cmd/hi: fixes and qol (#2649)

This commit is contained in:
Kristoffer Dalby
2025-06-23 13:43:14 +02:00
committed by GitHub
parent ea7376f522
commit afc11e1f0c
31 changed files with 1097 additions and 311 deletions

View File

@@ -1,44 +1,65 @@
package dockertestutil
import (
"fmt"
"os"
"strings"
"time"
"github.com/ory/dockertest/v3/docker"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/ory/dockertest/v3"
)
// GetIntegrationRunID returns the run ID for the current integration test session.
// This is set by the hi tool and passed through environment variables.
func GetIntegrationRunID() string {
return os.Getenv("HEADSCALE_INTEGRATION_RUN_ID")
}
// DockerAddIntegrationLabels adds integration test labels to Docker RunOptions.
// This allows the hi tool to identify containers belonging to specific test runs.
// This function should be called before passing RunOptions to dockertest functions.
func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) {
runID := GetIntegrationRunID()
if runID == "" {
panic("HEADSCALE_INTEGRATION_RUN_ID environment variable is required")
}
if opts.Labels == nil {
opts.Labels = make(map[string]string)
}
opts.Labels["hi.run-id"] = runID
opts.Labels["hi.test-type"] = testType
}
// GenerateRunID creates a unique run identifier with timestamp and random hash.
// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3)
func GenerateRunID() string {
now := time.Now()
timestamp := now.Format("20060102-150405")
// Add a short random hash to ensure uniqueness
randomHash := util.MustGenerateRandomStringDNSSafe(6)
return fmt.Sprintf("%s-%s", timestamp, randomHash)
}
// ExtractRunIDFromContainerName extracts the run ID from container name.
// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH"
func ExtractRunIDFromContainerName(containerName string) string {
parts := strings.Split(containerName, "-")
if len(parts) >= 3 {
// Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH)
return strings.Join(parts[len(parts)-3:], "-")
}
panic(fmt.Sprintf("unexpected container name format: %s", containerName))
}
// IsRunningInContainer checks if the current process is running inside a Docker container.
// This is used by tests to determine if they should run integration tests.
func IsRunningInContainer() bool {
if _, err := os.Stat("/.dockerenv"); err != nil {
return false
}
return true
}
func DockerRestartPolicy(config *docker.HostConfig) {
// set AutoRemove to true so that stopped container goes away by itself on error *immediately*.
// when set to false, containers remain until the end of the integration test.
config.AutoRemove = false
config.RestartPolicy = docker.RestartPolicy{
Name: "no",
}
}
func DockerAllowLocalIPv6(config *docker.HostConfig) {
if config.Sysctls == nil {
config.Sysctls = make(map[string]string, 1)
}
config.Sysctls["net.ipv6.conf.all.disable_ipv6"] = "0"
}
func DockerAllowNetworkAdministration(config *docker.HostConfig) {
// Needed since containerd (1.7.24)
// https://github.com/tailscale/tailscale/issues/14256
// https://github.com/opencontainers/runc/commit/2ce40b6ad72b4bd4391380cafc5ef1bad1fa0b31
config.CapAdd = append(config.CapAdd, "NET_ADMIN")
config.CapAdd = append(config.CapAdd, "NET_RAW")
config.Devices = append(config.Devices, docker.Device{
PathOnHost: "/dev/net/tun",
PathInContainer: "/dev/net/tun",
CgroupPermissions: "rwm",
})
}
// Check for the common indicator that we're in a container
// This could be improved with more robust detection if needed
_, err := os.Stat("/.dockerenv")
return err == nil
}

View File

@@ -126,3 +126,24 @@ func CleanImagesInCI(pool *dockertest.Pool) error {
return nil
}
// DockerRestartPolicy sets the restart policy for containers.
func DockerRestartPolicy(config *docker.HostConfig) {
config.RestartPolicy = docker.RestartPolicy{
Name: "unless-stopped",
}
}
// DockerAllowLocalIPv6 allows IPv6 traffic within the container.
func DockerAllowLocalIPv6(config *docker.HostConfig) {
config.NetworkMode = "default"
config.Sysctls = map[string]string{
"net.ipv6.conf.all.disable_ipv6": "0",
}
}
// DockerAllowNetworkAdministration gives the container network administration capabilities.
func DockerAllowNetworkAdministration(config *docker.HostConfig) {
config.CapAdd = append(config.CapAdd, "NET_ADMIN")
config.Privileged = true
}

View File

