mapper: produce map before poll (#2628)

This commit is contained in:
Kristoffer Dalby 2025-07-28 11:15:53 +02:00 committed by GitHub
parent b2a18830ed
commit a058bf3cd3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
70 changed files with 5771 additions and 2475 deletions

View File

@ -17,3 +17,7 @@ LICENSE
.vscode
*.sock
node_modules/
package-lock.json
package.json

55
.github/workflows/check-generated.yml vendored Normal file
View File

@ -0,0 +1,55 @@
# Fails the build when files produced by `make generate` differ from what is committed.
name: Check Generated Files
on:
push:
branches:
- main
pull_request:
branches:
- main
# NOTE(review): the "$$" before the second expression leaves a literal "$" in the
# group name — harmless for grouping, but confirm it is intended.
concurrency:
group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
check-generated:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 2
# Sets the `files` output to 'true' when any path matching the filters below changed.
- name: Get changed files
id: changed-files
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
with:
filters: |
files:
- '*.nix'
- 'go.*'
- '**/*.go'
- '**/*.proto'
- 'buf.gen.yaml'
- 'tools/**'
# Every remaining step is gated on the `files` output, so unrelated changes skip the Nix setup.
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'
with:
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', '**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
- name: Run make generate
if: steps.changed-files.outputs.files == 'true'
run: nix develop --command -- make generate
# `git diff --exit-code` is non-zero when regeneration modified any tracked file.
- name: Check for uncommitted changes
if: steps.changed-files.outputs.files == 'true'
run: |
if ! git diff --exit-code; then
echo "❌ Generated files are not up to date!"
echo "Please run 'make generate' and commit the changes."
exit 1
else
echo "✅ All generated files are up to date."
fi

View File

@ -77,7 +77,7 @@ jobs:
attempt_delay: 300000 # 5 min
attempt_limit: 2
command: |
nix develop --command -- hi run "^${{ inputs.test }}$" \
nix develop --command -- hi run --stats --ts-memory-limit=300 --hs-memory-limit=500 "^${{ inputs.test }}$" \
--timeout=120m \
${{ inputs.postgres_flag }}
- uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2

7
.gitignore vendored
View File

@ -1,6 +1,9 @@
ignored/
tailscale/
.vscode/
.claude/
*.prof
# Binaries for programs and plugins
*.exe
@ -46,3 +49,7 @@ integration_test/etc/config.dump.yaml
/site
__debug_bin
node_modules/
package-lock.json
package.json

View File

@ -2,6 +2,8 @@
## Next
**Minimum supported Tailscale client version: v1.64.0**
### Database integrity improvements
This release includes a significant database migration that addresses longstanding

395
CLAUDE.md Normal file
View File

@ -0,0 +1,395 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Overview
Headscale is an open-source implementation of the Tailscale control server written in Go. It provides self-hosted coordination for Tailscale networks (tailnets), managing node registration, IP allocation, policy enforcement, and DERP routing.
## Development Commands
### Quick Setup
```bash
# Recommended: Use Nix for dependency management
nix develop
# Full development workflow
make dev # runs fmt + lint + test + build
```
### Essential Commands
```bash
# Build headscale binary
make build
# Run tests
make test
go test ./... # All unit tests
go test -race ./... # With race detection
# Run specific integration test
go run ./cmd/hi run "TestName" --postgres
# Code formatting and linting
make fmt # Format all code (Go, docs, proto)
make lint # Lint all code (Go, proto)
make fmt-go # Format Go code only
make lint-go # Lint Go code only
# Protocol buffer generation (after modifying proto/)
make generate
# Clean build artifacts
make clean
```
### Integration Testing
```bash
# Use the hi (Headscale Integration) test runner
go run ./cmd/hi doctor # Check system requirements
go run ./cmd/hi run "TestPattern" # Run specific test
go run ./cmd/hi run "TestPattern" --postgres # With PostgreSQL backend
# Test artifacts are saved to control_logs/ with logs and debug data
```
## Project Structure & Architecture
### Top-Level Organization
```
headscale/
├── cmd/ # Command-line applications
│ ├── headscale/ # Main headscale server binary
│ └── hi/ # Headscale Integration test runner
├── hscontrol/ # Core control plane logic
├── integration/ # End-to-end Docker-based tests
├── proto/ # Protocol buffer definitions
├── gen/ # Generated code (protobuf)
├── docs/ # Documentation
└── packaging/ # Distribution packaging
```
### Core Packages (`hscontrol/`)
**Main Server (`hscontrol/`)**
- `app.go`: Application setup, dependency injection, server lifecycle
- `handlers.go`: HTTP/gRPC API endpoints for management operations
- `grpcv1.go`: gRPC service implementation for headscale API
- `poll.go`: **Critical** - Handles Tailscale MapRequest/MapResponse protocol
- `noise.go`: Noise protocol implementation for secure client communication
- `auth.go`: Authentication flows (web, OIDC, command-line)
- `oidc.go`: OpenID Connect integration for user authentication
**State Management (`hscontrol/state/`)**
- `state.go`: Central coordinator for all subsystems (database, policy, IP allocation, DERP)
- `node_store.go`: **Performance-critical** - In-memory cache with copy-on-write semantics
- Thread-safe operations with deadlock detection
- Coordinates between database persistence and real-time operations
**Database Layer (`hscontrol/db/`)**
- `db.go`: Database abstraction, GORM setup, migration management
- `node.go`: Node lifecycle, registration, expiration, IP assignment
- `users.go`: User management, namespace isolation
- `api_key.go`: API authentication tokens
- `preauth_keys.go`: Pre-authentication keys for automated node registration
- `ip.go`: IP address allocation and management
- `policy.go`: Policy storage and retrieval
- Schema migrations in `schema.sql` with extensive test data coverage
**Policy Engine (`hscontrol/policy/`)**
- `policy.go`: Core ACL evaluation logic, HuJSON parsing
- `v2/`: Next-generation policy system with improved filtering
- `matcher/`: ACL rule matching and evaluation engine
- Determines peer visibility, route approval, and network access rules
- Supports both file-based and database-stored policies
**Network Management (`hscontrol/`)**
- `derp/`: DERP (Designated Encrypted Relay for Packets) server implementation
- NAT traversal when direct connections fail
- Fallback relay for firewall-restricted environments
- `mapper/`: Converts internal Headscale state to Tailscale's wire protocol format
- `tail.go`: Tailscale-specific data structure generation
- `routes/`: Subnet route management and primary route selection
- `dns/`: DNS record management and MagicDNS implementation
**Utilities & Support (`hscontrol/`)**
- `types/`: Core data structures, configuration, validation
- `util/`: Helper functions for networking, DNS, key management
- `templates/`: Client configuration templates (Apple, Windows, etc.)
- `notifier/`: Event notification system for real-time updates
- `metrics.go`: Prometheus metrics collection
- `capver/`: Tailscale capability version management
### Key Subsystem Interactions
**Node Registration Flow**
1. **Client Connection**: `noise.go` handles secure protocol handshake
2. **Authentication**: `auth.go` validates credentials (web/OIDC/preauth)
3. **State Creation**: `state.go` coordinates IP allocation via `db/ip.go`
4. **Storage**: `db/node.go` persists node, `NodeStore` caches in memory
5. **Network Setup**: `mapper/` generates initial Tailscale network map
**Ongoing Operations**
1. **Poll Requests**: `poll.go` receives periodic client updates
2. **State Updates**: `NodeStore` maintains real-time node information
3. **Policy Application**: `policy/` evaluates ACL rules for peer relationships
4. **Map Distribution**: `mapper/` sends network topology to all affected clients
**Route Management**
1. **Advertisement**: Clients announce routes via `poll.go` Hostinfo updates
2. **Storage**: `db/` persists routes, `NodeStore` caches for performance
3. **Approval**: `policy/` auto-approves routes based on ACL rules
4. **Distribution**: `routes/` selects primary routes, `mapper/` distributes to peers
### Command-Line Tools (`cmd/`)
**Main Server (`cmd/headscale/`)**
- `headscale.go`: CLI parsing, configuration loading, server startup
- Supports daemon mode, CLI operations (user/node management), database operations
**Integration Test Runner (`cmd/hi/`)**
- `main.go`: Test execution framework with Docker orchestration
- `run.go`: Individual test execution with artifact collection
- `doctor.go`: System requirements validation
- `docker.go`: Container lifecycle management
- Essential for validating changes against real Tailscale clients
### Generated & External Code
**Protocol Buffers (`proto/` → `gen/`)**
- Defines gRPC API for headscale management operations
- Client libraries can generate from these definitions
- Run `make generate` after modifying `.proto` files
**Integration Testing (`integration/`)**
- `scenario.go`: Docker test environment setup
- `tailscale.go`: Tailscale client container management
- Individual test files for specific functionality areas
- Real end-to-end validation with network isolation
### Critical Performance Paths
**High-Frequency Operations**
1. **MapRequest Processing** (`poll.go`): Every 15-60 seconds per client
2. **NodeStore Reads** (`node_store.go`): Every operation requiring node data
3. **Policy Evaluation** (`policy/`): On every peer relationship calculation
4. **Route Lookups** (`routes/`): During network map generation
**Database Write Patterns**
- **Frequent**: Node heartbeats, endpoint updates, route changes
- **Moderate**: User operations, policy updates, API key management
- **Rare**: Schema migrations, bulk operations
### Configuration & Deployment
**Configuration (`hscontrol/types/config.go`)**
- Database connection settings (SQLite/PostgreSQL)
- Network configuration (IP ranges, DNS settings)
- Policy mode (file vs database)
- DERP relay configuration
- OIDC provider settings
**Key Dependencies**
- **GORM**: Database ORM with migration support
- **Tailscale Libraries**: Core networking and protocol code
- **Zerolog**: Structured logging throughout the application
- **Buf**: Protocol buffer toolchain for code generation
### Development Workflow Integration
The architecture supports incremental development:
- **Unit Tests**: Focus on individual packages (`*_test.go` files)
- **Integration Tests**: Validate cross-component interactions
- **Database Tests**: Extensive migration and data integrity validation
- **Policy Tests**: ACL rule evaluation and edge cases
- **Performance Tests**: NodeStore and high-frequency operation validation
## Integration Test System
### Overview
Integration tests use Docker containers running real Tailscale clients against a Headscale server. Tests validate end-to-end functionality including routing, ACLs, node lifecycle, and network coordination.
### Running Integration Tests
**System Requirements**
```bash
# Check if your system is ready
go run ./cmd/hi doctor
```
This verifies Docker, Go, required images, and disk space.
**Test Execution Patterns**
```bash
# Run a single test (recommended for development)
go run ./cmd/hi run "TestSubnetRouterMultiNetwork"
# Run with PostgreSQL backend (for database-heavy tests)
go run ./cmd/hi run "TestExpireNode" --postgres
# Run multiple tests with pattern matching
go run ./cmd/hi run "TestSubnet*"
# Run all integration tests (CI/full validation)
go test ./integration -timeout 30m
```
**Test Categories & Timing**
- **Fast tests** (< 2 min): Basic functionality, CLI operations
- **Medium tests** (2-5 min): Route management, ACL validation
- **Slow tests** (5+ min): Node expiration, HA failover
- **Long-running tests** (10+ min): `TestNodeOnlineStatus` (12 min duration)
### Test Infrastructure
**Docker Setup**
- Headscale server container with configurable database backend
- Multiple Tailscale client containers with different versions
- Isolated networks per test scenario
- Automatic cleanup after test completion
**Test Artifacts**
All test runs save artifacts to `control_logs/TIMESTAMP-ID/`:
```
control_logs/20250713-213106-iajsux/
├── hs-testname-abc123.stderr.log # Headscale server logs
├── hs-testname-abc123.stdout.log
├── hs-testname-abc123.db # Database snapshot
├── hs-testname-abc123_metrics.txt # Prometheus metrics
├── hs-testname-abc123-mapresponses/ # Protocol debug data
├── ts-client-xyz789.stderr.log # Tailscale client logs
├── ts-client-xyz789.stdout.log
└── ts-client-xyz789_status.json # Client status dump
```
### Test Development Guidelines
**Timing Considerations**
Integration tests involve real network operations and Docker container lifecycle:
```go
// ❌ Wrong: Immediate assertions after async operations
client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"})
nodes, _ := headscale.ListNodes()
require.Len(t, nodes[0].GetAvailableRoutes(), 1) // May fail due to timing
// ✅ Correct: Wait for async operations to complete
client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"})
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes[0].GetAvailableRoutes(), 1)
}, 10*time.Second, 100*time.Millisecond, "route should be advertised")
```
**Common Test Patterns**
- **Route Advertisement**: Use `EventuallyWithT` for route propagation
- **Node State Changes**: Wait for NodeStore synchronization
- **ACL Policy Changes**: Allow time for policy recalculation
- **Network Connectivity**: Use ping tests with retries
**Test Data Management**
```go
// Node identification: Don't assume array ordering
expectedRoutes := map[string]string{"1": "10.33.0.0/16"}
for _, node := range nodes {
nodeIDStr := fmt.Sprintf("%d", node.GetId())
if route, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute {
// Test the node that should have the route
}
}
```
### Troubleshooting Integration Tests
**Common Failure Patterns**
1. **Timing Issues**: Test assertions run before async operations complete
- **Solution**: Use `EventuallyWithT` with appropriate timeouts
- **Timeout Guidelines**: 3-5s for route operations, 10s for complex scenarios
2. **Infrastructure Problems**: Disk space, Docker issues, network conflicts
- **Check**: `go run ./cmd/hi doctor` for system health
- **Clean**: Remove old test containers and networks
3. **NodeStore Synchronization**: Tests expecting immediate data availability
- **Key Points**: Route advertisements must propagate through poll requests
- **Fix**: Wait for NodeStore updates after Hostinfo changes
4. **Database Backend Differences**: SQLite vs PostgreSQL behavior differences
- **Use**: `--postgres` flag for database-intensive tests
- **Note**: Some timing characteristics differ between backends
**Debugging Failed Tests**
1. **Check test artifacts** in `control_logs/` for detailed logs
2. **Examine MapResponse JSON** files for protocol-level debugging
3. **Review Headscale stderr logs** for server-side error messages
4. **Check Tailscale client status** for network-level issues
**Resource Management**
- Tests require significant disk space (each run ~100MB of logs)
- Docker containers are cleaned up automatically on success
- Failed tests may leave containers running - clean manually if needed
- Use `docker system prune` periodically to reclaim space
### Best Practices for Test Modifications
1. **Always test locally** before committing integration test changes
2. **Use appropriate timeouts** - too short causes flaky tests, too long slows CI
3. **Clean up properly** - ensure tests don't leave persistent state
4. **Handle both success and failure paths** in test scenarios
5. **Document timing requirements** for complex test scenarios
## NodeStore Implementation Details
**Key Insight from Recent Work**: The NodeStore is a critical performance optimization that caches node data in memory while ensuring consistency with the database. When working with route advertisements or node state changes:
1. **Timing Considerations**: Route advertisements need time to propagate from clients to server. Use `require.EventuallyWithT()` patterns in tests instead of immediate assertions.
2. **Synchronization Points**: NodeStore updates happen at specific points like `poll.go:420` after Hostinfo changes. Ensure these are maintained when modifying the polling logic.
3. **Peer Visibility**: The NodeStore's `peersFunc` determines which nodes are visible to each other. Policy-based filtering is separate from monitoring visibility - expired nodes should remain visible for debugging but marked as expired.
## Testing Guidelines
### Integration Test Patterns
```go
// Use EventuallyWithT for async operations
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
// Check expected state
}, 10*time.Second, 100*time.Millisecond, "description")
// Node route checking by actual node properties, not array position
var routeNode *v1.Node
for _, node := range nodes {
if nodeIDStr := fmt.Sprintf("%d", node.GetId()); expectedRoutes[nodeIDStr] != "" {
routeNode = node
break
}
}
```
### Running Problematic Tests
- Some tests require significant time (e.g., `TestNodeOnlineStatus` runs for 12 minutes)
- Infrastructure issues like disk space can cause test failures unrelated to code changes
- Use `--postgres` flag when testing database-heavy scenarios
## Important Notes
- **Dependencies**: Use `nix develop` for consistent toolchain (Go, buf, protobuf tools, linting)
- **Protocol Buffers**: Changes to `proto/` require `make generate` and should be committed separately
- **Code Style**: Enforced via golangci-lint with golines (width 88) and gofumpt formatting
- **Database**: Supports both SQLite (development) and PostgreSQL (production/testing)
- **Integration Tests**: Require Docker and can consume significant disk space
- **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management
## Debugging Integration Tests
Test artifacts are preserved in `control_logs/TIMESTAMP-ID/` including:
- Headscale server logs (stderr/stdout)
- Tailscale client logs and status
- Database dumps and network captures
- MapResponse JSON files for protocol debugging
When tests fail, check these artifacts first before assuming code issues.

View File

@ -87,10 +87,9 @@ lint-proto: check-deps $(PROTO_SOURCES)
# Code generation
.PHONY: generate
generate: check-deps $(PROTO_SOURCES)
@echo "Generating code from Protocol Buffers..."
rm -rf gen
buf generate proto
generate: check-deps
@echo "Generating code..."
go generate ./...
# Clean targets
.PHONY: clean

View File

@ -212,13 +212,10 @@ var listUsersCmd = &cobra.Command{
switch {
case id > 0:
request.Id = uint64(id)
break
case username != "":
request.Name = username
break
case email != "":
request.Email = email
break
}
response, err := client.ListUsers(ctx, request)

View File

@ -90,6 +90,32 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
log.Printf("Starting test: %s", config.TestPattern)
// Start stats collection for container resource monitoring (if enabled)
var statsCollector *StatsCollector
if config.Stats {
var err error
statsCollector, err = NewStatsCollector()
if err != nil {
if config.Verbose {
log.Printf("Warning: failed to create stats collector: %v", err)
}
statsCollector = nil
}
if statsCollector != nil {
defer statsCollector.Close()
// Start stats collection immediately - no need for complex retry logic
// The new implementation monitors Docker events and will catch containers as they start
if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
if config.Verbose {
log.Printf("Warning: failed to start stats collection: %v", err)
}
}
defer statsCollector.StopCollection()
}
}
exitCode, err := streamAndWait(ctx, cli, resp.ID)
// Ensure all containers have finished and logs are flushed before extracting artifacts
@ -105,6 +131,20 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
// Always list control files regardless of test outcome
listControlFiles(logsDir)
// Print stats summary and check memory limits if enabled
if config.Stats && statsCollector != nil {
violations := statsCollector.PrintSummaryAndCheckLimits(config.HSMemoryLimit, config.TSMemoryLimit)
if len(violations) > 0 {
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
log.Printf("=================================")
for _, violation := range violations {
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
}
return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
}
}
shouldCleanup := config.CleanAfter && (!config.KeepOnFailure || exitCode == 0)
if shouldCleanup {
if config.Verbose {
@ -379,10 +419,37 @@ func getDockerSocketPath() string {
return "/var/run/docker.sock"
}
// ensureImageAvailable pulls the specified Docker image to ensure it's available.
// checkImageAvailableLocally reports whether imageName can be inspected through
// the given Docker client. A "not found" response yields (false, nil); any other
// inspection failure is returned as a wrapped error.
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
	_, _, inspectErr := cli.ImageInspectWithRaw(ctx, imageName)

	switch {
	case inspectErr == nil:
		return true, nil
	case client.IsErrNotFound(inspectErr):
		// A missing image is an expected outcome, not a failure.
		return false, nil
	default:
		return false, fmt.Errorf("failed to inspect image %s: %w", imageName, inspectErr)
	}
}
// ensureImageAvailable checks if the image is available locally first, then pulls if needed.
func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName string, verbose bool) error {
// First check if image is available locally
available, err := checkImageAvailableLocally(ctx, cli, imageName)
if err != nil {
return fmt.Errorf("failed to check local image availability: %w", err)
}
if available {
if verbose {
log.Printf("Pulling image %s...", imageName)
log.Printf("Image %s is available locally", imageName)
}
return nil
}
// Image not available locally, try to pull it
if verbose {
log.Printf("Image %s not found locally, pulling...", imageName)
}
reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{})

View File

@ -190,7 +190,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
}
}
// checkGolangImage verifies we can access the golang Docker image.
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
func checkGolangImage(ctx context.Context) DoctorResult {
cli, err := createDockerClient()
if err != nil {
@ -205,17 +205,40 @@ func checkGolangImage(ctx context.Context) DoctorResult {
goVersion := detectGoVersion()
imageName := "golang:" + goVersion
// Check if we can pull the image
// First check if image is available locally
available, err := checkImageAvailableLocally(ctx, cli, imageName)
if err != nil {
return DoctorResult{
Name: "Golang Image",
Status: "FAIL",
Message: fmt.Sprintf("Cannot check golang image %s: %v", imageName, err),
Suggestions: []string{
"Check Docker daemon status",
"Try: docker images | grep golang",
},
}
}
if available {
return DoctorResult{
Name: "Golang Image",
Status: "PASS",
Message: fmt.Sprintf("Golang image %s is available locally", imageName),
}
}
// Image not available locally, try to pull it
err = ensureImageAvailable(ctx, cli, imageName, false)
if err != nil {
return DoctorResult{
Name: "Golang Image",
Status: "FAIL",
Message: fmt.Sprintf("Cannot pull golang image %s: %v", imageName, err),
Message: fmt.Sprintf("Golang image %s not available locally and cannot pull: %v", imageName, err),
Suggestions: []string{
"Check internet connectivity",
"Verify Docker Hub access",
"Try: docker pull " + imageName,
"Or run tests offline if image was pulled previously",
},
}
}
@ -223,7 +246,7 @@ func checkGolangImage(ctx context.Context) DoctorResult {
return DoctorResult{
Name: "Golang Image",
Status: "PASS",
Message: fmt.Sprintf("Golang image %s is available", imageName),
Message: fmt.Sprintf("Golang image %s is now available", imageName),
}
}

View File

@ -24,6 +24,9 @@ type RunConfig struct {
KeepOnFailure bool `flag:"keep-on-failure,default=false,Keep containers on test failure"`
LogsDir string `flag:"logs-dir,default=control_logs,Control logs directory"`
Verbose bool `flag:"verbose,default=false,Verbose output"`
Stats bool `flag:"stats,default=false,Collect and display container resource usage statistics"`
HSMemoryLimit float64 `flag:"hs-memory-limit,default=0,Fail test if any Headscale container exceeds this memory limit in MB (0 = disabled)"`
TSMemoryLimit float64 `flag:"ts-memory-limit,default=0,Fail test if any Tailscale container exceeds this memory limit in MB (0 = disabled)"`
}
// runIntegrationTest executes the integration test workflow.

468
cmd/hi/stats.go Normal file
View File

@ -0,0 +1,468 @@
package main
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/docker/docker/api/types"
	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/events"
	"github.com/docker/docker/api/types/filters"
	"github.com/docker/docker/client"
)
// ContainerStats holds the samples collected for a single container.
// Stats is appended to by the per-container collector goroutine and copied by
// GetSummary; both take mutex while touching it.
type ContainerStats struct {
// ContainerID is the Docker container ID; it is also the key in StatsCollector.containers.
ContainerID string
// ContainerName is the container name with any leading "/" removed.
ContainerName string
// Stats is the ordered list of samples recorded so far.
Stats []StatsSample
// mutex guards Stats.
mutex sync.RWMutex
}
// StatsSample is one measurement taken from a container's stats stream.
type StatsSample struct {
// Timestamp is the local time at which the sample was recorded.
Timestamp time.Time
CPUUsage float64 // CPU usage percentage, as computed by calculateCPUPercent
MemoryMB float64 // Memory usage in MB (reported usage in bytes divided by 1024*1024)
}
// StatsCollector manages collection of container statistics.
// It starts one goroutine per monitored container plus goroutines that discover
// containers, and tracks all of them with wg so StopCollection can wait for them.
type StatsCollector struct {
// client is the Docker API client used for listing, inspecting and streaming stats.
client *client.Client
// containers maps container ID to the samples collected for that container.
containers map[string]*ContainerStats
// stopChan is closed by StopCollection to tell every goroutine to exit.
stopChan chan struct{}
// wg counts the running discovery and collection goroutines.
wg sync.WaitGroup
// mutex guards containers and collectionStarted.
mutex sync.RWMutex
// collectionStarted is true between a successful StartCollection and the end of StopCollection.
collectionStarted bool
}
// NewStatsCollector builds a StatsCollector backed by a fresh Docker client.
// The returned collector is idle; call StartCollection to begin monitoring.
// It returns an error when the Docker client cannot be created.
func NewStatsCollector() (*StatsCollector, error) {
	dockerClient, err := createDockerClient()
	if err != nil {
		return nil, fmt.Errorf("failed to create Docker client: %w", err)
	}

	collector := &StatsCollector{
		client:     dockerClient,
		containers: map[string]*ContainerStats{},
		stopChan:   make(chan struct{}),
	}

	return collector, nil
}
// StartCollection launches the two discovery goroutines: one that scans the
// containers already running and one that follows Docker "start" events. Both
// only pick up hs-/ts- containers labelled with runID.
//
// It returns an error if collection has already been started on this collector.
func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, verbose bool) error {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()

	if sc.collectionStarted {
		return fmt.Errorf("stats collection already started")
	}
	sc.collectionStarted = true

	// Register both workers before launching either so a concurrent Wait
	// cannot observe a zero counter in between.
	sc.wg.Add(2)
	go sc.monitorExistingContainers(ctx, runID, verbose)
	go sc.monitorDockerEvents(ctx, runID, verbose)

	if verbose {
		log.Printf("Started container monitoring for run ID %s", runID)
	}

	return nil
}
// StopCollection signals every collection goroutine to exit, waits for them to
// finish, and then marks the collector as stopped. It is a no-op when
// collection is not running.
//
// NOTE(review): the started flag is read and the channel is closed in separate
// steps, so two concurrent callers could both reach close(sc.stopChan) — confirm
// callers are serialized.
func (sc *StatsCollector) StopCollection() {
	sc.mutex.RLock()
	running := sc.collectionStarted
	sc.mutex.RUnlock()

	if !running {
		return
	}

	// The lock is not held while closing and waiting: the worker goroutines
	// themselves acquire sc.mutex, so waiting under the lock would block them.
	close(sc.stopChan)
	sc.wg.Wait()

	sc.mutex.Lock()
	defer sc.mutex.Unlock()
	sc.collectionStarted = false
}
// monitorExistingContainers lists the containers known to Docker right now and
// starts stats collection for each one accepted by shouldMonitorContainer.
// It runs once and releases its WaitGroup slot on return.
func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID string, verbose bool) {
	defer sc.wg.Done()

	listed, err := sc.client.ContainerList(ctx, container.ListOptions{})
	if err != nil {
		if verbose {
			log.Printf("Failed to list existing containers: %v", err)
		}

		return
	}

	for i := range listed {
		candidate := listed[i]
		if !sc.shouldMonitorContainer(candidate, runID) {
			continue
		}

		// shouldMonitorContainer only accepts containers with at least one
		// matching name, so Names is non-empty here.
		sc.startStatsForContainer(ctx, candidate.ID, candidate.Names[0], verbose)
	}
}
// monitorDockerEvents subscribes to Docker container "start" events and begins
// stats collection for every started container accepted by
// shouldMonitorContainer. It returns when stopChan is closed, ctx is done, or
// the event stream reports an error.
func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, verbose bool) {
	defer sc.wg.Done()

	startFilter := filters.NewArgs()
	startFilter.Add("type", "container")
	startFilter.Add("event", "start")

	// Local names deliberately differ from the imported events package.
	eventCh, errCh := sc.client.Events(ctx, events.ListOptions{Filters: startFilter})

	for {
		select {
		case <-sc.stopChan:
			return

		case <-ctx.Done():
			return

		case err := <-errCh:
			if verbose {
				log.Printf("Error in Docker events stream: %v", err)
			}

			return

		case event := <-eventCh:
			if event.Type != "container" || event.Action != "start" {
				continue
			}

			// The event carries only the ID; name and labels come from inspect.
			info, err := sc.client.ContainerInspect(ctx, event.ID)
			if err != nil {
				continue
			}

			// Adapt the inspect result to the shape shouldMonitorContainer expects.
			started := types.Container{
				ID:     info.ID,
				Names:  []string{info.Name},
				Labels: info.Config.Labels,
			}
			if sc.shouldMonitorContainer(started, runID) {
				sc.startStatsForContainer(ctx, started.ID, started.Names[0], verbose)
			}
		}
	}
}
// shouldMonitorContainer reports whether cont belongs to the given run and is a
// Headscale or Tailscale container: its "hi.run-id" label must equal runID and
// at least one of its names (ignoring a leading "/") must start with "hs-" or "ts-".
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
	if cont.Labels == nil {
		return false
	}

	if label := cont.Labels["hi.run-id"]; label != runID {
		return false
	}

	for _, rawName := range cont.Names {
		switch trimmed := strings.TrimPrefix(rawName, "/"); {
		case strings.HasPrefix(trimmed, "hs-"), strings.HasPrefix(trimmed, "ts-"):
			return true
		}
	}

	return false
}
// startStatsForContainer registers containerID with the collector and launches
// a goroutine that streams its stats. Calling it again for a container that is
// already registered does nothing, so each container gets exactly one collector.
func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerID, containerName string, verbose bool) {
	containerName = strings.TrimPrefix(containerName, "/")

	// Check and insert under a single lock so two discoverers cannot both
	// register the same container.
	registered := func() bool {
		sc.mutex.Lock()
		defer sc.mutex.Unlock()

		if _, known := sc.containers[containerID]; known {
			return false
		}

		sc.containers[containerID] = &ContainerStats{
			ContainerID:   containerID,
			ContainerName: containerName,
			Stats:         []StatsSample{},
		}

		return true
	}()
	if !registered {
		return
	}

	if verbose {
		log.Printf("Starting stats collection for container %s (%s)", containerName, containerID[:12])
	}

	sc.wg.Add(1)
	go sc.collectStatsForContainer(ctx, containerID, verbose)
}
// collectStatsForContainer streams stats for containerID from the Docker API and
// appends one StatsSample per decoded message to that container's entry in
// sc.containers.
//
// The first message is used only as the baseline for the CPU delta and is not
// stored. The function returns when stopChan is closed, ctx is done, or the
// stream ends or fails to decode. Stop and cancellation are checked between
// messages, so they take effect once the pending Decode call returns.
//
// When verbose is set, failures other than a clean end of stream are logged.
func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containerID string, verbose bool) {
	defer sc.wg.Done()

	// Use Docker API streaming stats - much more efficient than CLI
	statsResponse, err := sc.client.ContainerStats(ctx, containerID, true)
	if err != nil {
		if verbose {
			log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
		}

		return
	}
	defer statsResponse.Body.Close()

	decoder := json.NewDecoder(statsResponse.Body)

	var prevStats *container.Stats

	for {
		select {
		case <-sc.stopChan:
			return
		case <-ctx.Done():
			return
		default:
		}

		// A fresh value per iteration: prevStats keeps pointing at the previous one.
		var stats container.Stats
		if err := decoder.Decode(&stats); err != nil {
			// io.EOF is the expected result when the container stops or the
			// stream ends; compare by identity rather than by message text.
			if !errors.Is(err, io.EOF) && verbose {
				log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
			}

			return
		}

		// CPU usage is a delta, so a sample can only be produced once a
		// previous reading exists.
		if prevStats != nil {
			sample := StatsSample{
				Timestamp: time.Now(),
				CPUUsage:  calculateCPUPercent(prevStats, &stats),
				MemoryMB:  float64(stats.MemoryStats.Usage) / (1024 * 1024),
			}

			// Hold the collector lock only for the map lookup; the append is
			// protected by the container's own mutex.
			sc.mutex.RLock()
			containerStats, exists := sc.containers[containerID]
			sc.mutex.RUnlock()

			if exists && containerStats != nil {
				containerStats.mutex.Lock()
				containerStats.Stats = append(containerStats.Stats, sample)
				containerStats.mutex.Unlock()
			}
		}

		// Save current stats for next iteration
		prevStats = &stats
	}
}
// calculateCPUPercent returns the container's CPU usage between two consecutive
// stats readings as a percentage, where 100 corresponds to one fully used CPU.
//
// The value is (container CPU delta / system CPU delta) * number of CPUs * 100.
// The CPU count is taken from OnlineCPUs when the daemon reports it, then from
// the length of PercpuUsage, and finally defaults to 1. It returns 0 when the
// system counter did not advance or the container counter went backwards.
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
	// Convert before subtracting so a counter that moved backwards yields a
	// negative delta instead of wrapping around as an unsigned value.
	cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
	systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)

	if systemDelta <= 0 || cpuDelta < 0 {
		return 0.0
	}

	// Prefer the reported online CPU count: PercpuUsage can be empty, and
	// falling straight back to 1 CPU would then under-report usage.
	numCPUs := float64(stats.CPUStats.OnlineCPUs)
	if numCPUs == 0 {
		numCPUs = float64(len(stats.CPUStats.CPUUsage.PercpuUsage))
	}
	if numCPUs == 0 {
		// Last resort when neither source is available.
		numCPUs = 1.0
	}

	return (cpuDelta / systemDelta) * numCPUs * 100.0
}
// ContainerStatsSummary holds the aggregated resource usage of a single
// container, as produced by GetSummary.
type ContainerStatsSummary struct {
// ContainerName is the name of the container the summary describes.
ContainerName string
// SampleCount is the number of samples the summary was computed from.
SampleCount int
// CPU summarizes the CPUUsage values of the samples.
CPU StatsSummary
// Memory summarizes the MemoryMB values of the samples.
Memory StatsSummary
}
// MemoryViolation describes a container whose peak sampled memory usage
// exceeded the limit that applies to it.
type MemoryViolation struct {
// ContainerName is the name of the offending container.
ContainerName string
// MaxMemoryMB is the highest memory usage recorded for the container, in MB.
MaxMemoryMB float64
// LimitMB is the limit, in MB, that MaxMemoryMB was compared against.
LimitMB float64
}
// StatsSummary holds the minimum, maximum, and arithmetic mean of a series of
// values for a single metric.
type StatsSummary struct {
// Min is the smallest value in the series.
Min float64
// Max is the largest value in the series.
Max float64
// Average is the sum of the values divided by their count.
Average float64
}
// GetSummary builds one ContainerStatsSummary for every tracked container that
// has at least one recorded sample and returns them ordered by container name.
// Containers without samples are left out.
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
	// Copy the entry pointers out under the collector lock so that it is
	// released before any per-container lock is taken.
	sc.mutex.RLock()
	tracked := make([]*ContainerStats, 0, len(sc.containers))
	for _, entry := range sc.containers {
		tracked = append(tracked, entry)
	}
	sc.mutex.RUnlock()

	result := make([]ContainerStatsSummary, 0, len(tracked))
	for _, entry := range tracked {
		// Snapshot the samples and the name so the aggregation below runs
		// without holding the container's lock.
		entry.mutex.RLock()
		samples := append(make([]StatsSample, 0, len(entry.Stats)), entry.Stats...)
		name := entry.ContainerName
		entry.mutex.RUnlock()

		count := len(samples)
		if count == 0 {
			continue
		}

		// Split the samples into one series per metric.
		cpuSeries := make([]float64, count)
		memorySeries := make([]float64, count)
		for i := range samples {
			cpuSeries[i] = samples[i].CPUUsage
			memorySeries[i] = samples[i].MemoryMB
		}

		result = append(result, ContainerStatsSummary{
			ContainerName: name,
			SampleCount:   count,
			CPU:           calculateStatsSummary(cpuSeries),
			Memory:        calculateStatsSummary(memorySeries),
		})
	}

	// Map iteration order is random; sort by name for stable output.
	sort.Slice(result, func(a, b int) bool {
		return result[a].ContainerName < result[b].ContainerName
	})

	return result
}
// calculateStatsSummary reduces a series of values to its minimum, maximum,
// and arithmetic mean. An empty (or nil) series yields the zero StatsSummary.
func calculateStatsSummary(values []float64) StatsSummary {
	count := len(values)
	if count == 0 {
		return StatsSummary{}
	}

	// Seed both extremes with the first value, then fold in every element.
	result := StatsSummary{Min: values[0], Max: values[0]}
	total := 0.0
	for i := 0; i < count; i++ {
		v := values[i]
		if v < result.Min {
			result.Min = v
		}
		if v > result.Max {
			result.Max = v
		}
		total += v
	}
	result.Average = total / float64(count)

	return result
}
// PrintSummary logs the per-container CPU and memory summary returned by
// GetSummary, or a single notice when there is nothing to report.
func (sc *StatsCollector) PrintSummary() {
	report := sc.GetSummary()
	if len(report) == 0 {
		log.Printf("No container statistics collected")
		return
	}

	log.Printf("Container Resource Usage Summary:")
	log.Printf("================================")

	for i := range report {
		entry := report[i]
		cpu, memory := entry.CPU, entry.Memory

		log.Printf("Container: %s (%d samples)", entry.ContainerName, entry.SampleCount)
		log.Printf(" CPU Usage: Min: %6.2f%% Max: %6.2f%% Avg: %6.2f%%",
			cpu.Min, cpu.Max, cpu.Average)
		log.Printf(" Memory Usage: Min: %6.1f MB Max: %6.1f MB Avg: %6.1f MB",
			memory.Min, memory.Max, memory.Average)
		// Blank entry separates one container from the next.
		log.Printf("")
	}
}
// CheckMemoryLimits reports every container whose peak sampled memory usage
// exceeded the limit for its kind. Containers whose name starts with "hs-" are
// checked against hsLimitMB and those starting with "ts-" against tsLimitMB;
// all other containers are ignored. A limit of zero or less disables the check
// for that kind, and nil is returned when both limits are disabled.
func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
	if hsLimitMB <= 0 && tsLimitMB <= 0 {
		return nil
	}

	var violations []MemoryViolation
	for _, summary := range sc.GetSummary() {
		// Pick the limit from the container name prefix.
		var limitMB float64
		switch {
		case strings.HasPrefix(summary.ContainerName, "hs-"):
			limitMB = hsLimitMB
		case strings.HasPrefix(summary.ContainerName, "ts-"):
			limitMB = tsLimitMB
		default:
			// Not one of the container kinds we have a limit for.
			continue
		}

		if limitMB > 0 && summary.Memory.Max > limitMB {
			violations = append(violations, MemoryViolation{
				ContainerName: summary.ContainerName,
				MaxMemoryMB:   summary.Memory.Max,
				LimitMB:       limitMB,
			})
		}
	}

	return violations
}
// PrintSummaryAndCheckLimits logs the statistics summary via PrintSummary and
// then returns the result of CheckMemoryLimits for the given limits (in MB):
// the containers that exceeded their memory limit, if any.
func (sc *StatsCollector) PrintSummaryAndCheckLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
sc.PrintSummary()
return sc.CheckMemoryLimits(hsLimitMB, tsLimitMB)
}
// Close shuts the stats collector down: it first stops collection via
// StopCollection and then closes the underlying client, returning the error
// (if any) from closing the client.
func (sc *StatsCollector) Close() error {
sc.StopCollection()
return sc.client.Close()
}

