integration: replace time.Sleep with assert.EventuallyWithT (#2680)

This commit is contained in:
Kristoffer Dalby
2025-07-10 23:38:55 +02:00
committed by GitHub
parent b904276f2b
commit c6d7b512bd
73 changed files with 584 additions and 573 deletions

View File

@@ -13,9 +13,7 @@ import (
"tailscale.com/types/views"
)
var (
ErrInvalidAction = errors.New("invalid action")
)
var ErrInvalidAction = errors.New("invalid action")
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
@@ -52,7 +50,7 @@ func (pol *Policy) compileFilterRules(
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Alias.Resolve(pol, users, nodes)
ips, err := dest.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving destination ips")
}
@@ -174,5 +172,6 @@ func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
for _, pref := range ips.Prefixes() {
out = append(out, pref.String())
}
return out
}

View File

@@ -4,19 +4,17 @@ import (
"encoding/json"
"fmt"
"net/netip"
"slices"
"strings"
"sync"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"slices"
"github.com/juanfont/headscale/hscontrol/types"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/deephash"
"tailscale.com/types/views"
"tailscale.com/util/deephash"
)
type PolicyManager struct {
@@ -166,6 +164,7 @@ func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
pm.mu.Lock()
defer pm.mu.Unlock()
return pm.filter, pm.matchers
}
@@ -178,6 +177,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
return pm.updateLocked()
}
@@ -190,6 +190,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
pm.mu.Lock()
defer pm.mu.Unlock()
pm.nodes = nodes
return pm.updateLocked()
}
@@ -249,7 +250,6 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
// cannot just lookup in the prefix map and have to check
// if there is a "parent" prefix available.
for prefix, approveAddrs := range pm.autoApproveMap {
// Check if prefix is larger (so containing) and then overlaps
// the route to see if the node can approve a subset of an autoapprover
if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {

View File

@@ -1,10 +1,10 @@
package v2
import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"gorm.io/gorm"

View File

@@ -6,9 +6,9 @@ import (
"errors"
"fmt"
"net/netip"
"strings"
"slices"
"strconv"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@@ -72,14 +72,14 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
// Check if it's the wildcard port range
if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 {
return json.Marshal(fmt.Sprintf("%s:*", alias))
return json.Marshal(alias + ":*")
}
// Otherwise, format as "alias:ports"
var ports []string
for _, port := range a.Ports {
if port.First == port.Last {
ports = append(ports, fmt.Sprintf("%d", port.First))
ports = append(ports, strconv.FormatUint(uint64(port.First), 10))
} else {
ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last))
}
@@ -133,6 +133,7 @@ func (u *Username) UnmarshalJSON(b []byte) error {
if err := u.Validate(); err != nil {
return err
}
return nil
}
@@ -203,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
return buildIPSetMultiErr(&ips, errs)
}
// Group is a special string which is always prefixed with `group:`
// Group is a special string which is always prefixed with `group:`.
type Group string
func (g Group) Validate() error {
@@ -218,6 +219,7 @@ func (g *Group) UnmarshalJSON(b []byte) error {
if err := g.Validate(); err != nil {
return err
}
return nil
}
@@ -264,7 +266,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod
return buildIPSetMultiErr(&ips, errs)
}
// Tag is a special string which is always prefixed with `tag:`
// Tag is a special string which is always prefixed with `tag:`.
type Tag string
func (t Tag) Validate() error {
@@ -279,6 +281,7 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
if err := t.Validate(); err != nil {
return err
}
return nil
}
@@ -347,6 +350,7 @@ func (h *Host) UnmarshalJSON(b []byte) error {
if err := h.Validate(); err != nil {
return err
}
return nil
}
@@ -409,6 +413,7 @@ func (p *Prefix) parseString(addr string) error {
}
*p = Prefix(addrPref)
return nil
}
@@ -417,6 +422,7 @@ func (p *Prefix) parseString(addr string) error {
return err
}
*p = Prefix(pref)
return nil
}
@@ -428,6 +434,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
if err := p.Validate(); err != nil {
return err
}
return nil
}
@@ -462,7 +469,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild
}
}
// AutoGroup is a special string which is always prefixed with `autogroup:`
// AutoGroup is a special string which is always prefixed with `autogroup:`.
type AutoGroup string
const (
@@ -495,6 +502,7 @@ func (ag *AutoGroup) UnmarshalJSON(b []byte) error {
if err := ag.Validate(); err != nil {
return err
}
return nil
}
@@ -632,13 +640,14 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
if err := ve.Alias.Validate(); err != nil {
if err := ve.Validate(); err != nil {
return err
}
default:
return fmt.Errorf("type %T not supported", vs)
}
return nil
}
@@ -713,6 +722,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.Alias = ptr
return nil
}
@@ -729,6 +739,7 @@ func (a *Aliases) UnmarshalJSON(b []byte) error {
for i, alias := range aliases {
(*a)[i] = alias.Alias
}
return nil
}
@@ -784,7 +795,7 @@ func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.I
return ips, multierr.New(append(errs, err)...)
}
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer
// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer.
func unmarshalPointer[T any](
b []byte,
parseFunc func(string) (T, error),
@@ -818,6 +829,7 @@ func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
for i, autoApprover := range autoApprovers {
(*aa)[i] = autoApprover.AutoApprover
}
return nil
}
@@ -874,6 +886,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.AutoApprover = ptr
return nil
}
@@ -894,6 +907,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
return err
}
ve.Owner = ptr
return nil
}
@@ -910,6 +924,7 @@ func (o *Owners) UnmarshalJSON(b []byte) error {
for i, owner := range owners {
(*o)[i] = owner.Owner
}
return nil
}
@@ -941,6 +956,7 @@ func parseOwner(s string) (Owner, error) {
case isGroup(s):
return ptr.To(Group(s)), nil
}
return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
- user (containing an "@")
- group (starting with "group:")
@@ -1001,6 +1017,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
(*g)[group] = usernames
}
return nil
}
@@ -1252,7 +1269,7 @@ type Policy struct {
// We use the default JSON marshalling behavior provided by the Go runtime.
var (
// TODO(kradalby): Add these checks for tagOwners and autoApprovers
// TODO(kradalby): Add these checks for tagOwners and autoApprovers.
autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged}
autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
@@ -1279,7 +1296,7 @@ func validateAutogroupForSrc(src *AutoGroup) error {
}
if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSrc, *src) {
@@ -1307,7 +1324,7 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error {
}
if src.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSSHSrc, *src) {
@@ -1323,7 +1340,7 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error {
}
if dst.Is(AutoGroupInternet) {
return fmt.Errorf(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
}
if !slices.Contains(autogroupForSSHDst, *dst) {
@@ -1360,14 +1377,14 @@ func (p *Policy) validate() error {
for _, acl := range p.ACLs {
for _, src := range acl.Sources {
switch src.(type) {
switch src := src.(type) {
case *Host:
h := src.(*Host)
h := src
if !p.Hosts.exist(*h) {
errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h))
}
case *AutoGroup:
ag := src.(*AutoGroup)
ag := src
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
@@ -1379,12 +1396,12 @@ func (p *Policy) validate() error {
continue
}
case *Group:
g := src.(*Group)
g := src
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := src.(*Tag)
tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1440,9 +1457,9 @@ func (p *Policy) validate() error {
}
for _, src := range ssh.Sources {
switch src.(type) {
switch src := src.(type) {
case *AutoGroup:
ag := src.(*AutoGroup)
ag := src
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
@@ -1454,21 +1471,21 @@ func (p *Policy) validate() error {
continue
}
case *Group:
g := src.(*Group)
g := src
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := src.(*Tag)
tagOwner := src
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
}
}
for _, dst := range ssh.Destinations {
switch dst.(type) {
switch dst := dst.(type) {
case *AutoGroup:
ag := dst.(*AutoGroup)
ag := dst
if err := validateAutogroupSupported(ag); err != nil {
errs = append(errs, err)
continue
@@ -1479,7 +1496,7 @@ func (p *Policy) validate() error {
continue
}
case *Tag:
tagOwner := dst.(*Tag)
tagOwner := dst
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1489,9 +1506,9 @@ func (p *Policy) validate() error {
for _, tagOwners := range p.TagOwners {
for _, tagOwner := range tagOwners {
switch tagOwner.(type) {
switch tagOwner := tagOwner.(type) {
case *Group:
g := tagOwner.(*Group)
g := tagOwner
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
@@ -1501,14 +1518,14 @@ func (p *Policy) validate() error {
for _, approvers := range p.AutoApprovers.Routes {
for _, approver := range approvers {
switch approver.(type) {
switch approver := approver.(type) {
case *Group:
g := approver.(*Group)
g := approver
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := approver.(*Tag)
tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1517,14 +1534,14 @@ func (p *Policy) validate() error {
}
for _, approver := range p.AutoApprovers.ExitNode {
switch approver.(type) {
switch approver := approver.(type) {
case *Group:
g := approver.(*Group)
g := approver
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
tagOwner := approver.(*Tag)
tagOwner := approver
if err := p.TagOwners.Contains(tagOwner); err != nil {
errs = append(errs, err)
}
@@ -1536,6 +1553,7 @@ func (p *Policy) validate() error {
}
p.validated = true
return nil
}
@@ -1589,6 +1607,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
)
}
}
return nil
}
@@ -1618,6 +1637,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
)
}
}
return nil
}

View File

@@ -5,13 +5,13 @@ import (
"net/netip"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go4.org/netipx"
@@ -68,7 +68,7 @@ func TestMarshalJSON(t *testing.T) {
// Marshal the policy to JSON
marshalled, err := json.MarshalIndent(policy, "", " ")
require.NoError(t, err)
// Make sure all expected fields are present in the JSON
jsonString := string(marshalled)
assert.Contains(t, jsonString, "group:example")
@@ -79,21 +79,21 @@ func TestMarshalJSON(t *testing.T) {
assert.Contains(t, jsonString, "accept")
assert.Contains(t, jsonString, "tcp")
assert.Contains(t, jsonString, "80")
// Unmarshal back to verify round trip
var roundTripped Policy
err = json.Unmarshal(marshalled, &roundTripped)
require.NoError(t, err)
// Compare the original and round-tripped policies
cmps := append(util.Comparers,
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
cmpopts.EquateEmpty(),
)
if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" {
t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff)
}
@@ -958,13 +958,13 @@ func TestUnmarshalPolicy(t *testing.T) {
},
}
cmps := append(util.Comparers,
cmps := append(util.Comparers,
cmp.Comparer(func(x, y Prefix) bool {
return x == y
}),
cmpopts.IgnoreUnexported(Policy{}),
)
// For round-trip testing, we'll normalize the policies before comparing
for _, tt := range tests {
@@ -981,6 +981,7 @@ func TestUnmarshalPolicy(t *testing.T) {
} else if !strings.Contains(err.Error(), tt.wantErr) {
t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr)
}
return // Skip the rest of the test if we expected an error
}
@@ -1001,9 +1002,9 @@ func TestUnmarshalPolicy(t *testing.T) {
if err != nil {
t.Fatalf("round-trip unmarshalling: %v", err)
}
// Add EquateEmpty to handle nil vs empty maps/slices
roundTripCmps := append(cmps,
roundTripCmps := append(cmps,
cmpopts.EquateEmpty(),
cmpopts.IgnoreUnexported(Policy{}),
)
@@ -1584,6 +1585,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
builder.AddPrefix(mp(p))
}
ipSet, _ := builder.IPSet()
return ipSet
}

View File

@@ -73,10 +73,10 @@ func TestParsePortRange(t *testing.T) {
expected []tailcfg.PortRange
err string
}{
{"80", []tailcfg.PortRange{{80, 80}}, ""},
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
{"80-", nil, "invalid port range format"},
{"-90", nil, "invalid port range format"},