diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 89409c6e..df7c0d0b 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -7,6 +7,7 @@ import ( "io" "log" "net/http" + "net/netip" "net/url" "strings" "testing" @@ -59,7 +60,6 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { } success := 0 - for _, client := range allClients { for _, ip := range allIps { err := client.Ping(ip.String()) @@ -79,6 +79,151 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { } } +func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + baseScenario, err := NewScenario() + if err != nil { + t.Errorf("failed to create scenario: %s", err) + } + + scenario := AuthWebFlowScenario{ + Scenario: baseScenario, + } + + spec := map[string]int{ + "namespace1": len(TailscaleVersions), + "namespace2": len(TailscaleVersions), + } + + err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("weblogout")) + if err != nil { + t.Errorf("failed to create headscale environment: %s", err) + } + + allClients, err := scenario.ListTailscaleClients() + if err != nil { + t.Errorf("failed to get clients: %s", err) + } + + allIps, err := scenario.ListTailscaleClientsIPs() + if err != nil { + t.Errorf("failed to get clients: %s", err) + } + + err = scenario.WaitForTailscaleSync() + if err != nil { + t.Errorf("failed wait for tailscale clients to be in sync: %s", err) + } + + success := 0 + for _, client := range allClients { + for _, ip := range allIps { + err := client.Ping(ip.String()) + if err != nil { + t.Errorf("failed to ping %s from %s: %s", ip, client.Hostname(), err) + } else { + success++ + } + } + } + + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { + ips, err := client.IPs() + if err != nil { + t.Errorf("failed to get IPs for client %s: %s", client.Hostname(), err) + } + clientIPs[client] = ips + } + + for _, client := range allClients { + _, _, err = client.Execute([]string{"tailscale", "logout"}) + if err != nil { + t.Errorf("failed to logout client %s: %s", client.Hostname(), err) + } + } + + scenario.waitForTailscaleLogout() + + t.Logf("all clients logged out") + + headscale, err := scenario.Headscale() + if err != nil { + t.Errorf("failed to get headscale server: %s", err) + } + + for namespaceName := range spec { + err = scenario.runTailscaleUp(namespaceName, headscale.GetEndpoint()) + if err != nil { + t.Errorf("failed to run tailscale up: %s", err) + } + } + + t.Logf("all clients logged in again") + + allClients, err = scenario.ListTailscaleClients() + if err != nil { + t.Errorf("failed to get clients: %s", err) + } + + allIps, err = scenario.ListTailscaleClientsIPs() + if err != nil { + t.Errorf("failed to get clients: %s", err) + } + + success = 0 + for _, client := range allClients { + for _, ip := range allIps { + err := client.Ping(ip.String()) + if err != nil { + t.Errorf("failed to ping %s from %s: %s", ip, client.Hostname(), err) + } else { + success++ + } + } + } + + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + for _, client := range allClients { + ips, err := client.IPs() + if err != nil { + t.Errorf("failed to get IPs for client %s: %s", client.Hostname(), err) + } + + // lets check if the IPs are the same + if len(ips) != len(clientIPs[client]) { + t.Errorf("IPs changed for client %s", client.Hostname()) + } + + for _, ip := range ips { + found := false + for _, oldIP := range clientIPs[client] { + if ip == oldIP { + found = true + + break + } + } + + if !found { + t.Errorf("IPs changed for client %s. Used to be %v now %v", client.Hostname(), clientIPs[client], ips) + } + } + } + + t.Logf("all clients IPs are the same") + + err = scenario.Shutdown() + if err != nil { + t.Errorf("failed to tear down scenario: %s", err) + } +} + func (s *AuthWebFlowScenario) CreateHeadscaleEnv( namespaces map[string]int, opts ...hsic.Option, @@ -114,6 +259,22 @@ func (s *AuthWebFlowScenario) CreateHeadscaleEnv( return nil } +func (s *AuthWebFlowScenario) waitForTailscaleLogout() { + for _, namespace := range s.namespaces { + for _, client := range namespace.Clients { + namespace.syncWaitGroup.Add(1) + + go func(c TailscaleClient) { + defer namespace.syncWaitGroup.Done() + + // TODO(kradalby): error handle this + _ = c.WaitForLogout() + }(client) + } + namespace.syncWaitGroup.Wait() + } +} + func (s *AuthWebFlowScenario) runTailscaleUp( namespaceStr, loginServer string, ) error { diff --git a/integration/tailscale.go b/integration/tailscale.go index b69b217a..935fcf7c 100644 --- a/integration/tailscale.go +++ b/integration/tailscale.go @@ -7,7 +7,7 @@ import ( "tailscale.com/ipn/ipnstate" ) -//nolint +// nolint type TailscaleClient interface { Hostname() string Shutdown() error @@ -19,6 +19,7 @@ type TailscaleClient interface { FQDN() (string, error) Status() (*ipnstate.Status, error) WaitForReady() error + WaitForLogout() error WaitForPeers(expected int) error Ping(hostnameOrIP string) error ID() string diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index d656b1c0..971cf011 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -30,6 +30,7 @@ var ( errTailscaleWrongPeerCount = errors.New("wrong peer count") errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey") errTailscaleNotConnected = errors.New("tailscale not connected") + errTailscaleNotLoggedOut = errors.New("tailscale not logged out") ) type TailscaleInContainer struct { @@ -350,6 +351,21 @@ func (t *TailscaleInContainer) WaitForReady() error { }) } +func (t *TailscaleInContainer) WaitForLogout() error { + return t.pool.Retry(func() error { + status, err := t.Status() + if err != nil { + return fmt.Errorf("failed to fetch tailscale status: %w", err) + } + + if status.CurrentTailnet == nil { + return nil + } + + return errTailscaleNotLoggedOut + }) +} + func (t *TailscaleInContainer) WaitForPeers(expected int) error { return t.pool.Retry(func() error { status, err := t.Status()