View File

@ -19,7 +19,7 @@
overlay = _: prev: let
pkgs = nixpkgs.legacyPackages.${prev.system};
buildGo = pkgs.buildGo124Module;
vendorHash = "sha256-S2GnCg2dyfjIyi5gXhVEuRs5Bop2JAhZcnhg1fu4/Gg=";
vendorHash = "sha256-83L2NMyOwKCHWqcowStJ7Ze/U9CJYhzleDRLrJNhX2g=";
in {
headscale = buildGo {
pname = "headscale";

27
go.mod
View File

@ -23,7 +23,6 @@ require (
github.com/gorilla/mux v1.8.1
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.0
github.com/jagottsicher/termcolor v1.0.2
github.com/klauspost/compress v1.18.0
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
github.com/ory/dockertest/v3 v3.12.0
github.com/philip-bui/grpc-zerolog v1.0.1
@ -43,11 +42,11 @@ require (
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/crypto v0.39.0
golang.org/x/crypto v0.40.0
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
golang.org/x/net v0.41.0
golang.org/x/net v0.42.0
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.15.0
golang.org/x/sync v0.16.0
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822
google.golang.org/grpc v1.73.0
google.golang.org/protobuf v1.36.6
@ -55,7 +54,7 @@ require (
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.6.0
gorm.io/gorm v1.30.0
tailscale.com v1.84.2
tailscale.com v1.84.3
zgo.at/zcache/v2 v2.2.0
zombiezen.com/go/postgrestest v1.0.1
)
@ -81,7 +80,7 @@ require (
modernc.org/libc v1.62.1 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.10.0 // indirect
modernc.org/sqlite v1.37.0 // indirect
modernc.org/sqlite v1.37.0
)
require (
@ -166,6 +165,7 @@ require (
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jsimonetti/rtnetlink v1.4.1 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lib/pq v1.10.9 // indirect
@ -231,14 +231,19 @@ require (
go.opentelemetry.io/otel/trace v1.36.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/term v0.32.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/mod v0.26.0 // indirect
golang.org/x/sys v0.34.0 // indirect
golang.org/x/term v0.33.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/time v0.10.0 // indirect
golang.org/x/tools v0.33.0 // indirect
golang.org/x/tools v0.35.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 // indirect
)
tool (
golang.org/x/tools/cmd/stringer
tailscale.com/cmd/viewer
)

34
go.sum
View File

@ -555,8 +555,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
@ -567,8 +567,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -577,8 +577,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -587,8 +587,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -615,8 +615,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -624,8 +624,8 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
@ -633,8 +633,8 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@ -643,8 +643,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@ -714,6 +714,8 @@ software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
tailscale.com v1.84.2 h1:v6aM4RWUgYiV52LRAx6ET+dlGnvO/5lnqPXb7/pMnR0=
tailscale.com v1.84.2/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
tailscale.com v1.84.3 h1:Ur9LMedSgicwbqpy5xn7t49G8490/s6rqAJOk5Q5AYE=
tailscale.com v1.84.3/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
zgo.at/zcache/v2 v2.2.0 h1:K29/IPjMniZfveYE+IRXfrl11tMzHkIPuyGrfVZ2fGo=
zgo.at/zcache/v2 v2.2.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=

View File

@ -28,14 +28,15 @@ import (
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
"github.com/juanfont/headscale/hscontrol/dns"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
zerolog "github.com/philip-bui/grpc-zerolog"
"github.com/pkg/profile"
zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/sync/errgroup"
@ -64,6 +65,19 @@ var (
)
)
var (
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
)
// init configures the go-deadlock package from the debug knobs above.
// Detection is disabled unless debugDeadlock is set; when it is set, the
// deadlock timeout is taken from debugDeadlockTimeout and
// PrintAllCurrentGoroutines is turned on.
func init() {
deadlock.Opts.Disable = !debugDeadlock
if debugDeadlock {
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
deadlock.Opts.PrintAllCurrentGoroutines = true
}
}
const (
AuthPrefix = "Bearer "
updateInterval = 5 * time.Second
@ -82,9 +96,8 @@ type Headscale struct {
// Things that generate changes
extraRecordMan *dns.ExtraRecordsMan
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
authProvider AuthProvider
mapBatcher mapper.Batcher
pollNetMapStreamWG sync.WaitGroup
}
@ -118,7 +131,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
cfg: cfg,
noisePrivateKey: noisePrivateKey,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(cfg),
state: s,
}
@ -136,12 +148,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "ephemeral-gc-policy", node.Hostname)
app.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
app.Change(policyChanged)
log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
})
app.ephemeralGC = ephemeralGC
@ -153,10 +160,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
defer cancel()
oidcProvider, err := NewAuthProviderOIDC(
ctx,
&app,
cfg.ServerURL,
&cfg.OIDC,
app.state,
app.nodeNotifier,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@ -262,16 +268,18 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
return
case <-expireTicker.C:
var update types.StateUpdate
var expiredNodeChanges []change.ChangeSet
var changed bool
lastExpiryCheck, update, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
log.Trace().Interface("changes", expiredNodeChanges).Msgf("expiring nodes")
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
// Send the changes directly since they're already in the new format
for _, nodeChange := range expiredNodeChanges {
h.Change(nodeChange)
}
}
case <-derpTickerChan:
@ -282,11 +290,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
derpMap.Regions[region.RegionID] = &region
}
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateDERPUpdated,
DERPMap: derpMap,
})
h.Change(change.DERPSet)
case records, ok := <-extraRecordsUpdate:
if !ok {
@ -294,19 +298,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
}
h.cfg.TailcfgDNSConfig.ExtraRecords = records
ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
// TODO(kradalby): We can probably do better than sending a full update here,
// but for now this will ensure that all of the nodes get the new records.
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
h.Change(change.ExtraRecordsSet)
}
}
}
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
req interface{},
req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
) (any, error) {
// Check if the request is coming from the on-server client.
// This is not secure, but it is to maintain maintainability
// with the "legacy" database-based client
@ -484,58 +485,6 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// // Maybe we should attempt a new in memory state and not go via the DB?
// // Maybe this should be implemented as an event bus?
// // A bool is returned indicating if a full update was sent to all nodes
// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
// users, err := db.ListUsers()
// if err != nil {
// return err
// }
// changed, err := polMan.SetUsers(users)
// if err != nil {
// return err
// }
// if changed {
// ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
// notif.NotifyAll(ctx, types.UpdateFull())
// }
// return nil
// }
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// // Maybe we should attempt a new in memory state and not go via the DB?
// // Maybe this should be implemented as an event bus?
// // A bool is returned indicating if a full update was sent to all nodes
// func nodesChangedHook(
// db *db.HSDatabase,
// polMan policy.PolicyManager,
// notif *notifier.Notifier,
// ) (bool, error) {
// nodes, err := db.ListNodes()
// if err != nil {
// return false, err
// }
// filterChanged, err := polMan.SetNodes(nodes)
// if err != nil {
// return false, err
// }
// if filterChanged {
// ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
// notif.NotifyAll(ctx, types.UpdateFull())
// return true, nil
// }
// return false, nil
// }
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
capver.CanOldCodeBeCleanedUp()
@ -562,8 +511,9 @@ func (h *Headscale) Serve() error {
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
Msg("Clients with a lower minimum version will be rejected")
// Fetch an initial DERP Map before we start serving
h.mapper = mapper.NewMapper(h.state, h.cfg, h.nodeNotifier)
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
h.mapBatcher.Start()
defer h.mapBatcher.Close()
// TODO(kradalby): fix state part.
if h.cfg.DERP.ServerEnabled {
@ -838,8 +788,12 @@ func (h *Headscale) Serve() error {
log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change")
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
err = h.state.AutoApproveNodes()
if err != nil {
log.Error().Err(err).Msg("failed to approve routes after new policy")
}
h.Change(change.PolicySet)
}
default:
info := func(msg string) { log.Info().Msg(msg) }
@ -865,7 +819,6 @@ func (h *Headscale) Serve() error {
}
info("closing node notifier")
h.nodeNotifier.Close()
info("waiting for netmap stream to close")
h.pollNetMapStreamWG.Wait()
@ -1047,3 +1000,10 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
return &machineKey, nil
}
// Change is used to send changes to nodes. It hands the change set to the
// map batcher as work.
// All changes should be enqueued here; empty change sets will be
// automatically ignored.
func (h *Headscale) Change(c change.ChangeSet) {
h.mapBatcher.AddWork(c)
}

View File

@ -10,6 +10,8 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -32,6 +34,21 @@ func (h *Headscale) handleRegister(
}
if node != nil {
// If an existing node is trying to register with an auth key,
// we need to validate the auth key even for existing nodes
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
if errors.As(err, &httpErr) {
return nil, httpErr
}
return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
}
return resp, nil
}
resp, err := h.handleExistingNode(node, regReq, machineKey)
if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err)
@ -47,6 +64,11 @@ func (h *Headscale) handleRegister(
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
if errors.As(err, &httpErr) {
return nil, httpErr
}
return nil, fmt.Errorf("handling register with auth key: %w", err)
}
@ -66,11 +88,13 @@ func (h *Headscale) handleExistingNode(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
if node.MachineKey != machineKey {
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
}
expired := node.IsExpired()
if !expired && !regReq.Expiry.IsZero() {
requestExpiry := regReq.Expiry
@ -82,39 +106,23 @@ func (h *Headscale) handleExistingNode(
// If the request expiry is in the past, we consider it a logout.
if requestExpiry.Before(time.Now()) {
if node.IsEphemeral() {
policyChanged, err := h.state.DeleteNode(node)
c, err := h.state.DeleteNode(node)
if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "auth-logout-ephemeral-policy", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
}
h.Change(c)
return nil, nil
}
}
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "auth-expiry-policy", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
}
return nodeToRegisterResponse(n), nil
h.Change(c)
}
return nodeToRegisterResponse(node), nil
@ -168,7 +176,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, changed, err := h.state.HandleNodeFromPreAuthKey(
node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
regReq,
machineKey,
)
@ -184,6 +192,12 @@ func (h *Headscale) handleRegisterWithAuthKey(
return nil, err
}
// If node is nil, it means an ephemeral node was deleted during logout
if node == nil {
h.Change(changed)
return nil, nil
}
// This is a bit of a back and forth, but we have a bit of a chicken and egg
// dependency here.
// Because the way the policy manager works, we need to have the node
@ -195,23 +209,22 @@ func (h *Headscale) handleRegisterWithAuthKey(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
routesChanged := h.state.AutoApproveRoutes(node)
// TODO(kradalby): Maybe this needs to be run as part of the batcher,
// now that we don't update the node/policy here anymore?
routeChange := h.state.AutoApproveRoutes(node)
if _, _, err := h.state.SaveNode(node); err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
if routesChanged {
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
} else if changed {
ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
// Existing node re-registering without route changes
// Still need to notify peers about the node being active again
// Use UpdateFull to ensure all peers get complete peer maps
ctx := types.NotifyCtx(context.Background(), "node re-registered", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
if routeChange && changed.Empty() {
changed = change.NodeAdded(node.ID)
}
h.Change(changed)
// If policy changed due to node registration, send a separate policy change
if policyChanged {
policyChange := change.PolicyChange()
h.Change(policyChange)
}
return &tailcfg.RegisterResponse{

View File

@ -1,5 +1,7 @@
package capver
//go:generate go run ../../tools/capver/main.go
import (
"slices"
"sort"
@ -10,7 +12,7 @@ import (
"tailscale.com/util/set"
)
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 88
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 90
// CanOldCodeBeCleanedUp is intended to be called on startup to see if
// there is old code that can be cleaned up, entries should contain

View File

@ -1,14 +1,10 @@
package capver
// Generated DO NOT EDIT
//Generated DO NOT EDIT
import "tailscale.com/tailcfg"
var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.60.0": 87,
"v1.60.1": 87,
"v1.62.0": 88,
"v1.62.1": 88,
"v1.64.0": 90,
"v1.64.1": 90,
"v1.64.2": 90,
@ -36,11 +32,13 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.80.3": 113,
"v1.82.0": 115,
"v1.82.5": 115,
"v1.84.0": 116,
"v1.84.1": 116,
"v1.84.2": 116,
}
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
87: "v1.60.0",
88: "v1.62.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
@ -50,4 +48,5 @@ var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
116: "v1.84.0",
}

View File

@ -13,11 +13,10 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
stripV bool
expected []string
}{
{3, false, []string{"v1.78", "v1.80", "v1.82"}},
{2, true, []string{"1.80", "1.82"}},
{3, false, []string{"v1.80", "v1.82", "v1.84"}},
{2, true, []string{"1.82", "1.84"}},
// Lazy way to see all supported versions
{10, true, []string{
"1.64",
"1.66",
"1.68",
"1.70",
@ -27,6 +26,7 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
"1.78",
"1.80",
"1.82",
"1.84",
}},
{0, false, nil},
}
@ -46,7 +46,6 @@ func TestCapVerMinimumTailscaleVersion(t *testing.T) {
input tailcfg.CapabilityVersion
expected string
}{
{88, "v1.62.0"},
{90, "v1.64.0"},
{95, "v1.66.0"},
{106, "v1.74.0"},

View File

@ -7,7 +7,6 @@ import (
"os/exec"
"path/filepath"
"slices"
"sort"
"strings"
"testing"
"time"
@ -362,8 +361,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
}
if diff := cmp.Diff(expectedKeys, keys, cmp.Comparer(func(a, b []string) bool {
sort.Sort(sort.StringSlice(a))
sort.Sort(sort.StringSlice(b))
slices.Sort(a)
slices.Sort(b)
return slices.Equal(a, b)
}), cmpopts.IgnoreFields(types.PreAuthKey{}, "User", "CreatedAt", "Reusable", "Ephemeral", "Used", "Expiration")); diff != "" {
t.Errorf("TestSQLiteMigrationAndDataValidation() pre-auth key tags migration mismatch (-want +got):\n%s", diff)

View File

@ -7,15 +7,19 @@ import (
"net/netip"
"slices"
"sort"
"strconv"
"sync"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
const (
@ -39,9 +43,7 @@ var (
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListPeers(rx, nodeID, peerIDs...)
})
return ListPeers(hsdb.DB, nodeID, peerIDs...)
}
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
@ -66,9 +68,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter.
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx, nodeIDs...)
})
return ListNodes(hsdb.DB, nodeIDs...)
}
// ListNodes queries the database for either all nodes if no parameters are given
@ -120,9 +120,7 @@ func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
}
func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByID(rx, id)
})
return GetNodeByID(hsdb.DB, id)
}
// GetNodeByID finds a Node by ID and returns the Node struct.
@ -140,9 +138,7 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
}
func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByMachineKey(rx, machineKey)
})
return GetNodeByMachineKey(hsdb.DB, machineKey)
}
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
@ -163,9 +159,7 @@ func GetNodeByMachineKey(
}
func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByNodeKey(rx, nodeKey)
})
return GetNodeByNodeKey(hsdb.DB, nodeKey)
}
// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct.
@ -352,8 +346,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
registrationMethod string,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
) (*types.Node, bool, error) {
var newNode bool
) (*types.Node, change.ChangeSet, error) {
var nodeChange change.ChangeSet
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if reg, ok := hsdb.regCache.Get(registrationID); ok {
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
@ -405,7 +399,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
}
close(reg.Registered)
newNode = true
nodeChange = change.NodeAdded(node.ID)
return node, err
} else {
@ -415,6 +409,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
return nil, err
}
nodeChange = change.KeyExpiry(node.ID)
return node, nil
}
}
@ -422,7 +418,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
return nil, ErrNodeNotFoundRegistrationCache
})
return node, newNode, err
return node, nodeChange, err
}
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
@ -448,6 +444,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
if oldNode != nil && oldNode.UserID == node.UserID {
node.ID = oldNode.ID
node.GivenName = oldNode.GivenName
node.ApprovedRoutes = oldNode.ApprovedRoutes
ipv4 = oldNode.IPv4
ipv6 = oldNode.IPv6
}
@ -594,17 +591,18 @@ func ensureUniqueGivenName(
// containing the expired nodes, and a boolean indicating if any nodes were found.
func ExpireExpiredNodes(tx *gorm.DB,
lastCheck time.Time,
) (time.Time, types.StateUpdate, bool) {
) (time.Time, []change.ChangeSet, bool) {
// use the time of the start of the function to ensure we
// don't miss some nodes by returning it _after_ we have
// checked everything.
started := time.Now()
expired := make([]*tailcfg.PeerChange, 0)
var updates []change.ChangeSet
nodes, err := ListNodes(tx)
if err != nil {
return time.Unix(0, 0), types.StateUpdate{}, false
return time.Unix(0, 0), nil, false
}
for _, node := range nodes {
if node.IsExpired() && node.Expiry.After(lastCheck) {
@ -612,14 +610,15 @@ func ExpireExpiredNodes(tx *gorm.DB,
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: node.Expiry,
})
updates = append(updates, change.KeyExpiry(node.ID))
}
}
if len(expired) > 0 {
return started, types.UpdatePeerPatch(expired...), true
return started, updates, true
}
return started, types.StateUpdate{}, false
return started, nil, false
}
// EphemeralGarbageCollector is a garbage collector that will delete nodes after
@ -732,3 +731,114 @@ func (e *EphemeralGarbageCollector) Start() {
}
}
}
// CreateNodeForTest writes a new node owned by user directly to the database
// and returns it.
//
// The node is named after the first non-empty entry of hostname, falling back
// to "testnode". A pre-auth key is created for the user and attached to the
// node, and fresh machine, node and disco keys are generated for it.
//
// It panics when called outside of a test binary, when user is nil, or when
// either the pre-auth key or the node cannot be stored.
func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) *types.Node {
	switch {
	case !testing.Testing():
		panic("CreateNodeForTest can only be called during tests")
	case user == nil:
		panic("CreateNodeForTest requires a valid user")
	}

	name := "testnode"
	if len(hostname) > 0 {
		if first := hostname[0]; first != "" {
			name = first
		}
	}

	// The node is saved with RegisterMethodAuthKey, so give it a key to point at.
	authKey, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
	if err != nil {
		panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
	}

	created := &types.Node{
		MachineKey:     key.NewMachine().Public(),
		NodeKey:        key.NewNode().Public(),
		DiscoKey:       key.NewDisco().Public(),
		Hostname:       name,
		UserID:         user.ID,
		RegisterMethod: util.RegisterMethodAuthKey,
		AuthKeyID:      ptr.To(authKey.ID),
	}

	if saveErr := hsdb.DB.Save(created).Error; saveErr != nil {
		panic(fmt.Sprintf("failed to create test node: %v", saveErr))
	}

	return created
}
// CreateRegisteredNodeForTest creates a node via CreateNodeForTest, passes it
// through RegisterNode inside a database transaction (with no IPv4/IPv6
// address supplied), and returns the node as it is read back by ID.
//
// It panics when called outside of a test binary, or when registration or the
// final lookup fails. Panics raised by CreateNodeForTest propagate as well.
func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node {
	if !testing.Testing() {
		panic("CreateRegisteredNodeForTest can only be called during tests")
	}

	unregistered := hsdb.CreateNodeForTest(user, hostname...)

	register := func(tx *gorm.DB) error {
		_, regErr := RegisterNode(tx, *unregistered, nil, nil)

		return regErr
	}
	if err := hsdb.DB.Transaction(register); err != nil {
		panic(fmt.Sprintf("failed to register test node: %v", err))
	}

	// Re-read the node so the caller sees what was actually persisted.
	stored, err := hsdb.GetNodeByID(unregistered.ID)
	if err != nil {
		panic(fmt.Sprintf("failed to get registered test node: %v", err))
	}

	return stored
}
// CreateNodesForTest creates count nodes for user with CreateNodeForTest and
// returns them in creation order.
//
// Hostnames have the form "<prefix>-<index>" with the index starting at 0.
// The prefix is the first non-empty entry of hostnamePrefix, or "testnode".
//
// It panics when called outside of a test binary or when user is nil.
func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
	switch {
	case !testing.Testing():
		panic("CreateNodesForTest can only be called during tests")
	case user == nil:
		panic("CreateNodesForTest requires a valid user")
	}

	base := "testnode"
	if len(hostnamePrefix) > 0 {
		if first := hostnamePrefix[0]; first != "" {
			base = first
		}
	}

	created := make([]*types.Node, count)
	for idx := 0; idx < count; idx++ {
		created[idx] = hsdb.CreateNodeForTest(user, base+"-"+strconv.Itoa(idx))
	}

	return created
}
// CreateRegisteredNodesForTest creates count nodes for user with
// CreateRegisteredNodeForTest and returns them in creation order.
//
// Hostnames have the form "<prefix>-<index>" with the index starting at 0.
// The prefix is the first non-empty entry of hostnamePrefix, or "testnode".
//
// It panics when called outside of a test binary or when user is nil.
func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
	switch {
	case !testing.Testing():
		panic("CreateRegisteredNodesForTest can only be called during tests")
	case user == nil:
		panic("CreateRegisteredNodesForTest requires a valid user")
	}

	base := "testnode"
	if len(hostnamePrefix) > 0 {
		if first := hostnamePrefix[0]; first != "" {
			base = first
		}
	}

	registered := make([]*types.Node, count)
	for idx := 0; idx < count; idx++ {
		registered[idx] = hsdb.CreateRegisteredNodeForTest(user, base+"-"+strconv.Itoa(idx))
	}

	return registered
}

