mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-11 15:27:37 +00:00
all: use immutable node view in read path
This commit changes most of our (*)types.Node to types.NodeView, which is a readonly version of the underlying node ensuring that there is no mutations happening in the read path. Based on the migration, there didnt seem to be any, but the idea here is to prevent it in the future and simplify other new implementations. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:

committed by
Kristoffer Dalby

parent
5ba7120418
commit
73023c2ec3
@@ -27,6 +27,7 @@ import (
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -88,16 +89,18 @@ func (m *Mapper) String() string {
|
||||
}
|
||||
|
||||
func generateUserProfiles(
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.User)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
userMap[node.User.ID] = &node.User
|
||||
ids = append(ids, node.User.ID)
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.ID] = &peer.User
|
||||
ids = append(ids, peer.User.ID)
|
||||
ids := make([]uint, 0, peers.Len()+1)
|
||||
user := node.User()
|
||||
userMap[user.ID] = &user
|
||||
ids = append(ids, user.ID)
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.User()
|
||||
userMap[peerUser.ID] = &peerUser
|
||||
ids = append(ids, peerUser.ID)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
@@ -114,7 +117,7 @@ func generateUserProfiles(
|
||||
|
||||
func generateDNSConfig(
|
||||
cfg *types.Config,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
) *tailcfg.DNSConfig {
|
||||
if cfg.TailcfgDNSConfig == nil {
|
||||
return nil
|
||||
@@ -134,16 +137,17 @@ func generateDNSConfig(
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{node.Hostname},
|
||||
"device_model": []string{node.Hostinfo.OS},
|
||||
"device_name": []string{node.Hostname()},
|
||||
"device_model": []string{node.Hostinfo().OS()},
|
||||
}
|
||||
|
||||
if len(node.IPs()) > 0 {
|
||||
attrs.Add("device_ip", node.IPs()[0].String())
|
||||
nodeIPs := node.IPs()
|
||||
if len(nodeIPs) > 0 {
|
||||
attrs.Add("device_ip", nodeIPs[0].String())
|
||||
}
|
||||
|
||||
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
||||
@@ -154,8 +158,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||
// fullMapResponse creates a complete MapResponse for a node.
|
||||
// It is a separate function to make testing easier.
|
||||
func (m *Mapper) fullMapResponse(
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
||||
@@ -182,15 +186,15 @@ func (m *Mapper) fullMapResponse(
|
||||
// FullMapResponse returns a MapResponse for the given node.
|
||||
func (m *Mapper) FullMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
peers, err := m.ListPeers(node.ID)
|
||||
peers, err := m.ListPeers(node.ID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
|
||||
resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -203,7 +207,7 @@ func (m *Mapper) FullMapResponse(
|
||||
// to be used to answer MapRequests with OmitPeers set to true.
|
||||
func (m *Mapper) ReadOnlyMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
|
||||
@@ -216,7 +220,7 @@ func (m *Mapper) ReadOnlyMapResponse(
|
||||
|
||||
func (m *Mapper) KeepAliveResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.KeepAlive = true
|
||||
@@ -226,7 +230,7 @@ func (m *Mapper) KeepAliveResponse(
|
||||
|
||||
func (m *Mapper) DERPMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
@@ -237,7 +241,7 @@ func (m *Mapper) DERPMapResponse(
|
||||
|
||||
func (m *Mapper) PeerChangedResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
changed map[types.NodeID]bool,
|
||||
patches []*tailcfg.PeerChange,
|
||||
messages ...string,
|
||||
@@ -249,7 +253,7 @@ func (m *Mapper) PeerChangedResponse(
|
||||
var changedIDs []types.NodeID
|
||||
for nodeID, nodeChanged := range changed {
|
||||
if nodeChanged {
|
||||
if nodeID != node.ID {
|
||||
if nodeID != node.ID() {
|
||||
changedIDs = append(changedIDs, nodeID)
|
||||
}
|
||||
} else {
|
||||
@@ -270,7 +274,7 @@ func (m *Mapper) PeerChangedResponse(
|
||||
m.state,
|
||||
node,
|
||||
mapRequest.Version,
|
||||
changedNodes,
|
||||
changedNodes.ViewSlice(),
|
||||
m.cfg,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -315,7 +319,7 @@ func (m *Mapper) PeerChangedResponse(
|
||||
// incoming update from a state change.
|
||||
func (m *Mapper) PeerChangedPatchResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
changed []*tailcfg.PeerChange,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
@@ -327,7 +331,7 @@ func (m *Mapper) PeerChangedPatchResponse(
|
||||
func (m *Mapper) marshalMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
resp *tailcfg.MapResponse,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
compression string,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
@@ -366,7 +370,7 @@ func (m *Mapper) marshalMapResponse(
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, node.Hostname)
|
||||
mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -444,7 +448,7 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
||||
// It is used in for bigger updates, such as full and lite, not
|
||||
// incremental.
|
||||
func (m *Mapper) baseWithConfigMapResponse(
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp := m.baseMapResponse()
|
||||
@@ -523,9 +527,9 @@ func appendPeerChanges(
|
||||
|
||||
fullChange bool,
|
||||
state *state.State,
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changed types.Nodes,
|
||||
changed views.Slice[types.NodeView],
|
||||
cfg *types.Config,
|
||||
) error {
|
||||
filter, matchers := state.Filter()
|
||||
@@ -537,16 +541,19 @@ func appendPeerChanges(
|
||||
|
||||
// If there are filter rules present, see if there are any nodes that cannot
|
||||
// access each-other at all and remove them from the peers.
|
||||
var reducedChanged views.Slice[types.NodeView]
|
||||
if len(filter) > 0 {
|
||||
changed = policy.ReduceNodes(node, changed, matchers)
|
||||
reducedChanged = policy.ReduceNodes(node, changed, matchers)
|
||||
} else {
|
||||
reducedChanged = changed
|
||||
}
|
||||
|
||||
profiles := generateUserProfiles(node, changed)
|
||||
profiles := generateUserProfiles(node, reducedChanged)
|
||||
|
||||
dnsConfig := generateDNSConfig(cfg, node)
|
||||
|
||||
tailPeers, err := tailNodes(
|
||||
changed, capVer, state,
|
||||
reducedChanged, capVer, state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
|
@@ -70,7 +70,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
&types.Config{
|
||||
TailcfgDNSConfig: &dnsConfigOrig,
|
||||
},
|
||||
nodeInShared1,
|
||||
nodeInShared1.View(),
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||
@@ -100,14 +100,14 @@ func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||
return m.polMan.Filter()
|
||||
}
|
||||
|
||||
func (m *mockState) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
if m.polMan == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.polMan.SSHPolicy(node)
|
||||
}
|
||||
|
||||
func (m *mockState) NodeCanHaveTag(node *types.Node, tag string) bool {
|
||||
func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
|
||||
if m.polMan == nil {
|
||||
return false
|
||||
}
|
||||
|
@@ -8,24 +8,25 @@ import (
|
||||
"github.com/samber/lo"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag
|
||||
type NodeCanHaveTagChecker interface {
|
||||
NodeCanHaveTag(node *types.Node, tag string) bool
|
||||
NodeCanHaveTag(node types.NodeView, tag string) bool
|
||||
}
|
||||
|
||||
func tailNodes(
|
||||
nodes types.Nodes,
|
||||
nodes views.Slice[types.NodeView],
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
checker NodeCanHaveTagChecker,
|
||||
primaryRouteFunc routeFilterFunc,
|
||||
cfg *types.Config,
|
||||
) ([]*tailcfg.Node, error) {
|
||||
tNodes := make([]*tailcfg.Node, len(nodes))
|
||||
tNodes := make([]*tailcfg.Node, 0, nodes.Len())
|
||||
|
||||
for index, node := range nodes {
|
||||
node, err := tailNode(
|
||||
for _, node := range nodes.All() {
|
||||
tNode, err := tailNode(
|
||||
node,
|
||||
capVer,
|
||||
checker,
|
||||
@@ -36,7 +37,7 @@ func tailNodes(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tNodes[index] = node
|
||||
tNodes = append(tNodes, tNode)
|
||||
}
|
||||
|
||||
return tNodes, nil
|
||||
@@ -44,7 +45,7 @@ func tailNodes(
|
||||
|
||||
// tailNode converts a Node into a Tailscale Node.
|
||||
func tailNode(
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
checker NodeCanHaveTagChecker,
|
||||
primaryRouteFunc routeFilterFunc,
|
||||
@@ -57,61 +58,64 @@ func tailNode(
|
||||
// TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077
|
||||
// and should be removed after 111 is the minimum capver.
|
||||
var legacyDERP string
|
||||
if node.Hostinfo != nil && node.Hostinfo.NetInfo != nil {
|
||||
legacyDERP = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP)
|
||||
derp = node.Hostinfo.NetInfo.PreferredDERP
|
||||
if node.Hostinfo().Valid() && node.Hostinfo().NetInfo().Valid() {
|
||||
legacyDERP = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo().NetInfo().PreferredDERP())
|
||||
derp = node.Hostinfo().NetInfo().PreferredDERP()
|
||||
} else {
|
||||
legacyDERP = "127.3.3.40:0" // Zero means disconnected or unknown.
|
||||
}
|
||||
|
||||
var keyExpiry time.Time
|
||||
if node.Expiry != nil {
|
||||
keyExpiry = *node.Expiry
|
||||
if node.Expiry().Valid() {
|
||||
keyExpiry = node.Expiry().Get()
|
||||
} else {
|
||||
keyExpiry = time.Time{}
|
||||
}
|
||||
|
||||
hostname, err := node.GetFQDN(cfg.BaseDomain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var tags []string
|
||||
for _, tag := range node.RequestTags() {
|
||||
for _, tag := range node.RequestTagsSlice().All() {
|
||||
if checker.NodeCanHaveTag(node, tag) {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
}
|
||||
tags = lo.Uniq(append(tags, node.ForcedTags...))
|
||||
for _, tag := range node.ForcedTags().All() {
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
tags = lo.Uniq(tags)
|
||||
|
||||
routes := primaryRouteFunc(node.ID)
|
||||
allowed := append(node.Prefixes(), routes...)
|
||||
routes := primaryRouteFunc(node.ID())
|
||||
allowed := append(addrs, routes...)
|
||||
allowed = append(allowed, node.ExitRoutes()...)
|
||||
tsaddr.SortPrefixes(allowed)
|
||||
|
||||
tNode := tailcfg.Node{
|
||||
ID: tailcfg.NodeID(node.ID), // this is the actual ID
|
||||
StableID: node.ID.StableID(),
|
||||
ID: tailcfg.NodeID(node.ID()), // this is the actual ID
|
||||
StableID: node.ID().StableID(),
|
||||
Name: hostname,
|
||||
Cap: capVer,
|
||||
|
||||
User: tailcfg.UserID(node.UserID),
|
||||
User: tailcfg.UserID(node.UserID()),
|
||||
|
||||
Key: node.NodeKey,
|
||||
Key: node.NodeKey(),
|
||||
KeyExpiry: keyExpiry.UTC(),
|
||||
|
||||
Machine: node.MachineKey,
|
||||
DiscoKey: node.DiscoKey,
|
||||
Machine: node.MachineKey(),
|
||||
DiscoKey: node.DiscoKey(),
|
||||
Addresses: addrs,
|
||||
PrimaryRoutes: routes,
|
||||
AllowedIPs: allowed,
|
||||
Endpoints: node.Endpoints,
|
||||
Endpoints: node.Endpoints().AsSlice(),
|
||||
HomeDERP: derp,
|
||||
LegacyDERPString: legacyDERP,
|
||||
Hostinfo: node.Hostinfo.View(),
|
||||
Created: node.CreatedAt.UTC(),
|
||||
Hostinfo: node.Hostinfo(),
|
||||
Created: node.CreatedAt().UTC(),
|
||||
|
||||
Online: node.IsOnline,
|
||||
Online: node.IsOnline().Clone(),
|
||||
|
||||
Tags: tags,
|
||||
|
||||
@@ -129,10 +133,13 @@ func tailNode(
|
||||
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
|
||||
}
|
||||
|
||||
if node.IsOnline == nil || !*node.IsOnline {
|
||||
if !node.IsOnline().Valid() || !node.IsOnline().Get() {
|
||||
// LastSeen is only set when node is
|
||||
// not connected to the control server.
|
||||
tNode.LastSeen = node.LastSeen
|
||||
if node.LastSeen().Valid() {
|
||||
lastSeen := node.LastSeen().Get()
|
||||
tNode.LastSeen = &lastSeen
|
||||
}
|
||||
}
|
||||
|
||||
return &tNode, nil
|
||||
|
@@ -202,7 +202,7 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node})
|
||||
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node}.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
primary := routes.New()
|
||||
cfg := &types.Config{
|
||||
@@ -216,7 +216,7 @@ func TestTailNode(t *testing.T) {
|
||||
// This should be baked into the test case proper if it is extended in the future.
|
||||
_ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24"))
|
||||
got, err := tailNode(
|
||||
tt.node,
|
||||
tt.node.View(),
|
||||
0,
|
||||
polMan,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
@@ -272,11 +272,11 @@ func TestNodeExpiry(t *testing.T) {
|
||||
GivenName: "test",
|
||||
Expiry: tt.exp,
|
||||
}
|
||||
polMan, err := policy.NewPolicyManager(nil, nil, nil)
|
||||
polMan, err := policy.NewPolicyManager(nil, nil, types.Nodes{}.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
tn, err := tailNode(
|
||||
node,
|
||||
node.View(),
|
||||
0,
|
||||
polMan,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
|
Reference in New Issue
Block a user