From 93d56362af687fb7d0b7797fe622c619abe86ab3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Mon, 14 Nov 2022 14:27:02 +0100 Subject: [PATCH] Lock and unify headscale start/get method Signed-off-by: Kristoffer Dalby --- integration/auth_web_flow_test.go | 15 +++++--- integration/cli_test.go | 36 ++++++++++++------ integration/control.go | 3 ++ integration/scenario.go | 62 +++++++++++++++++++------------ integration/scenario_test.go | 18 ++++++--- 5 files changed, 89 insertions(+), 45 deletions(-) diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index d522d9e2..a51272ce 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -77,12 +77,12 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { } func (s *AuthWebFlowScenario) CreateHeadscaleEnv(namespaces map[string]int) error { - err := s.StartHeadscale() + headscale, err := s.Headscale() if err != nil { return err } - err = s.MustHeadscale().WaitForReady() + err = headscale.WaitForReady() if err != nil { return err } @@ -99,7 +99,7 @@ func (s *AuthWebFlowScenario) CreateHeadscaleEnv(namespaces map[string]int) erro return err } - err = s.runTailscaleUp(namespaceName, s.MustHeadscale().GetEndpoint()) + err = s.runTailscaleUp(namespaceName, headscale.GetEndpoint()) if err != nil { return err } @@ -145,8 +145,13 @@ func (s *AuthWebFlowScenario) runTailscaleUp( } func (s *AuthWebFlowScenario) runHeadscaleRegister(namespaceStr string, loginURL *url.URL) error { + headscale, err := s.Headscale() + if err != nil { + return err + } + log.Printf("loginURL: %s", loginURL) - loginURL.Host = fmt.Sprintf("%s:8080", s.MustHeadscale().GetIP()) + loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) loginURL.Scheme = "http" httpClient := &http.Client{} @@ -177,7 +182,7 @@ func (s *AuthWebFlowScenario) runHeadscaleRegister(namespaceStr string, loginURL key := keySep[1] log.Printf("registering node %s", key) - if headscale, ok := s.controlServers["headscale"]; ok { + if headscale, err := s.Headscale(); err == nil { _, err = headscale.Execute( []string{"headscale", "-n", namespaceStr, "nodes", "register", "--key", key}, ) diff --git a/integration/cli_test.go b/integration/cli_test.go index 5045f2d2..329c8247 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -39,8 +39,11 @@ func TestNamespaceCommand(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec) assert.NoError(t, err) + headscale, err := scenario.Headscale() + assert.NoError(t, err) + var listNamespaces []v1.Namespace - err = executeAndUnmarshal(scenario.MustHeadscale(), + err = executeAndUnmarshal(headscale, []string{ "headscale", "namespaces", @@ -61,7 +64,7 @@ func TestNamespaceCommand(t *testing.T) { result, ) - _, err = scenario.MustHeadscale().Execute( + _, err = headscale.Execute( []string{ "headscale", "namespaces", @@ -75,7 +78,7 @@ func TestNamespaceCommand(t *testing.T) { assert.NoError(t, err) var listAfterRenameNamespaces []v1.Namespace - err = executeAndUnmarshal(scenario.MustHeadscale(), + err = executeAndUnmarshal(headscale, []string{ "headscale", "namespaces", @@ -117,13 +120,16 @@ func TestPreAuthKeyCommand(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec) assert.NoError(t, err) + headscale, err := scenario.Headscale() + assert.NoError(t, err) + keys := make([]*v1.PreAuthKey, count) assert.NoError(t, err) for index := 0; index < count; index++ { var preAuthKey v1.PreAuthKey err := executeAndUnmarshal( - scenario.MustHeadscale(), + headscale, []string{ "headscale", "preauthkeys", @@ -149,7 +155,7 @@ func TestPreAuthKeyCommand(t *testing.T) { var listedPreAuthKeys []v1.PreAuthKey err = executeAndUnmarshal( - scenario.MustHeadscale(), + headscale, []string{ "headscale", "preauthkeys", @@ -202,7 +208,7 @@ func TestPreAuthKeyCommand(t *testing.T) { } // Test key expiry - _, err = scenario.MustHeadscale().Execute( + _, err = headscale.Execute( []string{ "headscale", "preauthkeys", @@ -216,7 +222,7 @@ func TestPreAuthKeyCommand(t *testing.T) { var listedPreAuthKeysAfterExpire []v1.PreAuthKey err = executeAndUnmarshal( - scenario.MustHeadscale(), + headscale, []string{ "headscale", "preauthkeys", @@ -254,9 +260,12 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec) assert.NoError(t, err) + headscale, err := scenario.Headscale() + assert.NoError(t, err) + var preAuthKey v1.PreAuthKey err = executeAndUnmarshal( - scenario.MustHeadscale(), + headscale, []string{ "headscale", "preauthkeys", @@ -273,7 +282,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { var listedPreAuthKeys []v1.PreAuthKey err = executeAndUnmarshal( - scenario.MustHeadscale(), + headscale, []string{ "headscale", "preauthkeys", @@ -316,9 +325,12 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { err = scenario.CreateHeadscaleEnv(spec) assert.NoError(t, err) + headscale, err := scenario.Headscale() + assert.NoError(t, err) + var preAuthReusableKey v1.PreAuthKey err = executeAndUnmarshal( - scenario.MustHeadscale(), + headscale, []string{ "headscale", "preauthkeys", @@ -335,7 +347,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { var preAuthEphemeralKey v1.PreAuthKey err = executeAndUnmarshal( - scenario.MustHeadscale(), + headscale, []string{ "headscale", "preauthkeys", @@ -355,7 +367,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { var listedPreAuthKeys []v1.PreAuthKey err = executeAndUnmarshal( - scenario.MustHeadscale(), + headscale, []string{ "headscale", "preauthkeys", diff --git a/integration/control.go b/integration/control.go index 33a687c0..d05f5305 100644 --- a/integration/control.go +++ b/integration/control.go @@ -13,4 +13,7 @@ type ControlServer interface { CreateNamespace(namespace string) error CreateAuthKey(namespace string) (*v1.PreAuthKey, error) ListMachinesInNamespace(namespace string) ([]*v1.Machine, error) + GetCert() []byte + GetHostname() string + GetIP() string } diff --git a/integration/scenario.go b/integration/scenario.go index 29203309..1c94bce6 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -15,6 +15,7 @@ import ( "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/ory/dockertest/v3" + "github.com/puzpuzpuz/xsync/v2" ) const ( @@ -69,12 +70,14 @@ type Namespace struct { type Scenario struct { // TODO(kradalby): support multiple headcales for later, currently only // use one. - controlServers map[string]ControlServer + controlServers *xsync.MapOf[string, ControlServer] namespaces map[string]*Namespace pool *dockertest.Pool network *dockertest.Network + + headscaleLock sync.Mutex } func NewScenario() (*Scenario, error) { @@ -109,7 +112,7 @@ func NewScenario() (*Scenario, error) { } return &Scenario{ - controlServers: make(map[string]ControlServer), + controlServers: xsync.NewMapOf[ControlServer](), namespaces: make(map[string]*Namespace), pool: pool, @@ -118,12 +121,17 @@ func NewScenario() (*Scenario, error) { } func (s *Scenario) Shutdown() error { - for _, control := range s.controlServers { + s.controlServers.Range(func(_ string, control ControlServer) bool { err := control.Shutdown() if err != nil { - return fmt.Errorf("failed to tear down control: %w", err) + log.Printf( + "Failed to shut down control: %s", + fmt.Errorf("failed to tear down control: %w", err), + ) } - } + + return true + }) for namespaceName, namespace := range s.namespaces { for _, client := range namespace.Clients { @@ -160,31 +168,31 @@ func (s *Scenario) Namespaces() []string { // Note: These functions assume that there is a _single_ headscale instance for now // TODO(kradalby): make port and headscale configurable, multiple instances support? -func (s *Scenario) StartHeadscale(opts ...hsic.Option) error { +func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) { + s.headscaleLock.Lock() + defer s.headscaleLock.Unlock() + + if headscale, ok := s.controlServers.Load("headscale"); ok { + return headscale, nil + } + headscale, err := hsic.New(s.pool, s.network, opts...) if err != nil { - return fmt.Errorf("failed to create headscale container: %w", err) + return nil, fmt.Errorf("failed to create headscale container: %w", err) } err = headscale.WaitForReady() if err != nil { - return err + return nil, fmt.Errorf("failed reach headscale container: %w", err) } - s.controlServers["headscale"] = headscale + s.controlServers.Store("headscale", headscale) - return nil -} - -// MustHeadscale returns the headscale unit of a scenario, it will crash if it -// is not available. -func (s *Scenario) MustHeadscale() *hsic.HeadscaleInContainer { - //nolint - return s.controlServers["headscale"].(*hsic.HeadscaleInContainer) + return headscale, nil } func (s *Scenario) CreatePreAuthKey(namespace string) (*v1.PreAuthKey, error) { - if headscale, ok := s.controlServers["headscale"]; ok { + if headscale, err := s.Headscale(); err == nil { key, err := headscale.CreateAuthKey(namespace) if err != nil { return nil, fmt.Errorf("failed to create namespace: %w", err) @@ -197,7 +205,7 @@ func (s *Scenario) CreatePreAuthKey(namespace string) (*v1.PreAuthKey, error) { } func (s *Scenario) CreateNamespace(namespace string) error { - if headscale, ok := s.controlServers["headscale"]; ok { + if headscale, err := s.Headscale(); err == nil { err := headscale.CreateNamespace(namespace) if err != nil { return fmt.Errorf("failed to create namespace: %w", err) @@ -227,6 +235,14 @@ func (s *Scenario) CreateTailscaleNodesInNamespace( version = TailscaleVersions[i%len(TailscaleVersions)] } + headscale, err := s.Headscale() + if err != nil { + return fmt.Errorf("failed to create tailscale node: %w", err) + } + + cert := headscale.GetCert() + hostname := headscale.GetHostname() + namespace.createWaitGroup.Add(1) go func() { @@ -237,8 +253,8 @@ func (s *Scenario) CreateTailscaleNodesInNamespace( s.pool, version, s.network, - tsic.WithHeadscaleTLS(s.MustHeadscale().GetCert()), - tsic.WithHeadscaleName(s.MustHeadscale().GetHostname()), + tsic.WithHeadscaleTLS(cert), + tsic.WithHeadscaleName(hostname), ) if err != nil { // return fmt.Errorf("failed to add tailscale node: %w", err) @@ -324,7 +340,7 @@ func (s *Scenario) WaitForTailscaleSync() error { // test environment with nodes of all versions, joined to the server with X // namespaces. func (s *Scenario) CreateHeadscaleEnv(namespaces map[string]int, opts ...hsic.Option) error { - err := s.StartHeadscale(opts...) + headscale, err := s.Headscale(opts...) if err != nil { return err } @@ -345,7 +361,7 @@ func (s *Scenario) CreateHeadscaleEnv(namespaces map[string]int, opts ...hsic.Op return err } - err = s.RunTailscaleUp(namespaceName, s.MustHeadscale().GetEndpoint(), key.GetKey()) + err = s.RunTailscaleUp(namespaceName, headscale.GetEndpoint(), key.GetKey()) if err != nil { return err } diff --git a/integration/scenario_test.go b/integration/scenario_test.go index d762d4c8..faa50d65 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -34,12 +34,12 @@ func TestHeadscale(t *testing.T) { } t.Run("start-headscale", func(t *testing.T) { - err = scenario.StartHeadscale() + headscale, err := scenario.Headscale() if err != nil { t.Errorf("failed to create start headcale: %s", err) } - err = scenario.MustHeadscale().WaitForReady() + err = headscale.WaitForReady() if err != nil { t.Errorf("headscale failed to become ready: %s", err) } @@ -117,12 +117,11 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { } t.Run("start-headscale", func(t *testing.T) { - err = scenario.StartHeadscale() + headscale, err := scenario.Headscale() if err != nil { t.Errorf("failed to create start headcale: %s", err) } - headscale := scenario.MustHeadscale() err = headscale.WaitForReady() if err != nil { t.Errorf("headscale failed to become ready: %s", err) @@ -157,7 +156,16 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { t.Errorf("failed to create preauthkey: %s", err) } - err = scenario.RunTailscaleUp(namespace, scenario.MustHeadscale().GetEndpoint(), key.GetKey()) + headscale, err := scenario.Headscale() + if err != nil { + t.Errorf("failed to create start headcale: %s", err) + } + + err = scenario.RunTailscaleUp( + namespace, + headscale.GetEndpoint(), + key.GetKey(), + ) if err != nil { t.Errorf("failed to login: %s", err) }