View File

@ -6,7 +6,6 @@ import (
"math/big"
"net/netip"
"regexp"
"strconv"
"sync"
"testing"
"time"
@ -26,82 +25,36 @@ import (
)
func (s *Suite) TestGetNode(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode")
_, err := db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := &types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)
node := db.CreateNodeForTest(user, "testnode")
_, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.Hostname, check.Equals, "testnode")
}
func (s *Suite) TestGetNodeByID(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0)
_, err := db.GetNodeByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := db.CreateNodeForTest(user, "testnode")
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
_, err = db.GetNodeByID(0)
retrievedNode, err := db.GetNodeByID(node.ID)
c.Assert(err, check.IsNil)
c.Assert(retrievedNode.Hostname, check.Equals, "testnode")
}
func (s *Suite) TestHardDeleteNode(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
node := db.CreateNodeForTest(user, "testnode3")
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode3",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
err = db.DeleteNode(&node)
err := db.DeleteNode(node)
c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode3")
@ -109,42 +62,21 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
}
func (s *Suite) TestListPeers(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0)
_, err := db.GetNodeByID(0)
c.Assert(err, check.NotNil)
for index := range 11 {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
nodes := db.CreateNodesForTest(user, 11, "testnode")
node := types.Node{
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode" + strconv.Itoa(index),
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
}
node0ByID, err := db.GetNodeByID(0)
firstNode := nodes[0]
peersOfFirstNode, err := db.ListPeers(firstNode.ID)
c.Assert(err, check.IsNil)
peersOfNode0, err := db.ListPeers(node0ByID.ID)
c.Assert(err, check.IsNil)
c.Assert(len(peersOfNode0), check.Equals, 9)
c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2")
c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7")
c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10")
c.Assert(len(peersOfFirstNode), check.Equals, 10)
c.Assert(peersOfFirstNode[0].Hostname, check.Equals, "testnode-1")
c.Assert(peersOfFirstNode[5].Hostname, check.Equals, "testnode-6")
c.Assert(peersOfFirstNode[9].Hostname, check.Equals, "testnode-10")
}
func (s *Suite) TestExpireNode(c *check.C) {
@ -807,13 +739,13 @@ func TestListPeers(t *testing.T) {
// No parameter means no filter, should return all peers
nodes, err = db.ListPeers(1)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
// Empty node list should return all peers
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
// No match in IDs should return empty list and no error
@ -824,13 +756,13 @@ func TestListPeers(t *testing.T) {
// Partial match in IDs
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs, but node ID is still filtered out
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
}
@ -892,14 +824,14 @@ func TestListNodes(t *testing.T) {
// No parameter means no filter, should return all nodes
nodes, err = db.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, 2, len(nodes))
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
// Empty node list should return all nodes
nodes, err = db.ListNodes(types.NodeIDs{}...)
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, 2, len(nodes))
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
@ -911,13 +843,13 @@ func TestListNodes(t *testing.T) {
// Partial match in IDs
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, 2, len(nodes))
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
}

View File

@ -109,9 +109,7 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
}
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return GetPreAuthKey(rx, key)
})
return GetPreAuthKey(hsdb.DB, key)
}
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
@ -155,11 +153,8 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
// ExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
return nil
now := time.Now()
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
}
func generateKey() (string, error) {

View File

@ -1,7 +1,7 @@
package db
import (
"sort"
"slices"
"testing"
"github.com/juanfont/headscale/hscontrol/types"
@ -57,7 +57,7 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
c.Assert(err, check.IsNil)
gotTags := listedPaks[0].Proto().GetAclTags()
sort.Sort(sort.StringSlice(gotTags))
slices.Sort(gotTags)
c.Assert(gotTags, check.DeepEquals, tags)
}

View File

@ -3,6 +3,8 @@ package db
import (
"errors"
"fmt"
"strconv"
"testing"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@ -110,9 +112,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
}
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUserByID(rx, uid)
})
return GetUserByID(hsdb.DB, uid)
}
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
@ -146,9 +146,7 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
}
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
return ListUsers(rx, where...)
})
return ListUsers(hsdb.DB, where...)
}
// ListUsers gets all the existing users.
@ -217,3 +215,40 @@ func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error
return nil
}
// CreateUserForTest creates a user through CreateUser and returns it.
//
// The user is named after the first non-empty entry of name, falling back to
// "testuser".
//
// It panics when called outside of a test binary or when the user cannot be
// created.
func (hsdb *HSDatabase) CreateUserForTest(name ...string) *types.User {
	if !testing.Testing() {
		panic("CreateUserForTest can only be called during tests")
	}

	spec := types.User{Name: "testuser"}
	if len(name) > 0 {
		if first := name[0]; first != "" {
			spec.Name = first
		}
	}

	created, err := hsdb.CreateUser(spec)
	if err != nil {
		panic(fmt.Sprintf("failed to create test user: %v", err))
	}

	return created
}
// CreateUsersForTest creates count users with CreateUserForTest and returns
// them in creation order.
//
// User names have the form "<prefix>-<index>" with the index starting at 0.
// The prefix is the first non-empty entry of namePrefix, or "testuser".
//
// It panics when called outside of a test binary.
func (hsdb *HSDatabase) CreateUsersForTest(count int, namePrefix ...string) []*types.User {
	if !testing.Testing() {
		panic("CreateUsersForTest can only be called during tests")
	}

	base := "testuser"
	if len(namePrefix) > 0 {
		if first := namePrefix[0]; first != "" {
			base = first
		}
	}

	created := make([]*types.User, count)
	for idx := 0; idx < count; idx++ {
		created[idx] = hsdb.CreateUserForTest(base + "-" + strconv.Itoa(idx))
	}

	return created
}

View File

@ -11,8 +11,7 @@ import (
)
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
c.Assert(user.Name, check.Equals, "test")
users, err := db.ListUsers()
@ -30,8 +29,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
err := db.DestroyUser(9998)
c.Assert(err, check.Equals, ErrUserNotFound)
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
@ -64,8 +62,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
}
func (s *Suite) TestRenameUser(c *check.C) {
userTest, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
userTest := db.CreateUserForTest("test")
c.Assert(userTest.Name, check.Equals, "test")
users, err := db.ListUsers()
@ -86,8 +83,7 @@ func (s *Suite) TestRenameUser(c *check.C) {
err = db.RenameUser(99988, "test")
c.Assert(err, check.Equals, ErrUserNotFound)
userTest2, err := db.CreateUser(types.User{Name: "test2"})
c.Assert(err, check.IsNil)
userTest2 := db.CreateUserForTest("test2")
c.Assert(userTest2.Name, check.Equals, "test2")
want := "UNIQUE constraint failed"
@ -98,11 +94,8 @@ func (s *Suite) TestRenameUser(c *check.C) {
}
func (s *Suite) TestSetMachineUser(c *check.C) {
oldUser, err := db.CreateUser(types.User{Name: "old"})
c.Assert(err, check.IsNil)
newUser, err := db.CreateUser(types.User{Name: "new"})
c.Assert(err, check.IsNil)
oldUser := db.CreateUserForTest("old")
newUser := db.CreateUserForTest("new")
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)

View File

@ -17,10 +17,6 @@ import (
func (h *Headscale) debugHTTPServer() *http.Server {
debugMux := http.NewServeMux()
debug := tsweb.Debugger(debugMux)
debug.Handle("notifier", "Connected nodes in notifier", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(h.nodeNotifier.String()))
}))
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
config, err := json.MarshalIndent(h.cfg, "", " ")
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"io"
"maps"
"net/http"
"net/url"
"os"
@ -72,9 +73,7 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
}
for _, derpMap := range derpMaps {
for id, region := range derpMap.Regions {
result.Regions[id] = region
}
maps.Copy(result.Regions, derpMap.Regions)
}
return &result

View File