@@ -159,6 +159,7 @@ func New(
},
}
if dsic.workdir != "" {
runOptions.WorkingDir = dsic.workdir
}
@@ -189,6 +190,9 @@ func New(
Value: "v" + version,
})
}
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "derp")
container, err = pool.BuildAndRunWithBuildOptions(
buildOptions,
runOptions,

View File

@@ -1,6 +1,8 @@
package hsic
import (
"archive/tar"
"bytes"
"cmp"
"crypto/tls"
"encoding/json"
@@ -12,6 +14,7 @@ import (
"net/netip"
"os"
"path"
"path/filepath"
"sort"
"strconv"
"strings"
@@ -311,18 +314,22 @@ func New(
hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale"
delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH")
pg, err := pool.RunWithOptions(
&dockertest.RunOptions{
Name: fmt.Sprintf("postgres-%s", hash),
Repository: "postgres",
Tag: "latest",
Networks: networks,
Env: []string{
"POSTGRES_USER=headscale",
"POSTGRES_PASSWORD=headscale",
"POSTGRES_DB=headscale",
},
})
pgRunOptions := &dockertest.RunOptions{
Name: fmt.Sprintf("postgres-%s", hash),
Repository: "postgres",
Tag: "latest",
Networks: networks,
Env: []string{
"POSTGRES_USER=headscale",
"POSTGRES_PASSWORD=headscale",
"POSTGRES_DB=headscale",
},
}
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(pgRunOptions, "postgres")
pg, err := pool.RunWithOptions(pgRunOptions)
if err != nil {
return nil, fmt.Errorf("starting postgres container: %w", err)
}
@@ -366,6 +373,7 @@ func New(
Env: env,
}
if len(hsic.hostPortBindings) > 0 {
runOptions.PortBindings = map[docker.Port][]docker.PortBinding{}
for port, hostPorts := range hsic.hostPortBindings {
@@ -386,6 +394,9 @@ func New(
return nil, err
}
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale")
container, err := pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
runOptions,
@@ -553,22 +564,67 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error {
return nil
}
// extractTarToDirectory extracts a tar archive to a directory.
func extractTarToDirectory(tarData []byte, targetDir string) error {
if err := os.MkdirAll(targetDir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
}
tarReader := tar.NewReader(bytes.NewReader(tarData))
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar header: %w", err)
}
// Clean the path to prevent directory traversal
cleanName := filepath.Clean(header.Name)
if strings.Contains(cleanName, "..") {
continue // Skip potentially dangerous paths
}
targetPath := filepath.Join(targetDir, filepath.Base(cleanName))
switch header.Typeflag {
case tar.TypeDir:
// Create directory
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetPath, err)
}
case tar.TypeReg:
// Create file
outFile, err := os.Create(targetPath)
if err != nil {
return fmt.Errorf("failed to create file %s: %w", targetPath, err)
}
if _, err := io.Copy(outFile, tarReader); err != nil {
outFile.Close()
return fmt.Errorf("failed to copy file contents: %w", err)
}
outFile.Close()
// Set file permissions
if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil {
return fmt.Errorf("failed to set file permissions: %w", err)
}
}
}
return nil
}
func (t *HeadscaleInContainer) SaveProfile(savePath string) error {
tarFile, err := t.FetchPath("/tmp/profile")
if err != nil {
return err
}
err = os.WriteFile(
path.Join(savePath, t.hostname+".pprof.tar"),
tarFile,
os.ModePerm,
)
if err != nil {
return err
}
return nil
targetDir := path.Join(savePath, t.hostname+"-pprof")
return extractTarToDirectory(tarFile, targetDir)
}
func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error {
@@ -577,34 +633,101 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error {
return err
}
err = os.WriteFile(
path.Join(savePath, t.hostname+".maps.tar"),
tarFile,
os.ModePerm,
)
if err != nil {
return err
}
return nil
targetDir := path.Join(savePath, t.hostname+"-mapresponses")
return extractTarToDirectory(tarFile, targetDir)
}
func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
// If using PostgreSQL, skip database file extraction
if t.postgres {
return nil
}
// First, let's see what files are actually in /tmp
tmpListing, err := t.Execute([]string{"ls", "-la", "/tmp/"})
if err != nil {
log.Printf("Warning: could not list /tmp directory: %v", err)
} else {
log.Printf("Contents of /tmp in container %s:\n%s", t.hostname, tmpListing)
}
// Also check for any .sqlite files
sqliteFiles, err := t.Execute([]string{"find", "/tmp", "-name", "*.sqlite*", "-type", "f"})
if err != nil {
log.Printf("Warning: could not find sqlite files: %v", err)
} else {
log.Printf("SQLite files found in %s:\n%s", t.hostname, sqliteFiles)
}
// Check if the database file exists and has a schema
dbPath := "/tmp/integration_test_db.sqlite3"
fileInfo, err := t.Execute([]string{"ls", "-la", dbPath})
if err != nil {
return fmt.Errorf("database file does not exist at %s: %w", dbPath, err)
}
log.Printf("Database file info: %s", fileInfo)
// Check if the database has any tables (schema)
schemaCheck, err := t.Execute([]string{"sqlite3", dbPath, ".schema"})
if err != nil {
return fmt.Errorf("failed to check database schema (sqlite3 command failed): %w", err)
}
if strings.TrimSpace(schemaCheck) == "" {
return fmt.Errorf("database file exists but has no schema (empty database)")
}
// Show a preview of the schema (first 500 chars)
schemaPreview := schemaCheck
if len(schemaPreview) > 500 {
schemaPreview = schemaPreview[:500] + "..."
}
log.Printf("Database schema preview:\n%s", schemaPreview)
tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
if err != nil {
return err
return fmt.Errorf("failed to fetch database file: %w", err)
}
err = os.WriteFile(
path.Join(savePath, t.hostname+".db.tar"),
tarFile,
os.ModePerm,
)
if err != nil {
return err
// For database, extract the first regular file (should be the SQLite file)
tarReader := tar.NewReader(bytes.NewReader(tarFile))
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar header: %w", err)
}
log.Printf("Found file in tar: %s (type: %d, size: %d)", header.Name, header.Typeflag, header.Size)
// Extract the first regular file we find
if header.Typeflag == tar.TypeReg {
dbPath := path.Join(savePath, t.hostname+".db")
outFile, err := os.Create(dbPath)
if err != nil {
return fmt.Errorf("failed to create database file: %w", err)
}
written, err := io.Copy(outFile, tarReader)
outFile.Close()
if err != nil {
return fmt.Errorf("failed to copy database file: %w", err)
}
log.Printf("Extracted database file: %s (%d bytes written, header claimed %d bytes)", dbPath, written, header.Size)
// Check if we actually wrote something
if written == 0 {
return fmt.Errorf("database file is empty (size: %d, header size: %d)", written, header.Size)
}
return nil
}
}
return nil
return fmt.Errorf("no regular file found in database tar archive")
}
// Execute runs a command inside the Headscale container and returns the

