all: use immutable node view in read path

This commit changes most of our (*)types.Node to
types.NodeView, which is a readonly version of the
underlying node ensuring that there is no mutations
happening in the read path.

Based on the migration, there didnt seem to be any, but the
idea here is to prevent it in the future and simplify other
new implementations.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby
2025-07-05 23:31:13 +02:00
committed by Kristoffer Dalby
parent 5ba7120418
commit 73023c2ec3
24 changed files with 866 additions and 196 deletions

View File

@@ -10,6 +10,7 @@ import (
"github.com/rs/zerolog/log"
"go4.org/netipx"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
)
var (
@@ -20,7 +21,7 @@ var (
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *Policy) compileFilterRules(
users types.Users,
nodes types.Nodes,
nodes views.Slice[types.NodeView],
) ([]tailcfg.FilterRule, error) {
if pol == nil {
return tailcfg.FilterAllowAll, nil
@@ -97,8 +98,8 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
func (pol *Policy) compileSSHPolicy(
users types.Users,
node *types.Node,
nodes types.Nodes,
node types.NodeView,
nodes views.Slice[types.NodeView],
) (*tailcfg.SSHPolicy, error) {
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
return nil, nil

View File

@@ -362,7 +362,7 @@ func TestParsing(t *testing.T) {
User: users[0],
Hostinfo: &tailcfg.Hostinfo{},
},
})
}.ViewSlice())
if (err != nil) != tt.wantErr {
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)

View File

@@ -16,13 +16,14 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
"tailscale.com/types/views"
)
type PolicyManager struct {
mu sync.Mutex
pol *Policy
users []types.User
nodes types.Nodes
nodes views.Slice[types.NodeView]
filterHash deephash.Sum
filter []tailcfg.FilterRule
@@ -43,7 +44,7 @@ type PolicyManager struct {
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
// It returns an error if the policy file is invalid.
// The policy manager will update the filter rules based on the users and nodes.
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.NodeView]) (*PolicyManager, error) {
policy, err := unmarshalPolicy(b)
if err != nil {
return nil, fmt.Errorf("parsing policy: %w", err)
@@ -53,7 +54,7 @@ func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyM
pol: policy,
users: users,
nodes: nodes,
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)),
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()),
}
_, err = pm.updateLocked()
@@ -122,11 +123,11 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
return true, nil
}
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
if sshPol, ok := pm.sshPolicyMap[node.ID]; ok {
if sshPol, ok := pm.sshPolicyMap[node.ID()]; ok {
return sshPol, nil
}
@@ -134,7 +135,7 @@ func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error)
if err != nil {
return nil, fmt.Errorf("compiling SSH policy: %w", err)
}
pm.sshPolicyMap[node.ID] = sshPol
pm.sshPolicyMap[node.ID()] = sshPol
return sshPol, nil
}
@@ -181,7 +182,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
}
// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) {
if pm == nil {
return false, nil
}
@@ -192,7 +193,7 @@ func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
return pm.updateLocked()
}
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
func (pm *PolicyManager) NodeCanHaveTag(node types.NodeView, tag string) bool {
if pm == nil {
return false
}
@@ -209,7 +210,7 @@ func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
return false
}
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool {
if pm == nil {
return false
}
@@ -322,7 +323,11 @@ func (pm *PolicyManager) DebugString() string {
}
sb.WriteString("\n\n")
sb.WriteString(pm.nodes.DebugString())
sb.WriteString("Nodes:\n")
for _, node := range pm.nodes.All() {
sb.WriteString(node.String())
sb.WriteString("\n")
}
return sb.String()
}

View File

@@ -47,7 +47,7 @@ func TestPolicyManager(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes.ViewSlice())
require.NoError(t, err)
filter, matchers := pm.Filter()

View File