@ -1,3 +1,5 @@
//go:generate buf generate --template ../buf.gen.yaml -o .. ../proto
// nolint
package hscontrol
@ -27,6 +29,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
)
@ -56,12 +59,14 @@ func (api headscaleV1APIServer) CreateUser(
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
}
// Send policy update notifications if needed
c := change.UserAdded(types.UserID(user.ID))
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-user-created", user.Name)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
c.Change = change.Policy
}
api.h.Change(c)
return &v1.CreateUserResponse{User: user.Proto()}, nil
}
@ -81,8 +86,7 @@ func (api headscaleV1APIServer) RenameUser(
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-user-renamed", request.GetNewName())
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
api.h.Change(change.PolicyChange())
}
newUser, err := api.h.state.GetUserByName(request.GetNewName())
@ -107,6 +111,8 @@ func (api headscaleV1APIServer) DeleteUser(
return nil, err
}
api.h.Change(change.UserRemoved(types.UserID(user.ID)))
return &v1.DeleteUserResponse{}, nil
}
@ -246,7 +252,7 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, fmt.Errorf("looking up user: %w", err)
}
node, _, err := api.h.state.HandleNodeFromAuthPath(
node, nodeChange, err := api.h.state.HandleNodeFromAuthPath(
registrationId,
types.UserID(user.ID),
nil,
@ -267,22 +273,13 @@ func (api headscaleV1APIServer) RegisterNode(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
routesChanged := api.h.state.AutoApproveRoutes(node)
_, policyChanged, err := api.h.state.SaveNode(node)
_ = api.h.state.AutoApproveRoutes(node)
_, _, err = api.h.state.SaveNode(node)
if err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
// Send policy update notifications if needed (from SaveNode or route changes)
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-nodes-change", "all")
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
if routesChanged {
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
}
api.h.Change(nodeChange)
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
}
@ -300,7 +297,7 @@ func (api headscaleV1APIServer) GetNode(
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
return &v1.GetNodeResponse{Node: resp}, nil
}
@ -316,21 +313,14 @@ func (api headscaleV1APIServer) SetTags(
}
}
node, policyChanged, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
if err != nil {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-tags", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
@ -362,23 +352,19 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
tsaddr.SortPrefixes(routes)
routes = slices.Compact(routes)
node, policyChanged, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-routes-approved", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
if api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) {
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
// Always propagate node changes from SetApprovedRoutes
api.h.Change(nodeChange)
// If routes changed, propagate those changes too
if !routeChange.Empty() {
api.h.Change(routeChange)
}
proto := node.Proto()
@ -409,19 +395,12 @@ func (api headscaleV1APIServer) DeleteNode(
return nil, err
}
policyChanged, err := api.h.state.DeleteNode(node)
nodeChange, err := api.h.state.DeleteNode(node)
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-deleted", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
api.h.Change(nodeChange)
return &v1.DeleteNodeResponse{}, nil
}
@ -432,25 +411,13 @@ func (api headscaleV1APIServer) ExpireNode(
) (*v1.ExpireNodeResponse, error) {
now := time.Now()
node, policyChanged, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-expired", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID)
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
// TODO(kradalby): Ensure that both the selfupdate and peer updates are sent
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
@ -464,22 +431,13 @@ func (api headscaleV1APIServer) RenameNode(
ctx context.Context,
request *v1.RenameNodeRequest,
) (*v1.RenameNodeResponse, error) {
node, policyChanged, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
node, nodeChange, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-renamed", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
// TODO(kradalby): investigate if we need selfupdate
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
@ -498,7 +456,7 @@ func (api headscaleV1APIServer) ListNodes(
// probably be done once.
// TODO(kradalby): This should be done in one tx.
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
IsConnected := api.h.mapBatcher.ConnectedMap()
if request.GetUser() != "" {
user, err := api.h.state.GetUserByName(request.GetUser())
if err != nil {
@ -510,7 +468,7 @@ func (api headscaleV1APIServer) ListNodes(
return nil, err
}
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
response := nodesToProto(api.h.state, IsConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
@ -523,18 +481,18 @@ func (api headscaleV1APIServer) ListNodes(
return nodes[i].ID < nodes[j].ID
})
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
response := nodesToProto(api.h.state, IsConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
response := make([]*v1.Node, len(nodes))
for index, node := range nodes {
resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
if val, ok := isLikelyConnected.Load(node.ID); ok && val {
if val, ok := IsConnected.Load(node.ID); ok && val {
resp.Online = true
}
@ -556,24 +514,14 @@ func (api headscaleV1APIServer) MoveNode(
ctx context.Context,
request *v1.MoveNodeRequest,
) (*v1.MoveNodeResponse, error) {
node, policyChanged, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
node, nodeChange, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-moved", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID)
ctx = types.NotifyCtx(ctx, "cli-movenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
// TODO(kradalby): Ensure the policy is also sent
// TODO(kradalby): ensure that both the selfupdate and peer updates are sent
api.h.Change(nodeChange)
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
}
@ -754,8 +702,7 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err
}
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
api.h.Change(change.PolicyChange())
}
response := &v1.SetPolicyResponse{

155
hscontrol/mapper/batcher.go Normal file
View File

@ -0,0 +1,155 @@
package mapper
import (
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/puzpuzpuz/xsync/v4"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
// batcherFunc is the signature of a constructor that builds a Batcher from the
// server configuration and state; NewBatcherAndMapper has this shape.
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
// Batcher defines the common interface for all batcher implementations.
//
// A Batcher accepts [change.ChangeSet] values and is able to turn a change
// into a [tailcfg.MapResponse] for a specific node.
// NOTE(review): the per-method notes below are inferred from the signatures
// and should be confirmed against the concrete implementation.
type Batcher interface {
// Start and Close presumably bracket the batcher's lifetime.
Start()
Close()
// AddNode associates the response channel c with node id, together with
// whether the node is a subnet router and its capability version.
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
// RemoveNode is the counterpart of AddNode for the same id and channel.
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
// IsConnected reports the connection status of a single node.
IsConnected(id types.NodeID) bool
// ConnectedMap returns a per-node map of connection status.
ConnectedMap() *xsync.Map[types.NodeID, bool]
// AddWork submits a change set to the batcher.
AddWork(c change.ChangeSet)
// MapResponseFromChange builds the map response for node id from change c.
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
}
// NewBatcher builds a LockFreeBatcher that uses mapper to generate responses.
//
// batchTime is the period of the batcher's ticker; it is passed straight to
// [time.NewTicker], which panics for non-positive durations. workers is stored
// on the batcher and also determines the capacity of the work queue.
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
	// The queue capacity is arbitrarily chosen; the sizing should be revisited.
	queueSize := workers * 200

	b := &LockFreeBatcher{
		mapper:  mapper,
		workers: workers,
	}
	b.tick = time.NewTicker(batchTime)
	b.workCh = make(chan work, queueSize)
	b.nodes = xsync.NewMap[types.NodeID, *nodeConn]()
	b.connected = xsync.NewMap[types.NodeID, *time.Time]()
	b.pendingChanges = xsync.NewMap[types.NodeID, []change.ChangeSet]()

	return b
}
// NewBatcherAndMapper creates a mapper and a Batcher that reference each
// other and returns the Batcher.
//
// The batch delay and worker count are taken from cfg.Tuning.
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
	mp := newMapper(cfg, state)

	delay, workers := cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers
	batch := NewBatcher(delay, workers, mp)

	// The batcher already holds the mapper; hand the mapper its batcher too.
	mp.batcher = batch

	return batch
}
// nodeConnection abstracts over the different connection implementations a
// map response can be delivered to.
//
// handleNodeChange uses nodeID and version to generate the response for the
// connection.
type nodeConnection interface {
// nodeID returns the ID of the node this connection belongs to.
nodeID() types.NodeID
// version returns the capability version used when generating responses.
version() tailcfg.CapabilityVersion
// send delivers data over the connection and reports any failure.
send(data *tailcfg.MapResponse) error
}
// generateMapResponse builds the [tailcfg.MapResponse] that node nodeID should
// receive in reaction to the [change.ChangeSet] c.
//
// An empty change set yields (nil, nil). A zero nodeID or a nil mapper is
// rejected with an error. The kind of response depends on c.Change:
//
//   - change.DERP: a DERP map response.
//   - change.NodeRemove: a peer-removed response for c.NodeID.
//   - change.NodeCameOnline / change.NodeWentOffline: a peer patch carrying
//     the new online state, or a full map response when c.IsSubnetRouter.
//   - anything else (including change.NodeNewOrUpdate, change.Full and
//     change.Policy): a full map response.
//
// Errors from the mapper are wrapped with the node ID. A nil response is only
// accepted for DERP and NodeRemove changes; for every other change it is
// reported as an error.
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
	// Nothing changed, so there is nothing to send.
	if c.Empty() {
		return nil, nil
	}

	// Validate inputs before processing.
	switch {
	case nodeID == 0:
		return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
	case mapper == nil:
		return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
	}

	var (
		resp   *tailcfg.MapResponse
		genErr error
	)

	switch kind := c.Change; kind {
	case change.DERP:
		resp, genErr = mapper.derpMapResponse(nodeID)

	case change.NodeRemove:
		resp, genErr = mapper.peerRemovedResponse(nodeID, c.NodeID)

	case change.NodeCameOnline, change.NodeWentOffline:
		if c.IsSubnetRouter {
			// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
			resp, genErr = mapper.fullMapResponse(nodeID, version)
		} else {
			patch := &tailcfg.PeerChange{
				NodeID: c.NodeID.NodeID(),
				Online: ptr.To(kind == change.NodeCameOnline),
			}
			resp, genErr = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{patch})
		}

	default:
		// change.NodeNewOrUpdate, change.Full and change.Policy all end up
		// here and get a full map response.
		resp, genErr = mapper.fullMapResponse(nodeID, version)
	}

	if genErr != nil {
		return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, genErr)
	}

	// TODO(kradalby): Is this necessary?
	// Only the response itself is checked for nil; resp.Node may be nil for
	// peer updates, which is valid.
	nilAllowed := c.Change == change.DERP || c.Change == change.NodeRemove
	if resp == nil && !nilAllowed {
		return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
	}

	return resp, nil
}
// handleNodeChange generates the [tailcfg.MapResponse] that the node behind nc
// should receive for the [change.ChangeSet] c and sends it over nc.
// It returns an error when nc is nil, when generation fails or when sending
// fails. When generation yields no response, nothing is sent and nil is
// returned.
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
	if nc == nil {
		return fmt.Errorf("nodeConnection is nil")
	}

	id := nc.nodeID()

	resp, err := generateMapResponse(id, nc.version(), mapper, c)
	switch {
	case err != nil:
		return fmt.Errorf("generating map response for node %d: %w", id, err)
	case resp == nil:
		// Having nothing to send is valid for some change types.
		return nil
	}

	if sendErr := nc.send(resp); sendErr != nil {
		return fmt.Errorf("sending map response to node %d: %w", id, sendErr)
	}

	return nil
}
// workResult represents the result of processing a change.
// It is what a worker hands back over work.resultCh for synchronous requests.
type workResult struct {
// mapResponse is the generated response; it may be nil.
mapResponse *tailcfg.MapResponse
// err is the error encountered while generating the response, if any.
err error
}
// work represents a unit of work to be processed by workers.
type work struct {
// c is the change to process.
c change.ChangeSet
// nodeID is the node the map response is generated for.
nodeID types.NodeID
// resultCh, when non-nil, makes the worker return the generated response
// over this channel instead of sending it to the node.
resultCh chan<- workResult // optional channel for synchronous operations
}

View File

@ -0,0 +1,491 @@
package mapper
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
// It queues work for a pool of workers and periodically flushes changes that
// were batched per node.
type LockFreeBatcher struct {
// tick triggers the periodic flush of batched changes.
tick *time.Ticker
// mapper generates the map responses.
mapper *mapper
// workers is the number of worker goroutines started by doWork.
workers int
// Lock-free concurrent maps
// nodes holds the current connection of every registered node.
nodes *xsync.Map[types.NodeID, *nodeConn]
// connected records connection state per node: a nil value means the node
// is connected, a non-nil value is the time it was marked disconnected.
connected *xsync.Map[types.NodeID, *time.Time]
// Work queue channel
workCh chan work
// ctx and cancel are created by Start; cancelling ctx stops the background
// goroutines.
ctx context.Context
cancel context.CancelFunc
// Batching state
// pendingChanges holds, per target node, the changes waiting for the next flush.
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
// batchMutex serialises addToBatch and processBatchedChanges.
batchMutex sync.RWMutex
// Metrics
totalNodes atomic.Int64
totalUpdates atomic.Int64
workQueuedCount atomic.Int64
workProcessed atomic.Int64
workErrors atomic.Int64
}
// AddNode registers a new node connection with the batcher and sends an initial map response.
// It creates or updates the node's connection data, validates the initial map generation,
// and notifies other nodes that this node has come online.
//
// The initial (full, self) map is generated before any state is touched, so a
// generation failure leaves the batcher unchanged. If sending the initial map
// fails, the node is removed again, its connected marker is dropped and the
// node counter is adjusted so it stays in step with the nodes map.
//
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
	// First validate that we can generate initial map before doing anything else
	fullSelfChange := change.FullSelf(id)

	// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
	// This currently means that the goroutine for the node connection will do the processing
	// which means that we might have uncontrolled concurrency.
	// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
	// it to be processed in a more controlled manner.
	initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
	if err != nil {
		return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
	}

	// Only after validation succeeds, create or update node connection
	newConn := newNodeConn(id, c, version, b.mapper)

	var conn *nodeConn
	if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
		// Update existing connection
		existing.updateConnection(c, version)
		conn = existing
	} else {
		b.totalNodes.Add(1)
		conn = newConn
	}

	// Mark as connected only after validation succeeds
	b.connected.Store(id, nil) // nil = connected

	log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")

	// Send the validated initial map
	if initialMap != nil {
		if err := conn.send(initialMap); err != nil {
			// Clean up the connection state on send failure. The node was
			// counted when it was first stored, so the counter is only
			// decremented if this call actually removed the entry.
			if _, removed := b.nodes.LoadAndDelete(id); removed {
				b.totalNodes.Add(-1)
			}
			b.connected.Delete(id)

			return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
		}

		// Notify other nodes that this node came online
		b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
	}

	return nil
}
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
// It validates the connection channel matches the current one, closes the connection,
// and notifies other nodes that this node has gone offline.
//
// A call for a channel that is not the node's current connection (for example
// a stale connection that was replaced by a reconnect) is ignored. The node
// counter is only decremented when an entry was actually removed, so calls
// for nodes that are not registered cannot drive it negative.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
	// Check if this is the current connection and mark it as closed
	if existing, ok := b.nodes.Load(id); ok {
		if !existing.matchesChannel(c) {
			log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
			return // Not the current connection, not an error
		}

		// Mark the connection as closed to prevent further sends
		if connData := existing.connData.Load(); connData != nil {
			connData.closed.Store(true)
		}
	}

	log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")

	// Remove the node and record the time it went offline. These are two
	// independent map operations, not one atomic step.
	if _, removed := b.nodes.LoadAndDelete(id); removed {
		b.totalNodes.Add(-1)
	}
	b.connected.Store(id, ptr.To(time.Now()))

	// Notify other nodes that this node went offline
	b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
}
// AddWork queues a change to be processed by the batcher.
// Critical changes are processed immediately, while others are batched for efficiency.
// It is the exported entry point for addWork, which decides between the two.
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
b.addWork(c)
}
// Start creates the batcher's cancellable context and launches the doWork
// loop in its own goroutine. The context is cancelled by Close.
func (b *LockFreeBatcher) Start() {
	ctx, cancel := context.WithCancel(context.Background())
	b.ctx = ctx
	b.cancel = cancel

	go b.doWork()
}
// Close stops the batcher: it cancels the context created by Start, which
// makes the doWork loop and all workers return, and stops the batch ticker.
// It is safe to call Close more than once and before Start.
//
// The work channel is deliberately NOT closed. queueWork may be sending on it
// from other goroutines at the same time, and sending on a closed channel
// panics; closing it twice would panic as well. Workers and senders already
// observe shutdown through the cancelled context.
func (b *LockFreeBatcher) Close() {
	if b.cancel != nil {
		b.cancel()
	}

	// Release the ticker's resources; it is not needed once doWork has been
	// told to stop.
	if b.tick != nil {
		b.tick.Stop()
	}
}
// doWork is the batcher's main loop. It starts b.workers worker goroutines
// (numbered from 1) and then flushes the batched changes on every tick until
// the batcher context is cancelled.
func (b *LockFreeBatcher) doWork() {
	log.Debug().Msg("batcher doWork loop started")
	defer log.Debug().Msg("batcher doWork loop stopped")

	for id := 1; id <= b.workers; id++ {
		go b.worker(id)
	}

	for {
		select {
		case <-b.ctx.Done():
			return
		case <-b.tick.C:
			// Time to hand the batched changes over to the workers.
			b.processBatchedChanges()
		}
	}
}
// worker consumes work items from the work channel until the channel is
// closed or the batcher context is cancelled.
//
// There are two kinds of work:
//   - synchronous work (resultCh set): the map response is generated and
//     handed back to the waiting caller over resultCh;
//   - asynchronous work (resultCh nil): the change is applied to the target
//     node's connection, which generates the response and sends it to the node.
//
// Items that take longer than 100ms are logged as slow. workerID is only used
// to identify the worker in log output.
func (b *LockFreeBatcher) worker(workerID int) {
	log.Debug().Int("workerID", workerID).Msg("batcher worker started")
	defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")

	for {
		select {
		case w, ok := <-b.workCh:
			if !ok {
				return
			}

			startTime := time.Now()
			b.workProcessed.Add(1)

			// If the resultCh is set, it means that this is a work request
			// where there is a blocking function waiting for the map that
			// is being generated.
			// This is used for synchronous map generation.
			if w.resultCh != nil {
				var result workResult
				if nc, exists := b.nodes.Load(w.nodeID); exists {
					result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
					if result.err != nil {
						b.workErrors.Add(1)
						log.Error().Err(result.err).
							Int("workerID", workerID).
							Uint64("node.id", w.nodeID.Uint64()).
							Str("change", w.c.Change.String()).
							Msg("failed to generate map response for synchronous work")
					}
				} else {
					result.err = fmt.Errorf("node %d not found", w.nodeID)

					b.workErrors.Add(1)
					log.Error().Err(result.err).
						Int("workerID", workerID).
						Uint64("node.id", w.nodeID.Uint64()).
						Msg("node not found for synchronous work")
				}

				// Send result; give up if the batcher shuts down first.
				select {
				case w.resultCh <- result:
				case <-b.ctx.Done():
					return
				}

				duration := time.Since(startTime)
				if duration > 100*time.Millisecond {
					log.Warn().
						Int("workerID", workerID).
						Uint64("node.id", w.nodeID.Uint64()).
						Str("change", w.c.Change.String()).
						Dur("duration", duration).
						Msg("slow synchronous work processing")
				}

				continue
			}

			// If resultCh is nil, this is an asynchronous work request
			// that should be processed and sent to the node instead of
			// returned to the caller.
			if nc, exists := b.nodes.Load(w.nodeID); exists {
				// Check if this connection is still active before processing
				if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
					log.Debug().
						Int("workerID", workerID).
						Uint64("node.id", w.nodeID.Uint64()).
						Str("change", w.c.Change.String()).
						Msg("skipping work for closed connection")

					continue
				}

				err := nc.change(w.c)
				if err != nil {
					b.workErrors.Add(1)
					// "node.id" is the node the response was meant for, as in
					// every other log line of this function; the node the
					// change originated from is logged separately.
					log.Error().Err(err).
						Int("workerID", workerID).
						Uint64("node.id", w.nodeID.Uint64()).
						Uint64("change.node.id", w.c.NodeID.Uint64()).
						Str("change", w.c.Change.String()).
						Msg("failed to apply change")
				}
			} else {
				log.Debug().
					Int("workerID", workerID).
					Uint64("node.id", w.nodeID.Uint64()).
					Str("change", w.c.Change.String()).
					Msg("node not found for asynchronous work - node may have disconnected")
			}

			duration := time.Since(startTime)
			if duration > 100*time.Millisecond {
				log.Warn().
					Int("workerID", workerID).
					Uint64("node.id", w.nodeID.Uint64()).
					Str("change", w.c.Change.String()).
					Dur("duration", duration).
					Msg("slow asynchronous work processing")
			}

		case <-b.ctx.Done():
			return
		}
	}
}
// addWork routes a change either to the batch (non-critical changes) or
// straight to the work queue (changes that must be processed immediately).
// Immediate changes flagged SelfUpdateOnly are queued for the originating
// node only; all others are queued for every registered node, skipping the
// originating node unless the change also applies to it.
func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
	if !b.shouldProcessImmediately(c) {
		// Non-critical: wait for the next batch flush.
		b.addToBatch(c)
		return
	}

	if c.SelfUpdateOnly {
		b.queueWork(work{c: c, nodeID: c.NodeID})
		return
	}

	b.nodes.Range(func(target types.NodeID, _ *nodeConn) bool {
		if target != c.NodeID || c.AlsoSelf() {
			b.queueWork(work{c: c, nodeID: target})
		}

		return true
	})
}
// queueWork counts the work item and puts it on the work channel, blocking
// while the channel is full. If the batcher context is cancelled first, the
// item is dropped.
func (b *LockFreeBatcher) queueWork(w work) {
	b.workQueuedCount.Add(1)

	select {
	case <-b.ctx.Done():
		// The batcher is shutting down; drop the item.
	case b.workCh <- w:
		// Queued.
	}
}
// shouldProcessImmediately reports whether the change must bypass batching.
// Full, NodeRemove, NodeCameOnline, NodeWentOffline and Policy changes are
// processed right away so that critical functionality is not delayed.
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
	switch c.Change {
	case change.Full,
		change.NodeRemove,
		change.NodeCameOnline,
		change.NodeWentOffline,
		change.Policy:
		return true
	}

	return false
}
// addToBatch appends the change to the pending list of every node that should
// receive it at the next flush. A SelfUpdateOnly change is recorded for the
// originating node only; any other change is recorded for every registered
// node, skipping the originating node unless the change also applies to it.
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
	b.batchMutex.Lock()
	defer b.batchMutex.Unlock()

	// enqueue appends c to the pending changes of a single node.
	enqueue := func(target types.NodeID) {
		pending, _ := b.pendingChanges.LoadOrStore(target, []change.ChangeSet{})
		b.pendingChanges.Store(target, append(pending, c))
	}

	if c.SelfUpdateOnly {
		enqueue(c.NodeID)
		return
	}

	b.nodes.Range(func(target types.NodeID, _ *nodeConn) bool {
		if target != c.NodeID || c.AlsoSelf() {
			enqueue(target)
		}

		return true
	})
}
// processBatchedChanges flushes the batch: every pending change of every node
// is put on the work queue, after which the node's pending entry is removed.
// Nodes whose pending list is empty are left untouched.
func (b *LockFreeBatcher) processBatchedChanges() {
	b.batchMutex.Lock()
	defer b.batchMutex.Unlock()

	pending := b.pendingChanges
	if pending == nil {
		return
	}

	pending.Range(func(target types.NodeID, queued []change.ChangeSet) bool {
		if len(queued) > 0 {
			// Hand over every batched change for this node ...
			for i := range queued {
				b.queueWork(work{c: queued[i], nodeID: target})
			}

			// ... and clear its pending list.
			pending.Delete(target)
		}

		return true
	})
}
// IsConnected reports whether the node is currently connected. A node is
// connected when it has an entry in the connected map and that entry is nil;
// a non-nil entry or a missing entry means it is not.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
	disconnectedAt, known := b.connected.Load(id)

	return known && disconnectedAt == nil
}
// ConnectedMap returns a new map holding, for every node in the connected
// map, whether it is currently connected (true) or not (false).
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
	snapshot := xsync.NewMap[types.NodeID, bool]()

	b.connected.Range(func(id types.NodeID, disconnectedAt *time.Time) bool {
		// A nil timestamp means the node is connected.
		online := disconnectedAt == nil
		snapshot.Store(id, online)

		return true
	})

	return snapshot
}
// MapResponseFromChange generates the map response for node id and change c
// synchronously, but on the shared worker pool: it queues a work item with a
// result channel and blocks until a worker answers or the batcher shuts down.
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
	// Buffered so the worker never blocks on delivering the result.
	results := make(chan workResult, 1)

	b.queueWork(work{c: c, nodeID: id, resultCh: results})

	select {
	case <-b.ctx.Done():
		return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
	case res := <-results:
		return res.mapResponse, res.err
	}
}
// connectionData holds the channel and connection parameters.
// A nodeConn swaps in a fresh connectionData when its connection is updated.
type connectionData struct {
// c is the channel map responses are sent on.
c chan<- *tailcfg.MapResponse
// version is the capability version of the connection.
version tailcfg.CapabilityVersion
closed atomic.Bool // Track if this connection has been closed
}
// nodeConn describes the node connection and its associated data.
type nodeConn struct {
// id is the ID of the node this connection belongs to.
id types.NodeID
// mapper generates the map responses sent over this connection.
mapper *mapper
// Atomic pointer to connection data - allows lock-free updates
connData atomic.Pointer[connectionData]
// updateCount counts the map responses sent over this connection.
updateCount atomic.Int64
}
// newNodeConn creates the connection record for node id, with its connection
// data initialised to channel c and the given capability version.
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
	conn := &nodeConn{id: id, mapper: mapper}
	conn.connData.Store(&connectionData{
		c:       c,
		version: version,
	})

	return conn
}
// updateConnection atomically replaces the connection data with a fresh
// record for channel c and the given capability version. The new record
// starts out not closed.
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
	nc.connData.Store(&connectionData{
		c:       c,
		version: version,
	})
}
// matchesChannel reports whether c is the channel of the current connection.
// It returns false when there is no connection data.
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
	current := nc.connData.Load()

	// Channels are compared by identity.
	return current != nil && current.c == c
}
// version atomically reads the capability version of the current connection.
// It returns 0 when there is no connection data.
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
	if current := nc.connData.Load(); current != nil {
		return current.version
	}

	return 0
}
// nodeID returns the ID of the node this connection belongs to.
func (nc *nodeConn) nodeID() types.NodeID {
return nc.id
}
// change generates the map response for the change c and sends it over this
// connection, by delegating to handleNodeChange with the connection's mapper.
func (nc *nodeConn) change(c change.ChangeSet) error {
return handleNodeChange(nc, nc.mapper, c)
}
// send writes data to the channel of the current connection and bumps the
// update counter. The node will pick it up and send it to the HTTP handler.
// It fails when there is no connection data or when the connection has been
// marked closed.
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
	current := nc.connData.Load()

	switch {
	case current == nil:
		return fmt.Errorf("node %d: no connection data", nc.id)
	case current.closed.Load():
		return fmt.Errorf("node %d: connection closed", nc.id)
	}

	// TODO(kradalby): We might need some sort of timeout here if the client is not reading
	// the channel. That might mean that we are sending to a node that has gone offline, but
	// the channel is still open.
	current.c <- data
	nc.updateCount.Add(1)

	return nil
}

File diff suppressed because it is too large Load Diff

259
hscontrol/mapper/builder.go Normal file
View File

@ -0,0 +1,259 @@
package mapper
import (
"net/netip"
"sort"
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
"tailscale.com/util/multierr"
)
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
// Errors raised by the With* steps are collected and reported by Build.
type MapResponseBuilder struct {
// resp is the response under construction.
resp *tailcfg.MapResponse
// mapper gives access to the state and configuration.
mapper *mapper
// nodeID is the node the response is built for.
nodeID types.NodeID
// capVer is the capability version set by WithCapabilityVersion.
capVer tailcfg.CapabilityVersion
// errs holds the errors accumulated so far.
errs []error
}
// NewMapResponseBuilder starts a builder for node nodeID. The response begins
// with ControlTime set to the current time and KeepAlive false; no errors are
// recorded yet.
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
	controlTime := time.Now()

	// KeepAlive and errs are intentionally left at their zero values
	// (false and nil).
	return &MapResponseBuilder{
		resp:   &tailcfg.MapResponse{ControlTime: &controlTime},
		mapper: m,
		nodeID: nodeID,
	}
}
// addError records err in the builder's error list; a nil err is ignored.
func (b *MapResponseBuilder) addError(err error) {
	if err == nil {
		return
	}

	b.errs = append(b.errs, err)
}
// hasErrors reports whether at least one error has been recorded.
func (b *MapResponseBuilder) hasErrors() bool {
	return len(b.errs) != 0
}
// WithCapabilityVersion sets the capability version for the response
// It is stored on the builder and used by the steps that convert nodes.
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
b.capVer = capVer
return b
}
// WithSelfNode sets the response's Node to the requesting node itself,
// converted with the builder's capability version and with its routes reduced
// by the current policy matchers. A lookup or conversion failure is recorded
// on the builder and leaves the response untouched.
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
	self, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	_, matchers := b.mapper.state.Filter()

	// routesFor yields the primary routes of node id as visible to self.
	routesFor := func(id types.NodeID) []netip.Prefix {
		return policy.ReduceRoutes(self.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
	}

	selfNode, err := tailNode(self.View(), b.capVer, b.mapper.state, routesFor, b.mapper.cfg)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.Node = selfNode

	return b
}
// WithDERPMap adds the DERP map to the response
// The map is taken from the mapper's state.
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
b.resp.DERPMap = b.mapper.state.DERPMap()
return b
}
// WithDomain adds the domain configuration
// The value comes from the mapper configuration's Domain method.
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
b.resp.Domain = b.mapper.cfg.Domain()
return b
}
// WithCollectServicesDisabled sets the collect services flag to false
// The flag is set explicitly, so it reads back as set with the value false.
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
b.resp.CollectServices.Set(false)
return b
}
// WithDebugConfig adds debug configuration
// It disables log tailing if the mapper's LogTail is not enabled
// Any Debug value already on the response is replaced.
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
b.resp.Debug = &tailcfg.Debug{
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
}
return b
}
// WithSSHPolicy sets the response's SSH policy to the one the state computes
// for the requesting node. A lookup or policy failure is recorded on the
// builder and leaves the response untouched.
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
	self, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	pol, err := b.mapper.state.SSHPolicy(self.View())
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.SSHPolicy = pol

	return b
}
// WithDNSConfig sets the response's DNS configuration, generated from the
// mapper configuration for the requesting node. A failed node lookup is
// recorded on the builder and leaves the response untouched.
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
	self, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	dnsCfg := generateDNSConfig(b.mapper.cfg, self)
	b.resp.DNSConfig = dnsCfg

	return b
}
// WithUserProfiles sets the response's user profiles, generated for the
// requesting node together with the given peers. A failed node lookup is
// recorded on the builder and leaves the response untouched.
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
	self, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	profiles := generateUserProfiles(self, peers)
	b.resp.UserProfiles = profiles

	return b
}
// WithPacketFilters sets the response's packet filters to the policy filter
// rules reduced for the requesting node, published under the "base" key.
// A failed node lookup is recorded on the builder and leaves the response
// untouched.
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
	self, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	rules, _ := b.mapper.state.Filter()

	// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
	// Incremental packet filters are currently not sent. Using the
	// PacketFilters field with a single "base" entry still allows a full
	// update to be sent even when the rule list is empty.
	reduced := policy.ReduceFilterRules(self.View(), rules)
	b.resp.PacketFilters = map[string][]tailcfg.FilterRule{"base": reduced}

	return b
}
// WithPeers sets the response's full peer list (Peers) from the given nodes,
// after policy filtering and sorting; used for full map responses. A
// conversion failure is recorded on the builder instead.
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
	if converted, err := b.buildTailPeers(peers); err != nil {
		b.addError(err)
	} else {
		b.resp.Peers = converted
	}

	return b
}
// WithPeerChanges sets the response's changed peers (PeersChanged) from the
// given nodes, after policy filtering and sorting; used for incremental
// updates. A conversion failure is recorded on the builder instead.
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
	if converted, err := b.buildTailPeers(peers); err != nil {
		b.addError(err)
	} else {
		b.resp.PeersChanged = converted
	}

	return b
}
// buildTailPeers converts peers into []*tailcfg.Node as seen by the requesting
// node. When filter rules exist, peers are first reduced with the policy
// matchers; routes are reduced per peer as well. The result is sorted by
// node ID.
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
	self, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		return nil, err
	}

	rules, matchers := b.mapper.state.Filter()

	// Without filter rules every peer is kept. With rules present, peers that
	// cannot access each other at all are removed.
	var candidates views.Slice[types.NodeView]
	if len(rules) == 0 {
		candidates = peers.ViewSlice()
	} else {
		candidates = policy.ReduceNodes(self.View(), peers.ViewSlice(), matchers)
	}

	// routesFor yields the primary routes of node id as visible to self.
	routesFor := func(id types.NodeID) []netip.Prefix {
		return policy.ReduceRoutes(self.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
	}

	converted, err := tailNodes(candidates, b.capVer, b.mapper.state, routesFor, b.mapper.cfg)
	if err != nil {
		return nil, err
	}

	// Peers is always returned sorted by Node.ID.
	sort.SliceStable(converted, func(i, j int) bool {
		return converted[i].ID < converted[j].ID
	})

	return converted, nil
}
// WithPeerChangedPatch adds peer change patches
// The slice is stored as given, replacing any patches set earlier.
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
b.resp.PeersChangedPatch = changes
return b
}
// WithPeersRemoved sets the response's removed peers to the given IDs,
// converted to tailcfg node IDs. Each call replaces the list set by a
// previous call; calling it without IDs sets a nil list.
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
	var removed []tailcfg.NodeID
	for i := range removedIDs {
		removed = append(removed, removedIDs[i].NodeID())
	}

	b.resp.PeersRemoved = removed

	return b
}
// Build finalizes the builder and returns the assembled response. If any step
// recorded an error, no response is returned and the accumulated errors are
// returned combined instead. When a debug dump path is configured, the
// response is also written out for debugging. The messages argument is
// currently not used.
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
	if b.hasErrors() {
		return nil, multierr.New(b.errs...)
	}

	if dumpEnabled := debugDumpMapResponsePath != ""; dumpEnabled {
		writeDebugMapResponse(b.resp, b.nodeID)
	}

	return b.resp, nil
}

View File