View File

@@ -32,7 +32,7 @@ import (
"github.com/oauth2-proxy/mockoidc"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"github.com/puzpuzpuz/xsync/v3"
"github.com/puzpuzpuz/xsync/v4"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -1102,6 +1102,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
},
}
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,
@@ -1114,6 +1115,9 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
s.mockOIDC = scenarioOIDC{}
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc")
if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
mockOidcOptions,
@@ -1198,6 +1202,9 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
Env: []string{},
}
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(webOpts, "web")
webBOpts := &dockertest.BuildOptions{
Dockerfile: hsic.IntegrationTestDockerFileName,
ContextDir: dockerContextPath,

View File

@@ -17,7 +17,9 @@ import (
func isSSHNoAccessStdError(stderr string) bool {
return strings.Contains(stderr, "Permission denied (tailscale)") ||
// Since https://github.com/tailscale/tailscale/pull/14853
strings.Contains(stderr, "failed to evaluate SSH policy")
strings.Contains(stderr, "failed to evaluate SSH policy") ||
// Since https://github.com/tailscale/tailscale/pull/16127
strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node")
}
var retry = func(times int, sleepInterval time.Duration,

View File

@@ -251,6 +251,7 @@ func New(
Env: []string{},
}
if tsic.withWebsocketDERP {
if version != VersionHead {
return tsic, errInvalidClientConfig
@@ -279,6 +280,9 @@ func New(
return nil, err
}
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(tailscaleOptions, "tailscale")
var container *dockertest.Resource
if version != VersionHead {

View File

@@ -3,6 +3,7 @@ package integration
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/netip"
@@ -11,7 +12,7 @@ import (
"testing"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/cenkalti/backoff/v5"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/tsic"
@@ -310,20 +311,18 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) {
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
t.Helper()
err := backoff.Retry(func() error {
_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
stdout, stderr, err := c.Execute(command)
if err != nil {
return fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err)
}
if !strings.Contains(stdout, contains) {
return fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout)
return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout)
}
return nil
}, backoff.NewExponentialBackOff(
backoff.WithMaxElapsedTime(10*time.Second)),
)
return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
assert.NoError(t, err)
}