mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-11 21:37:34 +00:00
all: use immutable node view in read path
This commit changes most of our (*)types.Node to types.NodeView, which is a readonly version of the underlying node ensuring that there is no mutations happening in the read path. Based on the migration, there didnt seem to be any, but the idea here is to prevent it in the future and simplify other new implementations. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:

committed by
Kristoffer Dalby

parent
5ba7120418
commit
73023c2ec3
@@ -8,27 +8,28 @@ import (
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
type PolicyManager interface {
|
||||
// Filter returns the current filter rules for the entire tailnet and the associated matchers.
|
||||
Filter() ([]tailcfg.FilterRule, []matcher.Match)
|
||||
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
|
||||
SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
|
||||
SetPolicy([]byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes types.Nodes) (bool, error)
|
||||
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)
|
||||
// NodeCanHaveTag reports whether the given node can have the given tag.
|
||||
NodeCanHaveTag(*types.Node, string) bool
|
||||
NodeCanHaveTag(types.NodeView, string) bool
|
||||
|
||||
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
||||
NodeCanApproveRoute(*types.Node, netip.Prefix) bool
|
||||
NodeCanApproveRoute(types.NodeView, netip.Prefix) bool
|
||||
|
||||
Version() int
|
||||
DebugString() string
|
||||
}
|
||||
|
||||
// NewPolicyManager returns a new policy manager.
|
||||
func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
var polMan PolicyManager
|
||||
var err error
|
||||
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
|
||||
@@ -42,7 +43,7 @@ func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (Policy
|
||||
// PolicyManagersForTest returns all available PostureManagers to be used
|
||||
// in tests to validate them in tests that try to determine that they
|
||||
// behave the same.
|
||||
func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([]PolicyManager, error) {
|
||||
func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) ([]PolicyManager, error) {
|
||||
var polMans []PolicyManager
|
||||
|
||||
for _, pmf := range PolicyManagerFuncsForTest(pol) {
|
||||
@@ -56,10 +57,10 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([
|
||||
return polMans, nil
|
||||
}
|
||||
|
||||
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, types.Nodes) (PolicyManager, error) {
|
||||
var polmanFuncs []func([]types.User, types.Nodes) (PolicyManager, error)
|
||||
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error)
|
||||
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) {
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
return policyv2.NewPolicyManager(pol, u, n)
|
||||
})
|
||||
|
||||
|
@@ -11,32 +11,33 @@ import (
|
||||
"github.com/samber/lo"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
// ReduceNodes returns the list of peers authorized to be accessed from a given node.
|
||||
func ReduceNodes(
|
||||
node *types.Node,
|
||||
nodes types.Nodes,
|
||||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
matchers []matcher.Match,
|
||||
) types.Nodes {
|
||||
var result types.Nodes
|
||||
) views.Slice[types.NodeView] {
|
||||
var result []types.NodeView
|
||||
|
||||
for index, peer := range nodes {
|
||||
if peer.ID == node.ID {
|
||||
for _, peer := range nodes.All() {
|
||||
if peer.ID() == node.ID() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.CanAccess(matchers, nodes[index]) || peer.CanAccess(matchers, node) {
|
||||
if node.CanAccess(matchers, peer) || peer.CanAccess(matchers, node) {
|
||||
result = append(result, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
return views.SliceOf(result)
|
||||
}
|
||||
|
||||
// ReduceRoutes returns a reduced list of routes for a given node that it can access.
|
||||
func ReduceRoutes(
|
||||
node *types.Node,
|
||||
node types.NodeView,
|
||||
routes []netip.Prefix,
|
||||
matchers []matcher.Match,
|
||||
) []netip.Prefix {
|
||||
@@ -51,9 +52,36 @@ func ReduceRoutes(
|
||||
return result
|
||||
}
|
||||
|
||||
// BuildPeerMap builds a map of all peers that can be accessed by each node.
|
||||
func BuildPeerMap(
|
||||
nodes views.Slice[types.NodeView],
|
||||
matchers []matcher.Match,
|
||||
) map[types.NodeID][]types.NodeView {
|
||||
ret := make(map[types.NodeID][]types.NodeView, nodes.Len())
|
||||
|
||||
// Build the map of all peers according to the matchers.
|
||||
// Compared to ReduceNodes, which builds the list per node, we end up with doing
|
||||
// the full work for every node (On^2), while this will reduce the list as we see
|
||||
// relationships while building the map, making it O(n^2/2) in the end, but with less work per node.
|
||||
for i := range nodes.Len() {
|
||||
for j := i + 1; j < nodes.Len(); j++ {
|
||||
if nodes.At(i).ID() == nodes.At(j).ID() {
|
||||
continue
|
||||
}
|
||||
|
||||
if nodes.At(i).CanAccess(matchers, nodes.At(j)) || nodes.At(j).CanAccess(matchers, nodes.At(i)) {
|
||||
ret[nodes.At(i).ID()] = append(ret[nodes.At(i).ID()], nodes.At(j))
|
||||
ret[nodes.At(j).ID()] = append(ret[nodes.At(j).ID()], nodes.At(i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
|
||||
// that are not relevant to that particular node.
|
||||
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
|
||||
func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
|
||||
ret := []tailcfg.FilterRule{}
|
||||
|
||||
for _, rule := range rules {
|
||||
@@ -75,9 +103,10 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
|
||||
|
||||
// If the node exposes routes, ensure they are note removed
|
||||
// when the filters are reduced.
|
||||
if node.Hostinfo != nil {
|
||||
if len(node.Hostinfo.RoutableIPs) > 0 {
|
||||
for _, routableIP := range node.Hostinfo.RoutableIPs {
|
||||
if node.Hostinfo().Valid() {
|
||||
routableIPs := node.Hostinfo().RoutableIPs()
|
||||
if routableIPs.Len() > 0 {
|
||||
for _, routableIP := range routableIPs.All() {
|
||||
if expanded.OverlapsPrefix(routableIP) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
@@ -102,13 +131,15 @@ func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.F
|
||||
// AutoApproveRoutes approves any route that can be autoapproved from
|
||||
// the nodes perspective according to the given policy.
|
||||
// It reports true if any routes were approved.
|
||||
// Note: This function now takes a pointer to the actual node to modify ApprovedRoutes
|
||||
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
}
|
||||
nodeView := node.View()
|
||||
var newApproved []netip.Prefix
|
||||
for _, route := range node.AnnouncedRoutes() {
|
||||
if pm.NodeCanApproveRoute(node, route) {
|
||||
for _, route := range nodeView.AnnouncedRoutes() {
|
||||
if pm.NodeCanApproveRoute(nodeView, route) {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
|
@@ -815,11 +815,11 @@ func TestReduceFilterRules(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
pm, err = pmf(users, append(tt.peers, tt.node))
|
||||
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
|
||||
require.NoError(t, err)
|
||||
got, _ := pm.Filter()
|
||||
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
|
||||
got = ReduceFilterRules(tt.node, got)
|
||||
got = ReduceFilterRules(tt.node.View(), got)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got); diff != "" {
|
||||
log.Trace().Interface("got", got).Msg("result")
|
||||
@@ -1576,11 +1576,16 @@ func TestReduceNodes(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
||||
got := ReduceNodes(
|
||||
tt.args.node,
|
||||
tt.args.nodes,
|
||||
gotViews := ReduceNodes(
|
||||
tt.args.node.View(),
|
||||
tt.args.nodes.ViewSlice(),
|
||||
matchers,
|
||||
)
|
||||
// Convert views back to nodes for comparison in tests
|
||||
var got types.Nodes
|
||||
for _, v := range gotViews.All() {
|
||||
got = append(got, v.AsStruct())
|
||||
}
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
@@ -1949,7 +1954,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
pm, err = pmf(users, append(tt.peers, &tt.targetNode))
|
||||
pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
|
||||
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
@@ -1959,7 +1964,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := pm.SSHPolicy(&tt.targetNode)
|
||||
got, err := pm.SSHPolicy(tt.targetNode.View())
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.wantSSH, got); diff != "" {
|
||||
@@ -2426,7 +2431,7 @@ func TestReduceRoutes(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
||||
got := ReduceRoutes(
|
||||
tt.args.node,
|
||||
tt.args.node.View(),
|
||||
tt.args.routes,
|
||||
matchers,
|
||||
)
|
||||
|
@@ -776,7 +776,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Initialize all policy manager implementations
|
||||
policyManagers, err := PolicyManagersForTest([]byte(tt.policy), users, types.Nodes{&tt.node})
|
||||
policyManagers, err := PolicyManagersForTest([]byte(tt.policy), users, types.Nodes{&tt.node}.ViewSlice())
|
||||
if tt.name == "empty policy" {
|
||||
// We expect this one to have a valid but empty policy
|
||||
require.NoError(t, err)
|
||||
@@ -789,7 +789,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
|
||||
for i, pm := range policyManagers {
|
||||
t.Run(fmt.Sprintf("policy-index%d", i), func(t *testing.T) {
|
||||
result := pm.NodeCanApproveRoute(&tt.node, tt.route)
|
||||
result := pm.NodeCanApproveRoute(tt.node.View(), tt.route)
|
||||
|
||||
if diff := cmp.Diff(tt.canApprove, result); diff != "" {
|
||||
t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff)
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -20,7 +21,7 @@ var (
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
func (pol *Policy) compileFilterRules(
|
||||
users types.Users,
|
||||
nodes types.Nodes,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) ([]tailcfg.FilterRule, error) {
|
||||
if pol == nil {
|
||||
return tailcfg.FilterAllowAll, nil
|
||||
@@ -97,8 +98,8 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
||||
|
||||
func (pol *Policy) compileSSHPolicy(
|
||||
users types.Users,
|
||||
node *types.Node,
|
||||
nodes types.Nodes,
|
||||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) (*tailcfg.SSHPolicy, error) {
|
||||
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
|
||||
return nil, nil
|
||||
|
@@ -362,7 +362,7 @@ func TestParsing(t *testing.T) {
|
||||
User: users[0],
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
})
|
||||
}.ViewSlice())
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
@@ -16,13 +16,14 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
type PolicyManager struct {
|
||||
mu sync.Mutex
|
||||
pol *Policy
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
nodes views.Slice[types.NodeView]
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
@@ -43,7 +44,7 @@ type PolicyManager struct {
|
||||
// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
|
||||
// It returns an error if the policy file is invalid.
|
||||
// The policy manager will update the filter rules based on the users and nodes.
|
||||
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
|
||||
func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.NodeView]) (*PolicyManager, error) {
|
||||
policy, err := unmarshalPolicy(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
@@ -53,7 +54,7 @@ func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyM
|
||||
pol: policy,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)),
|
||||
sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()),
|
||||
}
|
||||
|
||||
_, err = pm.updateLocked()
|
||||
@@ -122,11 +123,11 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
if sshPol, ok := pm.sshPolicyMap[node.ID]; ok {
|
||||
if sshPol, ok := pm.sshPolicyMap[node.ID()]; ok {
|
||||
return sshPol, nil
|
||||
}
|
||||
|
||||
@@ -134,7 +135,7 @@ func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("compiling SSH policy: %w", err)
|
||||
}
|
||||
pm.sshPolicyMap[node.ID] = sshPol
|
||||
pm.sshPolicyMap[node.ID()] = sshPol
|
||||
|
||||
return sshPol, nil
|
||||
}
|
||||
@@ -181,7 +182,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) {
|
||||
if pm == nil {
|
||||
return false, nil
|
||||
}
|
||||
@@ -192,7 +193,7 @@ func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
|
||||
func (pm *PolicyManager) NodeCanHaveTag(node types.NodeView, tag string) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
}
|
||||
@@ -209,7 +210,7 @@ func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||
func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool {
|
||||
if pm == nil {
|
||||
return false
|
||||
}
|
||||
@@ -322,7 +323,11 @@ func (pm *PolicyManager) DebugString() string {
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString(pm.nodes.DebugString())
|
||||
sb.WriteString("Nodes:\n")
|
||||
for _, node := range pm.nodes.All() {
|
||||
sb.WriteString(node.String())
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
@@ -47,7 +47,7 @@ func TestPolicyManager(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
|
||||
pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
filter, matchers := pm.Filter()
|
||||
|
@@ -18,6 +18,7 @@ import (
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
@@ -91,7 +92,7 @@ func (a Asterix) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
// TODO(kradalby):
|
||||
@@ -179,7 +180,7 @@ func (u Username) resolveUser(users types.Users) (types.User, error) {
|
||||
return potentialUsers[0], nil
|
||||
}
|
||||
|
||||
func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
|
||||
@@ -188,12 +189,13 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes types.Nodes) (*net
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
for _, node := range nodes.All() {
|
||||
// Skip tagged nodes
|
||||
if node.IsTagged() {
|
||||
continue
|
||||
}
|
||||
|
||||
if node.User.ID == user.ID {
|
||||
if node.User().ID == user.ID {
|
||||
node.AppendToIPSet(&ips)
|
||||
}
|
||||
}
|
||||
@@ -246,7 +248,7 @@ func (g Group) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(string(g))
|
||||
}
|
||||
|
||||
func (g Group) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
|
||||
@@ -280,7 +282,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t Tag) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
|
||||
// TODO(kradalby): This is currently resolved twice, and should be resolved once.
|
||||
@@ -295,17 +297,19 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.I
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
if node.HasTag(string(t)) {
|
||||
for _, node := range nodes.All() {
|
||||
// Check if node has this tag in all tags (ForcedTags + AuthKey.Tags)
|
||||
if slices.Contains(node.Tags(), string(t)) {
|
||||
node.AppendToIPSet(&ips)
|
||||
}
|
||||
|
||||
// TODO(kradalby): remove as part of #2417, see comment above
|
||||
if tagMap != nil {
|
||||
if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo != nil {
|
||||
for _, tag := range node.Hostinfo.RequestTags {
|
||||
if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo().Valid() {
|
||||
for _, tag := range node.RequestTagsSlice().All() {
|
||||
if tag == string(t) {
|
||||
node.AppendToIPSet(&ips)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -346,7 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
|
||||
@@ -371,7 +375,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSe
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
for _, node := range nodes {
|
||||
for _, node := range nodes.All() {
|
||||
if node.InIPSet(ipsTemp) {
|
||||
node.AppendToIPSet(&ips)
|
||||
}
|
||||
@@ -432,7 +436,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
|
||||
// of the Prefix and the Policy, Users, and Nodes.
|
||||
//
|
||||
// See [Policy], [types.Users], and [types.Nodes] for more details.
|
||||
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
|
||||
@@ -446,12 +450,12 @@ func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes types.Nodes) (*netipx.IP
|
||||
|
||||
// appendIfNodeHasIP appends the IPs of the nodes to the IPSet if the node has the
|
||||
// IP address in the prefix.
|
||||
func appendIfNodeHasIP(nodes types.Nodes, ips *netipx.IPSetBuilder, pref netip.Prefix) {
|
||||
func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuilder, pref netip.Prefix) {
|
||||
if !pref.IsSingleIP() && !tsaddr.IsTailscaleIP(pref.Addr()) {
|
||||
return
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
for _, node := range nodes.All() {
|
||||
if node.HasIP(pref.Addr()) {
|
||||
node.AppendToIPSet(ips)
|
||||
}
|
||||
@@ -499,7 +503,7 @@ func (ag AutoGroup) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(string(ag))
|
||||
}
|
||||
|
||||
func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var build netipx.IPSetBuilder
|
||||
|
||||
switch ag {
|
||||
@@ -513,17 +517,17 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
// Skip if node has forced tags
|
||||
if len(node.ForcedTags) != 0 {
|
||||
for _, node := range nodes.All() {
|
||||
// Skip if node is tagged
|
||||
if node.IsTagged() {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip if node has any allowed requested tags
|
||||
hasAllowedTag := false
|
||||
if node.Hostinfo != nil && len(node.Hostinfo.RequestTags) != 0 {
|
||||
for _, tag := range node.Hostinfo.RequestTags {
|
||||
if tagips, ok := tagMap[Tag(tag)]; ok && node.InIPSet(tagips) {
|
||||
if node.RequestTagsSlice().Len() != 0 {
|
||||
for _, tag := range node.RequestTagsSlice().All() {
|
||||
if _, ok := tagMap[Tag(tag)]; ok {
|
||||
hasAllowedTag = true
|
||||
break
|
||||
}
|
||||
@@ -546,16 +550,16 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*n
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
// Include if node has forced tags
|
||||
if len(node.ForcedTags) != 0 {
|
||||
for _, node := range nodes.All() {
|
||||
// Include if node is tagged
|
||||
if node.IsTagged() {
|
||||
node.AppendToIPSet(&build)
|
||||
continue
|
||||
}
|
||||
|
||||
// Include if node has any allowed requested tags
|
||||
if node.Hostinfo != nil && len(node.Hostinfo.RequestTags) != 0 {
|
||||
for _, tag := range node.Hostinfo.RequestTags {
|
||||
if node.RequestTagsSlice().Len() != 0 {
|
||||
for _, tag := range node.RequestTagsSlice().All() {
|
||||
if _, ok := tagMap[Tag(tag)]; ok {
|
||||
node.AppendToIPSet(&build)
|
||||
break
|
||||
@@ -588,7 +592,7 @@ type Alias interface {
|
||||
// of the Alias and the Policy, Users and Nodes.
|
||||
// This is an interface definition and the implementation is independent of
|
||||
// the Alias type.
|
||||
Resolve(*Policy, types.Users, types.Nodes) (*netipx.IPSet, error)
|
||||
Resolve(*Policy, types.Users, views.Slice[types.NodeView]) (*netipx.IPSet, error)
|
||||
}
|
||||
|
||||
type AliasWithPorts struct {
|
||||
@@ -759,7 +763,7 @@ func (a Aliases) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(aliases)
|
||||
}
|
||||
|
||||
func (a Aliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
|
||||
@@ -1094,7 +1098,7 @@ func (to TagOwners) Contains(tagOwner *Tag) error {
|
||||
// resolveTagOwners resolves the TagOwners to a map of Tag to netipx.IPSet.
|
||||
// The resulting map can be used to quickly look up the IPSet for a given Tag.
|
||||
// It is intended for internal use in a PolicyManager.
|
||||
func resolveTagOwners(p *Policy, users types.Users, nodes types.Nodes) (map[Tag]*netipx.IPSet, error) {
|
||||
func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[Tag]*netipx.IPSet, error) {
|
||||
if p == nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1158,7 +1162,7 @@ func (ap AutoApproverPolicy) MarshalJSON() ([]byte, error) {
|
||||
// resolveAutoApprovers resolves the AutoApprovers to a map of netip.Prefix to netipx.IPSet.
|
||||
// The resulting map can be used to quickly look up if a node can self-approve a route.
|
||||
// It is intended for internal use in a PolicyManager.
|
||||
func resolveAutoApprovers(p *Policy, users types.Users, nodes types.Nodes) (map[netip.Prefix]*netipx.IPSet, *netipx.IPSet, error) {
|
||||
func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[netip.Prefix]*netipx.IPSet, *netipx.IPSet, error) {
|
||||
if p == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
@@ -1671,7 +1675,7 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(aliases)
|
||||
}
|
||||
|
||||
func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes types.Nodes) (*netipx.IPSet, error) {
|
||||
func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
|
||||
|
@@ -1377,7 +1377,7 @@ func TestResolvePolicy(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ips, err := tt.toResolve.Resolve(tt.pol,
|
||||
xmaps.Values(users),
|
||||
tt.nodes)
|
||||
tt.nodes.ViewSlice())
|
||||
if tt.wantErr == "" {
|
||||
if err != nil {
|
||||
t.Fatalf("got %v; want no error", err)
|
||||
@@ -1557,7 +1557,7 @@ func TestResolveAutoApprovers(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes)
|
||||
got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes.ViewSlice())
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -1716,10 +1716,10 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
||||
b, err := json.Marshal(tt.policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
pm, err := NewPolicyManager(b, users, nodes)
|
||||
pm, err := NewPolicyManager(b, users, nodes.ViewSlice())
|
||||
require.NoErrorf(t, err, "NewPolicyManager() error = %v", err)
|
||||
|
||||
got := pm.NodeCanApproveRoute(tt.node, tt.route)
|
||||
got := pm.NodeCanApproveRoute(tt.node.View(), tt.route)
|
||||
if got != tt.want {
|
||||
t.Errorf("NodeCanApproveRoute() = %v, want %v", got, tt.want)
|
||||
}
|
||||
@@ -1800,7 +1800,7 @@ func TestResolveTagOwners(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := resolveTagOwners(tt.policy, users, nodes)
|
||||
got, err := resolveTagOwners(tt.policy, users, nodes.ViewSlice())
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -1911,14 +1911,14 @@ func TestNodeCanHaveTag(t *testing.T) {
|
||||
b, err := json.Marshal(tt.policy)
|
||||
require.NoError(t, err)
|
||||
|
||||
pm, err := NewPolicyManager(b, users, nodes)
|
||||
pm, err := NewPolicyManager(b, users, nodes.ViewSlice())
|
||||
if tt.wantErr != "" {
|
||||
require.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
got := pm.NodeCanHaveTag(tt.node, tt.tag)
|
||||
got := pm.NodeCanHaveTag(tt.node.View(), tt.tag)
|
||||
if got != tt.want {
|
||||
t.Errorf("NodeCanHaveTag() = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
Reference in New Issue
Block a user