@ -0,0 +1,347 @@
package mapper
import (
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
)
func TestMapResponseBuilder_Basic(t *testing.T) {
cfg := &types.Config{
BaseDomain: "example.com",
LogTail: types.LogTailConfig{
Enabled: true,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID)
// Test basic builder creation
assert.NotNil(t, builder)
assert.Equal(t, nodeID, builder.nodeID)
assert.NotNil(t, builder.resp)
assert.False(t, builder.resp.KeepAlive)
assert.NotNil(t, builder.resp.ControlTime)
assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second)
}
func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
capVer := tailcfg.CapabilityVersion(42)
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer)
assert.Equal(t, capVer, builder.capVer)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithDomain(t *testing.T) {
domain := "test.example.com"
cfg := &types.Config{
ServerURL: "https://test.example.com",
BaseDomain: domain,
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithDomain()
assert.Equal(t, domain, builder.resp.Domain)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithCollectServicesDisabled()
value, isSet := builder.resp.CollectServices.Get()
assert.True(t, isSet)
assert.False(t, value)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
tests := []struct {
name string
logTailEnabled bool
expected bool
}{
{
name: "LogTail enabled",
logTailEnabled: true,
expected: false, // DisableLogTail should be false when LogTail is enabled
},
{
name: "LogTail disabled",
logTailEnabled: false,
expected: true, // DisableLogTail should be true when LogTail is disabled
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &types.Config{
LogTail: types.LogTailConfig{
Enabled: tt.logTailEnabled,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithDebugConfig()
require.NotNil(t, builder.resp.Debug)
assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
assert.False(t, builder.hasErrors())
})
}
}
func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
changes := []*tailcfg.PeerChange{
{
NodeID: 123,
DERPRegion: 1,
},
{
NodeID: 456,
DERPRegion: 2,
},
}
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(changes)
assert.Equal(t, changes, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
removedID1 := types.NodeID(123)
removedID2 := types.NodeID(456)
builder := m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedID1, removedID2)
expected := []tailcfg.NodeID{
removedID1.NodeID(),
removedID2.NodeID(),
}
assert.Equal(t, expected, builder.resp.PeersRemoved)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
// Simulate an error in the builder
builder := m.NewMapResponseBuilder(nodeID)
builder.addError(assert.AnError)
// All subsequent calls should continue to work and accumulate errors
result := builder.
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig()
assert.True(t, result.hasErrors())
assert.Len(t, result.errs, 1)
assert.Equal(t, assert.AnError, result.errs[0])
// Build should return the error
data, err := result.Build("none")
assert.Nil(t, data)
assert.Error(t, err)
}
func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
domain := "chained.example.com"
cfg := &types.Config{
ServerURL: "https://chained.example.com",
BaseDomain: domain,
LogTail: types.LogTailConfig{
Enabled: false,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
capVer := tailcfg.CapabilityVersion(99)
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig()
// Verify all fields are set correctly
assert.Equal(t, capVer, builder.capVer)
assert.Equal(t, domain, builder.resp.Domain)
value, isSet := builder.resp.CollectServices.Get()
assert.True(t, isSet)
assert.False(t, value)
assert.NotNil(t, builder.resp.Debug)
assert.True(t, builder.resp.Debug.DisableLogTail)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
removedID1 := types.NodeID(100)
removedID2 := types.NodeID(200)
// Test calling WithPeersRemoved multiple times
builder := m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedID1).
WithPeersRemoved(removedID2)
// Second call should overwrite the first
expected := []tailcfg.NodeID{removedID2.NodeID()}
assert.Equal(t, expected, builder.resp.PeersRemoved)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch([]*tailcfg.PeerChange{})
assert.Empty(t, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(nil)
assert.Nil(t, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
// Create a builder and add multiple errors
builder := m.NewMapResponseBuilder(nodeID)
builder.addError(assert.AnError)
builder.addError(assert.AnError)
builder.addError(nil) // This should be ignored
// All subsequent calls should continue to work
result := builder.
WithDomain().
WithCollectServicesDisabled()
assert.True(t, result.hasErrors())
assert.Len(t, result.errs, 2) // nil error should be ignored
// Build should return a multierr
data, err := result.Build("none")
assert.Nil(t, data)
assert.Error(t, err)
// The error should contain information about multiple errors
assert.Contains(t, err.Error(), "multiple errors")
}

View File

@ -1,7 +1,6 @@
package mapper
import (
"encoding/binary"
"encoding/json"
"fmt"
"io/fs"
@ -10,29 +9,19 @@ import (
"os"
"path"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"tailscale.com/envknob"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/views"
)
const (
nextDNSDoHPrefix = "https://dns.nextdns.io"
reservedResponseHeaderSize = 4
mapperIDLength = 8
debugMapResponsePerm = 0o755
)
@ -50,15 +39,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
// - Create a "minifier" that removes info not needed for the node
// - some sort of batching, wait for 5 or 60 seconds before sending
type Mapper struct {
type mapper struct {
// Configuration
state *state.State
cfg *types.Config
notif *notifier.Notifier
batcher Batcher
uid string
created time.Time
seq uint64
}
type patch struct {
@ -66,41 +53,31 @@ type patch struct {
change *tailcfg.PeerChange
}
func NewMapper(
state *state.State,
func newMapper(
cfg *types.Config,
notif *notifier.Notifier,
) *Mapper {
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
state *state.State,
) *mapper {
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &Mapper{
return &mapper{
state: state,
cfg: cfg,
notif: notif,
uid: uid,
created: time.Now(),
seq: 0,
}
}
func (m *Mapper) String() string {
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
}
func generateUserProfiles(
node types.NodeView,
peers views.Slice[types.NodeView],
node *types.Node,
peers types.Nodes,
) []tailcfg.UserProfile {
userMap := make(map[uint]*types.User)
ids := make([]uint, 0, peers.Len()+1)
user := node.User()
userMap[user.ID] = &user
ids = append(ids, user.ID)
for _, peer := range peers.All() {
peerUser := peer.User()
userMap[peerUser.ID] = &peerUser
ids = append(ids, peerUser.ID)
ids := make([]uint, 0, len(userMap))
userMap[node.User.ID] = &node.User
ids = append(ids, node.User.ID)
for _, peer := range peers {
userMap[peer.User.ID] = &peer.User
ids = append(ids, peer.User.ID)
}
slices.Sort(ids)
@ -117,7 +94,7 @@ func generateUserProfiles(
func generateDNSConfig(
cfg *types.Config,
node types.NodeView,
node *types.Node,
) *tailcfg.DNSConfig {
if cfg.TailcfgDNSConfig == nil {
return nil
@ -137,17 +114,16 @@ func generateDNSConfig(
//
// This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{
"device_name": []string{node.Hostname()},
"device_model": []string{node.Hostinfo().OS()},
"device_name": []string{node.Hostname},
"device_model": []string{node.Hostinfo.OS},
}
nodeIPs := node.IPs()
if len(nodeIPs) > 0 {
attrs.Add("device_ip", nodeIPs[0].String())
if len(node.IPs()) > 0 {
attrs.Add("device_ip", node.IPs()[0].String())
}
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
@ -155,222 +131,112 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
}
}
// fullMapResponse creates a complete MapResponse for a node.
// It is a separate function to make testing easier.
func (m *Mapper) fullMapResponse(
node types.NodeView,
peers views.Slice[types.NodeView],
// fullMapResponse returns a MapResponse for the given node.
func (m *mapper) fullMapResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
messages ...string,
) (*tailcfg.MapResponse, error) {
resp, err := m.baseWithConfigMapResponse(node, capVer)
peers, err := m.listPeers(nodeID)
if err != nil {
return nil, err
}
err = appendPeerChanges(
resp,
true, // full change
m.state,
node,
capVer,
peers,
m.cfg,
)
if err != nil {
return nil, err
}
return resp, nil
return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithSelfNode().
WithDERPMap().
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig().
WithSSHPolicy().
WithDNSConfig().
WithUserProfiles(peers).
WithPacketFilters().
WithPeers(peers).
Build(messages...)
}
// FullMapResponse returns a MapResponse for the given node.
func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
messages ...string,
) ([]byte, error) {
peers, err := m.ListPeers(node.ID())
if err != nil {
return nil, err
}
resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
// ReadOnlyMapResponse returns a MapResponse for the given node.
// Lite means that the peers has been omitted, this is intended
// to be used to answer MapRequests with OmitPeers set to true.
func (m *Mapper) ReadOnlyMapResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
messages ...string,
) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
func (m *Mapper) KeepAliveResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.KeepAlive = true
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) DERPMapResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
derpMap *tailcfg.DERPMap,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.DERPMap = derpMap
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) PeerChangedResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
changed map[types.NodeID]bool,
patches []*tailcfg.PeerChange,
messages ...string,
) ([]byte, error) {
var err error
resp := m.baseMapResponse()
var removedIDs []tailcfg.NodeID
var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed {
if nodeChanged {
if nodeID != node.ID() {
changedIDs = append(changedIDs, nodeID)
}
} else {
removedIDs = append(removedIDs, nodeID.NodeID())
}
}
changedNodes := types.Nodes{}
if len(changedIDs) > 0 {
changedNodes, err = m.ListNodes(changedIDs...)
if err != nil {
return nil, err
}
}
err = appendPeerChanges(
&resp,
false, // partial change
m.state,
node,
mapRequest.Version,
changedNodes.ViewSlice(),
m.cfg,
)
if err != nil {
return nil, err
}
resp.PeersRemoved = removedIDs
// Sending patches as a part of a PeersChanged response
// is technically not suppose to be done, but they are
// applied after the PeersChanged. The patch list
// should _only_ contain Nodes that are not in the
// PeersChanged or PeersRemoved list and the caller
// should filter them out.
//
// From tailcfg docs:
// These are applied after Peers* above, but in practice the
// control server should only send these on their own, without
// the Peers* fields also set.
if patches != nil {
resp.PeersChangedPatch = patches
}
_, matchers := m.state.Filter()
// Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update.
tailnode, err := tailNode(
node, mapRequest.Version, m.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
},
m.cfg)
if err != nil {
return nil, err
}
resp.Node = tailnode
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
func (m *mapper) derpMapResponse(
nodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDERPMap().
Build()
}
// PeerChangedPatchResponse creates a patch MapResponse with
// incoming update from a state change.
func (m *Mapper) PeerChangedPatchResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
func (m *mapper) peerChangedPatchResponse(
nodeID types.NodeID,
changed []*tailcfg.PeerChange,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.PeersChangedPatch = changed
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(changed).
Build()
}
func (m *Mapper) marshalMapResponse(
mapRequest tailcfg.MapRequest,
resp *tailcfg.MapResponse,
node types.NodeView,
compression string,
messages ...string,
) ([]byte, error) {
atomic.AddUint64(&m.seq, 1)
jsonBody, err := json.Marshal(resp)
// peerChangeResponse returns a MapResponse with changed or added nodes.
func (m *mapper) peerChangeResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
peers, err := m.listPeers(nodeID, changedNodeID)
if err != nil {
return nil, fmt.Errorf("marshalling map response: %w", err)
return nil, err
}
if debugDumpMapResponsePath != "" {
return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithSelfNode().
WithUserProfiles(peers).
WithPeerChanges(peers).
Build()
}
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
func (m *mapper) peerRemovedResponse(
nodeID types.NodeID,
removedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedNodeID).
Build()
}
func writeDebugMapResponse(
resp *tailcfg.MapResponse,
nodeID types.NodeID,
messages ...string,
) {
data := map[string]any{
"Messages": messages,
"MapRequest": mapRequest,
"MapResponse": resp,
}
responseType := "keepalive"
switch {
case resp.Peers != nil && len(resp.Peers) > 0:
case len(resp.Peers) > 0:
responseType = "full"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
responseType = "self"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
case len(resp.PeersChanged) > 0:
responseType = "changed"
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
case len(resp.PeersChangedPatch) > 0:
responseType = "patch"
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
case len(resp.PeersRemoved) > 0:
responseType = "removed"
}
body, err := json.MarshalIndent(data, "", " ")
if err != nil {
return nil, fmt.Errorf("marshalling map response: %w", err)
panic(err)
}
perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
@ -380,7 +246,7 @@ func (m *Mapper) marshalMapResponse(
mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
fmt.Sprintf("%s-%s.json", now, responseType),
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
@ -388,201 +254,28 @@ func (m *Mapper) marshalMapResponse(
if err != nil {
panic(err)
}
}
var respBody []byte
if compression == util.ZstdCompression {
respBody = zstdEncode(jsonBody)
} else {
respBody = jsonBody
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func zstdEncode(in []byte) []byte {
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
if !ok {
panic("invalid type in sync pool")
}
out := encoder.EncodeAll(in, nil)
_ = encoder.Close()
zstdEncoderPool.Put(encoder)
return out
}
var zstdEncoderPool = &sync.Pool{
New: func() any {
encoder, err := smallzstd.NewEncoder(
nil,
zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
panic(err)
}
return encoder
},
}
// baseMapResponse returns a tailcfg.MapResponse with
// KeepAlive false and ControlTime set to now.
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
now := time.Now()
resp := tailcfg.MapResponse{
KeepAlive: false,
ControlTime: &now,
// TODO(kradalby): Implement PingRequest?
}
return resp
}
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
// with the basic configuration from headscale set.
// It is used in for bigger updates, such as full and lite, not
// incremental.
func (m *Mapper) baseWithConfigMapResponse(
node types.NodeView,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse()
_, matchers := m.state.Filter()
tailnode, err := tailNode(
node, capVer, m.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
},
m.cfg)
if err != nil {
return nil, err
}
resp.Node = tailnode
resp.DERPMap = m.state.DERPMap()
resp.Domain = m.cfg.Domain()
// Do not instruct clients to collect services we do not
// support or do anything with them
resp.CollectServices = "false"
resp.KeepAlive = false
resp.Debug = &tailcfg.Debug{
DisableLogTail: !m.cfg.LogTail.Enabled,
}
return &resp, nil
}
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
peers, err := m.state.ListPeers(nodeID, peerIDs...)
if err != nil {
return nil, err
}
// TODO(kradalby): Add back online via batcher. This was removed
// to avoid a circular dependency between the mapper and the notification.
for _, peer := range peers {
online := m.notif.IsLikelyConnected(peer.ID)
online := m.batcher.IsConnected(peer.ID)
peer.IsOnline = &online
}
return peers, nil
}
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter.
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.state.ListNodes(nodeIDs...)
if err != nil {
return nil, err
}
for _, node := range nodes {
online := m.notif.IsLikelyConnected(node.ID)
node.IsOnline = &online
}
return nodes, nil
}
// routeFilterFunc is a function that takes a node ID and returns a list of
// netip.Prefixes that are allowed for that node. It is used to filter routes
// from the primary route manager to the node.
type routeFilterFunc func(id types.NodeID) []netip.Prefix
// appendPeerChanges mutates a tailcfg.MapResponse with all the
// necessary changes when peers have changed.
func appendPeerChanges(
resp *tailcfg.MapResponse,
fullChange bool,
state *state.State,
node types.NodeView,
capVer tailcfg.CapabilityVersion,
changed views.Slice[types.NodeView],
cfg *types.Config,
) error {
filter, matchers := state.Filter()
sshPolicy, err := state.SSHPolicy(node)
if err != nil {
return err
}
// If there are filter rules present, see if there are any nodes that cannot
// access each-other at all and remove them from the peers.
var reducedChanged views.Slice[types.NodeView]
if len(filter) > 0 {
reducedChanged = policy.ReduceNodes(node, changed, matchers)
} else {
reducedChanged = changed
}
profiles := generateUserProfiles(node, reducedChanged)
dnsConfig := generateDNSConfig(cfg, node)
tailPeers, err := tailNodes(
reducedChanged, capVer, state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
},
cfg)
if err != nil {
return err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
if fullChange {
resp.Peers = tailPeers
} else {
resp.PeersChanged = tailPeers
}
resp.DNSConfig = dnsConfig
resp.UserProfiles = profiles
resp.SSHPolicy = sshPolicy
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
// Currently, we do not send incremental package filters, however using the
// new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block.
resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node, filter),
}
return nil
}

View File

@ -3,6 +3,7 @@ package mapper
import (
"fmt"
"net/netip"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
@ -70,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
&types.Config{
TailcfgDNSConfig: &dnsConfigOrig,
},
nodeInShared1.View(),
nodeInShared1,
)
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
@ -126,11 +127,8 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
// Filter peers by the provided IDs
var filtered types.Nodes
for _, peer := range m.peers {
for _, id := range peerIDs {
if peer.ID == id {
if slices.Contains(peerIDs, peer.ID) {
filtered = append(filtered, peer)
break
}
}
}
@ -152,11 +150,8 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
// Filter nodes by the provided IDs
var filtered types.Nodes
for _, node := range m.nodes {
for _, id := range nodeIDs {
if node.ID == id {
if slices.Contains(nodeIDs, node.ID) {
filtered = append(filtered, node)
break
}
}
}

47
hscontrol/mapper/utils.go Normal file
View File

@ -0,0 +1,47 @@
package mapper
import "tailscale.com/tailcfg"
// mergePatch folds newPatch into currPatch in place.
//
// A field counts as "set" in newPatch when it is non-zero (DERPRegion, Cap)
// or non-nil (every other field handled here). Set fields replace the value
// in currPatch; unset fields leave currPatch as it was.
func mergePatch(currPatch, newPatch *tailcfg.PeerChange) {
	// Numeric fields: the zero value means "not part of this patch".
	if region := newPatch.DERPRegion; region != 0 {
		currPatch.DERPRegion = region
	}

	if capVer := newPatch.Cap; capVer != 0 {
		currPatch.Cap = capVer
	}

	// Map, slice and pointer fields: nil means "not part of this patch".
	if capMap := newPatch.CapMap; capMap != nil {
		currPatch.CapMap = capMap
	}

	if endpoints := newPatch.Endpoints; endpoints != nil {
		currPatch.Endpoints = endpoints
	}

	if nodeKey := newPatch.Key; nodeKey != nil {
		currPatch.Key = nodeKey
	}

	if sig := newPatch.KeySignature; sig != nil {
		currPatch.KeySignature = sig
	}

	if disco := newPatch.DiscoKey; disco != nil {
		currPatch.DiscoKey = disco
	}

	if online := newPatch.Online; online != nil {
		currPatch.Online = online
	}

	if seen := newPatch.LastSeen; seen != nil {
		currPatch.LastSeen = seen
	}

	if expiry := newPatch.KeyExpiry; expiry != nil {
		currPatch.KeyExpiry = expiry
	}
}

View File

@ -221,7 +221,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
ns.nodeKey = nv.NodeKey()
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv)
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
sess.tracef("a node sending a MapRequest with Noise protocol")
if !sess.isStreaming() {
sess.serve()
@ -279,28 +279,33 @@ func (ns *noiseServer) NoiseRegistrationHandler(
return
}
respBody, err := json.Marshal(registerResponse)
if err != nil {
httpError(writer, err)
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
writer.Write(respBody)
// Ensure response is flushed to client
if flusher, ok := writer.(http.Flusher); ok {
flusher.Flush()
}
}
// getAndValidateNode retrieves the node from the database using the NodeKey
// and validates that it matches the MachineKey from the Noise session.
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
nv, err := ns.headscale.state.GetNodeViewByNodeKey(mapRequest.NodeKey)
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
}
return types.NodeView{}, err
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
}
nv := node.View()
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
if ns.machineKey != nv.MachineKey() {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)

View File

@ -1,68 +0,0 @@
package notifier
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"tailscale.com/envknob"
)
// prometheusNamespace is the Namespace set on every metric declared in this file.
const prometheusNamespace = "headscale"
// debugHighCardinalityMetrics is read from the
// HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS environment knob. init consults it
// to decide whether notifierUpdateSent gets an additional "id" label.
var debugHighCardinalityMetrics = envknob.Bool("HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS")
// notifierUpdateSent counts updates sent on node channels. It is assigned in
// init because its label set depends on debugHighCardinalityMetrics.
var notifierUpdateSent *prometheus.CounterVec
// init registers the notifierUpdateSent counter.
//
// The counter always carries the "status", "type" and "trigger" labels; the
// per-node "id" label is appended only when debugHighCardinalityMetrics is set.
func init() {
	labels := []string{"status", "type", "trigger"}
	if debugHighCardinalityMetrics {
		labels = append(labels, "id")
	}

	notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: prometheusNamespace,
		Name:      "notifier_update_sent_total",
		Help:      "total count of update sent on nodes channel",
	}, labels)
}
// Metrics for the notifier and its batcher, all registered through promauto
// under prometheusNamespace.
var (
// notifierWaitersForLock is incremented before a notifier method takes its
// lock and decremented once the lock is held; labelled by "type" and "action".
notifierWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_waiters_for_lock",
Help: "gauge of waiters for the notifier lock",
}, []string{"type", "action"})
// notifierWaitForLock records, in seconds, how long a caller waited for the
// notifier lock; labelled by "action".
notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "notifier_wait_for_lock_seconds",
Help: "histogram of time spent waiting for the notifier lock",
Buckets: []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10},
}, []string{"action"})
// notifierUpdateReceived counts updates handed to the notifier; labelled by
// update "type" and "trigger".
notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "notifier_update_received_total",
Help: "total count of updates received by notifier",
}, []string{"type", "trigger"})
// notifierNodeUpdateChans is incremented in AddNode and decremented in
// RemoveNode.
notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_open_channels_total",
Help: "total count open channels in notifier",
})
// notifierBatcherWaitersForLock is the batcher counterpart of
// notifierWaitersForLock; labelled by "type" and "action".
notifierBatcherWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_batcher_waiters_for_lock",
Help: "gauge of waiters for the notifier batcher lock",
}, []string{"type", "action"})
// notifierBatcherChanges is set to the number of changed node IDs pending in
// the batcher and reset to 0 on flush. It has no labels.
notifierBatcherChanges = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_batcher_changes_pending",
Help: "gauge of full changes pending in the notifier batcher",
}, []string{})
// notifierBatcherPatches is set to the number of patches pending in the
// batcher and reset to 0 on flush. It has no labels.
notifierBatcherPatches = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_batcher_patches_pending",
Help: "gauge of patches pending in the notifier batcher",
}, []string{})
)

View File