@@ -18,6 +18,7 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
"tailscale.com/util/multierr"
)
@@ -91,7 +92,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error {
return nil
}
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
// TODO(kradalby):
@@ -179,7 +180,7 @@ func (u Username) resolveUser(users types.Users) (types.User, error) {
return potentialUsers[0], nil
}
func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
@@ -188,12 +189,13 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*net
errs = append(errs, err)
}
for _, node := range nodes {
for _, node := range nodes.All() {
// Skip tagged nodes
if node.IsTagged() {
continue
}
if node.User.ID == user.ID {
if node.User().ID == user.ID {
node.AppendToIPSet(&ips)
}
}
@@ -246,7 +248,7 @@ func (g Group) MarshalJSON() ([]byte, error) {
return json.Marshal(string(g))
}
func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
@@ -280,7 +282,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
return nil
}
func (t Tag) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
// TODO(kradalby): This is currently resolved twice, and should be resolved once.
@@ -295,17 +297,19 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.I
return nil, err
}
for _, node := range nodes {
if node.HasTag(string(t)) {
for _, node := range nodes.All() {
// Check if node has this tag in all tags (ForcedTags + AuthKey.Tags)
if slices.Contains(node.Tags(), string(t)) {
node.AppendToIPSet(&ips)
}
// TODO(kradalby): remove as part of #2417, see comment above
if tagMap != nil {
if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo != nil {
for _, tag := range node.Hostinfo.RequestTags {
if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo().Valid() {
for _, tag := range node.RequestTagsSlice().All() {
if tag == string(t) {
node.AppendToIPSet(&ips)
break
}
}
}
@@ -346,7 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
return nil
}
func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
@@ -371,7 +375,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSe
if err != nil {
errs = append(errs, err)
}
for _, node := range nodes {
for _, node := range nodes.All() {
if node.InIPSet(ipsTemp) {
node.AppendToIPSet(&ips)
}
@@ -432,7 +436,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
// of the Prefix and the Policy, Users, and Nodes.
//
// See [Policy], [types.Users], and [types.Nodes] for more details.
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
@@ -446,12 +450,12 @@ func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IP
// appendIfNodeHasIP appends the IPs of the nodes to the IPSet if the node has the
// IP address in the prefix.
func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref netip.Prefix) {
func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuilder, pref netip.Prefix) {
if !pref.IsSingleIP() && !tsaddr.IsTailscaleIP(pref.Addr()) {
return
}
for _, node := range nodes {
for _, node := range nodes.All() {
if node.HasIP(pref.Addr()) {
node.AppendToIPSet(ips)
}
@@ -499,7 +503,7 @@ func (ag AutoGroup) MarshalJSON() ([]byte, error) {
return json.Marshal(string(ag))
}
func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var build netipx.IPSetBuilder
switch ag {
@@ -513,17 +517,17 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n
return nil, err
}
for _, node := range nodes {
// Skip if node has forced tags
if len(node.ForcedTags) != 0 {
for _, node := range nodes.All() {
// Skip if node is tagged
if node.IsTagged() {
continue
}
// Skip if node has any allowed requested tags
hasAllowedTag := false
if node.Hostinfo != nil && len(node.Hostinfo.RequestTags) != 0 {
for _, tag := range node.Hostinfo.RequestTags {
if tagips, ok := tagMap[Tag(tag)]; ok && node.InIPSet(tagips) {
if node.RequestTagsSlice().Len() != 0 {
for _, tag := range node.RequestTagsSlice().All() {
if _, ok := tagMap[Tag(tag)]; ok {
hasAllowedTag = true
break
}
@@ -546,16 +550,16 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n
return nil, err
}
for _, node := range nodes {
// Include if node has forced tags
if len(node.ForcedTags) != 0 {
for _, node := range nodes.All() {
// Include if node is tagged
if node.IsTagged() {
node.AppendToIPSet(&build)
continue
}
// Include if node has any allowed requested tags
if node.Hostinfo != nil && len(node.Hostinfo.RequestTags) != 0 {
for _, tag := range node.Hostinfo.RequestTags {
if node.RequestTagsSlice().Len() != 0 {
for _, tag := range node.RequestTagsSlice().All() {
if _, ok := tagMap[Tag(tag)]; ok {
node.AppendToIPSet(&build)
break
@@ -588,7 +592,7 @@ type Alias interface {
// of the Alias and the Policy, Users and Nodes.
// This is an interface definition and the implementation is independent of
// the Alias type.
Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error)
Resolve(*Policy, types.Users, views.Slice[types.NodeView]) (*netipx.IPSet, error)
}
type AliasWithPorts struct {
@@ -759,7 +763,7 @@ func (a Aliases) MarshalJSON() ([]byte, error) {
return json.Marshal(aliases)
}
func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
@@ -1094,7 +1098,7 @@ func (to TagOwners) Contains(tagOwner *Tag) error {
// resolveTagOwners resolves the TagOwners to a map of Tag to netipx.IPSet.
// The resulting map can be used to quickly look up the IPSet for a given Tag.
// It is intended for internal use in a PolicyManager.
func resolveTagOwners(p *Policy, users types.Users, nodes types.Nodes) (map[Tag]*netipx.IPSet, error) {
func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[Tag]*netipx.IPSet, error) {
if p == nil {
return nil, nil
}
@@ -1158,7 +1162,7 @@ func (ap AutoApproverPolicy) MarshalJSON() ([]byte, error) {
// resolveAutoApprovers resolves the AutoApprovers to a map of netip.Prefix to netipx.IPSet.
// The resulting map can be used to quickly look up if a node can self-approve a route.
// It is intended for internal use in a PolicyManager.
func resolveAutoApprovers(p *Policy, users types.Users, nodes types.Nodes) (map[netip.Prefix]*netipx.IPSet, *netipx.IPSet, error) {
func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[netip.Prefix]*netipx.IPSet, *netipx.IPSet, error) {
if p == nil {
return nil, nil, nil
}
@@ -1671,7 +1675,7 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) {
return json.Marshal(aliases)
}
func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error

View File

@@ -1377,7 +1377,7 @@ func TestResolvePolicy(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
ips, err := tt.toResolve.Resolve(tt.pol,
xmaps.Values(users),
tt.nodes)
tt.nodes.ViewSlice())
if tt.wantErr == "" {
if err != nil {
t.Fatalf("got %v; want no error", err)
@@ -1557,7 +1557,7 @@ func TestResolveAutoApprovers(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes)
got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes.ViewSlice())
if (err != nil) != tt.wantErr {
t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -1716,10 +1716,10 @@ func TestNodeCanApproveRoute(t *testing.T) {
b, err := json.Marshal(tt.policy)
require.NoError(t, err)
pm, err := NewPolicyManager(b, users, nodes)
pm, err := NewPolicyManager(b, users, nodes.ViewSlice())
require.NoErrorf(t, err, "NewPolicyManager() error = %v", err)
got := pm.NodeCanApproveRoute(tt.node, tt.route)
got := pm.NodeCanApproveRoute(tt.node.View(), tt.route)
if got != tt.want {
t.Errorf("NodeCanApproveRoute() = %v, want %v", got, tt.want)
}
@@ -1800,7 +1800,7 @@ func TestResolveTagOwners(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := resolveTagOwners(tt.policy, users, nodes)
got, err := resolveTagOwners(tt.policy, users, nodes.ViewSlice())
if (err != nil) != tt.wantErr {
t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr)
return
@@ -1911,14 +1911,14 @@ func TestNodeCanHaveTag(t *testing.T) {
b, err := json.Marshal(tt.policy)
require.NoError(t, err)
pm, err := NewPolicyManager(b, users, nodes)
pm, err := NewPolicyManager(b, users, nodes.ViewSlice())
if tt.wantErr != "" {
require.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
got := pm.NodeCanHaveTag(tt.node, tt.tag)
got := pm.NodeCanHaveTag(tt.node.View(), tt.tag)
if got != tt.want {
t.Errorf("NodeCanHaveTag() = %v, want %v", got, tt.want)
}