@ -1,488 +0,0 @@
package notifier
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/util/set"
)
// Environment knobs that configure the go-deadlock options set in init.
var (
// debugDeadlock is read from HEADSCALE_DEBUG_DEADLOCK; when it is false,
// init disables deadlock detection.
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
// debugDeadlockTimeout is registered for HEADSCALE_DEBUG_DEADLOCK_TIMEOUT;
// init calls it to obtain the detection timeout when debugDeadlock is set.
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
)
func init() {
deadlock.Opts.Disable = !debugDeadlock
if debugDeadlock {
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
deadlock.Opts.PrintAllCurrentGoroutines = true
}
}
// Notifier keeps one update channel per node and forwards state updates to
// them, either directly or through its batcher.
type Notifier struct {
// l is taken by the methods that read or modify nodes.
l deadlock.Mutex
// nodes maps a node ID to the send-only channel its updates are written to.
nodes map[types.NodeID]chan<- types.StateUpdate
// connected is set to true for a node in AddNode and to false in RemoveNode.
connected *xsync.MapOf[types.NodeID, bool]
// b is the batcher created in NewNotifier.
b *batcher
// cfg supplies the tuning values used by the notifier
// (BatchChangeDelay, NotifierSendTimeout).
cfg *types.Config
// closed is set to true by Close; the other methods return early once it is set.
closed bool
}
// NewNotifier builds a Notifier with empty node and connection tables, attaches
// a batcher that ticks every cfg.Tuning.BatchChangeDelay, and starts the
// batcher's work loop in its own goroutine.
func NewNotifier(cfg *types.Config) *Notifier {
	notifier := &Notifier{
		nodes:     make(map[types.NodeID]chan<- types.StateUpdate),
		connected: xsync.NewMapOf[types.NodeID, bool](),
		cfg:       cfg,
	}

	notifier.b = newBatcher(cfg.Tuning.BatchChangeDelay, notifier)
	go notifier.b.doWork()

	return notifier
}
// Close stops the batcher, marks the notifier as closed and closes all node
// channels.
//
// The batcher is stopped before n.l is taken. batcher.close blocks until the
// work loop receives the cancel signal, and that loop may be in the middle of
// a flush that is waiting for n.l inside sendAll. Holding n.l across the stop
// would therefore let Close and the work loop wait on each other forever.
func (n *Notifier) Close() {
	n.b.close()

	notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "close").Dec()

	n.closed = true

	// Close channels safely using the helper method.
	for nodeID, c := range n.nodes {
		n.safeCloseChannel(nodeID, c)
	}

	// Clear node map after closing channels.
	n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
}
// safeCloseChannel closes c. If closing panics (for example because the
// channel was already closed), the panic is recovered and logged together
// with the node ID instead of propagating.
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}

		log.Error().
			Uint64("node.id", nodeID.Uint64()).
			Any("recover", r).
			Msg("recovered from panic when closing channel in Close()")
	}()

	close(c)
}
// tracef emits a trace-level log line for node nID, annotated with the node ID
// and the current number of entries in n.nodes.
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
	event := log.Trace().Uint64("node.id", nID.Uint64())
	event = event.Int("open_chans", len(n.nodes))
	event.Msgf(msg, args...)
}
// AddNode registers c as the update channel for nodeID and marks the node as
// connected. It does nothing once the notifier has been closed.
//
// If the node already has a channel, the node has opened a new connection:
// the old channel is closed in the background and replaced by c.
//
// The open-channels gauge is only incremented when a new entry is created.
// RemoveNode ignores (and does not count) the removal of a channel that has
// been replaced, so counting the replacement as well would make the gauge
// drift upwards on every reconnect.
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "add").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
	notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())

	if n.closed {
		return
	}

	curr, replacing := n.nodes[nodeID]
	if replacing {
		n.tracef(nodeID, "channel present, closing and replacing")
		// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
		// if/when someone is waiting to send on this channel.
		go func(ch chan<- types.StateUpdate) {
			n.safeCloseChannel(nodeID, ch)
		}(curr)
	}

	n.nodes[nodeID] = c
	n.connected.Store(nodeID, true)

	n.tracef(nodeID, "added new channel")

	if !replacing {
		notifierNodeUpdateChans.Inc()
	}
}
// RemoveNode unregisters the channel c of node nodeID and marks the node as
// disconnected.
//
// The removal only happens when c is still the channel registered for the
// node; if the channel has been replaced in the meantime the call is ignored
// and false is returned. In every other case, including a closed notifier or
// an empty node table, RemoveNode returns true.
func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "remove").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
	notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())

	if n.closed || len(n.nodes) == 0 {
		return true
	}

	// A registered channel that is not the caller's belongs to a newer
	// connection; leave it alone.
	registered, present := n.nodes[nodeID]
	if present && registered != c {
		n.tracef(nodeID, "channel has been replaced, not removing")

		return false
	}

	delete(n.nodes, nodeID)
	n.connected.Store(nodeID, false)

	n.tracef(nodeID, "removed channel")
	notifierNodeUpdateChans.Dec()

	return true
}
// IsConnected reports whether nodeID is recorded as connected, i.e. has a poll
// session open against headscale. The lookup is done while holding the
// notifier lock; unknown nodes are reported as not connected.
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
	notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec()

	connected, known := n.connected.Load(nodeID)

	return known && connected
}
// IsLikelyConnected is the lock-free variant of IsConnected: it reports
// whether nodeID is recorded as connected without taking the notifier lock,
// so the answer might be wrong. Unknown nodes are reported as not connected.
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
	connected, known := n.connected.Load(nodeID)

	return known && connected
}
// LikelyConnectedMap exposes the notifier's own thread safe map of node
// connection states. The returned map is the live instance, not a copy.
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
	connected := n.connected

	return connected
}
// NotifyAll forwards update to NotifyWithIgnore with an empty ignore list.
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
	n.NotifyWithIgnore(ctx, update)
}
// NotifyWithIgnore counts update in the received-updates metric and hands it
// to the batcher, which either queues it or sends it immediately. Nothing
// happens once the notifier has been closed.
//
// ignoreNodeIDs is accepted but not consulted by this implementation, and the
// closed flag is read here without taking the notifier lock.
func (n *Notifier) NotifyWithIgnore(
	ctx context.Context,
	update types.StateUpdate,
	ignoreNodeIDs ...types.NodeID,
) {
	if !n.closed {
		updateType := update.Type.String()
		origin := types.NotifyOriginKey.Value(ctx)

		notifierUpdateReceived.WithLabelValues(updateType, origin).Inc()
		n.b.addOrPassthrough(update)
	}
}
// NotifyByNodeID sends update to the channel registered for nodeID while
// holding the notifier lock.
//
// The call is a no-op when the notifier is closed or the node has no channel.
// Otherwise it blocks until either the update has been handed to the channel
// or ctx is done; the outcome is recorded in notifierUpdateSent with status
// "ok" or "cancelled".
//
// The origin and hostname used for logging are read through the typed context
// keys (types.NotifyOriginKey, types.NotifyHostnameKey) in both branches, so
// the success trace reports the same values as the cancellation log.
func (n *Notifier) NotifyByNodeID(
	ctx context.Context,
	update types.StateUpdate,
	nodeID types.NodeID,
) {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "notify").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
	notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())

	if n.closed {
		return
	}

	c, ok := n.nodes[nodeID]
	if !ok {
		return
	}

	// recordSent bumps the sent counter, adding the node ID label only when
	// high cardinality metrics are enabled.
	recordSent := func(status string) {
		updateType := update.Type.String()
		origin := types.NotifyOriginKey.Value(ctx)

		if debugHighCardinalityMetrics {
			notifierUpdateSent.WithLabelValues(status, updateType, origin, nodeID.String()).Inc()

			return
		}

		notifierUpdateSent.WithLabelValues(status, updateType, origin).Inc()
	}

	select {
	case <-ctx.Done():
		log.Error().
			Err(ctx.Err()).
			Uint64("node.id", nodeID.Uint64()).
			Any("origin", types.NotifyOriginKey.Value(ctx)).
			Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
			Msgf("update not sent, context cancelled")
		recordSent("cancelled")

	case c <- update:
		n.tracef(
			nodeID,
			"update successfully sent on chan, origin: %s, origin-hostname: %s",
			types.NotifyOriginKey.Value(ctx),
			types.NotifyHostnameKey.Value(ctx),
		)
		recordSent("ok")
	}
}
// sendAll delivers update to every registered node channel while holding the
// notifier lock. It is a no-op once the notifier has been closed.
//
// Every send gets its own timeout of n.cfg.Tuning.NotifierSendTimeout. A node
// whose channel does not accept the update in time is logged, counted as
// "cancelled" and skipped; the remaining nodes still receive the update.
func (n *Notifier) sendAll(update types.StateUpdate) {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
	notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())

	if n.closed {
		return
	}

	updateType := update.Type.String()

	for id, c := range n.nodes {
		// Whenever an update is sent to all nodes, there is a chance that the node
		// has disconnected and the goroutine that was supposed to consume the update
		// has shut down the channel and is waiting for the lock held here in RemoveNode.
		// This means that there is potential for a deadlock which would stop all updates
		// going out to clients. This timeout prevents that from happening by moving on to the
		// next node if the context is cancelled. After sendAll releases the lock, the add/remove
		// call will succeed and the update will go to the correct nodes on the next call.
		ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout)

		status := "ok"

		select {
		case <-ctx.Done():
			status = "cancelled"

			log.Error().
				Err(ctx.Err()).
				Uint64("node.id", id.Uint64()).
				Msgf("update not sent, context cancelled")
		case c <- update:
		}

		// Release the timeout's resources per iteration; a deferred cancel
		// would keep every context alive until the whole loop has finished.
		cancel()

		if debugHighCardinalityMetrics {
			notifierUpdateSent.WithLabelValues(status, updateType, "send-all", id.String()).Inc()
		} else {
			notifierUpdateSent.WithLabelValues(status, updateType, "send-all").Inc()
		}
	}
}
// String renders a debug dump of the notifier under the notifier lock.
//
// The dump has two sections, both listing every node ID found in the
// connected map in ascending order: "chans" prints the channel pointer
// registered for the node, "connected" prints the stored connection flag.
// Both section headers print len(n.nodes) as their count.
func (n *Notifier) String() string {
	notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "string").Dec()

	var ids []types.NodeID
	n.connected.Range(func(id types.NodeID, _ bool) bool {
		ids = append(ids, id)

		return true
	})
	sort.Slice(ids, func(a, b int) bool {
		return ids[a] < ids[b]
	})

	var out strings.Builder

	fmt.Fprintf(&out, "chans (%d):\n", len(n.nodes))
	for _, id := range ids {
		fmt.Fprintf(&out, "\t%d: %p\n", id, n.nodes[id])
	}

	out.WriteString("\n")

	fmt.Fprintf(&out, "connected (%d):\n", len(n.nodes))
	for _, id := range ids {
		state, _ := n.connected.Load(id)
		fmt.Fprintf(&out, "\t%d: %t\n", id, state)
	}

	return out.String()
}
// batcher accumulates peer change and peer patch updates and sends them to
// all nodes through the Notifier each time its ticker fires.
type batcher struct {
// tick triggers a flush in doWork every time it fires.
tick *time.Ticker
// mu is held by addOrPassthrough and flush while they touch the fields below.
mu sync.Mutex
// cancelCh tells doWork to exit; close sends on it.
cancelCh chan struct{}
// changedNodeIDs collects the node IDs from StatePeerChanged updates since
// the last flush.
changedNodeIDs set.Slice[types.NodeID]
// nodesChanged is set when changedNodeIDs has been added to since the last flush.
nodesChanged bool
// patches holds the merged pending patch per node from StatePeerChangedPatch
// updates since the last flush.
patches map[types.NodeID]tailcfg.PeerChange
// patchesChanged is set when patches has been added to since the last flush.
patchesChanged bool
// n is the Notifier whose sendAll is used to deliver updates.
n *Notifier
}
// newBatcher returns a batcher bound to n whose ticker fires every batchTime.
// The returned batcher does nothing until its doWork loop is started.
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
	b := new(batcher)
	b.tick = time.NewTicker(batchTime)
	b.cancelCh = make(chan struct{})
	b.patches = make(map[types.NodeID]tailcfg.PeerChange)
	b.n = n

	return b
}
// close signals the doWork loop to exit by sending on cancelCh. The channel is
// created without a buffer, so the call blocks until the signal is received.
func (b *batcher) close() {
	stop := struct{}{}
	b.cancelCh <- stop
}
// addOrPassthrough queues update in the batcher when it is of a batched type
// and sends it to all nodes immediately otherwise.
//
// Batched types are StatePeerChanged, whose node IDs are added to the pending
// set, and StatePeerChangedPatch, whose patches are merged per node into the
// pending patch map.
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
	notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Inc()
	b.mu.Lock()
	defer b.mu.Unlock()
	notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Dec()

	switch update.Type {
	case types.StatePeerChanged:
		b.changedNodeIDs.Add(update.ChangeNodes...)
		b.nodesChanged = true
		notifierBatcherChanges.WithLabelValues().Set(float64(b.changedNodeIDs.Len()))

	case types.StatePeerChangedPatch:
		for _, incoming := range update.ChangePatches {
			id := types.NodeID(incoming.NodeID)

			pending, seen := b.patches[id]
			if !seen {
				// First patch for this node since the last flush.
				b.patches[id] = *incoming

				continue
			}

			// Fold the newer patch into the one already pending.
			overwritePatch(&pending, incoming)
			b.patches[id] = pending
		}

		b.patchesChanged = true
		notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))

	default:
		b.n.sendAll(update)
	}
}
// flush sends everything accumulated since the previous flush to all nodes in
// the notifier and then resets the batcher state. It does nothing when neither
// node changes nor patches were added.
//
// Changed node IDs go out first, sorted ascending, as one peer-changed update.
// Pending patches follow as one patch update; patches for nodes that are also
// in the changed set are dropped, since those nodes get a full update anyway.
func (b *batcher) flush() {
	notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Inc()
	b.mu.Lock()
	defer b.mu.Unlock()
	notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Dec()

	if !b.nodesChanged && !b.patchesChanged {
		return
	}

	var patches []*tailcfg.PeerChange
	for nodeID, patch := range b.patches {
		// A node that is part of the changed set makes its patch redundant.
		if b.changedNodeIDs.Contains(nodeID) {
			delete(b.patches, nodeID)

			continue
		}

		patches = append(patches, &patch)
	}

	changedNodes := b.changedNodeIDs.Slice().AsSlice()
	sort.Slice(changedNodes, func(x, y int) bool {
		return changedNodes[x] < changedNodes[y]
	})

	if len(changedNodes) > 0 {
		b.n.sendAll(types.UpdatePeerChanged(changedNodes...))
	}

	if len(patches) > 0 {
		b.n.sendAll(types.UpdatePeerPatch(patches...))
	}

	// Reset the pending node changes.
	b.changedNodeIDs = set.Slice[types.NodeID]{}
	notifierBatcherChanges.WithLabelValues().Set(0)
	b.nodesChanged = false

	// Reset the pending patches, sizing the new map after the old one.
	b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
	notifierBatcherPatches.WithLabelValues().Set(0)
	b.patchesChanged = false
}
// doWork is the batcher's worker loop. It flushes on every ticker tick and
// returns as soon as a value arrives on cancelCh.
func (b *batcher) doWork() {
	for {
		select {
		case <-b.tick.C:
			b.flush()

		case <-b.cancelCh:
			return
		}
	}
}
// overwritePatch merges newPatch into currPatch in place. Only fields that
// newPatch actually sets (non-zero numbers, non-nil pointers, slices and
// maps) replace the corresponding field of currPatch; everything else in
// currPatch is left as it was.
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
	// Numeric fields: zero means "not set".
	if region := newPatch.DERPRegion; region != 0 {
		currPatch.DERPRegion = region
	}
	if capVer := newPatch.Cap; capVer != 0 {
		currPatch.Cap = capVer
	}

	// Map and slice fields: nil means "not set".
	if capMap := newPatch.CapMap; capMap != nil {
		currPatch.CapMap = capMap
	}
	if endpoints := newPatch.Endpoints; endpoints != nil {
		currPatch.Endpoints = endpoints
	}
	if sig := newPatch.KeySignature; sig != nil {
		currPatch.KeySignature = sig
	}

	// Pointer fields: nil means "not set".
	if nodeKey := newPatch.Key; nodeKey != nil {
		currPatch.Key = nodeKey
	}
	if disco := newPatch.DiscoKey; disco != nil {
		currPatch.DiscoKey = disco
	}
	if online := newPatch.Online; online != nil {
		currPatch.Online = online
	}
	if seen := newPatch.LastSeen; seen != nil {
		currPatch.LastSeen = seen
	}
	if expiry := newPatch.KeyExpiry; expiry != nil {
		currPatch.KeyExpiry = expiry
	}
}

View File

@ -1,342 +0,0 @@
package notifier
import (
"fmt"
"math/rand"
"net/netip"
"slices"
"sort"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
)
// TestBatcher pushes a series of updates through a Notifier, flushes the
// batcher by hand, and checks what a single registered node receives:
// non-batched update types pass straight through, node changes are merged
// into one update, and patches for the same node are folded together.
func TestBatcher(t *testing.T) {
	// Constructors for the two batched update kinds keep the table compact.
	nodesChanged := func(ids ...types.NodeID) types.StateUpdate {
		return types.StateUpdate{
			Type:        types.StatePeerChanged,
			ChangeNodes: ids,
		}
	}
	patched := func(changes ...*tailcfg.PeerChange) types.StateUpdate {
		return types.StateUpdate{
			Type:          types.StatePeerChangedPatch,
			ChangePatches: changes,
		}
	}

	first := netip.MustParseAddrPort("1.1.1.1:9090")
	second := netip.MustParseAddrPort("2.2.2.2:8080")

	tests := []struct {
		name    string
		updates []types.StateUpdate
		want    []types.StateUpdate
	}{
		{
			name:    "full-passthrough",
			updates: []types.StateUpdate{{Type: types.StateFullUpdate}},
			want:    []types.StateUpdate{{Type: types.StateFullUpdate}},
		},
		{
			name:    "derp-passthrough",
			updates: []types.StateUpdate{{Type: types.StateDERPUpdated}},
			want:    []types.StateUpdate{{Type: types.StateDERPUpdated}},
		},
		{
			name:    "single-node-update",
			updates: []types.StateUpdate{nodesChanged(2)},
			want:    []types.StateUpdate{nodesChanged(2)},
		},
		{
			name: "merge-node-update",
			updates: []types.StateUpdate{
				nodesChanged(2, 4),
				nodesChanged(2, 3),
			},
			want: []types.StateUpdate{nodesChanged(2, 3, 4)},
		},
		{
			name: "single-patch-update",
			updates: []types.StateUpdate{
				patched(&tailcfg.PeerChange{NodeID: 2, DERPRegion: 5}),
			},
			want: []types.StateUpdate{
				patched(&tailcfg.PeerChange{NodeID: 2, DERPRegion: 5}),
			},
		},
		{
			name: "merge-patch-to-same-node-update",
			updates: []types.StateUpdate{
				patched(&tailcfg.PeerChange{NodeID: 2, DERPRegion: 5}),
				patched(&tailcfg.PeerChange{NodeID: 2, DERPRegion: 6}),
			},
			want: []types.StateUpdate{
				patched(&tailcfg.PeerChange{NodeID: 2, DERPRegion: 6}),
			},
		},
		{
			name: "merge-patch-to-multiple-node-update",
			updates: []types.StateUpdate{
				patched(&tailcfg.PeerChange{
					NodeID:    3,
					Endpoints: []netip.AddrPort{first},
				}),
				patched(&tailcfg.PeerChange{
					NodeID:    3,
					Endpoints: []netip.AddrPort{first, second},
				}),
				patched(&tailcfg.PeerChange{NodeID: 4, DERPRegion: 6}),
				patched(&tailcfg.PeerChange{
					NodeID: 4,
					Cap:    tailcfg.CapabilityVersion(54),
				}),
			},
			want: []types.StateUpdate{
				patched(
					&tailcfg.PeerChange{
						NodeID:    3,
						Endpoints: []netip.AddrPort{first, second},
					},
					&tailcfg.PeerChange{
						NodeID:     4,
						DERPRegion: 6,
						Cap:        tailcfg.CapabilityVersion(54),
					},
				),
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &types.Config{
				Tuning: types.Tuning{
					// The test flushes by hand, so pick a delay long
					// enough that the ticker never fires on its own.
					BatchChangeDelay: time.Hour,
					// The config is not loaded from disk, so there is no
					// default send timeout; set one explicitly to avoid
					// timeouts and flaky runs.
					NotifierSendTimeout: time.Second,
				},
			}
			n := NewNotifier(cfg)

			ch := make(chan types.StateUpdate, 30)
			defer close(ch)

			n.AddNode(1, ch)
			// Deferred calls run last-in first-out: the node is removed
			// before the channel is closed.
			defer n.RemoveNode(1, ch)

			for _, u := range tt.updates {
				n.NotifyAll(t.Context(), u)
			}

			n.b.flush()

			// Drain whatever was delivered to the node.
			var got []types.StateUpdate
			for len(ch) > 0 {
				got = append(got, <-ch)
			}

			// Sort the inner slices so the comparison does not depend on
			// delivery order.
			for i := range got {
				patches := got[i].ChangePatches

				slices.Sort(got[i].ChangeNodes)
				sort.Slice(patches, func(a, b int) bool {
					return patches[a].NodeID < patches[b].NodeID
				})
			}

			diff := cmp.Diff(tt.want, got, util.Comparers...)
			if diff != "" {
				t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
			}
		})
	}
}
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// mock config for the notifier
cfg := &types.Config{
Tuning: types.Tuning{
NotifierSendTimeout: 1 * time.Second,
BatchChangeDelay: 1 * time.Second,
NodeMapSessionBufferedChanSize: 30,
},
}
notifier := NewNotifier(cfg)
defer notifier.Close()
nodeID := types.NodeID(1)
updateChan := make(chan types.StateUpdate, 10)
var wg sync.WaitGroup
// Number of goroutines to spawn for concurrent access
concurrentAccessors := 100
iterations := 100
// Add node to notifier
notifier.AddNode(nodeID, updateChan)
// Track errors
errChan := make(chan string, concurrentAccessors*iterations)
// Start goroutines to cause a race
wg.Add(concurrentAccessors)
for i := range concurrentAccessors {
go func(routineID int) {
defer wg.Done()
for range iterations {
// Simulate race by having some goroutines check IsLikelyConnected
// while others add/remove the node
switch routineID % 3 {
case 0:
// This goroutine checks connection status
isConnected := notifier.IsLikelyConnected(nodeID)
if isConnected != true && isConnected != false {
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
}
case 1:
// This goroutine removes the node
notifier.RemoveNode(nodeID, updateChan)
default:
// This goroutine adds the node back
notifier.AddNode(nodeID, updateChan)
}
// Small random delay to increase chance of races
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
}
}(i)
}
wg.Wait()
close(errChan)
// Collate errors
var errors []string
for err := range errChan {
errors = append(errors, err)
}
if len(errors) > 0 {
t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
}
}

View File

@ -16,9 +16,8 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
@ -56,11 +55,10 @@ type RegistrationInfo struct {
}
type AuthProviderOIDC struct {
h *Headscale
serverURL string
cfg *types.OIDCConfig
state *state.State
registrationCache *zcache.Cache[string, RegistrationInfo]
notifier *notifier.Notifier
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@ -68,10 +66,9 @@ type AuthProviderOIDC struct {
func NewAuthProviderOIDC(
ctx context.Context,
h *Headscale,
serverURL string,
cfg *types.OIDCConfig,
state *state.State,
notif *notifier.Notifier,
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
@ -94,11 +91,10 @@ func NewAuthProviderOIDC(
)
return &AuthProviderOIDC{
h: h,
serverURL: serverURL,
cfg: cfg,
state: state,
registrationCache: registrationCache,
notifier: notif,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
@ -318,8 +314,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-user-created", user.Name)
a.notifier.NotifyAll(ctx, types.UpdateFull())
a.h.Change(change.PolicyChange())
}
// TODO(kradalby): Is this comment right?
@ -360,8 +355,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
func extractCodeAndStateParamFromRequest(
@ -490,12 +483,14 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
var err error
var newUser bool
var policyChanged bool
user, err = a.state.GetUserByOIDCIdentifier(claims.Identifier())
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, false, fmt.Errorf("creating or updating user: %w", err)
}
// if the user is still not found, create a new empty user.
// TODO(kradalby): This might cause us to not have an ID below which
// is a problem.
if user == nil {
newUser = true
user = &types.User{}
@ -504,12 +499,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
user.FromClaim(claims)
if newUser {
user, policyChanged, err = a.state.CreateUser(*user)
user, policyChanged, err = a.h.state.CreateUser(*user)
if err != nil {
return nil, false, fmt.Errorf("creating user: %w", err)
}
} else {
_, policyChanged, err = a.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
*u = *user
return nil
})
@ -526,7 +521,7 @@ func (a *AuthProviderOIDC) handleRegistration(
registrationID types.RegistrationID,
expiry time.Time,
) (bool, error) {
node, newNode, err := a.state.HandleNodeFromAuthPath(
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
registrationID,
types.UserID(user.ID),
&expiry,
@ -547,31 +542,20 @@ func (a *AuthProviderOIDC) handleRegistration(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
routesChanged := a.state.AutoApproveRoutes(node)
_, policyChanged, err := a.state.SaveNode(node)
_ = a.h.state.AutoApproveRoutes(node)
_, policyChange, err := a.h.state.SaveNode(node)
if err != nil {
return false, fmt.Errorf("saving auto approved routes to node: %w", err)
}
// Send policy update notifications if needed (from SaveNode or route changes)
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-nodes-change", "all")
a.notifier.NotifyAll(ctx, types.UpdateFull())
// Policy updates are full and take precedence over node changes.
if !policyChange.Empty() {
a.h.Change(policyChange)
} else {
a.h.Change(nodeChange)
}
if routesChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
}
return newNode, nil
return !nodeChange.Empty(), nil
}
// TODO(kradalby):

View File

@ -113,6 +113,17 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
}
}
}
// Also check approved subnet routes - nodes should have access
// to subnets they're approved to route traffic for.
subnetRoutes := node.SubnetRoutes()
for _, subnetRoute := range subnetRoutes {
if expanded.OverlapsPrefix(subnetRoute) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
if len(dests) > 0 {
@ -142,17 +153,24 @@ func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
newApproved = append(newApproved, route)
}
}
if newApproved != nil {
newApproved = append(newApproved, node.ApprovedRoutes...)
tsaddr.SortPrefixes(newApproved)
newApproved = slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
// Only modify ApprovedRoutes if we have new routes to approve.
// This prevents clearing existing approved routes when nodes
// temporarily don't have announced routes during policy changes.
if len(newApproved) > 0 {
combined := append(newApproved, node.ApprovedRoutes...)
tsaddr.SortPrefixes(combined)
combined = slices.Compact(combined)
combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
return route.IsValid()
})
node.ApprovedRoutes = newApproved
// Only update if the routes actually changed
if !slices.Equal(node.ApprovedRoutes, combined) {
node.ApprovedRoutes = combined
return true
}
}
return false
}

View File

@ -56,10 +56,13 @@ func (pol *Policy) compileFilterRules(
}
if ips == nil {
log.Debug().Msgf("destination resolved to nil ips: %v", dest)
continue
}
for _, pref := range ips.Prefixes() {
prefixes := ips.Prefixes()
for _, pref := range prefixes {
for _, port := range dest.Ports {
pr := tailcfg.NetPortRange{
IP: pref.String(),
@ -103,6 +106,8 @@ func (pol *Policy) compileSSHPolicy(
return nil, nil
}
log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname())
var rules []*tailcfg.SSHRule
for index, rule := range pol.SSHs {
@ -137,7 +142,8 @@ func (pol *Policy) compileSSHPolicy(
var principals []*tailcfg.SSHPrincipal
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving source ips")
log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
continue // Skip this rule if we can't resolve sources
}
for addr := range util.IPSetAddrIter(srcIPs) {

View File

@ -70,7 +70,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
// TODO(kradalby): This could potentially be optimized by only clearing the
// policies for nodes that have changed. Particularly if the only difference is
// that nodes has been added or removed.
defer clear(pm.sshPolicyMap)
clear(pm.sshPolicyMap)
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
if err != nil {

View File

@ -1730,7 +1730,7 @@ func (u SSHUser) MarshalJSON() ([]byte, error) {
// In addition to unmarshalling, it will also validate the policy.
// This is the only entrypoint of reading a policy from a file or other source.
func unmarshalPolicy(b []byte) (*Policy, error) {
if b == nil || len(b) == 0 {
if len(b) == 0 {
return nil, nil
}

View File

@ -2,20 +2,20 @@ package hscontrol
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"math/rand/v2"
"net/http"
"net/netip"
"slices"
"time"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
xslices "golang.org/x/exp/slices"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/zstdframe"
)
const (
@ -31,18 +31,17 @@ type mapSession struct {
req tailcfg.MapRequest
ctx context.Context
capVer tailcfg.CapabilityVersion
mapper *mapper.Mapper
cancelChMu deadlock.Mutex
ch chan types.StateUpdate
ch chan *tailcfg.MapResponse
cancelCh chan struct{}
cancelChOpen bool
keepAlive time.Duration
keepAliveTicker *time.Ticker
node types.NodeView
node *types.Node
w http.ResponseWriter
warnf func(string, ...any)
@ -55,18 +54,9 @@ func (h *Headscale) newMapSession(
ctx context.Context,
req tailcfg.MapRequest,
w http.ResponseWriter,
nv types.NodeView,
node *types.Node,
) *mapSession {
warnf, infof, tracef, errf := logPollFuncView(req, nv)
var updateChan chan types.StateUpdate
if req.Stream {
// Use a buffered channel in case a node is not fully ready
// to receive a message to make sure we dont block the entire
// notifier.
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
updateChan <- types.UpdateFull()
}
warnf, infof, tracef, errf := logPollFunc(req, node)
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
@ -75,11 +65,10 @@ func (h *Headscale) newMapSession(
ctx: ctx,
req: req,
w: w,
node: nv,
node: node,
capVer: req.Version,
mapper: h.mapper,
ch: updateChan,
ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize),
cancelCh: make(chan struct{}),
cancelChOpen: true,
@ -95,15 +84,11 @@ func (h *Headscale) newMapSession(
}
func (m *mapSession) isStreaming() bool {
return m.req.Stream && !m.req.ReadOnly
return m.req.Stream
}
func (m *mapSession) isEndpointUpdate() bool {
return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers
}
func (m *mapSession) isReadOnlyUpdate() bool {
return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly
return !m.req.Stream && m.req.OmitPeers
}
func (m *mapSession) resetKeepAlive() {
@ -112,25 +97,22 @@ func (m *mapSession) resetKeepAlive() {
func (m *mapSession) beforeServeLongPoll() {
if m.node.IsEphemeral() {
m.h.ephemeralGC.Cancel(m.node.ID())
m.h.ephemeralGC.Cancel(m.node.ID)
}
}
func (m *mapSession) afterServeLongPoll() {
if m.node.IsEphemeral() {
m.h.ephemeralGC.Schedule(m.node.ID(), m.h.cfg.EphemeralNodeInactivityTimeout)
m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout)
}
}
// serve handles non-streaming requests.
func (m *mapSession) serve() {
// TODO(kradalby): A set todos to harden:
// - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true
// This is the mechanism where the node gives us information about its
// current configuration.
//
// If OmitPeers is true, Stream is false, and ReadOnly is false,
// If OmitPeers is true and Stream is false
// then the server will let clients update their endpoints without
// breaking existing long-polling (Stream == true) connections.
// In this case, the server can omit the entire response; the client
@ -138,26 +120,18 @@ func (m *mapSession) serve() {
//
// This is what Tailscale calls a Lite update, the client ignores
// the response and just wants a 200.
// !req.stream && !req.ReadOnly && req.OmitPeers
//
// TODO(kradalby): remove ReadOnly when we only support capVer 68+
// !req.stream && req.OmitPeers
if m.isEndpointUpdate() {
m.handleEndpointUpdate()
c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
if err != nil {
httpError(m.w, err)
return
}
// ReadOnly is whether the client just wants to fetch the
// MapResponse, without updating their Endpoints. The
// Endpoints field will be ignored and LastSeen will not be
// updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at
// start-up before their first real endpoint update.
if m.isReadOnlyUpdate() {
m.handleReadOnlyRequest()
m.h.Change(c)
return
m.w.WriteHeader(http.StatusOK)
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
}
}
@ -175,23 +149,15 @@ func (m *mapSession) serveLongPoll() {
close(m.cancelCh)
m.cancelChMu.Unlock()
// only update node status if the node channel was removed.
// in principal, it will be removed, but the client rapidly
// reconnects, the channel might be of another connection.
// In that case, it is not closed and the node is still online.
if m.h.nodeNotifier.RemoveNode(m.node.ID(), m.ch) {
// TODO(kradalby): This can likely be made more effective, but likely most
// nodes has access to the same routes, so it might not be a big deal.
change, err := m.h.state.Disconnect(m.node.ID())
disconnectChange, err := m.h.state.Disconnect(m.node)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname())
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
}
m.h.Change(disconnectChange)
if change {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
}
m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
@ -201,21 +167,30 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
m.h.state.Connect(m.node.ID())
// Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w)
// Longpolling will break if there is a write timeout,
// so it needs to be disabled.
rc.SetWriteDeadline(time.Time{})
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname()))
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
defer cancel()
m.keepAliveTicker = time.NewTicker(m.keepAlive)
m.h.nodeNotifier.AddNode(m.node.ID(), m.ch)
// Add node to batcher BEFORE sending Connect change to prevent race condition
// where the change is sent before the node is in the batcher's node map
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
// Send empty response to client to fail fast for invalid/non-existent nodes
select {
case m.ch <- &tailcfg.MapResponse{}:
default:
// Channel might be closed
}
return
}
// Now send the Connect change - the batcher handles NodeCameOnline internally
// but we still need to update routes and other state-level changes
connectChange := m.h.state.Connect(m.node)
if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
m.h.Change(connectChange)
}
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
@ -236,290 +211,94 @@ func (m *mapSession) serveLongPoll() {
// Consume updates sent to node
case update, ok := <-m.ch:
m.tracef("received update from channel, ok: %t", ok)
if !ok {
m.tracef("update channel closed, streaming session is likely being replaced")
return
}
// If the node has been removed from headscale, close the stream
if slices.Contains(update.Removed, m.node.ID()) {
m.tracef("node removed, closing stream")
if err := m.writeMap(update); err != nil {
m.errf(err, "cannot write update to client")
return
}
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
var data []byte
var err error
var lastMessage string
// Ensure the node view is updated, for example, there
// might have been a hostinfo update in a sidechannel
// which contains data needed to generate a map response.
m.node, err = m.h.state.GetNodeViewByID(m.node.ID())
if err != nil {
m.errf(err, "Could not get machine from db")
return
}
updateType := "full"
switch update.Type {
case types.StateFullUpdate:
m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
for _, nodeID := range update.ChangeNodes {
changed[nodeID] = true
}
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "change"
case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
updateType = "patch"
case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed))
for _, nodeID := range update.Removed {
changed[nodeID] = false
}
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateSelfUpdate:
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
updateType = "derp"
}
if err != nil {
m.errf(err, "Could not get the create map update")
return
}
// Only send update if there is change
if data != nil {
startWrite := time.Now()
_, err = m.w.Write(data)
if err != nil {
mapResponseSent.WithLabelValues("error", updateType).Inc()
m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m)
return
}
err = rc.Flush()
if err != nil {
mapResponseSent.WithLabelValues("error", updateType).Inc()
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
return
}
log.Trace().Str("node", m.node.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey().String()).Msg("finished writing mapresp to node")
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID().String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", updateType).Inc()
m.tracef("update sent")
m.resetKeepAlive()
}
case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
if err != nil {
m.errf(err, "Error generating the keep alive msg")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return
}
_, err = m.w.Write(data)
if err != nil {
m.errf(err, "Cannot write keep alive message")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return
}
err = rc.Flush()
if err != nil {
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
if err := m.writeMap(&keepAlive); err != nil {
m.errf(err, "cannot write keep alive")
return
}
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID().String()).Set(float64(time.Now().Unix()))
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
}
}
}
func (m *mapSession) handleEndpointUpdate() {
m.tracef("received endpoint update")
// Get fresh node state from database for accurate route calculations
node, err := m.h.state.GetNodeByID(m.node.ID())
// writeMap writes the map response to the client.
// It handles compression if requested and any headers that need to be set.
// It also handles flushing the response if the ResponseWriter
// implements http.Flusher.
func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
jsonBody, err := json.Marshal(msg)
if err != nil {
m.errf(err, "Failed to get fresh node from database for endpoint update")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
return fmt.Errorf("marshalling map response: %w", err)
}
change := m.node.PeerChangeFromMapRequest(m.req)
online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID())
change.Online = &online
node.ApplyPeerChange(&change)
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, m.req.Hostinfo)
// The node might not set NetInfo if it has not changed and if
// the full HostInfo object is overwritten, the information is lost.
// If there is no NetInfo, keep the previous one.
// From 1.66 the client only sends it if changed:
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
// TODO(kradalby): evaluate if we need better comparing of hostinfo
// before we take the changes.
if m.req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
m.req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
}
node.Hostinfo = m.req.Hostinfo
logTracePeerChange(node.Hostname, sendUpdate, &change)
// If there is no changes and nothing to save,
// return early.
if peerChangeEmpty(change) && !sendUpdate {
mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
return
if m.req.Compress == util.ZstdCompression {
jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
}
// Auto approve any routes that have been defined in policy as
// auto approved. Check if this actually changed the node.
routesAutoApproved := m.h.state.AutoApproveRoutes(node)
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
data = append(data, jsonBody...)
// Always update routes for connected nodes to handle reconnection scenarios
// where routes need to be restored to the primary routes system
routesToSet := node.SubnetRoutes()
startWrite := time.Now()
if m.h.state.SetNodeRoutes(node.ID, routesToSet...) {
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else if routesChanged {
// Only send peer changed notification if routes actually changed
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
// TODO(kradalby): I am not sure if we need this?
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID)
}
// If routes were auto-approved, we need to save the node to persist the changes
if routesAutoApproved {
if _, _, err := m.h.state.SaveNode(node); err != nil {
m.errf(err, "Failed to save auto-approved routes to node")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
}
}
// Check if there has been a change to Hostname and update them
// in the database. Then send a Changed update
// (containing the whole node object) to peers to inform about
// the hostname change.
node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
_, policyChanged, err := m.h.state.SaveNode(node)
_, err = m.w.Write(data)
if err != nil {
m.errf(err, "Failed to persist/update node in the database")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
return err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
if m.isStreaming() {
if f, ok := m.w.(http.Flusher); ok {
f.Flush()
} else {
m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush")
}
}
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(
ctx,
types.UpdatePeerChanged(node.ID),
node.ID,
)
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
m.w.WriteHeader(http.StatusOK)
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
return nil
}
func (m *mapSession) handleReadOnlyRequest() {
m.tracef("Client asked for a lite update, responding without peers")
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
if err != nil {
m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseReadOnly.WithLabelValues("error").Inc()
return
}
m.w.Header().Set("Content-Type", "application/json; charset=utf-8")
m.w.WriteHeader(http.StatusOK)
_, err = m.w.Write(mapResp)
if err != nil {
m.errf(err, "Failed to write response")
mapResponseReadOnly.WithLabelValues("error").Inc()
return
}
m.w.WriteHeader(http.StatusOK)
mapResponseReadOnly.WithLabelValues("ok").Inc()
// keepAlive is a MapResponse that has only its KeepAlive flag set.
var keepAlive = tailcfg.MapResponse{KeepAlive: true}
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname)
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
if change.Key != nil {
trace = trace.Str("node_key", change.Key.ShortString())
if peerChange.Key != nil {
trace = trace.Str("node_key", peerChange.Key.ShortString())
}
if change.DiscoKey != nil {
trace = trace.Str("disco_key", change.DiscoKey.ShortString())
if peerChange.DiscoKey != nil {
trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
}
if change.Online != nil {
trace = trace.Bool("online", *change.Online)
if peerChange.Online != nil {
trace = trace.Bool("online", *peerChange.Online)
}
if change.Endpoints != nil {
eps := make([]string, len(change.Endpoints))
for idx, ep := range change.Endpoints {
if peerChange.Endpoints != nil {
eps := make([]string, len(peerChange.Endpoints))
for idx, ep := range peerChange.Endpoints {
eps[idx] = ep.String()
}
@ -530,21 +309,11 @@ func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.Pe
trace = trace.Bool("hostinfo_changed", hostinfoChange)
}
if change.DERPRegion != 0 {
trace = trace.Int("derp_region", change.DERPRegion)
if peerChange.DERPRegion != 0 {
trace = trace.Int("derp_region", peerChange.DERPRegion)
}
trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received")
}
func peerChangeEmpty(chng tailcfg.PeerChange) bool {
return chng.Key == nil &&
chng.DiscoKey == nil &&
chng.Online == nil &&
chng.Endpoints == nil &&
chng.DERPRegion == 0 &&
chng.LastSeen == nil &&
chng.KeyExpiry == nil
trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received")
}
func logPollFunc(
@ -554,7 +323,6 @@ func logPollFunc(
return func(msg string, a ...any) {
log.Warn().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -564,7 +332,6 @@ func logPollFunc(
func(msg string, a ...any) {
log.Info().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -574,7 +341,6 @@ func logPollFunc(
func(msg string, a ...any) {
log.Trace().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -584,7 +350,6 @@ func logPollFunc(
func(err error, msg string, a ...any) {
log.Error().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -593,91 +358,3 @@ func logPollFunc(
Msgf(msg, a...)
}
}
// logPollFuncView builds four logging closures for a map request, returned in
// the order warn, info, trace, error. Every closure attaches the request's
// ReadOnly, OmitPeers and Stream flags plus the node ID and hostname read from
// nodeView; the error closure additionally attaches the error it is given.
// The message and its arguments are formatted with Msgf.
func logPollFuncView(
mapRequest tailcfg.MapRequest,
nodeView types.NodeView,
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
// Warn-level logger.
return func(msg string, a ...any) {
log.Warn().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Msgf(msg, a...)
},
// Info-level logger.
func(msg string, a ...any) {
log.Info().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Msgf(msg, a...)
},
// Trace-level logger.
func(msg string, a ...any) {
log.Trace().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Msgf(msg, a...)
},
// Error-level logger; the only variant that takes and records an error.
func(err error, msg string, a ...any) {
log.Error().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Err(err).
Msgf(msg, a...)
}
}
// hostInfoChanged compares the previously known Hostinfo (old) with a newly
// received one (new) and reports two things:
//   - the first bool reports if an update needs to be sent to nodes,
//   - the second bool reports if there have been changes to the routes.
//
// The caller can use this to save and update nodes and routes as needed.
// When exactly one of old and new is nil, both results are true.
//
// NOTE: the RoutableIPs slices of both arguments are sorted in place before
// they are compared.
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
	if old.Equal(new) {
		return false, false
	}

	if old == nil && new != nil {
		return true, true
	}

	// Without this guard a nil new Hostinfo is dereferenced below. Treat it
	// like the nil old case: everything may have changed.
	if new == nil {
		return true, true
	}

	// Routes
	oldRoutes := make([]netip.Prefix, 0)
	if old != nil {
		oldRoutes = old.RoutableIPs
	}

	newRoutes := new.RoutableIPs

	tsaddr.SortPrefixes(oldRoutes)
	tsaddr.SortPrefixes(newRoutes)

	if !xslices.Equal(oldRoutes, newRoutes) {
		return true, true
	}

	// Services is mostly useful for discovery and not critical,
	// except for peerapi, which is how nodes talk to each other.
	// If peerapi was not part of the initial mapresponse, we
	// need to make sure it is sent out later as it is needed for
	// Taildrop.
	// TODO(kradalby): Length comparison is a bit naive, replace.
	if len(old.Services) != len(new.Services) {
		return true, false
	}

	return false, false
}

View File

@ -17,10 +17,13 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
xslices "golang.org/x/exp/slices"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
@ -46,12 +49,6 @@ type State struct {
// cfg holds the current Headscale configuration
cfg *types.Config
// in-memory data, protected by mu
// nodes contains the current set of registered nodes
nodes types.Nodes
// users contains the current set of users/namespaces
users types.Users
// subsystem keeping state
// db provides persistent storage and database operations
db *hsdb.HSDatabase
@ -113,9 +110,6 @@ func NewState(cfg *types.Config) (*State, error) {
return &State{
cfg: cfg,
nodes: nodes,
users: users,
db: db,
ipAlloc: ipAlloc,
// TODO(kradalby): Update DERPMap
@ -215,6 +209,7 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.db.DB.Save(&user).Error; err != nil {
return nil, false, fmt.Errorf("creating user: %w", err)
}
@ -226,6 +221,18 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
}
// Even if the policy manager doesn't detect a filter change, SSH policies
// might now be resolvable when they weren't before. If there are existing
// nodes, we should send a policy change to ensure they get updated SSH policies.
if !policyChanged {
nodes, err := s.ListNodes()
if err == nil && len(nodes) > 0 {
policyChanged = true
}
}
log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated")
// TODO(kradalby): implement the user in-memory cache
return &user, policyChanged, nil
@ -329,7 +336,7 @@ func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
}
// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, bool, error) {
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
@ -350,72 +357,100 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err
return node, nil
})
if err != nil {
return nil, false, err
return nil, change.EmptySet, err
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return node, false, fmt.Errorf("failed to update policy manager after node update: %w", err)
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
return node, policyChanged, nil
var c change.ChangeSet
if policyChanged {
c = change.PolicyChange()
} else {
// Basic node change without specific details since this is a generic update
c = change.NodeAdded(node.ID)
}
return node, c, nil
}
// SaveNode persists an existing node to the database and updates the policy manager.
func (s *State) SaveNode(node *types.Node) (*types.Node, bool, error) {
func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.db.DB.Save(node).Error; err != nil {
return nil, false, fmt.Errorf("saving node: %w", err)
return nil, change.EmptySet, fmt.Errorf("saving node: %w", err)
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return node, false, fmt.Errorf("failed to update policy manager after node save: %w", err)
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
return node, policyChanged, nil
if policyChanged {
return node, change.PolicyChange(), nil
}
return node, change.EmptySet, nil
}
// DeleteNode permanently removes a node and cleans up associated resources.
// Returns whether policies changed and any error. This operation is irreversible.
func (s *State) DeleteNode(node *types.Node) (bool, error) {
func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) {
err := s.db.DeleteNode(node)
if err != nil {
return false, err
return change.EmptySet, err
}
c := change.NodeRemoved(node.ID)
// Check if policy manager needs updating after node deletion
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return false, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
}
return policyChanged, nil
if policyChanged {
c = change.PolicyChange()
}
return c, nil
}
func (s *State) Connect(id types.NodeID) {
func (s *State) Connect(node *types.Node) change.ChangeSet {
c := change.NodeOnline(node.ID)
routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
if routeChange {
c = change.NodeAdded(node.ID)
}
return c
}
func (s *State) Disconnect(id types.NodeID) (bool, error) {
// TODO(kradalby): This node should update the in memory state
_, polChanged, err := s.SetLastSeen(id, time.Now())
func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
c := change.NodeOffline(node.ID)
_, _, err := s.SetLastSeen(node.ID, time.Now())
if err != nil {
return false, fmt.Errorf("disconnecting node: %w", err)
return c, fmt.Errorf("disconnecting node: %w", err)
}
changed := s.primaryRoutes.SetRoutes(id)
if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange {
c = change.PolicyChange()
}
// TODO(kradalby): the returned change should be more nuanced allowing us to
// send more directed updates.
return changed || polChanged, nil
// TODO(kradalby): This node should update the in memory state
return c, nil
}
// GetNodeByID retrieves a node by ID.
@ -475,45 +510,93 @@ func (s *State) ListEphemeralNodes() (types.Nodes, error) {
}
// SetNodeExpiry updates the expiration time for a node.
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
}
if !c.IsFull() {
c = change.KeyExpiry(nodeID)
}
return n, c, nil
}
// SetNodeTags assigns tags to a node for use in access control policies.
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetTags(tx, nodeID, tags)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
}
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
return n, c, nil
}
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
}
// Update primary routes after changing approved routes
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...)
if routeChange || !c.IsFull() {
c = change.PolicyChange()
}
return n, c, nil
}
// RenameNode changes the display name of a node.
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.RenameNode(tx, nodeID, newName)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err)
}
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
return n, c, nil
}
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, bool, error) {
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
})
}
// AssignNodeToUser transfers a node to a different user.
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.AssignNodeToUser(tx, nodeID, userID)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err)
}
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
return n, c, nil
}
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
@ -523,7 +606,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
// ExpireExpiredNodes finds and processes expired nodes since the last check.
// Returns next check time, state update with expired nodes, and whether any were found.
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateUpdate, bool) {
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) {
return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck)
}
@ -568,8 +651,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
}
// SetNodeRoutes sets the primary routes for a node.
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) bool {
return s.primaryRoutes.SetRoutes(nodeID, routes...)
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet {
if s.primaryRoutes.SetRoutes(nodeID, routes...) {
// Route changes affect packet filters for all nodes, so trigger a policy change
// to ensure filters are regenerated across the entire network
return change.PolicyChange()
}
return change.EmptySet
}
// GetNodePrimaryRoutes returns the primary routes for a node.
@ -653,10 +742,10 @@ func (s *State) HandleNodeFromAuthPath(
userID types.UserID,
expiry *time.Time,
registrationMethod string,
) (*types.Node, bool, error) {
) (*types.Node, change.ChangeSet, error) {
ipv4, ipv6, err := s.ipAlloc.Next()
if err != nil {
return nil, false, err
return nil, change.EmptySet, err
}
return s.db.HandleNodeFromAuthPath(
@ -672,12 +761,15 @@ func (s *State) HandleNodeFromAuthPath(
func (s *State) HandleNodeFromPreAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*types.Node, bool, error) {
) (*types.Node, change.ChangeSet, bool, error) {
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
if err != nil {
return nil, change.EmptySet, false, err
}
err = pak.Validate()
if err != nil {
return nil, false, err
return nil, change.EmptySet, false, err
}
nodeToRegister := types.Node{
@ -698,22 +790,13 @@ func (s *State) HandleNodeFromPreAuthKey(
AuthKeyID: &pak.ID,
}
// For auth key registration, ensure we don't keep an expired node
// This is especially important for re-registration after logout
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
if !regReq.Expiry.IsZero() {
nodeToRegister.Expiry = &regReq.Expiry
} else if !regReq.Expiry.IsZero() {
// If client is sending an expired time (e.g., after logout),
// don't set expiry so the node won't be considered expired
log.Debug().
Time("requested_expiry", regReq.Expiry).
Str("node", regReq.Hostinfo.Hostname).
Msg("Ignoring expired expiry time from auth key registration")
}
ipv4, ipv6, err := s.ipAlloc.Next()
if err != nil {
return nil, false, fmt.Errorf("allocating IPs: %w", err)
return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err)
}
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
@ -735,18 +818,38 @@ func (s *State) HandleNodeFromPreAuthKey(
return node, nil
})
if err != nil {
return nil, false, fmt.Errorf("writing node to database: %w", err)
return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err)
}
// Check if this is a logout request for an ephemeral node
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
// This is a logout request for an ephemeral node, delete it immediately
c, err := s.DeleteNode(node)
if err != nil {
return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err)
}
return nil, c, false, nil
}
// Check if policy manager needs updating
// This is necessary because we just created a new node.
// We need to ensure that the policy manager is aware of this new node.
policyChanged, err := s.updatePolicyManagerNodes()
// Also update users to ensure all users are known when evaluating policies.
usersChanged, err := s.updatePolicyManagerUsers()
if err != nil {
return nil, false, fmt.Errorf("failed to update policy manager after node registration: %w", err)
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
}
return node, policyChanged, nil
nodesChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
}
policyChanged := usersChanged || nodesChanged
c := change.NodeAdded(node.ID)
return node, c, policyChanged, nil
}
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
@ -766,11 +869,15 @@ func (s *State) updatePolicyManagerUsers() (bool, error) {
return false, fmt.Errorf("listing users for policy update: %w", err)
}
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
changed, err := s.polMan.SetUsers(users)
if err != nil {
return false, fmt.Errorf("updating policy manager users: %w", err)
}
log.Debug().Bool("changed", changed).Msg("Policy manager users updated")
return changed, nil
}
@ -835,3 +942,125 @@ func (s *State) autoApproveNodes() error {
return nil
}
// UpdateNodeFromMapRequest applies the information carried by a MapRequest to
// node, saves the node with SaveNode and returns the change describing the
// outcome:
//   - change.EmptySet when neither the peer change nor the Hostinfo requires
//     an update (node is still modified in memory, but not saved),
//   - the change reported by SaveNode when that is a full update,
//   - otherwise the change reported by SetNodeRoutes, or
//     change.NodeAdded(node.ID) when that is empty.
//
// node is modified in place. A request that carries no Hostinfo leaves the
// stored Hostinfo untouched instead of dereferencing a nil pointer.
//
// TODO(kradalby): This should just take the node ID?
func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) {
	// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes,
	// which means we could shortcut the whole change thing if there are no other important updates.
	peerChange := node.PeerChangeFromMapRequest(req)

	node.ApplyPeerChange(&peerChange)

	var sendUpdate, routesChanged bool

	if req.Hostinfo != nil {
		sendUpdate, routesChanged = hostInfoChanged(node.Hostinfo, req.Hostinfo)

		// The node might not set NetInfo if it has not changed and if
		// the full HostInfo object is overwritten, the information is lost.
		// If there is no NetInfo, keep the previous one.
		// From 1.66 the client only sends it if changed:
		// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
		// TODO(kradalby): evaluate if we need better comparing of hostinfo
		// before we take the changes.
		if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
			req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
		}

		node.Hostinfo = req.Hostinfo
	}

	// If there are no changes and nothing to save,
	// return early.
	if peerChangeEmpty(peerChange) && !sendUpdate {
		// mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
		return change.EmptySet, nil
	}

	c := change.EmptySet

	// Check if the Hostinfo of the node has changed.
	// If it has changed, check if there has been a change to
	// the routable IPs of the host and update them in
	// the database. Then send a Changed update
	// (containing the whole node object) to peers to inform about
	// the route change.
	// If the hostinfo has changed, but not the routes, just update
	// hostinfo and let the function continue.
	if routesChanged {
		// Auto approve any routes that have been defined in policy as
		// auto approved. The result is deliberately discarded: the node
		// is saved unconditionally further down.
		_ = s.AutoApproveRoutes(node)

		// Update the routes of the given node in the route manager to
		// see if an update needs to be sent.
		c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
	}

	// Check if there has been a change to Hostname and update them
	// in the database. Then send a Changed update
	// (containing the whole node object) to peers to inform about
	// the hostname change.
	if req.Hostinfo != nil {
		node.ApplyHostnameFromHostInfo(req.Hostinfo)
	}

	_, policyChange, err := s.SaveNode(node)
	if err != nil {
		return change.EmptySet, err
	}

	// A full (policy) change supersedes whatever the route update reported.
	if policyChange.IsFull() {
		c = policyChange
	}

	// Something was saved, so always report at least a node update.
	if c.Empty() {
		c = change.NodeAdded(node.ID)
	}

	return c, nil
}
// hostInfoChanged compares the previously known Hostinfo (old) with a newly
// received one (new) and reports two things:
//   - the first bool reports if an update needs to be sent to nodes,
//   - the second bool reports if there have been changes to the routes.
//
// The caller can use this to save and update nodes and routes as needed.
// When exactly one of old and new is nil, both results are true.
//
// NOTE: the RoutableIPs slices of both arguments are sorted in place before
// they are compared.
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
	if old.Equal(new) {
		return false, false
	}

	if old == nil && new != nil {
		return true, true
	}

	// Without this guard a nil new Hostinfo is dereferenced below. Treat it
	// like the nil old case: everything may have changed.
	if new == nil {
		return true, true
	}

	// Routes
	oldRoutes := make([]netip.Prefix, 0)
	if old != nil {
		oldRoutes = old.RoutableIPs
	}

	newRoutes := new.RoutableIPs

	tsaddr.SortPrefixes(oldRoutes)
	tsaddr.SortPrefixes(newRoutes)

	if !xslices.Equal(oldRoutes, newRoutes) {
		return true, true
	}

	// Services is mostly useful for discovery and not critical,
	// except for peerapi, which is how nodes talk to each other.
	// If peerapi was not part of the initial mapresponse, we
	// need to make sure it is sent out later as it is needed for
	// Taildrop.
	// TODO(kradalby): Length comparison is a bit naive, replace.
	if len(old.Services) != len(new.Services) {
		return true, false
	}

	return false, false
}
// peerChangeEmpty reports whether peerChange carries no information in the
// fields inspected here: Key, DiscoKey, Online, Endpoints, LastSeen and
// KeyExpiry must all be nil and DERPRegion must be zero.
func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
	switch {
	case peerChange.Key != nil,
		peerChange.DiscoKey != nil,
		peerChange.Online != nil,
		peerChange.Endpoints != nil,
		peerChange.DERPRegion != 0,
		peerChange.LastSeen != nil,
		peerChange.KeyExpiry != nil:
		return false
	}

	return true
}

View File

@ -0,0 +1,183 @@
//go:generate go tool stringer -type=Change
package change
import (
"errors"
"github.com/juanfont/headscale/hscontrol/types"
)
// Aliases for the identifier types declared in package types.
type (
NodeID = types.NodeID
UserID = types.UserID
)
// Change identifies the kind of change a ChangeSet describes.
// The values are grouped in numeric ranges (server 11-13, node 21-25,
// user 51-52); ChangeSet.Validate compares against the first and last
// constant of the node and user groups.
type Change int
const (
// ChangeUnknown is the zero value and is the kind used by EmptySet.
ChangeUnknown Change = 0
// Full is a legacy change to ensure places where we
// have not yet determined the specific update, can send.
//
// Deprecated: Use specific change instead.
Full Change = 9
// Server changes.
Policy Change = 11
DERP Change = 12
ExtraRecords Change = 13
// Node changes.
NodeCameOnline Change = 21
NodeWentOffline Change = 22
NodeRemove Change = 23
NodeKeyExpiry Change = 24
NodeNewOrUpdate Change = 25
// User changes.
UserNewOrUpdate Change = 51
UserRemove Change = 52
)
// AlsoSelf reports whether a change of this kind must be delivered to the
// affected node as well. That is the case for NodeRemove, NodeKeyExpiry and
// NodeNewOrUpdate only.
func (c Change) AlsoSelf() bool {
	return c == NodeRemove || c == NodeKeyExpiry || c == NodeNewOrUpdate
}
// ChangeSet describes one change: its kind plus the node or user it
// concerns and flags that influence who should receive it.
type ChangeSet struct {
Change Change
// SelfUpdateOnly indicates that this change should only be sent
// to the node itself, and not to other nodes.
// This is used for changes that are not relevant to other nodes.
// NodeID must be set if this is true.
SelfUpdateOnly bool
// NodeID if set, is the ID of the node that is being changed.
// It must be set if this is a node change.
NodeID types.NodeID
// UserID if set, is the ID of the user that is being changed.
// It must be set if this is a user change.
UserID types.UserID
// IsSubnetRouter indicates whether the node is a subnet router.
IsSubnetRouter bool
}
// Validate checks that the ChangeSet carries the identifier its kind requires.
//
// Node changes (NodeCameOnline through NodeNewOrUpdate) need a non-zero
// NodeID; user changes (UserNewOrUpdate through UserRemove) need a non-zero
// UserID. Every other kind is accepted as is. It returns nil when the
// ChangeSet is valid and a descriptive error otherwise.
func (c *ChangeSet) Validate() error {
	// Both bounds must hold: with "||" the range test is true for every
	// value, which rejects e.g. a policy change for lacking a NodeID.
	if c.Change >= NodeCameOnline && c.Change <= NodeNewOrUpdate {
		if c.NodeID == 0 {
			return errors.New("ChangeSet.NodeID must be set for node updates")
		}
	}

	if c.Change >= UserNewOrUpdate && c.Change <= UserRemove {
		if c.UserID == 0 {
			return errors.New("ChangeSet.UserID must be set for user updates")
		}
	}

	return nil
}
// Empty reports whether the ChangeSet represents no change at all: its kind
// is ChangeUnknown and neither a node nor a user ID is set.
func (c ChangeSet) Empty() bool {
	if c.Change != ChangeUnknown {
		return false
	}

	return c.NodeID == 0 && c.UserID == 0
}
// IsFull reports whether the ChangeSet calls for a full update, which is the
// case for the Full and Policy kinds.
func (c ChangeSet) IsFull() bool {
	switch c.Change {
	case Full, Policy:
		return true
	}

	return false
}
// AlsoSelf reports whether the change must also be sent to the node it
// concerns.
func (c ChangeSet) AlsoSelf() bool {
	switch {
	case c.NodeID == 0:
		// Not tied to a specific node, so it is treated as a change
		// that goes to all nodes.
		return true
	case c.SelfUpdateOnly:
		return true
	default:
		return c.Change.AlsoSelf()
	}
}
// Predefined ChangeSets that carry only a change kind and no node or user ID.
var (
EmptySet = ChangeSet{Change: ChangeUnknown}
FullSet = ChangeSet{Change: Full}
DERPSet = ChangeSet{Change: DERP}
PolicySet = ChangeSet{Change: Policy}
ExtraRecordsSet = ChangeSet{Change: ExtraRecords}
)
// FullSelf returns a Full ChangeSet for node id that is flagged
// SelfUpdateOnly.
func FullSelf(id types.NodeID) ChangeSet {
	var c ChangeSet
	c.Change = Full
	c.SelfUpdateOnly = true
	c.NodeID = id

	return c
}
// NodeAdded returns a NodeNewOrUpdate ChangeSet for node id.
func NodeAdded(id types.NodeID) ChangeSet {
	var c ChangeSet
	c.Change = NodeNewOrUpdate
	c.NodeID = id

	return c
}
// NodeRemoved returns a NodeRemove ChangeSet for node id.
func NodeRemoved(id types.NodeID) ChangeSet {
	var c ChangeSet
	c.Change = NodeRemove
	c.NodeID = id

	return c
}
// NodeOnline returns a NodeCameOnline ChangeSet for node id.
func NodeOnline(id types.NodeID) ChangeSet {
	var c ChangeSet
	c.Change = NodeCameOnline
	c.NodeID = id

	return c
}
// NodeOffline returns a NodeWentOffline ChangeSet for node id.
func NodeOffline(id types.NodeID) ChangeSet {
	var c ChangeSet
	c.Change = NodeWentOffline
	c.NodeID = id

	return c
}
// KeyExpiry returns a NodeKeyExpiry ChangeSet for node id.
func KeyExpiry(id types.NodeID) ChangeSet {
	var c ChangeSet
	c.Change = NodeKeyExpiry
	c.NodeID = id

	return c
}
// UserAdded returns a UserNewOrUpdate ChangeSet for user id.
func UserAdded(id types.UserID) ChangeSet {
	var c ChangeSet
	c.Change = UserNewOrUpdate
	c.UserID = id

	return c
}
// UserRemoved returns a UserRemove ChangeSet for user id.
func UserRemoved(id types.UserID) ChangeSet {
	var c ChangeSet
	c.Change = UserRemove
	c.UserID = id

	return c
}
// PolicyChange returns a ChangeSet of kind Policy with no node or user ID.
func PolicyChange() ChangeSet {
	var c ChangeSet
	c.Change = Policy

	return c
}
// DERPChange returns a ChangeSet of kind DERP with no node or user ID.
func DERPChange() ChangeSet {
	var c ChangeSet
	c.Change = DERP

	return c
}

View File

@ -0,0 +1,57 @@
// Code generated by "stringer -type=Change"; DO NOT EDIT.
package change
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ChangeUnknown-0]
_ = x[Full-9]
_ = x[Policy-11]
_ = x[DERP-12]
_ = x[ExtraRecords-13]
_ = x[NodeCameOnline-21]
_ = x[NodeWentOffline-22]
_ = x[NodeRemove-23]
_ = x[NodeKeyExpiry-24]
_ = x[NodeNewOrUpdate-25]
_ = x[UserNewOrUpdate-51]
_ = x[UserRemove-52]
}
const (
_Change_name_0 = "ChangeUnknown"
_Change_name_1 = "Full"
_Change_name_2 = "PolicyDERPExtraRecords"
_Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdate"
_Change_name_4 = "UserNewOrUpdateUserRemove"
)
var (
_Change_index_2 = [...]uint8{0, 6, 10, 22}
_Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67}
_Change_index_4 = [...]uint8{0, 15, 25}
)
func (i Change) String() string {
switch {
case i == 0:
return _Change_name_0
case i == 9:
return _Change_name_1
case 11 <= i && i <= 13:
i -= 11
return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]]
case 21 <= i && i <= 25:
i -= 21
return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]]
case 51 <= i && i <= 52:
i -= 51
return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]]
default:
return "Change(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View File

@ -1,16 +1,16 @@
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
//go:generate go tool viewer --type=User,Node,PreAuthKey
package types
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
import (
"context"
"errors"
"fmt"
"runtime"
"time"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/util/ctxkey"
)
const (
@ -150,18 +150,6 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
}
}
// Context keys under which NotifyCtx stores the origin and the hostname of a
// notification. The second argument ("") is presumably the default value
// returned when the key is absent — confirm against the ctxkey package.
var (
NotifyOriginKey = ctxkey.New("notify.origin", "")
NotifyHostnameKey = ctxkey.New("notify.hostname", "")
)
// NotifyCtx derives a context from ctx that times out after three seconds and
// carries origin under NotifyOriginKey and hostname under NotifyHostnameKey.
//
// The cancel function of the timeout is discarded, so the derived context is
// only released when the timeout fires or ctx is cancelled.
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
	timed, _ := context.WithTimeout(ctx, 3*time.Second)
	withOrigin := NotifyOriginKey.WithValue(timed, origin)

	return NotifyHostnameKey.WithValue(withOrigin, hostname)
}
const RegistrationIDLength = 24
type RegistrationID string
@ -199,3 +187,20 @@ type RegisterNode struct {
Node Node
Registered chan *Node
}
// DefaultBatcherWorkers returns the default number of batcher workers for
// this machine: the value DefaultBatcherWorkersFor computes for the CPU count
// reported by runtime.NumCPU.
func DefaultBatcherWorkers() int {
	cpus := runtime.NumCPU()

	return DefaultBatcherWorkersFor(cpus)
}
// DefaultBatcherWorkersFor returns the default number of batcher workers for
// the given CPU count: three quarters of cpuCount using integer division,
// but never less than 1. There is no upper bound.
func DefaultBatcherWorkersFor(cpuCount int) int {
	if workers := cpuCount * 3 / 4; workers >= 1 {
		return workers
	}

	return 1
}

View File

@ -0,0 +1,36 @@
package types
import (
"testing"
)
// TestDefaultBatcherWorkersFor checks DefaultBatcherWorkersFor against a
// table of CPU counts and the worker count expected for each.
func TestDefaultBatcherWorkersFor(t *testing.T) {
	type testCase struct {
		cpuCount int
		expected int
	}

	cases := []testCase{
		{cpuCount: 1, expected: 1},   // (1*3)/4 = 0, should be minimum 1
		{cpuCount: 2, expected: 1},   // (2*3)/4 = 1
		{cpuCount: 4, expected: 3},   // (4*3)/4 = 3
		{cpuCount: 8, expected: 6},   // (8*3)/4 = 6
		{cpuCount: 12, expected: 9},  // (12*3)/4 = 9
		{cpuCount: 16, expected: 12}, // (16*3)/4 = 12
		{cpuCount: 20, expected: 15}, // (20*3)/4 = 15
		{cpuCount: 24, expected: 18}, // (24*3)/4 = 18
	}

	for i := range cases {
		tc := cases[i]
		if got := DefaultBatcherWorkersFor(tc.cpuCount); got != tc.expected {
			t.Errorf("DefaultBatcherWorkersFor(%d) = %d, expected %d", tc.cpuCount, got, tc.expected)
		}
	}
}
// TestDefaultBatcherWorkers checks only the lower bound, because the exact
// result depends on the CPU count of the machine running the test.
func TestDefaultBatcherWorkers(t *testing.T) {
	if workers := DefaultBatcherWorkers(); workers < 1 {
		t.Errorf("DefaultBatcherWorkers() = %d, expected value >= 1", workers)
	}
}

View File

@ -234,6 +234,7 @@ type Tuning struct {
NotifierSendTimeout time.Duration
BatchChangeDelay time.Duration
NodeMapSessionBufferedChanSize int
BatcherWorkers int
}
func validatePKCEMethod(method string) error {
@ -991,6 +992,12 @@ func LoadServerConfig() (*Config, error) {
NodeMapSessionBufferedChanSize: viper.GetInt(
"tuning.node_mapsession_buffered_chan_size",
),
BatcherWorkers: func() int {
if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
return workers
}
return DefaultBatcherWorkers()
}(),
},
}, nil
}

View File

@ -431,6 +431,11 @@ func (node *Node) SubnetRoutes() []netip.Prefix {
return routes
}
// IsSubnetRouter reports whether SubnetRoutes returns at least one prefix
// for the node.
func (node *Node) IsSubnetRouter() bool {
	routes := node.SubnetRoutes()

	return len(routes) != 0
}
// String returns the node's Hostname.
func (node *Node) String() string {
return node.Hostname
}
@ -669,6 +674,13 @@ func (v NodeView) SubnetRoutes() []netip.Prefix {
return v.ж.SubnetRoutes()
}
// IsSubnetRouter reports whether the underlying node is a subnet router.
// An invalid view reports false without touching the underlying node.
func (v NodeView) IsSubnetRouter() bool {
	return v.Valid() && v.ж.IsSubnetRouter()
}
func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) {
if !v.Valid() {
return

View File

@ -1,17 +1,16 @@
package types
import (
"fmt"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
)
// PAKError is a pre-auth key error whose message is the string value itself.
type PAKError string

// Error returns the error message.
func (e PAKError) Error() string { return string(e) }

// Unwrap returns the error message prefixed with "preauth key error: ".
//
// The returned error deliberately does not wrap e (no %w). An Unwrap result
// that wraps its own receiver makes the unwrap chain cyclic, so errors.Is and
// errors.As would never terminate when the target does not match.
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %s", string(e)) }
// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {
@ -60,6 +59,21 @@ func (pak *PreAuthKey) Validate() error {
if pak == nil {
return PAKError("invalid authkey")
}
log.Debug().
Str("key", pak.Key).
Bool("hasExpiration", pak.Expiration != nil).
Time("expiration", func() time.Time {
if pak.Expiration != nil {
return *pak.Expiration
}
return time.Time{}
}()).
Time("now", time.Now()).
Bool("reusable", pak.Reusable).
Bool("used", pak.Used).
Msg("PreAuthKey.Validate: checking key")
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return PAKError("authkey expired")
}

View File

@ -5,6 +5,8 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"tailscale.com/util/dnsname"
"tailscale.com/util/must"
)
func TestCheckForFQDNRules(t *testing.T) {
@ -102,59 +104,16 @@ func TestConvertWithFQDNRules(t *testing.T) {
func TestMagicDNSRootDomains100(t *testing.T) {
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("100.64.0.0/10"))
found := false
for _, domain := range domains {
if domain == "64.100.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
found = false
for _, domain := range domains {
if domain == "100.100.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
found = false
for _, domain := range domains {
if domain == "127.100.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("64.100.in-addr.arpa.")))
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("100.100.in-addr.arpa.")))
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("127.100.in-addr.arpa.")))
}
func TestMagicDNSRootDomains172(t *testing.T) {
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("172.16.0.0/16"))
found := false
for _, domain := range domains {
if domain == "0.16.172.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
found = false
for _, domain := range domains {
if domain == "255.16.172.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("0.16.172.in-addr.arpa.")))
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("255.16.172.in-addr.arpa.")))
}
// Happens when netmask is a multiple of 4 bits (sounds likely).

View File

@ -143,7 +143,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
// Parse latencies
for j := 5; j <= 7; j++ {
if matches[j] != "" {
if j < len(matches) && matches[j] != "" {
ms, err := strconv.ParseFloat(matches[j], 64)
if err != nil {
return Traceroute{}, fmt.Errorf("parsing latency: %w", err)

View File

@ -88,7 +88,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match before logout count")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count")
}, 20*time.Second, 1*time.Second)
for _, node := range listNodes {
@ -123,7 +123,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match after HTTPS reconnection")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection")
}, 30*time.Second, 2*time.Second)
for _, node := range listNodes {
@ -161,7 +161,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSet(t, node)
}
@ -355,7 +355,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
"--user",
strconv.FormatUint(userMap[userName].GetId(), 10),
"expire",
key.Key,
key.GetKey(),
})
assertNoErr(t, err)

View File

@ -147,3 +147,9 @@ func DockerAllowNetworkAdministration(config *docker.HostConfig) {
config.CapAdd = append(config.CapAdd, "NET_ADMIN")
config.Privileged = true
}
// DockerMemoryLimit caps the container's memory at 2 GiB and sets
// OOMKillDisable on the host config.
func DockerMemoryLimit(config *docker.HostConfig) {
	const memoryLimitBytes = 2 << 30 // 2 GiB expressed in bytes

	config.OOMKillDisable = true
	config.Memory = memoryLimitBytes
}

View File

@ -883,6 +883,10 @@ func TestNodeOnlineStatus(t *testing.T) {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()
assert.NoError(ct, err)
if status == nil {
assert.Fail(ct, "status is nil")
return
}
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
@ -984,16 +988,11 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
}
// Wait for sync and successful pings after nodes come back up
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = scenario.WaitForTailscaleSync()
assert.NoError(ct, err)
assert.NoError(t, err)
success := pingAllHelper(t, allClients, allAddrs)
assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up")
}, 30*time.Second, 2*time.Second)
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps))
}
}

View File

@ -260,7 +260,9 @@ func WithDERPConfig(derpMap tailcfg.DERPMap) Option {
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
return func(hsic *HeadscaleInContainer) {
hsic.env["HEADSCALE_TUNING_BATCH_CHANGE_DELAY"] = batchTimeout.String()
hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(mapSessionChanSize)
hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(
mapSessionChanSize,
)
}
}
@ -279,9 +281,15 @@ func WithDebugPort(port int) Option {
// buildEntrypoint builds the container entrypoint command based on configuration.
func (hsic *HeadscaleInContainer) buildEntrypoint() []string {
debugCmd := fmt.Sprintf("/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve", hsic.debugPort)
debugCmd := fmt.Sprintf(
"/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve",
hsic.debugPort,
)
entrypoint := fmt.Sprintf("/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30", debugCmd)
entrypoint := fmt.Sprintf(
"/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30",
debugCmd,
)
return []string{"/bin/bash", "-c", entrypoint}
}
@ -448,7 +456,11 @@ func New(
hsic.container = container
log.Printf("Debug ports for %s: delve=%s, metrics/pprof=49090\n", hsic.hostname, hsic.GetHostDebugPort())
log.Printf(
"Debug ports for %s: delve=%s, metrics/pprof=49090\n",
hsic.hostname,
hsic.GetHostDebugPort(),
)
// Write the CA certificates to the container
for i, cert := range hsic.caCerts {
@ -684,14 +696,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
return nil
}
// First, let's see what files are actually in /tmp
tmpListing, err := t.Execute([]string{"ls", "-la", "/tmp/"})
if err != nil {
log.Printf("Warning: could not list /tmp directory: %v", err)
} else {
log.Printf("Contents of /tmp in container %s:\n%s", t.hostname, tmpListing)
}
// Also check for any .sqlite files
sqliteFiles, err := t.Execute([]string{"find", "/tmp", "-name", "*.sqlite*", "-type", "f"})
if err != nil {
@ -718,12 +722,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
return errors.New("database file exists but has no schema (empty database)")
}
// Show a preview of the schema (first 500 chars)
schemaPreview := schemaCheck
if len(schemaPreview) > 500 {
schemaPreview = schemaPreview[:500] + "..."
}
tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
if err != nil {
return fmt.Errorf("failed to fetch database file: %w", err)
@ -740,7 +738,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
return fmt.Errorf("failed to read tar header: %w", err)
}
log.Printf("Found file in tar: %s (type: %d, size: %d)", header.Name, header.Typeflag, header.Size)
log.Printf(
"Found file in tar: %s (type: %d, size: %d)",
header.Name,
header.Typeflag,
header.Size,
)
// Extract the first regular file we find
if header.Typeflag == tar.TypeReg {
@ -756,11 +759,20 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
return fmt.Errorf("failed to copy database file: %w", err)
}
log.Printf("Extracted database file: %s (%d bytes written, header claimed %d bytes)", dbPath, written, header.Size)
log.Printf(
"Extracted database file: %s (%d bytes written, header claimed %d bytes)",
dbPath,
written,
header.Size,
)
// Check if we actually wrote something
if written == 0 {
return fmt.Errorf("database file is empty (size: %d, header size: %d)", written, header.Size)
return fmt.Errorf(
"database file is empty (size: %d, header size: %d)",
written,
header.Size,
)
}
return nil
@ -871,7 +883,15 @@ func (t *HeadscaleInContainer) WaitForRunning() error {
func (t *HeadscaleInContainer) CreateUser(
user string,
) (*v1.User, error) {
command := []string{"headscale", "users", "create", user, fmt.Sprintf("--email=%s@test.no", user), "--output", "json"}
command := []string{
"headscale",
"users",
"create",
user,
fmt.Sprintf("--email=%s@test.no", user),
"--output",
"json",
}
result, _, err := dockertestutil.ExecuteCommand(
t.container,
@ -1182,13 +1202,18 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
[]string{},
)
if err != nil {
return nil, fmt.Errorf("failed to execute list node command: %w", err)
return nil, fmt.Errorf(
"failed to execute approve routes command (node %d, routes %v): %w",
id,
routes,
err,
)
}
var node *v1.Node
err = json.Unmarshal([]byte(result), &node)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal nodes: %w", err)
return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err)
}
return node, nil

View File

@ -310,7 +310,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Enable route on node 1
t.Logf("Enabling route on subnet router 1, no HA")
_, err = headscale.ApproveRoutes(
1,
MustFindNode(subRouter1.Hostname(), nodes).GetId(),
[]netip.Prefix{pref},
)
require.NoError(t, err)
@ -366,7 +366,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Enable route on node 2, now we will have a HA subnet router
t.Logf("Enabling route on subnet router 2, now HA, subnetrouter 1 is primary, 2 is standby")
_, err = headscale.ApproveRoutes(
2,
MustFindNode(subRouter2.Hostname(), nodes).GetId(),
[]netip.Prefix{pref},
)
require.NoError(t, err)
@ -422,7 +422,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// be enabled.
t.Logf("Enabling route on subnet router 3, now HA, subnetrouter 1 is primary, 2 and 3 is standby")
_, err = headscale.ApproveRoutes(
3,
MustFindNode(subRouter3.Hostname(), nodes).GetId(),
[]netip.Prefix{pref},
)
require.NoError(t, err)
@ -639,7 +639,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf("disabling route in subnet router r3 (%s)", subRouter3.Hostname())
t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
_, err = headscale.ApproveRoutes(nodes[2].GetId(), []netip.Prefix{})
_, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
time.Sleep(5 * time.Second)
@ -647,9 +647,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
@ -684,7 +684,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Disable the route of subnet router 1, making it failover to 2
t.Logf("disabling route in subnet router r1 (%s)", subRouter1.Hostname())
t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{})
_, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
time.Sleep(5 * time.Second)
@ -692,9 +692,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 0, 0)
requireNodeRouteCount(t, nodes[1], 1, 1, 1)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
@ -729,9 +729,10 @@ func TestHASubnetRouterFailover(t *testing.T) {
// enable the route of subnet router 1, no change expected
t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
t.Logf("both online, expecting r2 (%s) to still be primary (no flapping)", subRouter2.Hostname())
r1Node := MustFindNode(subRouter1.Hostname(), nodes)
_, err = headscale.ApproveRoutes(
nodes[0].GetId(),
util.MustStringsToPrefixes(nodes[0].GetAvailableRoutes()),
r1Node.GetId(),
util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
)
time.Sleep(5 * time.Second)
@ -740,9 +741,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 1, 0)
requireNodeRouteCount(t, nodes[1], 1, 1, 1)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()

View File

@ -223,7 +223,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
s.userToNetwork = userToNetwork
if spec.OIDCUsers != nil && len(spec.OIDCUsers) != 0 {
if len(spec.OIDCUsers) != 0 {
ttl := defaultAccessTTL
if spec.OIDCAccessTTL != 0 {
ttl = spec.OIDCAccessTTL

View File

@ -370,10 +370,12 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
}
// doSSH forwards to doSSHWithRetry with true as the final argument
// (presumably enabling retries; confirm against doSSHWithRetry) and returns
// its three results unchanged. It is marked as a test helper so failures are
// attributed to the caller.
func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
t.Helper()
return doSSHWithRetry(t, client, peer, true)
}
// doSSHWithoutRetry forwards to doSSHWithRetry with false as the final
// argument (presumably disabling retries; confirm against doSSHWithRetry)
// and returns its three results unchanged. It is marked as a test helper so
// failures are attributed to the caller.
func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
t.Helper()
return doSSHWithRetry(t, client, peer, false)
}

View File

@ -319,6 +319,7 @@ func New(
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
case "unstable":
tailscaleOptions.Repository = "tailscale/tailscale"
@ -329,6 +330,7 @@ func New(
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
default:
tailscaleOptions.Repository = "tailscale/tailscale"
@ -339,6 +341,7 @@ func New(
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
}

View File

@ -22,11 +22,11 @@ import (
const (
// derpPingTimeout defines the timeout for individual DERP ping operations
// Used in DERP connectivity tests to verify relay server communication
// Used in DERP connectivity tests to verify relay server communication.
derpPingTimeout = 2 * time.Second
// derpPingCount defines the number of ping attempts for DERP connectivity tests
// Higher count provides better reliability assessment of DERP connectivity
// Higher count provides better reliability assessment of DERP connectivity.
derpPingCount = 10
)
@ -321,7 +321,7 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) {
// before executing commands that depend on network state propagation.
//
// Timeout: 10 seconds with exponential backoff
// Use cases: DNS resolution, route propagation, policy updates
// Use cases: DNS resolution, route propagation, policy updates.
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
t.Helper()
@ -361,10 +361,10 @@ func isSelfClient(client TailscaleClient, addr string) bool {
}
func dockertestMaxWait() time.Duration {
wait := 120 * time.Second //nolint
wait := 300 * time.Second //nolint
if util.IsCI() {
wait = 300 * time.Second //nolint
wait = 600 * time.Second //nolint
}
return wait

View File

@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"regexp"
@ -21,7 +20,7 @@ import (
const (
releasesURL = "https://api.github.com/repos/tailscale/tailscale/releases"
rawFileURL = "https://github.com/tailscale/tailscale/raw/refs/tags/%s/tailcfg/tailcfg.go"
outputFile = "../capver_generated.go"
outputFile = "../../hscontrol/capver/capver_generated.go"
)
type Release struct {
@ -105,7 +104,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
sortedVersions := xmaps.Keys(versions)
sort.Strings(sortedVersions)
for _, version := range sortedVersions {
file.WriteString(fmt.Sprintf("\t\"%s\": %d,\n", version, versions[version]))
fmt.Fprintf(file, "\t\"%s\": %d,\n", version, versions[version])
}
file.WriteString("}\n")
@ -115,16 +114,13 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string)
for _, v := range sortedVersions {
cap := versions[v]
log.Printf("cap for v: %d, %s", cap, v)
// If it is already set, skip and continue;
// we only want the first Tailscale version per
// capability version.
if _, ok := capVarToTailscaleVer[cap]; ok {
log.Printf("Skipping %d, %s", cap, v)
continue
}
log.Printf("Storing %d, %s", cap, v)
capVarToTailscaleVer[cap] = v
}
@ -133,7 +129,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
return capsSorted[i] < capsSorted[j]
})
for _, capVer := range capsSorted {
file.WriteString(fmt.Sprintf("\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer]))
fmt.Fprintf(file, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
}
file.WriteString("}\n")