mapper: produce map before poll (#2628)

Kristoffer Dalby 2025-07-28 11:15:53 +02:00 committed by GitHub
parent b2a18830ed
commit a058bf3cd3
70 changed files with 5771 additions and 2475 deletions


@ -17,3 +17,7 @@ LICENSE
.vscode
*.sock
node_modules/
package-lock.json
package.json

.github/workflows/check-generated.yml (new file)

@ -0,0 +1,55 @@
name: Check Generated Files
on:
push:
branches:
- main
pull_request:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
check-generated:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 2
- name: Get changed files
id: changed-files
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
with:
filters: |
files:
- '*.nix'
- 'go.*'
- '**/*.go'
- '**/*.proto'
- 'buf.gen.yaml'
- 'tools/**'
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'
with:
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', '**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
- name: Run make generate
if: steps.changed-files.outputs.files == 'true'
run: nix develop --command -- make generate
- name: Check for uncommitted changes
if: steps.changed-files.outputs.files == 'true'
run: |
if ! git diff --exit-code; then
echo "❌ Generated files are not up to date!"
echo "Please run 'make generate' and commit the changes."
exit 1
else
echo "✅ All generated files are up to date."
fi


@ -77,7 +77,7 @@ jobs:
attempt_delay: 300000 # 5 min
attempt_limit: 2
command: |
nix develop --command -- hi run "^${{ inputs.test }}$" \
nix develop --command -- hi run --stats --ts-memory-limit=300 --hs-memory-limit=500 "^${{ inputs.test }}$" \
--timeout=120m \
${{ inputs.postgres_flag }}
- uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2

.gitignore

@ -1,6 +1,9 @@
ignored/
tailscale/
.vscode/
.claude/
*.prof
# Binaries for programs and plugins
*.exe
@ -46,3 +49,7 @@ integration_test/etc/config.dump.yaml
/site
__debug_bin
node_modules/
package-lock.json
package.json


@ -2,6 +2,8 @@
## Next
**Minimum supported Tailscale client version: v1.64.0**
### Database integrity improvements
This release includes a significant database migration that addresses longstanding

CLAUDE.md (new file)

@ -0,0 +1,395 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Overview
Headscale is an open-source implementation of the Tailscale control server written in Go. It provides self-hosted coordination for Tailscale networks (tailnets), managing node registration, IP allocation, policy enforcement, and DERP routing.
## Development Commands
### Quick Setup
```bash
# Recommended: Use Nix for dependency management
nix develop
# Full development workflow
make dev # runs fmt + lint + test + build
```
### Essential Commands
```bash
# Build headscale binary
make build
# Run tests
make test
go test ./... # All unit tests
go test -race ./... # With race detection
# Run specific integration test
go run ./cmd/hi run "TestName" --postgres
# Code formatting and linting
make fmt # Format all code (Go, docs, proto)
make lint # Lint all code (Go, proto)
make fmt-go # Format Go code only
make lint-go # Lint Go code only
# Protocol buffer generation (after modifying proto/)
make generate
# Clean build artifacts
make clean
```
### Integration Testing
```bash
# Use the hi (Headscale Integration) test runner
go run ./cmd/hi doctor # Check system requirements
go run ./cmd/hi run "TestPattern" # Run specific test
go run ./cmd/hi run "TestPattern" --postgres # With PostgreSQL backend
# Test artifacts are saved to control_logs/ with logs and debug data
```
## Project Structure & Architecture
### Top-Level Organization
```
headscale/
├── cmd/ # Command-line applications
│ ├── headscale/ # Main headscale server binary
│ └── hi/ # Headscale Integration test runner
├── hscontrol/ # Core control plane logic
├── integration/ # End-to-end Docker-based tests
├── proto/ # Protocol buffer definitions
├── gen/ # Generated code (protobuf)
├── docs/ # Documentation
└── packaging/ # Distribution packaging
```
### Core Packages (`hscontrol/`)
**Main Server (`hscontrol/`)**
- `app.go`: Application setup, dependency injection, server lifecycle
- `handlers.go`: HTTP/gRPC API endpoints for management operations
- `grpcv1.go`: gRPC service implementation for headscale API
- `poll.go`: **Critical** - Handles Tailscale MapRequest/MapResponse protocol
- `noise.go`: Noise protocol implementation for secure client communication
- `auth.go`: Authentication flows (web, OIDC, command-line)
- `oidc.go`: OpenID Connect integration for user authentication
**State Management (`hscontrol/state/`)**
- `state.go`: Central coordinator for all subsystems (database, policy, IP allocation, DERP)
- `node_store.go`: **Performance-critical** - In-memory cache with copy-on-write semantics
- Thread-safe operations with deadlock detection
- Coordinates between database persistence and real-time operations
**Database Layer (`hscontrol/db/`)**
- `db.go`: Database abstraction, GORM setup, migration management
- `node.go`: Node lifecycle, registration, expiration, IP assignment
- `users.go`: User management, namespace isolation
- `api_key.go`: API authentication tokens
- `preauth_keys.go`: Pre-authentication keys for automated node registration
- `ip.go`: IP address allocation and management
- `policy.go`: Policy storage and retrieval
- Schema migrations in `schema.sql` with extensive test data coverage
**Policy Engine (`hscontrol/policy/`)**
- `policy.go`: Core ACL evaluation logic, HuJSON parsing
- `v2/`: Next-generation policy system with improved filtering
- `matcher/`: ACL rule matching and evaluation engine
- Determines peer visibility, route approval, and network access rules
- Supports both file-based and database-stored policies
**Network Management (`hscontrol/`)**
- `derp/`: DERP (Designated Encrypted Relay for Packets) server implementation
- NAT traversal when direct connections fail
- Fallback relay for firewall-restricted environments
- `mapper/`: Converts internal Headscale state to Tailscale's wire protocol format
- `tail.go`: Tailscale-specific data structure generation
- `routes/`: Subnet route management and primary route selection
- `dns/`: DNS record management and MagicDNS implementation
**Utilities & Support (`hscontrol/`)**
- `types/`: Core data structures, configuration, validation
- `util/`: Helper functions for networking, DNS, key management
- `templates/`: Client configuration templates (Apple, Windows, etc.)
- `notifier/`: Event notification system for real-time updates
- `metrics.go`: Prometheus metrics collection
- `capver/`: Tailscale capability version management
### Key Subsystem Interactions
**Node Registration Flow**
1. **Client Connection**: `noise.go` handles secure protocol handshake
2. **Authentication**: `auth.go` validates credentials (web/OIDC/preauth)
3. **State Creation**: `state.go` coordinates IP allocation via `db/ip.go`
4. **Storage**: `db/node.go` persists node, `NodeStore` caches in memory
5. **Network Setup**: `mapper/` generates initial Tailscale network map
**Ongoing Operations**
1. **Poll Requests**: `poll.go` receives periodic client updates
2. **State Updates**: `NodeStore` maintains real-time node information
3. **Policy Application**: `policy/` evaluates ACL rules for peer relationships
4. **Map Distribution**: `mapper/` sends network topology to all affected clients
**Route Management**
1. **Advertisement**: Clients announce routes via `poll.go` Hostinfo updates
2. **Storage**: `db/` persists routes, `NodeStore` caches for performance
3. **Approval**: `policy/` auto-approves routes based on ACL rules
4. **Distribution**: `routes/` selects primary routes, `mapper/` distributes to peers
### Command-Line Tools (`cmd/`)
**Main Server (`cmd/headscale/`)**
- `headscale.go`: CLI parsing, configuration loading, server startup
- Supports daemon mode, CLI operations (user/node management), database operations
**Integration Test Runner (`cmd/hi/`)**
- `main.go`: Test execution framework with Docker orchestration
- `run.go`: Individual test execution with artifact collection
- `doctor.go`: System requirements validation
- `docker.go`: Container lifecycle management
- Essential for validating changes against real Tailscale clients
### Generated & External Code
**Protocol Buffers (`proto/` → `gen/`)**
- Defines gRPC API for headscale management operations
- Client libraries can generate from these definitions
- Run `make generate` after modifying `.proto` files
**Integration Testing (`integration/`)**
- `scenario.go`: Docker test environment setup
- `tailscale.go`: Tailscale client container management
- Individual test files for specific functionality areas
- Real end-to-end validation with network isolation
### Critical Performance Paths
**High-Frequency Operations**
1. **MapRequest Processing** (`poll.go`): Every 15-60 seconds per client
2. **NodeStore Reads** (`node_store.go`): Every operation requiring node data
3. **Policy Evaluation** (`policy/`): On every peer relationship calculation
4. **Route Lookups** (`routes/`): During network map generation
**Database Write Patterns**
- **Frequent**: Node heartbeats, endpoint updates, route changes
- **Moderate**: User operations, policy updates, API key management
- **Rare**: Schema migrations, bulk operations
### Configuration & Deployment
**Configuration (`hscontrol/types/config.go`)**
- Database connection settings (SQLite/PostgreSQL)
- Network configuration (IP ranges, DNS settings)
- Policy mode (file vs database)
- DERP relay configuration
- OIDC provider settings
**Key Dependencies**
- **GORM**: Database ORM with migration support
- **Tailscale Libraries**: Core networking and protocol code
- **Zerolog**: Structured logging throughout the application
- **Buf**: Protocol buffer toolchain for code generation
### Development Workflow Integration
The architecture supports incremental development:
- **Unit Tests**: Focus on individual packages (`*_test.go` files)
- **Integration Tests**: Validate cross-component interactions
- **Database Tests**: Extensive migration and data integrity validation
- **Policy Tests**: ACL rule evaluation and edge cases
- **Performance Tests**: NodeStore and high-frequency operation validation
## Integration Test System
### Overview
Integration tests use Docker containers running real Tailscale clients against a Headscale server. Tests validate end-to-end functionality including routing, ACLs, node lifecycle, and network coordination.
### Running Integration Tests
**System Requirements**
```bash
# Check if your system is ready
go run ./cmd/hi doctor
```
This verifies Docker, Go, required images, and disk space.
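If `doctor` flags a problem, the underlying checks can be approximated by hand. A minimal sketch using standard tooling (the exact checks `doctor` performs may differ):
```bash
# Docker daemon reachable via the configured socket
docker info > /dev/null && echo "docker OK"
# Go toolchain available
go version
# Golang image present locally (tag depends on the detected Go version)
docker images | grep golang
# Free disk space in the working directory
df -h .
```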
**Test Execution Patterns**
```bash
# Run a single test (recommended for development)
go run ./cmd/hi run "TestSubnetRouterMultiNetwork"
# Run with PostgreSQL backend (for database-heavy tests)
go run ./cmd/hi run "TestExpireNode" --postgres
# Run multiple tests with pattern matching
go run ./cmd/hi run "TestSubnet*"
# Run all integration tests (CI/full validation)
go test ./integration -timeout 30m
```
**Test Categories & Timing**
- **Fast tests** (< 2 min): Basic functionality, CLI operations
- **Medium tests** (2-5 min): Route management, ACL validation
- **Slow tests** (5+ min): Node expiration, HA failover
- **Long-running tests** (10+ min): `TestNodeOnlineStatus` (12 min duration); see the timeout example below
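For the slow and long-running categories, pass a generous timeout to the runner, mirroring the CI invocation (the timeout values below are illustrative):
```bash
# Long-running test with an explicit timeout
go run ./cmd/hi run "TestNodeOnlineStatus" --timeout=30m
# Database-heavy test against PostgreSQL with a longer timeout
go run ./cmd/hi run "TestExpireNode" --postgres --timeout=30m
```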
### Test Infrastructure
**Docker Setup**
- Headscale server container with configurable database backend
- Multiple Tailscale client containers with different versions
- Isolated networks per test scenario
- Automatic cleanup after test completion
**Test Artifacts**
All test runs save artifacts to `control_logs/TIMESTAMP-ID/`:
```
control_logs/20250713-213106-iajsux/
├── hs-testname-abc123.stderr.log # Headscale server logs
├── hs-testname-abc123.stdout.log
├── hs-testname-abc123.db # Database snapshot
├── hs-testname-abc123_metrics.txt # Prometheus metrics
├── hs-testname-abc123-mapresponses/ # Protocol debug data
├── ts-client-xyz789.stderr.log # Tailscale client logs
├── ts-client-xyz789.stdout.log
└── ts-client-xyz789_status.json # Client status dump
```
### Test Development Guidelines
**Timing Considerations**
Integration tests involve real network operations and Docker container lifecycle:
```go
// ❌ Wrong: Immediate assertions after async operations
client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"})
nodes, _ := headscale.ListNodes()
require.Len(t, nodes[0].GetAvailableRoutes(), 1) // May fail due to timing
// ✅ Correct: Wait for async operations to complete
client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"})
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes[0].GetAvailableRoutes(), 1)
}, 10*time.Second, 100*time.Millisecond, "route should be advertised")
```
**Common Test Patterns**
- **Route Advertisement**: Use `EventuallyWithT` for route propagation
- **Node State Changes**: Wait for NodeStore synchronization
- **ACL Policy Changes**: Allow time for policy recalculation
- **Network Connectivity**: Use ping tests with retries
**Test Data Management**
```go
// Node identification: Don't assume array ordering
expectedRoutes := map[string]string{"1": "10.33.0.0/16"}
for _, node := range nodes {
nodeIDStr := fmt.Sprintf("%d", node.GetId())
if route, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute {
// Test the node that should have the route
}
}
```
### Troubleshooting Integration Tests
**Common Failure Patterns**
1. **Timing Issues**: Test assertions run before async operations complete
- **Solution**: Use `EventuallyWithT` with appropriate timeouts
- **Timeout Guidelines**: 3-5s for route operations, 10s for complex scenarios
2. **Infrastructure Problems**: Disk space, Docker issues, network conflicts
- **Check**: `go run ./cmd/hi doctor` for system health
- **Clean**: Remove old test containers and networks
3. **NodeStore Synchronization**: Tests expecting immediate data availability
- **Key Points**: Route advertisements must propagate through poll requests
- **Fix**: Wait for NodeStore updates after Hostinfo changes
4. **Database Backend Differences**: SQLite vs PostgreSQL behavior differences
- **Use**: `--postgres` flag for database-intensive tests
- **Note**: Some timing characteristics differ between backends
**Debugging Failed Tests**
1. **Check test artifacts** in `control_logs/` for detailed logs
2. **Examine MapResponse JSON** files for protocol-level debugging (see the sketch below)
3. **Review Headscale stderr logs** for server-side error messages
4. **Check Tailscale client status** for network-level issues
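A rough sketch of inspecting a failed run's artifacts from the shell (the directory name follows the layout above; `jq` is assumed to be installed, and the status fields assume the dump comes from `tailscale status --json`):
```bash
RUN_DIR=control_logs/20250713-213106-iajsux   # substitute the run you are debugging
# Server-side errors and panics
grep -iE "error|panic" "$RUN_DIR"/hs-*.stderr.log | tail -n 50
# Pretty-print captured MapResponses for protocol-level debugging
jq . "$RUN_DIR"/hs-*-mapresponses/* | less
# Client state at the end of the run
jq '.BackendState' "$RUN_DIR"/ts-*_status.json
```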
**Resource Management**
- Tests require significant disk space (each run produces roughly 100 MB of logs)
- Docker containers are cleaned up automatically on success
- Failed tests may leave containers running - clean manually if needed
- Use `docker system prune` periodically to reclaim space (see the cleanup sketch below)
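A conservative cleanup sketch for when a failed run leaves containers behind (the `hs-`/`ts-` name filters match the harness's container naming; double-check on shared machines before pruning):
```bash
# List leftover test containers
docker ps -a --filter "name=hs-" --filter "name=ts-"
# Remove them once debugging is finished (prints an error and does nothing if the list is empty)
docker rm -f $(docker ps -aq --filter "name=hs-" --filter "name=ts-")
# Reclaim dangling images, networks, and build cache
docker system prune
```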
### Best Practices for Test Modifications
1. **Always test locally** before committing integration test changes
2. **Use appropriate timeouts** - too short causes flaky tests, too long slows CI
3. **Clean up properly** - ensure tests don't leave persistent state
4. **Handle both success and failure paths** in test scenarios
5. **Document timing requirements** for complex test scenarios
## NodeStore Implementation Details
**Key Insight from Recent Work**: The NodeStore is a critical performance optimization that caches node data in memory while ensuring consistency with the database. When working with route advertisements or node state changes:
1. **Timing Considerations**: Route advertisements need time to propagate from clients to server. Use `require.EventuallyWithT()` patterns in tests instead of immediate assertions.
2. **Synchronization Points**: NodeStore updates happen at specific points like `poll.go:420` after Hostinfo changes. Ensure these are maintained when modifying the polling logic.
3. **Peer Visibility**: The NodeStore's `peersFunc` determines which nodes are visible to each other. Policy-based filtering is separate from monitoring visibility - expired nodes should remain visible for debugging but marked as expired.
## Testing Guidelines
### Integration Test Patterns
```go
// Use EventuallyWithT for async operations
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
// Check expected state
}, 10*time.Second, 100*time.Millisecond, "description")
// Node route checking by actual node properties, not array position
var routeNode *v1.Node
for _, node := range nodes {
if nodeIDStr := fmt.Sprintf("%d", node.GetId()); expectedRoutes[nodeIDStr] != "" {
routeNode = node
break
}
}
```
### Running Problematic Tests
- Some tests require significant time (e.g., `TestNodeOnlineStatus` runs for 12 minutes)
- Infrastructure issues like disk space can cause test failures unrelated to code changes
- Use `--postgres` flag when testing database-heavy scenarios
## Important Notes
- **Dependencies**: Use `nix develop` for consistent toolchain (Go, buf, protobuf tools, linting)
- **Protocol Buffers**: Changes to `proto/` require `make generate` and should be committed separately
- **Code Style**: Enforced via golangci-lint with golines (width 88) and gofumpt formatting
- **Database**: Supports both SQLite (development) and PostgreSQL (production/testing)
- **Integration Tests**: Require Docker and can consume significant disk space
- **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management
## Debugging Integration Tests
Test artifacts are preserved in `control_logs/TIMESTAMP-ID/` including:
- Headscale server logs (stderr/stdout)
- Tailscale client logs and status
- Database dumps and network captures
- MapResponse JSON files for protocol debugging
When tests fail, check these artifacts first before assuming code issues.


@ -87,10 +87,9 @@ lint-proto: check-deps $(PROTO_SOURCES)
# Code generation
.PHONY: generate
generate: check-deps $(PROTO_SOURCES)
@echo "Generating code from Protocol Buffers..."
rm -rf gen
buf generate proto
generate: check-deps
@echo "Generating code..."
go generate ./...
# Clean targets
.PHONY: clean


@ -212,13 +212,10 @@ var listUsersCmd = &cobra.Command{
switch {
case id > 0:
request.Id = uint64(id)
break
case username != "":
request.Name = username
break
case email != "":
request.Email = email
break
}
response, err := client.ListUsers(ctx, request)


@ -90,6 +90,32 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
log.Printf("Starting test: %s", config.TestPattern)
// Start stats collection for container resource monitoring (if enabled)
var statsCollector *StatsCollector
if config.Stats {
var err error
statsCollector, err = NewStatsCollector()
if err != nil {
if config.Verbose {
log.Printf("Warning: failed to create stats collector: %v", err)
}
statsCollector = nil
}
if statsCollector != nil {
defer statsCollector.Close()
// Start stats collection immediately - no need for complex retry logic
// The new implementation monitors Docker events and will catch containers as they start
if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
if config.Verbose {
log.Printf("Warning: failed to start stats collection: %v", err)
}
}
defer statsCollector.StopCollection()
}
}
exitCode, err := streamAndWait(ctx, cli, resp.ID)
// Ensure all containers have finished and logs are flushed before extracting artifacts
@ -105,6 +131,20 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
// Always list control files regardless of test outcome
listControlFiles(logsDir)
// Print stats summary and check memory limits if enabled
if config.Stats && statsCollector != nil {
violations := statsCollector.PrintSummaryAndCheckLimits(config.HSMemoryLimit, config.TSMemoryLimit)
if len(violations) > 0 {
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
log.Printf("=================================")
for _, violation := range violations {
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
}
return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
}
}
shouldCleanup := config.CleanAfter && (!config.KeepOnFailure || exitCode == 0)
if shouldCleanup {
if config.Verbose {
@ -379,10 +419,37 @@ func getDockerSocketPath() string {
return "/var/run/docker.sock"
}
// ensureImageAvailable pulls the specified Docker image to ensure it's available.
// checkImageAvailableLocally checks if the specified Docker image is available locally.
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
if err != nil {
if client.IsErrNotFound(err) {
return false, nil
}
return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
}
return true, nil
}
// ensureImageAvailable checks if the image is available locally first, then pulls if needed.
func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName string, verbose bool) error {
// First check if image is available locally
available, err := checkImageAvailableLocally(ctx, cli, imageName)
if err != nil {
return fmt.Errorf("failed to check local image availability: %w", err)
}
if available {
if verbose {
log.Printf("Image %s is available locally", imageName)
}
return nil
}
// Image not available locally, try to pull it
if verbose {
log.Printf("Pulling image %s...", imageName)
log.Printf("Image %s not found locally, pulling...", imageName)
}
reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{})


@ -190,7 +190,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
}
}
// checkGolangImage verifies we can access the golang Docker image.
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
func checkGolangImage(ctx context.Context) DoctorResult {
cli, err := createDockerClient()
if err != nil {
@ -205,17 +205,40 @@ func checkGolangImage(ctx context.Context) DoctorResult {
goVersion := detectGoVersion()
imageName := "golang:" + goVersion
// Check if we can pull the image
// First check if image is available locally
available, err := checkImageAvailableLocally(ctx, cli, imageName)
if err != nil {
return DoctorResult{
Name: "Golang Image",
Status: "FAIL",
Message: fmt.Sprintf("Cannot check golang image %s: %v", imageName, err),
Suggestions: []string{
"Check Docker daemon status",
"Try: docker images | grep golang",
},
}
}
if available {
return DoctorResult{
Name: "Golang Image",
Status: "PASS",
Message: fmt.Sprintf("Golang image %s is available locally", imageName),
}
}
// Image not available locally, try to pull it
err = ensureImageAvailable(ctx, cli, imageName, false)
if err != nil {
return DoctorResult{
Name: "Golang Image",
Status: "FAIL",
Message: fmt.Sprintf("Cannot pull golang image %s: %v", imageName, err),
Message: fmt.Sprintf("Golang image %s not available locally and cannot pull: %v", imageName, err),
Suggestions: []string{
"Check internet connectivity",
"Verify Docker Hub access",
"Try: docker pull " + imageName,
"Or run tests offline if image was pulled previously",
},
}
}
@ -223,7 +246,7 @@ func checkGolangImage(ctx context.Context) DoctorResult {
return DoctorResult{
Name: "Golang Image",
Status: "PASS",
Message: fmt.Sprintf("Golang image %s is available", imageName),
Message: fmt.Sprintf("Golang image %s is now available", imageName),
}
}


@ -24,6 +24,9 @@ type RunConfig struct {
KeepOnFailure bool `flag:"keep-on-failure,default=false,Keep containers on test failure"`
LogsDir string `flag:"logs-dir,default=control_logs,Control logs directory"`
Verbose bool `flag:"verbose,default=false,Verbose output"`
Stats bool `flag:"stats,default=false,Collect and display container resource usage statistics"`
HSMemoryLimit float64 `flag:"hs-memory-limit,default=0,Fail test if any Headscale container exceeds this memory limit in MB (0 = disabled)"`
TSMemoryLimit float64 `flag:"ts-memory-limit,default=0,Fail test if any Tailscale container exceeds this memory limit in MB (0 = disabled)"`
}
// runIntegrationTest executes the integration test workflow.

cmd/hi/stats.go (new file)

@ -0,0 +1,468 @@
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"sort"
"strings"
"sync"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client"
)
// ContainerStats represents statistics for a single container
type ContainerStats struct {
ContainerID string
ContainerName string
Stats []StatsSample
mutex sync.RWMutex
}
// StatsSample represents a single stats measurement
type StatsSample struct {
Timestamp time.Time
CPUUsage float64 // CPU usage percentage
MemoryMB float64 // Memory usage in MB
}
// StatsCollector manages collection of container statistics
type StatsCollector struct {
client *client.Client
containers map[string]*ContainerStats
stopChan chan struct{}
wg sync.WaitGroup
mutex sync.RWMutex
collectionStarted bool
}
// NewStatsCollector creates a new stats collector instance
func NewStatsCollector() (*StatsCollector, error) {
cli, err := createDockerClient()
if err != nil {
return nil, fmt.Errorf("failed to create Docker client: %w", err)
}
return &StatsCollector{
client: cli,
containers: make(map[string]*ContainerStats),
stopChan: make(chan struct{}),
}, nil
}
// StartCollection begins monitoring all containers and collecting stats for hs- and ts- containers with matching run ID
func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, verbose bool) error {
sc.mutex.Lock()
defer sc.mutex.Unlock()
if sc.collectionStarted {
return fmt.Errorf("stats collection already started")
}
sc.collectionStarted = true
// Start monitoring existing containers
sc.wg.Add(1)
go sc.monitorExistingContainers(ctx, runID, verbose)
// Start Docker events monitoring for new containers
sc.wg.Add(1)
go sc.monitorDockerEvents(ctx, runID, verbose)
if verbose {
log.Printf("Started container monitoring for run ID %s", runID)
}
return nil
}
// StopCollection stops all stats collection
func (sc *StatsCollector) StopCollection() {
// Check if already stopped without holding lock
sc.mutex.RLock()
if !sc.collectionStarted {
sc.mutex.RUnlock()
return
}
sc.mutex.RUnlock()
// Signal stop to all goroutines
close(sc.stopChan)
// Wait for all goroutines to finish
sc.wg.Wait()
// Mark as stopped
sc.mutex.Lock()
sc.collectionStarted = false
sc.mutex.Unlock()
}
// monitorExistingContainers checks for existing containers that match our criteria
func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID string, verbose bool) {
defer sc.wg.Done()
containers, err := sc.client.ContainerList(ctx, container.ListOptions{})
if err != nil {
if verbose {
log.Printf("Failed to list existing containers: %v", err)
}
return
}
for _, cont := range containers {
if sc.shouldMonitorContainer(cont, runID) {
sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose)
}
}
}
// monitorDockerEvents listens for container start events and begins monitoring relevant containers
func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, verbose bool) {
defer sc.wg.Done()
filter := filters.NewArgs()
filter.Add("type", "container")
filter.Add("event", "start")
eventOptions := events.ListOptions{
Filters: filter,
}
events, errs := sc.client.Events(ctx, eventOptions)
for {
select {
case <-sc.stopChan:
return
case <-ctx.Done():
return
case event := <-events:
if event.Type == "container" && event.Action == "start" {
// Get container details
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
if err != nil {
continue
}
// Convert to types.Container format for consistency
cont := types.Container{
ID: containerInfo.ID,
Names: []string{containerInfo.Name},
Labels: containerInfo.Config.Labels,
}
if sc.shouldMonitorContainer(cont, runID) {
sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose)
}
}
case err := <-errs:
if verbose {
log.Printf("Error in Docker events stream: %v", err)
}
return
}
}
}
// shouldMonitorContainer determines if a container should be monitored
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
// Check if it has the correct run ID label
if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
return false
}
// Check if it's an hs- or ts- container
for _, name := range cont.Names {
containerName := strings.TrimPrefix(name, "/")
if strings.HasPrefix(containerName, "hs-") || strings.HasPrefix(containerName, "ts-") {
return true
}
}
return false
}
// startStatsForContainer begins stats collection for a specific container
func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerID, containerName string, verbose bool) {
containerName = strings.TrimPrefix(containerName, "/")
sc.mutex.Lock()
// Check if we're already monitoring this container
if _, exists := sc.containers[containerID]; exists {
sc.mutex.Unlock()
return
}
sc.containers[containerID] = &ContainerStats{
ContainerID: containerID,
ContainerName: containerName,
Stats: make([]StatsSample, 0),
}
sc.mutex.Unlock()
if verbose {
log.Printf("Starting stats collection for container %s (%s)", containerName, containerID[:12])
}
sc.wg.Add(1)
go sc.collectStatsForContainer(ctx, containerID, verbose)
}
// collectStatsForContainer collects stats for a specific container using Docker API streaming
func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containerID string, verbose bool) {
defer sc.wg.Done()
// Use Docker API streaming stats - much more efficient than CLI
statsResponse, err := sc.client.ContainerStats(ctx, containerID, true)
if err != nil {
if verbose {
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
}
return
}
defer statsResponse.Body.Close()
decoder := json.NewDecoder(statsResponse.Body)
var prevStats *container.Stats
for {
select {
case <-sc.stopChan:
return
case <-ctx.Done():
return
default:
var stats container.Stats
if err := decoder.Decode(&stats); err != nil {
// EOF is expected when container stops or stream ends
if err.Error() != "EOF" && verbose {
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
}
return
}
// Calculate CPU percentage (only if we have previous stats)
var cpuPercent float64
if prevStats != nil {
cpuPercent = calculateCPUPercent(prevStats, &stats)
}
// Calculate memory usage in MB
memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024)
// Store the sample (skip first sample since CPU calculation needs previous stats)
if prevStats != nil {
// Get container stats reference without holding the main mutex
var containerStats *ContainerStats
var exists bool
sc.mutex.RLock()
containerStats, exists = sc.containers[containerID]
sc.mutex.RUnlock()
if exists && containerStats != nil {
containerStats.mutex.Lock()
containerStats.Stats = append(containerStats.Stats, StatsSample{
Timestamp: time.Now(),
CPUUsage: cpuPercent,
MemoryMB: memoryMB,
})
containerStats.mutex.Unlock()
}
}
// Save current stats for next iteration
prevStats = &stats
}
}
}
// calculateCPUPercent calculates CPU usage percentage from Docker stats
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
// CPU calculation based on Docker's implementation
cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
if systemDelta > 0 && cpuDelta >= 0 {
// Calculate CPU percentage: (container CPU delta / system CPU delta) * number of CPUs * 100
numCPUs := float64(len(stats.CPUStats.CPUUsage.PercpuUsage))
if numCPUs == 0 {
// Fallback: if PercpuUsage is not available, assume 1 CPU
numCPUs = 1.0
}
return (cpuDelta / systemDelta) * numCPUs * 100.0
}
return 0.0
}
// ContainerStatsSummary represents summary statistics for a container
type ContainerStatsSummary struct {
ContainerName string
SampleCount int
CPU StatsSummary
Memory StatsSummary
}
// MemoryViolation represents a container that exceeded the memory limit
type MemoryViolation struct {
ContainerName string
MaxMemoryMB float64
LimitMB float64
}
// StatsSummary represents min, max, and average for a metric
type StatsSummary struct {
Min float64
Max float64
Average float64
}
// GetSummary returns a summary of collected statistics
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
// Take snapshot of container references without holding main lock long
sc.mutex.RLock()
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
for _, containerStats := range sc.containers {
containerRefs = append(containerRefs, containerStats)
}
sc.mutex.RUnlock()
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
for _, containerStats := range containerRefs {
containerStats.mutex.RLock()
stats := make([]StatsSample, len(containerStats.Stats))
copy(stats, containerStats.Stats)
containerName := containerStats.ContainerName
containerStats.mutex.RUnlock()
if len(stats) == 0 {
continue
}
summary := ContainerStatsSummary{
ContainerName: containerName,
SampleCount: len(stats),
}
// Calculate CPU stats
cpuValues := make([]float64, len(stats))
memoryValues := make([]float64, len(stats))
for i, sample := range stats {
cpuValues[i] = sample.CPUUsage
memoryValues[i] = sample.MemoryMB
}
summary.CPU = calculateStatsSummary(cpuValues)
summary.Memory = calculateStatsSummary(memoryValues)
summaries = append(summaries, summary)
}
// Sort by container name for consistent output
sort.Slice(summaries, func(i, j int) bool {
return summaries[i].ContainerName < summaries[j].ContainerName
})
return summaries
}
// calculateStatsSummary calculates min, max, and average for a slice of values
func calculateStatsSummary(values []float64) StatsSummary {
if len(values) == 0 {
return StatsSummary{}
}
min := values[0]
max := values[0]
sum := 0.0
for _, value := range values {
if value < min {
min = value
}
if value > max {
max = value
}
sum += value
}
return StatsSummary{
Min: min,
Max: max,
Average: sum / float64(len(values)),
}
}
// PrintSummary prints the statistics summary to the console
func (sc *StatsCollector) PrintSummary() {
summaries := sc.GetSummary()
if len(summaries) == 0 {
log.Printf("No container statistics collected")
return
}
log.Printf("Container Resource Usage Summary:")
log.Printf("================================")
for _, summary := range summaries {
log.Printf("Container: %s (%d samples)", summary.ContainerName, summary.SampleCount)
log.Printf(" CPU Usage: Min: %6.2f%% Max: %6.2f%% Avg: %6.2f%%",
summary.CPU.Min, summary.CPU.Max, summary.CPU.Average)
log.Printf(" Memory Usage: Min: %6.1f MB Max: %6.1f MB Avg: %6.1f MB",
summary.Memory.Min, summary.Memory.Max, summary.Memory.Average)
log.Printf("")
}
}
// CheckMemoryLimits checks if any containers exceeded their memory limits
func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
if hsLimitMB <= 0 && tsLimitMB <= 0 {
return nil
}
summaries := sc.GetSummary()
var violations []MemoryViolation
for _, summary := range summaries {
var limitMB float64
if strings.HasPrefix(summary.ContainerName, "hs-") {
limitMB = hsLimitMB
} else if strings.HasPrefix(summary.ContainerName, "ts-") {
limitMB = tsLimitMB
} else {
continue // Skip containers that don't match our patterns
}
if limitMB > 0 && summary.Memory.Max > limitMB {
violations = append(violations, MemoryViolation{
ContainerName: summary.ContainerName,
MaxMemoryMB: summary.Memory.Max,
LimitMB: limitMB,
})
}
}
return violations
}
// PrintSummaryAndCheckLimits prints the statistics summary and returns memory violations if any
func (sc *StatsCollector) PrintSummaryAndCheckLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
sc.PrintSummary()
return sc.CheckMemoryLimits(hsLimitMB, tsLimitMB)
}
// Close closes the stats collector and cleans up resources
func (sc *StatsCollector) Close() error {
sc.StopCollection()
return sc.client.Close()
}


@ -19,7 +19,7 @@
overlay = _: prev: let
pkgs = nixpkgs.legacyPackages.${prev.system};
buildGo = pkgs.buildGo124Module;
vendorHash = "sha256-S2GnCg2dyfjIyi5gXhVEuRs5Bop2JAhZcnhg1fu4/Gg=";
vendorHash = "sha256-83L2NMyOwKCHWqcowStJ7Ze/U9CJYhzleDRLrJNhX2g=";
in {
headscale = buildGo {
pname = "headscale";

go.mod

@ -23,7 +23,6 @@ require (
github.com/gorilla/mux v1.8.1
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.0
github.com/jagottsicher/termcolor v1.0.2
github.com/klauspost/compress v1.18.0
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
github.com/ory/dockertest/v3 v3.12.0
github.com/philip-bui/grpc-zerolog v1.0.1
@ -43,11 +42,11 @@ require (
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/crypto v0.39.0
golang.org/x/crypto v0.40.0
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
golang.org/x/net v0.41.0
golang.org/x/net v0.42.0
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.15.0
golang.org/x/sync v0.16.0
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822
google.golang.org/grpc v1.73.0
google.golang.org/protobuf v1.36.6
@ -55,7 +54,7 @@ require (
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.6.0
gorm.io/gorm v1.30.0
tailscale.com v1.84.2
tailscale.com v1.84.3
zgo.at/zcache/v2 v2.2.0
zombiezen.com/go/postgrestest v1.0.1
)
@ -81,7 +80,7 @@ require (
modernc.org/libc v1.62.1 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.10.0 // indirect
modernc.org/sqlite v1.37.0 // indirect
modernc.org/sqlite v1.37.0
)
require (
@ -166,6 +165,7 @@ require (
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/jsimonetti/rtnetlink v1.4.1 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/lib/pq v1.10.9 // indirect
@ -231,14 +231,19 @@ require (
go.opentelemetry.io/otel/trace v1.36.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
golang.org/x/mod v0.25.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/term v0.32.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/mod v0.26.0 // indirect
golang.org/x/sys v0.34.0 // indirect
golang.org/x/term v0.33.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/time v0.10.0 // indirect
golang.org/x/tools v0.33.0 // indirect
golang.org/x/tools v0.35.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 // indirect
)
tool (
golang.org/x/tools/cmd/stringer
tailscale.com/cmd/viewer
)

go.sum

@ -555,8 +555,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
@ -567,8 +567,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -577,8 +577,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -587,8 +587,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@ -615,8 +615,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@ -624,8 +624,8 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
@ -633,8 +633,8 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@ -643,8 +643,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@ -714,6 +714,8 @@ software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
tailscale.com v1.84.2 h1:v6aM4RWUgYiV52LRAx6ET+dlGnvO/5lnqPXb7/pMnR0=
tailscale.com v1.84.2/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
tailscale.com v1.84.3 h1:Ur9LMedSgicwbqpy5xn7t49G8490/s6rqAJOk5Q5AYE=
tailscale.com v1.84.3/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
zgo.at/zcache/v2 v2.2.0 h1:K29/IPjMniZfveYE+IRXfrl11tMzHkIPuyGrfVZ2fGo=
zgo.at/zcache/v2 v2.2.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=


@ -28,14 +28,15 @@ import (
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
"github.com/juanfont/headscale/hscontrol/dns"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
zerolog "github.com/philip-bui/grpc-zerolog"
"github.com/pkg/profile"
zl "github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
"golang.org/x/sync/errgroup"
@ -64,6 +65,19 @@ var (
)
)
var (
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
)
func init() {
deadlock.Opts.Disable = !debugDeadlock
if debugDeadlock {
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
deadlock.Opts.PrintAllCurrentGoroutines = true
}
}
const (
AuthPrefix = "Bearer "
updateInterval = 5 * time.Second
@ -82,9 +96,8 @@ type Headscale struct {
// Things that generate changes
extraRecordMan *dns.ExtraRecordsMan
mapper *mapper.Mapper
nodeNotifier *notifier.Notifier
authProvider AuthProvider
mapBatcher mapper.Batcher
pollNetMapStreamWG sync.WaitGroup
}
@ -118,7 +131,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
cfg: cfg,
noisePrivateKey: noisePrivateKey,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(cfg),
state: s,
}
@ -136,12 +148,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
return
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "ephemeral-gc-policy", node.Hostname)
app.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
app.Change(policyChanged)
log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
})
app.ephemeralGC = ephemeralGC
@ -153,10 +160,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
defer cancel()
oidcProvider, err := NewAuthProviderOIDC(
ctx,
&app,
cfg.ServerURL,
&cfg.OIDC,
app.state,
app.nodeNotifier,
)
if err != nil {
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@ -262,16 +268,18 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
return
case <-expireTicker.C:
var update types.StateUpdate
var expiredNodeChanges []change.ChangeSet
var changed bool
lastExpiryCheck, update, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
log.Trace().Interface("changes", expiredNodeChanges).Msgf("expiring nodes")
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
// Send the changes directly since they're already in the new format
for _, nodeChange := range expiredNodeChanges {
h.Change(nodeChange)
}
}
case <-derpTickerChan:
@ -282,11 +290,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
derpMap.Regions[region.RegionID] = &region
}
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StateDERPUpdated,
DERPMap: derpMap,
})
h.Change(change.DERPSet)
case records, ok := <-extraRecordsUpdate:
if !ok {
@ -294,19 +298,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
}
h.cfg.TailcfgDNSConfig.ExtraRecords = records
ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
// TODO(kradalby): We can probably do better than sending a full update here,
// but for now this will ensure that all of the nodes get the new records.
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
h.Change(change.ExtraRecordsSet)
}
}
}
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
req interface{},
req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
) (any, error) {
// Check if the request is coming from the on-server client.
// This is not secure, but it is to maintain maintainability
// with the "legacy" database-based client
@ -484,58 +485,6 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
return router
}
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// // Maybe we should attempt a new in memory state and not go via the DB?
// // Maybe this should be implemented as an event bus?
// // A bool is returned indicating if a full update was sent to all nodes
// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
// users, err := db.ListUsers()
// if err != nil {
// return err
// }
// changed, err := polMan.SetUsers(users)
// if err != nil {
// return err
// }
// if changed {
// ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
// notif.NotifyAll(ctx, types.UpdateFull())
// }
// return nil
// }
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
// // Maybe we should attempt a new in memory state and not go via the DB?
// // Maybe this should be implemented as an event bus?
// // A bool is returned indicating if a full update was sent to all nodes
// func nodesChangedHook(
// db *db.HSDatabase,
// polMan policy.PolicyManager,
// notif *notifier.Notifier,
// ) (bool, error) {
// nodes, err := db.ListNodes()
// if err != nil {
// return false, err
// }
// filterChanged, err := polMan.SetNodes(nodes)
// if err != nil {
// return false, err
// }
// if filterChanged {
// ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
// notif.NotifyAll(ctx, types.UpdateFull())
// return true, nil
// }
// return false, nil
// }
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
capver.CanOldCodeBeCleanedUp()
@ -562,8 +511,9 @@ func (h *Headscale) Serve() error {
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
Msg("Clients with a lower minimum version will be rejected")
// Fetch an initial DERP Map before we start serving
h.mapper = mapper.NewMapper(h.state, h.cfg, h.nodeNotifier)
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
h.mapBatcher.Start()
defer h.mapBatcher.Close()
// TODO(kradalby): fix state part.
if h.cfg.DERP.ServerEnabled {
@ -838,8 +788,12 @@ func (h *Headscale) Serve() error {
log.Info().
Msg("ACL policy successfully reloaded, notifying nodes of change")
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
err = h.state.AutoApproveNodes()
if err != nil {
log.Error().Err(err).Msg("failed to approve routes after new policy")
}
h.Change(change.PolicySet)
}
default:
info := func(msg string) { log.Info().Msg(msg) }
@ -865,7 +819,6 @@ func (h *Headscale) Serve() error {
}
info("closing node notifier")
h.nodeNotifier.Close()
info("waiting for netmap stream to close")
h.pollNetMapStreamWG.Wait()
@ -1047,3 +1000,10 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
return &machineKey, nil
}
// Change is used to send changes to nodes.
// All change should be enqueued here and empty will be automatically
// ignored.
func (h *Headscale) Change(c change.ChangeSet) {
h.mapBatcher.AddWork(c)
}


@ -10,6 +10,8 @@ import (
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -32,6 +34,21 @@ func (h *Headscale) handleRegister(
}
if node != nil {
// If an existing node is trying to register with an auth key,
// we need to validate the auth key even for existing nodes
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
if errors.As(err, &httpErr) {
return nil, httpErr
}
return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
}
return resp, nil
}
resp, err := h.handleExistingNode(node, regReq, machineKey)
if err != nil {
return nil, fmt.Errorf("handling existing node: %w", err)
@ -47,6 +64,11 @@ func (h *Headscale) handleRegister(
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
if errors.As(err, &httpErr) {
return nil, httpErr
}
return nil, fmt.Errorf("handling register with auth key: %w", err)
}
@ -66,11 +88,13 @@ func (h *Headscale) handleExistingNode(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
if node.MachineKey != machineKey {
return nil, NewHTTPError(http.StatusUnauthorized, "node exists with different machine key", nil)
}
expired := node.IsExpired()
if !expired && !regReq.Expiry.IsZero() {
requestExpiry := regReq.Expiry
@ -82,42 +106,26 @@ func (h *Headscale) handleExistingNode(
// If the request expiry is in the past, we consider it a logout.
if requestExpiry.Before(time.Now()) {
if node.IsEphemeral() {
policyChanged, err := h.state.DeleteNode(node)
c, err := h.state.DeleteNode(node)
if err != nil {
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "auth-logout-ephemeral-policy", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
}
h.Change(c)
return nil, nil
}
}
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
if err != nil {
return nil, fmt.Errorf("setting node expiry: %w", err)
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "auth-expiry-policy", "na")
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
h.Change(c)
}
return nodeToRegisterResponse(n), nil
}
return nodeToRegisterResponse(node), nil
return nodeToRegisterResponse(node), nil
}
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
@ -168,7 +176,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*tailcfg.RegisterResponse, error) {
node, changed, err := h.state.HandleNodeFromPreAuthKey(
node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
regReq,
machineKey,
)
@ -184,6 +192,12 @@ func (h *Headscale) handleRegisterWithAuthKey(
return nil, err
}
// If node is nil, it means an ephemeral node was deleted during logout
if node == nil {
h.Change(changed)
return nil, nil
}
// This is a bit of a back and forth, but we have a bit of a chicken and egg
// dependency here.
// Because the way the policy manager works, we need to have the node
@ -195,23 +209,22 @@ func (h *Headscale) handleRegisterWithAuthKey(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
routesChanged := h.state.AutoApproveRoutes(node)
// TODO(kradalby): This needs to be run as part of the batcher maybe?
// now since we don't update the node/pol here anymore
routeChange := h.state.AutoApproveRoutes(node)
if _, _, err := h.state.SaveNode(node); err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
if routesChanged {
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
} else if changed {
ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
// Existing node re-registering without route changes
// Still need to notify peers about the node being active again
// Use UpdateFull to ensure all peers get complete peer maps
ctx := types.NotifyCtx(context.Background(), "node re-registered", node.Hostname)
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
if routeChange && changed.Empty() {
changed = change.NodeAdded(node.ID)
}
h.Change(changed)
// If policy changed due to node registration, send a separate policy change
if policyChanged {
policyChange := change.PolicyChange()
h.Change(policyChange)
}
return &tailcfg.RegisterResponse{

View File

@ -1,5 +1,7 @@
package capver
//go:generate go run ../../tools/capver/main.go
import (
"slices"
"sort"
@ -10,7 +12,7 @@ import (
"tailscale.com/util/set"
)
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 88
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 90
// CanOldCodeBeCleanedUp is intended to be called on startup to see if
// there is old code that can be cleaned up; entries should contain

View File

@ -1,14 +1,10 @@
package capver
// Generated DO NOT EDIT
//Generated DO NOT EDIT
import "tailscale.com/tailcfg"
var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.60.0": 87,
"v1.60.1": 87,
"v1.62.0": 88,
"v1.62.1": 88,
"v1.64.0": 90,
"v1.64.1": 90,
"v1.64.2": 90,
@ -36,18 +32,21 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.80.3": 113,
"v1.82.0": 115,
"v1.82.5": 115,
"v1.84.0": 116,
"v1.84.1": 116,
"v1.84.2": 116,
}
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
87: "v1.60.0",
88: "v1.62.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
90: "v1.64.0",
95: "v1.66.0",
97: "v1.68.0",
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
116: "v1.84.0",
}

View File

@ -13,11 +13,10 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
stripV bool
expected []string
}{
{3, false, []string{"v1.78", "v1.80", "v1.82"}},
{2, true, []string{"1.80", "1.82"}},
{3, false, []string{"v1.80", "v1.82", "v1.84"}},
{2, true, []string{"1.82", "1.84"}},
// Lazy way to see all supported versions
{10, true, []string{
"1.64",
"1.66",
"1.68",
"1.70",
@ -27,6 +26,7 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
"1.78",
"1.80",
"1.82",
"1.84",
}},
{0, false, nil},
}
@ -46,7 +46,6 @@ func TestCapVerMinimumTailscaleVersion(t *testing.T) {
input tailcfg.CapabilityVersion
expected string
}{
{88, "v1.62.0"},
{90, "v1.64.0"},
{95, "v1.66.0"},
{106, "v1.74.0"},

View File

@ -7,7 +7,6 @@ import (
"os/exec"
"path/filepath"
"slices"
"sort"
"strings"
"testing"
"time"
@ -362,8 +361,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
}
if diff := cmp.Diff(expectedKeys, keys, cmp.Comparer(func(a, b []string) bool {
sort.Sort(sort.StringSlice(a))
sort.Sort(sort.StringSlice(b))
slices.Sort(a)
slices.Sort(b)
return slices.Equal(a, b)
}), cmpopts.IgnoreFields(types.PreAuthKey{}, "User", "CreatedAt", "Reusable", "Ephemeral", "Used", "Expiration")); diff != "" {
t.Errorf("TestSQLiteMigrationAndDataValidation() pre-auth key tags migration mismatch (-want +got):\n%s", diff)

View File

@ -7,15 +7,19 @@ import (
"net/netip"
"slices"
"sort"
"strconv"
"sync"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
const (
@ -39,9 +43,7 @@ var (
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only these peer nodes will be returned.
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListPeers(rx, nodeID, peerIDs...)
})
return ListPeers(hsdb.DB, nodeID, peerIDs...)
}
// ListPeers returns the peers of a node, regardless of any Policy or whether the node is expired.
@ -66,9 +68,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter.
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
return ListNodes(rx, nodeIDs...)
})
return ListNodes(hsdb.DB, nodeIDs...)
}
// ListNodes queries the database for either all nodes if no parameters are given
@ -120,9 +120,7 @@ func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
}
func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByID(rx, id)
})
return GetNodeByID(hsdb.DB, id)
}
// GetNodeByID finds a Node by ID and returns the Node struct.
@ -140,9 +138,7 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
}
func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByMachineKey(rx, machineKey)
})
return GetNodeByMachineKey(hsdb.DB, machineKey)
}
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
@ -163,9 +159,7 @@ func GetNodeByMachineKey(
}
func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
return GetNodeByNodeKey(rx, nodeKey)
})
return GetNodeByNodeKey(hsdb.DB, nodeKey)
}
// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct.
@ -352,8 +346,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
registrationMethod string,
ipv4 *netip.Addr,
ipv6 *netip.Addr,
) (*types.Node, bool, error) {
var newNode bool
) (*types.Node, change.ChangeSet, error) {
var nodeChange change.ChangeSet
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
if reg, ok := hsdb.regCache.Get(registrationID); ok {
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
@ -405,7 +399,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
}
close(reg.Registered)
newNode = true
nodeChange = change.NodeAdded(node.ID)
return node, err
} else {
@ -415,6 +409,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
return nil, err
}
nodeChange = change.KeyExpiry(node.ID)
return node, nil
}
}
@ -422,7 +418,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
return nil, ErrNodeNotFoundRegistrationCache
})
return node, newNode, err
return node, nodeChange, err
}
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
@ -448,6 +444,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
if oldNode != nil && oldNode.UserID == node.UserID {
node.ID = oldNode.ID
node.GivenName = oldNode.GivenName
node.ApprovedRoutes = oldNode.ApprovedRoutes
ipv4 = oldNode.IPv4
ipv6 = oldNode.IPv6
}
@ -594,17 +591,18 @@ func ensureUniqueGivenName(
// containing the expired nodes, and a boolean indicating if any nodes were found.
func ExpireExpiredNodes(tx *gorm.DB,
lastCheck time.Time,
) (time.Time, types.StateUpdate, bool) {
) (time.Time, []change.ChangeSet, bool) {
// use the time of the start of the function to ensure we
// don't miss some nodes by returning it _after_ we have
// checked everything.
started := time.Now()
expired := make([]*tailcfg.PeerChange, 0)
var updates []change.ChangeSet
nodes, err := ListNodes(tx)
if err != nil {
return time.Unix(0, 0), types.StateUpdate{}, false
return time.Unix(0, 0), nil, false
}
for _, node := range nodes {
if node.IsExpired() && node.Expiry.After(lastCheck) {
@ -612,14 +610,15 @@ func ExpireExpiredNodes(tx *gorm.DB,
NodeID: tailcfg.NodeID(node.ID),
KeyExpiry: node.Expiry,
})
updates = append(updates, change.KeyExpiry(node.ID))
}
}
if len(expired) > 0 {
return started, types.UpdatePeerPatch(expired...), true
return started, updates, true
}
return started, types.StateUpdate{}, false
return started, nil, false
}
// EphemeralGarbageCollector is a garbage collector that will delete nodes after
@ -732,3 +731,114 @@ func (e *EphemeralGarbageCollector) Start() {
}
}
}
func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) *types.Node {
if !testing.Testing() {
panic("CreateNodeForTest can only be called during tests")
}
if user == nil {
panic("CreateNodeForTest requires a valid user")
}
nodeName := "testnode"
if len(hostname) > 0 && hostname[0] != "" {
nodeName = hostname[0]
}
// Create a preauth key for the node
pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
if err != nil {
panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
}
nodeKey := key.NewNode()
machineKey := key.NewMachine()
discoKey := key.NewDisco()
node := &types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
DiscoKey: discoKey.Public(),
Hostname: nodeName,
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
err = hsdb.DB.Save(node).Error
if err != nil {
panic(fmt.Sprintf("failed to create test node: %v", err))
}
return node
}
func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node {
if !testing.Testing() {
panic("CreateRegisteredNodeForTest can only be called during tests")
}
node := hsdb.CreateNodeForTest(user, hostname...)
err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
_, err := RegisterNode(tx, *node, nil, nil)
return err
})
if err != nil {
panic(fmt.Sprintf("failed to register test node: %v", err))
}
registeredNode, err := hsdb.GetNodeByID(node.ID)
if err != nil {
panic(fmt.Sprintf("failed to get registered test node: %v", err))
}
return registeredNode
}
func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
if !testing.Testing() {
panic("CreateNodesForTest can only be called during tests")
}
if user == nil {
panic("CreateNodesForTest requires a valid user")
}
prefix := "testnode"
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
prefix = hostnamePrefix[0]
}
nodes := make([]*types.Node, count)
for i := range count {
hostname := prefix + "-" + strconv.Itoa(i)
nodes[i] = hsdb.CreateNodeForTest(user, hostname)
}
return nodes
}
func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
if !testing.Testing() {
panic("CreateRegisteredNodesForTest can only be called during tests")
}
if user == nil {
panic("CreateRegisteredNodesForTest requires a valid user")
}
prefix := "testnode"
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
prefix = hostnamePrefix[0]
}
nodes := make([]*types.Node, count)
for i := range count {
hostname := prefix + "-" + strconv.Itoa(i)
nodes[i] = hsdb.CreateRegisteredNodeForTest(user, hostname)
}
return nodes
}

View File

@ -6,7 +6,6 @@ import (
"math/big"
"net/netip"
"regexp"
"strconv"
"sync"
"testing"
"time"
@ -26,82 +25,36 @@ import (
)
func (s *Suite) TestGetNode(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode")
_, err := db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := &types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(node)
c.Assert(trx.Error, check.IsNil)
node := db.CreateNodeForTest(user, "testnode")
_, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.Hostname, check.Equals, "testnode")
}
func (s *Suite) TestGetNodeByID(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0)
_, err := db.GetNodeByID(0)
c.Assert(err, check.NotNil)
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := db.CreateNodeForTest(user, "testnode")
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
_, err = db.GetNodeByID(0)
retrievedNode, err := db.GetNodeByID(node.ID)
c.Assert(err, check.IsNil)
c.Assert(retrievedNode.Hostname, check.Equals, "testnode")
}
func (s *Suite) TestHardDeleteNode(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
node := db.CreateNodeForTest(user, "testnode3")
nodeKey := key.NewNode()
machineKey := key.NewMachine()
node := types.Node{
ID: 0,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode3",
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
err = db.DeleteNode(&node)
err := db.DeleteNode(node)
c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode3")
@ -109,42 +62,21 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
}
func (s *Suite) TestListPeers(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.GetNodeByID(0)
_, err := db.GetNodeByID(0)
c.Assert(err, check.NotNil)
for index := range 11 {
nodeKey := key.NewNode()
machineKey := key.NewMachine()
nodes := db.CreateNodesForTest(user, 11, "testnode")
node := types.Node{
ID: types.NodeID(index),
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode" + strconv.Itoa(index),
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
trx := db.DB.Save(&node)
c.Assert(trx.Error, check.IsNil)
}
node0ByID, err := db.GetNodeByID(0)
firstNode := nodes[0]
peersOfFirstNode, err := db.ListPeers(firstNode.ID)
c.Assert(err, check.IsNil)
peersOfNode0, err := db.ListPeers(node0ByID.ID)
c.Assert(err, check.IsNil)
c.Assert(len(peersOfNode0), check.Equals, 9)
c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2")
c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7")
c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10")
c.Assert(len(peersOfFirstNode), check.Equals, 10)
c.Assert(peersOfFirstNode[0].Hostname, check.Equals, "testnode-1")
c.Assert(peersOfFirstNode[5].Hostname, check.Equals, "testnode-6")
c.Assert(peersOfFirstNode[9].Hostname, check.Equals, "testnode-10")
}
func (s *Suite) TestExpireNode(c *check.C) {
@ -807,13 +739,13 @@ func TestListPeers(t *testing.T) {
// No parameter means no filter, should return all peers
nodes, err = db.ListPeers(1)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
// Empty node list should return all peers
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
// No match in IDs should return empty list and no error
@ -824,13 +756,13 @@ func TestListPeers(t *testing.T) {
// Partial match in IDs
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs, but node ID is still filtered out
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
}
@ -892,14 +824,14 @@ func TestListNodes(t *testing.T) {
// No parameter means no filter, should return all nodes
nodes, err = db.ListNodes()
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, 2, len(nodes))
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
// Empty node list should return all nodes
nodes, err = db.ListNodes(types.NodeIDs{}...)
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, 2, len(nodes))
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
@ -911,13 +843,13 @@ func TestListNodes(t *testing.T) {
// Partial match in IDs
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
require.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, 1, len(nodes))
assert.Equal(t, "test2", nodes[0].Hostname)
// Several matched IDs
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
require.NoError(t, err)
assert.Len(t, nodes, 2)
assert.Equal(t, 2, len(nodes))
assert.Equal(t, "test1", nodes[0].Hostname)
assert.Equal(t, "test2", nodes[1].Hostname)
}

View File

@ -109,9 +109,7 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
}
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
return GetPreAuthKey(rx, key)
})
return GetPreAuthKey(hsdb.DB, key)
}
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
@ -155,11 +153,8 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
// ExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
return nil
now := time.Now()
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
}
func generateKey() (string, error) {

View File

@ -1,7 +1,7 @@
package db
import (
"sort"
"slices"
"testing"
"github.com/juanfont/headscale/hscontrol/types"
@ -57,7 +57,7 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
c.Assert(err, check.IsNil)
gotTags := listedPaks[0].Proto().GetAclTags()
sort.Sort(sort.StringSlice(gotTags))
slices.Sort(gotTags)
c.Assert(gotTags, check.DeepEquals, tags)
}

View File

@ -3,6 +3,8 @@ package db
import (
"errors"
"fmt"
"strconv"
"testing"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
@ -110,9 +112,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
}
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
return GetUserByID(rx, uid)
})
return GetUserByID(hsdb.DB, uid)
}
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
@ -146,9 +146,7 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
}
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
return ListUsers(rx, where...)
})
return ListUsers(hsdb.DB, where...)
}
// ListUsers gets all the existing users.
@ -217,3 +215,40 @@ func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error
return nil
}
func (hsdb *HSDatabase) CreateUserForTest(name ...string) *types.User {
if !testing.Testing() {
panic("CreateUserForTest can only be called during tests")
}
userName := "testuser"
if len(name) > 0 && name[0] != "" {
userName = name[0]
}
user, err := hsdb.CreateUser(types.User{Name: userName})
if err != nil {
panic(fmt.Sprintf("failed to create test user: %v", err))
}
return user
}
func (hsdb *HSDatabase) CreateUsersForTest(count int, namePrefix ...string) []*types.User {
if !testing.Testing() {
panic("CreateUsersForTest can only be called during tests")
}
prefix := "testuser"
if len(namePrefix) > 0 && namePrefix[0] != "" {
prefix = namePrefix[0]
}
users := make([]*types.User, count)
for i := range count {
name := prefix + "-" + strconv.Itoa(i)
users[i] = hsdb.CreateUserForTest(name)
}
return users
}

View File

@ -11,8 +11,7 @@ import (
)
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
c.Assert(user.Name, check.Equals, "test")
users, err := db.ListUsers()
@ -30,8 +29,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
err := db.DestroyUser(9998)
c.Assert(err, check.Equals, ErrUserNotFound)
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)
@ -64,8 +62,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
}
func (s *Suite) TestRenameUser(c *check.C) {
userTest, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
userTest := db.CreateUserForTest("test")
c.Assert(userTest.Name, check.Equals, "test")
users, err := db.ListUsers()
@ -86,8 +83,7 @@ func (s *Suite) TestRenameUser(c *check.C) {
err = db.RenameUser(99988, "test")
c.Assert(err, check.Equals, ErrUserNotFound)
userTest2, err := db.CreateUser(types.User{Name: "test2"})
c.Assert(err, check.IsNil)
userTest2 := db.CreateUserForTest("test2")
c.Assert(userTest2.Name, check.Equals, "test2")
want := "UNIQUE constraint failed"
@ -98,11 +94,8 @@ func (s *Suite) TestRenameUser(c *check.C) {
}
func (s *Suite) TestSetMachineUser(c *check.C) {
oldUser, err := db.CreateUser(types.User{Name: "old"})
c.Assert(err, check.IsNil)
newUser, err := db.CreateUser(types.User{Name: "new"})
c.Assert(err, check.IsNil)
oldUser := db.CreateUserForTest("old")
newUser := db.CreateUserForTest("new")
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
c.Assert(err, check.IsNil)

View File

@ -17,10 +17,6 @@ import (
func (h *Headscale) debugHTTPServer() *http.Server {
debugMux := http.NewServeMux()
debug := tsweb.Debugger(debugMux)
debug.Handle("notifier", "Connected nodes in notifier", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(h.nodeNotifier.String()))
}))
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
config, err := json.MarshalIndent(h.cfg, "", " ")
if err != nil {

View File

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"io"
"maps"
"net/http"
"net/url"
"os"
@ -72,9 +73,7 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
}
for _, derpMap := range derpMaps {
for id, region := range derpMap.Regions {
result.Regions[id] = region
}
maps.Copy(result.Regions, derpMap.Regions)
}
return &result

View File

@ -1,3 +1,5 @@
//go:generate buf generate --template ../buf.gen.yaml -o .. ../proto
// nolint
package hscontrol
@ -27,6 +29,7 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
)
@ -56,12 +59,14 @@ func (api headscaleV1APIServer) CreateUser(
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
}
// Send policy update notifications if needed
c := change.UserAdded(types.UserID(user.ID))
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-user-created", user.Name)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
c.Change = change.Policy
}
api.h.Change(c)
return &v1.CreateUserResponse{User: user.Proto()}, nil
}
@ -81,8 +86,7 @@ func (api headscaleV1APIServer) RenameUser(
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-user-renamed", request.GetNewName())
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
api.h.Change(change.PolicyChange())
}
newUser, err := api.h.state.GetUserByName(request.GetNewName())
@ -107,6 +111,8 @@ func (api headscaleV1APIServer) DeleteUser(
return nil, err
}
api.h.Change(change.UserRemoved(types.UserID(user.ID)))
return &v1.DeleteUserResponse{}, nil
}
@ -246,7 +252,7 @@ func (api headscaleV1APIServer) RegisterNode(
return nil, fmt.Errorf("looking up user: %w", err)
}
node, _, err := api.h.state.HandleNodeFromAuthPath(
node, nodeChange, err := api.h.state.HandleNodeFromAuthPath(
registrationId,
types.UserID(user.ID),
nil,
@ -267,22 +273,13 @@ func (api headscaleV1APIServer) RegisterNode(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
routesChanged := api.h.state.AutoApproveRoutes(node)
_, policyChanged, err := api.h.state.SaveNode(node)
_ = api.h.state.AutoApproveRoutes(node)
_, _, err = api.h.state.SaveNode(node)
if err != nil {
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
}
// Send policy update notifications if needed (from SaveNode or route changes)
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-nodes-change", "all")
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
if routesChanged {
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
}
api.h.Change(nodeChange)
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
}
@ -300,7 +297,7 @@ func (api headscaleV1APIServer) GetNode(
// Populate the online field based on
// currently connected nodes.
resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
return &v1.GetNodeResponse{Node: resp}, nil
}
@ -316,21 +313,14 @@ func (api headscaleV1APIServer) SetTags(
}
}
node, policyChanged, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
if err != nil {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-tags", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
@ -362,23 +352,19 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
tsaddr.SortPrefixes(routes)
routes = slices.Compact(routes)
node, policyChanged, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-routes-approved", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
if api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) {
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else {
ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
// Always propagate node changes from SetApprovedRoutes
api.h.Change(nodeChange)
// If routes changed, propagate those changes too
if !routeChange.Empty() {
api.h.Change(routeChange)
}
proto := node.Proto()
@ -409,19 +395,12 @@ func (api headscaleV1APIServer) DeleteNode(
return nil, err
}
policyChanged, err := api.h.state.DeleteNode(node)
nodeChange, err := api.h.state.DeleteNode(node)
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-deleted", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
api.h.Change(nodeChange)
return &v1.DeleteNodeResponse{}, nil
}
@ -432,25 +411,13 @@ func (api headscaleV1APIServer) ExpireNode(
) (*v1.ExpireNodeResponse, error) {
now := time.Now()
node, policyChanged, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-expired", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID)
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
// TODO(kradalby): Ensure that both the selfupdate and peer updates are sent
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
@ -464,22 +431,13 @@ func (api headscaleV1APIServer) RenameNode(
ctx context.Context,
request *v1.RenameNodeRequest,
) (*v1.RenameNodeResponse, error) {
node, policyChanged, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
node, nodeChange, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-renamed", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
// TODO(kradalby): investigate if we need selfupdate
api.h.Change(nodeChange)
log.Trace().
Str("node", node.Hostname).
@ -498,7 +456,7 @@ func (api headscaleV1APIServer) ListNodes(
// probably be done once.
// TODO(kradalby): This should be done in one tx.
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
IsConnected := api.h.mapBatcher.ConnectedMap()
if request.GetUser() != "" {
user, err := api.h.state.GetUserByName(request.GetUser())
if err != nil {
@ -510,7 +468,7 @@ func (api headscaleV1APIServer) ListNodes(
return nil, err
}
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
response := nodesToProto(api.h.state, IsConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
@ -523,18 +481,18 @@ func (api headscaleV1APIServer) ListNodes(
return nodes[i].ID < nodes[j].ID
})
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
response := nodesToProto(api.h.state, IsConnected, nodes)
return &v1.ListNodesResponse{Nodes: response}, nil
}
func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
response := make([]*v1.Node, len(nodes))
for index, node := range nodes {
resp := node.Proto()
// Populate the online field based on
// currently connected nodes.
if val, ok := isLikelyConnected.Load(node.ID); ok && val {
if val, ok := IsConnected.Load(node.ID); ok && val {
resp.Online = true
}
@ -556,24 +514,14 @@ func (api headscaleV1APIServer) MoveNode(
ctx context.Context,
request *v1.MoveNodeRequest,
) (*v1.MoveNodeResponse, error) {
node, policyChanged, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
node, nodeChange, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
if err != nil {
return nil, err
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "grpc-node-moved", node.Hostname)
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname)
api.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID)
ctx = types.NotifyCtx(ctx, "cli-movenode", node.Hostname)
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
// TODO(kradalby): Ensure the policy is also sent
// TODO(kradalby): ensure that both the selfupdate and peer updates are sent
api.h.Change(nodeChange)
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
}
@ -754,8 +702,7 @@ func (api headscaleV1APIServer) SetPolicy(
return nil, err
}
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
api.h.Change(change.PolicyChange())
}
response := &v1.SetPolicyResponse{

155
hscontrol/mapper/batcher.go Normal file
View File

@ -0,0 +1,155 @@
package mapper
import (
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/puzpuzpuz/xsync/v4"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
// Batcher defines the common interface for all batcher implementations.
type Batcher interface {
Start()
Close()
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
IsConnected(id types.NodeID) bool
ConnectedMap() *xsync.Map[types.NodeID, bool]
AddWork(c change.ChangeSet)
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
}
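A rough lifecycle sketch (editor illustration, not part of this commit) of how a streaming handler would drive this interface; it is written as if it lived in this package, and the context import plus the simplified error handling are assumptions:

// Editor sketch: register a node's stream with the batcher, forward generated
// map responses until the request context ends, then deregister.
// Assumes "context" is imported alongside the packages already imported above.
func exampleStreamLifecycle(ctx context.Context, b Batcher, id types.NodeID, version tailcfg.CapabilityVersion) error {
    ch := make(chan *tailcfg.MapResponse, 1)
    if err := b.AddNode(id, ch, false, version); err != nil {
        return err // AddNode fails when the initial map cannot be generated
    }
    defer b.RemoveNode(id, ch, false)

    for {
        select {
        case resp := <-ch:
            _ = resp // write each MapResponse to the long-poll HTTP stream here
        case <-ctx.Done():
            return nil
        }
    }
}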
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
return &LockFreeBatcher{
mapper: mapper,
workers: workers,
tick: time.NewTicker(batchTime),
// The size of this channel is arbitrarily chosen; the sizing should be revisited.
workCh: make(chan work, workers*200),
nodes: xsync.NewMap[types.NodeID, *nodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
}
}
// NewBatcherAndMapper creates a Batcher implementation.
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
m := newMapper(cfg, state)
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
m.batcher = b
return b
}
// nodeConnection interface for different connection implementations.
type nodeConnection interface {
nodeID() types.NodeID
version() tailcfg.CapabilityVersion
send(data *tailcfg.MapResponse) error
}
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
if c.Empty() {
return nil, nil
}
// Validate inputs before processing
if nodeID == 0 {
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
}
if mapper == nil {
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
}
var mapResp *tailcfg.MapResponse
var err error
switch c.Change {
case change.DERP:
mapResp, err = mapper.derpMapResponse(nodeID)
case change.NodeCameOnline, change.NodeWentOffline:
if c.IsSubnetRouter {
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
Online: ptr.To(c.Change == change.NodeCameOnline),
},
})
}
case change.NodeNewOrUpdate:
mapResp, err = mapper.fullMapResponse(nodeID, version)
case change.NodeRemove:
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
default:
// The default case always handles these change types:
// change.Full, change.Policy
mapResp, err = mapper.fullMapResponse(nodeID, version)
}
if err != nil {
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
}
// TODO(kradalby): Is this necessary?
// Validate the generated map response - only check for nil response
// Note: mapResp.Node can be nil for peer updates, which is valid
if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
}
return mapResp, nil
}
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
if nc == nil {
return fmt.Errorf("nodeConnection is nil")
}
nodeID := nc.nodeID()
data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
if err != nil {
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
}
if data == nil {
// No data to send is valid for some change types
return nil
}
// Send the map response
if err := nc.send(data); err != nil {
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
}
return nil
}
// workResult represents the result of processing a change.
type workResult struct {
mapResponse *tailcfg.MapResponse
err error
}
// work represents a unit of work to be processed by workers.
type work struct {
c change.ChangeSet
nodeID types.NodeID
resultCh chan<- workResult // optional channel for synchronous operations
}

View File

@ -0,0 +1,491 @@
package mapper
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
type LockFreeBatcher struct {
tick *time.Ticker
mapper *mapper
workers int
// Lock-free concurrent maps
nodes *xsync.Map[types.NodeID, *nodeConn]
connected *xsync.Map[types.NodeID, *time.Time]
// Work queue channel
workCh chan work
ctx context.Context
cancel context.CancelFunc
// Batching state
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
batchMutex sync.RWMutex
// Metrics
totalNodes atomic.Int64
totalUpdates atomic.Int64
workQueuedCount atomic.Int64
workProcessed atomic.Int64
workErrors atomic.Int64
}
// AddNode registers a new node connection with the batcher and sends an initial map response.
// It creates or updates the node's connection data, validates the initial map generation,
// and notifies other nodes that this node has come online.
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
// First validate that we can generate initial map before doing anything else
fullSelfChange := change.FullSelf(id)
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
// This currently means that the goroutine for the node connection will do the processing
// which means that we might have uncontrolled concurrency.
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
// it to be processed in a more controlled manner.
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
if err != nil {
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
}
// Only after validation succeeds, create or update node connection
newConn := newNodeConn(id, c, version, b.mapper)
var conn *nodeConn
if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
// Update existing connection
existing.updateConnection(c, version)
conn = existing
} else {
b.totalNodes.Add(1)
conn = newConn
}
// Mark as connected only after validation succeeds
b.connected.Store(id, nil) // nil = connected
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
// Send the validated initial map
if initialMap != nil {
if err := conn.send(initialMap); err != nil {
// Clean up the connection state on send failure
b.nodes.Delete(id)
b.connected.Delete(id)
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
}
// Notify other nodes that this node came online
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
}
return nil
}
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
// It validates the connection channel matches the current one, closes the connection,
// and notifies other nodes that this node has gone offline.
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
// Check if this is the current connection and mark it as closed
if existing, ok := b.nodes.Load(id); ok {
if !existing.matchesChannel(c) {
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
return // Not the current connection, not an error
}
// Mark the connection as closed to prevent further sends
if connData := existing.connData.Load(); connData != nil {
connData.closed.Store(true)
}
}
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
// Remove node and mark disconnected atomically
b.nodes.Delete(id)
b.connected.Store(id, ptr.To(time.Now()))
b.totalNodes.Add(-1)
// Notify other nodes that this node went offline
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
}
// AddWork queues a change to be processed by the batcher.
// Critical changes are processed immediately, while others are batched for efficiency.
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
b.addWork(c)
}
func (b *LockFreeBatcher) Start() {
b.ctx, b.cancel = context.WithCancel(context.Background())
go b.doWork()
}
func (b *LockFreeBatcher) Close() {
if b.cancel != nil {
b.cancel()
}
close(b.workCh)
}
func (b *LockFreeBatcher) doWork() {
log.Debug().Msg("batcher doWork loop started")
defer log.Debug().Msg("batcher doWork loop stopped")
for i := range b.workers {
go b.worker(i + 1)
}
for {
select {
case <-b.tick.C:
// Process batched changes
b.processBatchedChanges()
case <-b.ctx.Done():
return
}
}
}
func (b *LockFreeBatcher) worker(workerID int) {
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
for {
select {
case w, ok := <-b.workCh:
if !ok {
return
}
startTime := time.Now()
b.workProcessed.Add(1)
// If the resultCh is set, it means that this is a work request
// where there is a blocking function waiting for the map that
// is being generated.
// This is used for synchronous map generation.
if w.resultCh != nil {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists {
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
if result.err != nil {
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("failed to generate map response for synchronous work")
}
} else {
result.err = fmt.Errorf("node %d not found", w.nodeID)
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Msg("node not found for synchronous work")
}
// Send result
select {
case w.resultCh <- result:
case <-b.ctx.Done():
return
}
duration := time.Since(startTime)
if duration > 100*time.Millisecond {
log.Warn().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Dur("duration", duration).
Msg("slow synchronous work processing")
}
continue
}
// If resultCh is nil, this is an asynchronous work request
// that should be processed and sent to the node instead of
// returned to the caller.
if nc, exists := b.nodes.Load(w.nodeID); exists {
// Check if this connection is still active before processing
if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
log.Debug().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("skipping work for closed connection")
continue
}
err := nc.change(w.c)
if err != nil {
b.workErrors.Add(1)
log.Error().Err(err).
Int("workerID", workerID).
Uint64("node.id", w.c.NodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("failed to apply change")
}
} else {
log.Debug().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Msg("node not found for asynchronous work - node may have disconnected")
}
duration := time.Since(startTime)
if duration > 100*time.Millisecond {
log.Warn().
Int("workerID", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Dur("duration", duration).
Msg("slow asynchronous work processing")
}
case <-b.ctx.Done():
return
}
}
}
func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
// For critical changes that need immediate processing, send directly
if b.shouldProcessImmediately(c) {
if c.SelfUpdateOnly {
b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
return
}
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
if c.NodeID == nodeID && !c.AlsoSelf() {
return true
}
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
return true
})
return
}
// For non-critical changes, add to batch
b.addToBatch(c)
}
// queueWork safely queues work
func (b *LockFreeBatcher) queueWork(w work) {
b.workQueuedCount.Add(1)
select {
case b.workCh <- w:
// Successfully queued
case <-b.ctx.Done():
// Batcher is shutting down
return
}
}
// shouldProcessImmediately determines if a change should bypass batching
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
// Process these changes immediately to avoid delaying critical functionality
switch c.Change {
case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
return true
default:
return false
}
}
// addToBatch adds a change to the pending batch
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
b.batchMutex.Lock()
defer b.batchMutex.Unlock()
if c.SelfUpdateOnly {
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
changes = append(changes, c)
b.pendingChanges.Store(c.NodeID, changes)
return
}
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
if c.NodeID == nodeID && !c.AlsoSelf() {
return true
}
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
changes = append(changes, c)
b.pendingChanges.Store(nodeID, changes)
return true
})
}
// processBatchedChanges processes all pending batched changes
func (b *LockFreeBatcher) processBatchedChanges() {
b.batchMutex.Lock()
defer b.batchMutex.Unlock()
if b.pendingChanges == nil {
return
}
// Process all pending changes
b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
if len(changes) == 0 {
return true
}
// Send all batched changes for this node
for _, c := range changes {
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
}
// Clear the pending changes for this node
b.pendingChanges.Delete(nodeID)
return true
})
}
// IsConnected reports whether the node is connected, using a lock-free read.
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
if val, ok := b.connected.Load(id); ok {
// nil means connected
return val == nil
}
return false
}
// ConnectedMap returns a lock-free map of all connected nodes.
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
ret := xsync.NewMap[types.NodeID, bool]()
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
// nil means connected
ret.Store(id, val == nil)
return true
})
return ret
}
// MapResponseFromChange queues work to generate a map response and waits for the result.
// This allows synchronous map generation using the same worker pool.
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
resultCh := make(chan workResult, 1)
// Queue the work with a result channel using the safe queueing method
b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})
// Wait for the result
select {
case result := <-resultCh:
return result.mapResponse, result.err
case <-b.ctx.Done():
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
}
}
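This synchronous path is what lets a handler produce a map before the poll stream starts (the behaviour named in the commit title). A hedged sketch (editor illustration, not part of this commit; change.FullSelf is used the same way by AddNode above, and the node must already be registered via AddNode or the worker returns a "node not found" error):

// Editor sketch: synchronously build a full map for an already-registered node
// using the same worker pool that handles asynchronous change fan-out.
func exampleInitialMap(b *LockFreeBatcher, id types.NodeID) (*tailcfg.MapResponse, error) {
    return b.MapResponseFromChange(id, change.FullSelf(id))
}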
// connectionData holds the channel and connection parameters.
type connectionData struct {
c chan<- *tailcfg.MapResponse
version tailcfg.CapabilityVersion
closed atomic.Bool // Track if this connection has been closed
}
// nodeConn describes the node connection and its associated data.
type nodeConn struct {
id types.NodeID
mapper *mapper
// Atomic pointer to connection data - allows lock-free updates
connData atomic.Pointer[connectionData]
updateCount atomic.Int64
}
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
nc := &nodeConn{
id: id,
mapper: mapper,
}
// Initialize connection data
data := &connectionData{
c: c,
version: version,
}
nc.connData.Store(data)
return nc
}
// updateConnection atomically updates connection parameters.
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
newData := &connectionData{
c: c,
version: version,
}
nc.connData.Store(newData)
}
// matchesChannel checks if the given channel matches current connection.
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
data := nc.connData.Load()
if data == nil {
return false
}
// Compare channel pointers directly
return data.c == c
}
// version atomically reads the connection's capability version.
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
data := nc.connData.Load()
if data == nil {
return 0
}
return data.version
}
func (nc *nodeConn) nodeID() types.NodeID {
return nc.id
}
func (nc *nodeConn) change(c change.ChangeSet) error {
return handleNodeChange(nc, nc.mapper, c)
}
// send sends data to the node's channel.
// The node will pick it up and send it to the HTTP handler.
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
connData := nc.connData.Load()
if connData == nil {
return fmt.Errorf("node %d: no connection data", nc.id)
}
// Check if connection has been closed
if connData.closed.Load() {
return fmt.Errorf("node %d: connection closed", nc.id)
}
// TODO(kradalby): We might need some sort of timeout here if the client is not reading
// the channel. That might mean that we are sending to a node that has gone offline, but
// the channel is still open.
connData.c <- data
nc.updateCount.Add(1)
return nil
}

File diff suppressed because it is too large

259
hscontrol/mapper/builder.go Normal file
View File

@ -0,0 +1,259 @@
package mapper
import (
"net/netip"
"sort"
"time"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
"tailscale.com/util/multierr"
)
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
type MapResponseBuilder struct {
resp *tailcfg.MapResponse
mapper *mapper
nodeID types.NodeID
capVer tailcfg.CapabilityVersion
errs []error
}
// NewMapResponseBuilder creates a new builder with basic fields set
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
now := time.Now()
return &MapResponseBuilder{
resp: &tailcfg.MapResponse{
KeepAlive: false,
ControlTime: &now,
},
mapper: m,
nodeID: nodeID,
errs: nil,
}
}
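To make the fluent style concrete, an editor sketch (not part of this commit) of chaining the With* methods defined below; the finishing step that extracts the accumulated *tailcfg.MapResponse and any recorded errors is not visible in this excerpt, so it is only assumed here:

// Editor sketch (within package mapper): assemble a full map response for a node.
// Each With* method records failures via addError instead of returning them,
// so the chain reads linearly; a finalizer (assumed, not shown in this diff)
// would hand back b.resp together with the accumulated errs.
func exampleBuildFullMap(m *mapper, nodeID types.NodeID, capVer tailcfg.CapabilityVersion, peers types.Nodes) *MapResponseBuilder {
    return m.NewMapResponseBuilder(nodeID).
        WithCapabilityVersion(capVer).
        WithSelfNode().
        WithDERPMap().
        WithDomain().
        WithCollectServicesDisabled().
        WithDebugConfig().
        WithSSHPolicy().
        WithDNSConfig().
        WithUserProfiles(peers).
        WithPacketFilters().
        WithPeers(peers)
}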
// addError adds an error to the builder's error list
func (b *MapResponseBuilder) addError(err error) {
if err != nil {
b.errs = append(b.errs, err)
}
}
// hasErrors returns true if the builder has accumulated any errors
func (b *MapResponseBuilder) hasErrors() bool {
return len(b.errs) > 0
}
// WithCapabilityVersion sets the capability version for the response
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
b.capVer = capVer
return b
}
// WithSelfNode adds the requesting node to the response
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
_, matchers := b.mapper.state.Filter()
tailnode, err := tailNode(
node.View(), b.capVer, b.mapper.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
},
b.mapper.cfg)
if err != nil {
b.addError(err)
return b
}
b.resp.Node = tailnode
return b
}
// WithDERPMap adds the DERP map to the response
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
b.resp.DERPMap = b.mapper.state.DERPMap()
return b
}
// WithDomain adds the domain configuration
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
b.resp.Domain = b.mapper.cfg.Domain()
return b
}
// WithCollectServicesDisabled sets the collect services flag to false
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
b.resp.CollectServices.Set(false)
return b
}
// WithDebugConfig adds debug configuration
// It disables log tailing if the mapper's LogTail is not enabled
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
b.resp.Debug = &tailcfg.Debug{
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
}
return b
}
// WithSSHPolicy adds SSH policy configuration for the requesting node
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
if err != nil {
b.addError(err)
return b
}
b.resp.SSHPolicy = sshPolicy
return b
}
// WithDNSConfig adds DNS configuration for the requesting node
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
return b
}
// WithUserProfiles adds user profiles for the requesting node and given peers
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
b.resp.UserProfiles = generateUserProfiles(node, peers)
return b
}
// WithPacketFilters adds packet filter rules based on policy
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
b.addError(err)
return b
}
filter, _ := b.mapper.state.Filter()
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
// Currently, we do not send incremental packet filters; however, using the
// new PacketFilters field with the "base" key allows us to send a full update
// even when we have to send an empty list.
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node.View(), filter),
}
return b
}
// WithPeers adds full peer list with policy filtering (for full map response)
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers)
if err != nil {
b.addError(err)
return b
}
b.resp.Peers = tailPeers
return b
}
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
tailPeers, err := b.buildTailPeers(peers)
if err != nil {
b.addError(err)
return b
}
b.resp.PeersChanged = tailPeers
return b
}
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
node, err := b.mapper.state.GetNodeByID(b.nodeID)
if err != nil {
return nil, err
}
filter, matchers := b.mapper.state.Filter()
// If there are filter rules present, see if there are any nodes that cannot
// access each other at all and remove them from the peers.
var changedViews views.Slice[types.NodeView]
if len(filter) > 0 {
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
} else {
changedViews = peers.ViewSlice()
}
tailPeers, err := tailNodes(
changedViews, b.capVer, b.mapper.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
},
b.mapper.cfg)
if err != nil {
return nil, err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
return tailPeers, nil
}
// WithPeerChangedPatch adds peer change patches
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
b.resp.PeersChangedPatch = changes
return b
}
// WithPeersRemoved adds removed peer IDs
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
var tailscaleIDs []tailcfg.NodeID
for _, id := range removedIDs {
tailscaleIDs = append(tailscaleIDs, id.NodeID())
}
b.resp.PeersRemoved = tailscaleIDs
return b
}
// Build finalizes the response and returns it, or any accumulated errors.
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
if len(b.errs) > 0 {
return nil, multierr.New(b.errs...)
}
if debugDumpMapResponsePath != "" {
writeDebugMapResponse(b.resp, b.nodeID, messages...)
}
return b.resp, nil
}
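For illustration, the same builder methods compose incremental responses as well; the sketch below is hypothetical (its name and parameters are placeholders) and only chains methods defined above:
// exampleIncrementalResponse is a sketch only: combine a removed peer and a
// set of patches into one response using the fluent builder.
func (m *mapper) exampleIncrementalResponse(
	nodeID types.NodeID,
	capVer tailcfg.CapabilityVersion,
	removedID types.NodeID,
	patches []*tailcfg.PeerChange,
) (*tailcfg.MapResponse, error) {
	return m.NewMapResponseBuilder(nodeID).
		WithCapabilityVersion(capVer).
		WithPeersRemoved(removedID).
		WithPeerChangedPatch(patches).
		Build()
}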

View File

@ -0,0 +1,347 @@
package mapper
import (
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
)
func TestMapResponseBuilder_Basic(t *testing.T) {
cfg := &types.Config{
BaseDomain: "example.com",
LogTail: types.LogTailConfig{
Enabled: true,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID)
// Test basic builder creation
assert.NotNil(t, builder)
assert.Equal(t, nodeID, builder.nodeID)
assert.NotNil(t, builder.resp)
assert.False(t, builder.resp.KeepAlive)
assert.NotNil(t, builder.resp.ControlTime)
assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second)
}
func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
capVer := tailcfg.CapabilityVersion(42)
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer)
assert.Equal(t, capVer, builder.capVer)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithDomain(t *testing.T) {
domain := "test.example.com"
cfg := &types.Config{
ServerURL: "https://test.example.com",
BaseDomain: domain,
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithDomain()
assert.Equal(t, domain, builder.resp.Domain)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithCollectServicesDisabled()
value, isSet := builder.resp.CollectServices.Get()
assert.True(t, isSet)
assert.False(t, value)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
tests := []struct {
name string
logTailEnabled bool
expected bool
}{
{
name: "LogTail enabled",
logTailEnabled: true,
expected: false, // DisableLogTail should be false when LogTail is enabled
},
{
name: "LogTail disabled",
logTailEnabled: false,
expected: true, // DisableLogTail should be true when LogTail is disabled
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := &types.Config{
LogTail: types.LogTailConfig{
Enabled: tt.logTailEnabled,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithDebugConfig()
require.NotNil(t, builder.resp.Debug)
assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
assert.False(t, builder.hasErrors())
})
}
}
func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
changes := []*tailcfg.PeerChange{
{
NodeID: 123,
DERPRegion: 1,
},
{
NodeID: 456,
DERPRegion: 2,
},
}
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(changes)
assert.Equal(t, changes, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
removedID1 := types.NodeID(123)
removedID2 := types.NodeID(456)
builder := m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedID1, removedID2)
expected := []tailcfg.NodeID{
removedID1.NodeID(),
removedID2.NodeID(),
}
assert.Equal(t, expected, builder.resp.PeersRemoved)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
// Simulate an error in the builder
builder := m.NewMapResponseBuilder(nodeID)
builder.addError(assert.AnError)
// All subsequent calls should continue to work and accumulate errors
result := builder.
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig()
assert.True(t, result.hasErrors())
assert.Len(t, result.errs, 1)
assert.Equal(t, assert.AnError, result.errs[0])
// Build should return the error
data, err := result.Build("none")
assert.Nil(t, data)
assert.Error(t, err)
}
func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
domain := "chained.example.com"
cfg := &types.Config{
ServerURL: "https://chained.example.com",
BaseDomain: domain,
LogTail: types.LogTailConfig{
Enabled: false,
},
}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
capVer := tailcfg.CapabilityVersion(99)
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig()
// Verify all fields are set correctly
assert.Equal(t, capVer, builder.capVer)
assert.Equal(t, domain, builder.resp.Domain)
value, isSet := builder.resp.CollectServices.Get()
assert.True(t, isSet)
assert.False(t, value)
assert.NotNil(t, builder.resp.Debug)
assert.True(t, builder.resp.Debug.DisableLogTail)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
removedID1 := types.NodeID(100)
removedID2 := types.NodeID(200)
// Test calling WithPeersRemoved multiple times
builder := m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedID1).
WithPeersRemoved(removedID2)
// Second call should overwrite the first
expected := []tailcfg.NodeID{removedID2.NodeID()}
assert.Equal(t, expected, builder.resp.PeersRemoved)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch([]*tailcfg.PeerChange{})
assert.Empty(t, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
builder := m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(nil)
assert.Nil(t, builder.resp.PeersChangedPatch)
assert.False(t, builder.hasErrors())
}
func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
cfg := &types.Config{}
mockState := &state.State{}
m := &mapper{
cfg: cfg,
state: mockState,
}
nodeID := types.NodeID(1)
// Create a builder and add multiple errors
builder := m.NewMapResponseBuilder(nodeID)
builder.addError(assert.AnError)
builder.addError(assert.AnError)
builder.addError(nil) // This should be ignored
// All subsequent calls should continue to work
result := builder.
WithDomain().
WithCollectServicesDisabled()
assert.True(t, result.hasErrors())
assert.Len(t, result.errs, 2) // nil error should be ignored
// Build should return a multierr
data, err := result.Build("none")
assert.Nil(t, data)
assert.Error(t, err)
// The error should contain information about multiple errors
assert.Contains(t, err.Error(), "multiple errors")
}

View File

@ -1,7 +1,6 @@
package mapper
import (
"encoding/binary"
"encoding/json"
"fmt"
"io/fs"
@ -10,31 +9,21 @@ import (
"os"
"path"
"slices"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"tailscale.com/envknob"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/views"
)
const (
nextDNSDoHPrefix = "https://dns.nextdns.io"
reservedResponseHeaderSize = 4
mapperIDLength = 8
debugMapResponsePerm = 0o755
nextDNSDoHPrefix = "https://dns.nextdns.io"
mapperIDLength = 8
debugMapResponsePerm = 0o755
)
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
@ -50,15 +39,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
// - Create a "minifier" that removes info not needed for the node
// - some sort of batching, wait for 5 or 60 seconds before sending
type Mapper struct {
type mapper struct {
// Configuration
state *state.State
cfg *types.Config
notif *notifier.Notifier
state *state.State
cfg *types.Config
batcher Batcher
uid string
created time.Time
seq uint64
}
type patch struct {
@ -66,41 +53,31 @@ type patch struct {
change *tailcfg.PeerChange
}
func NewMapper(
state *state.State,
func newMapper(
cfg *types.Config,
notif *notifier.Notifier,
) *Mapper {
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
state *state.State,
) *mapper {
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &Mapper{
return &mapper{
state: state,
cfg: cfg,
notif: notif,
uid: uid,
created: time.Now(),
seq: 0,
}
}
func (m *Mapper) String() string {
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
}
func generateUserProfiles(
node types.NodeView,
peers views.Slice[types.NodeView],
node *types.Node,
peers types.Nodes,
) []tailcfg.UserProfile {
userMap := make(map[uint]*types.User)
ids := make([]uint, 0, peers.Len()+1)
user := node.User()
userMap[user.ID] = &user
ids = append(ids, user.ID)
for _, peer := range peers.All() {
peerUser := peer.User()
userMap[peerUser.ID] = &peerUser
ids = append(ids, peerUser.ID)
ids := make([]uint, 0, len(userMap))
userMap[node.User.ID] = &node.User
ids = append(ids, node.User.ID)
for _, peer := range peers {
userMap[peer.User.ID] = &peer.User
ids = append(ids, peer.User.ID)
}
slices.Sort(ids)
@ -117,7 +94,7 @@ func generateUserProfiles(
func generateDNSConfig(
cfg *types.Config,
node types.NodeView,
node *types.Node,
) *tailcfg.DNSConfig {
if cfg.TailcfgDNSConfig == nil {
return nil
@ -137,17 +114,16 @@ func generateDNSConfig(
//
// This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{
"device_name": []string{node.Hostname()},
"device_model": []string{node.Hostinfo().OS()},
"device_name": []string{node.Hostname},
"device_model": []string{node.Hostinfo.OS},
}
nodeIPs := node.IPs()
if len(nodeIPs) > 0 {
attrs.Add("device_ip", nodeIPs[0].String())
if len(node.IPs()) > 0 {
attrs.Add("device_ip", node.IPs()[0].String())
}
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
@ -155,434 +131,151 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
}
}
// fullMapResponse creates a complete MapResponse for a node.
// It is a separate function to make testing easier.
func (m *Mapper) fullMapResponse(
node types.NodeView,
peers views.Slice[types.NodeView],
// fullMapResponse returns a MapResponse for the given node.
func (m *mapper) fullMapResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
messages ...string,
) (*tailcfg.MapResponse, error) {
resp, err := m.baseWithConfigMapResponse(node, capVer)
peers, err := m.listPeers(nodeID)
if err != nil {
return nil, err
}
err = appendPeerChanges(
resp,
true, // full change
m.state,
node,
capVer,
peers,
m.cfg,
)
if err != nil {
return nil, err
}
return resp, nil
return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithSelfNode().
WithDERPMap().
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig().
WithSSHPolicy().
WithDNSConfig().
WithUserProfiles(peers).
WithPacketFilters().
WithPeers(peers).
Build(messages...)
}
// FullMapResponse returns a MapResponse for the given node.
func (m *Mapper) FullMapResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
messages ...string,
) ([]byte, error) {
peers, err := m.ListPeers(node.ID())
if err != nil {
return nil, err
}
resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
// ReadOnlyMapResponse returns a MapResponse for the given node.
// Lite means that the peers has been omitted, this is intended
// to be used to answer MapRequests with OmitPeers set to true.
func (m *Mapper) ReadOnlyMapResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
messages ...string,
) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
if err != nil {
return nil, err
}
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
}
func (m *Mapper) KeepAliveResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.KeepAlive = true
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) DERPMapResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
derpMap *tailcfg.DERPMap,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.DERPMap = derpMap
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) PeerChangedResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
changed map[types.NodeID]bool,
patches []*tailcfg.PeerChange,
messages ...string,
) ([]byte, error) {
var err error
resp := m.baseMapResponse()
var removedIDs []tailcfg.NodeID
var changedIDs []types.NodeID
for nodeID, nodeChanged := range changed {
if nodeChanged {
if nodeID != node.ID() {
changedIDs = append(changedIDs, nodeID)
}
} else {
removedIDs = append(removedIDs, nodeID.NodeID())
}
}
changedNodes := types.Nodes{}
if len(changedIDs) > 0 {
changedNodes, err = m.ListNodes(changedIDs...)
if err != nil {
return nil, err
}
}
err = appendPeerChanges(
&resp,
false, // partial change
m.state,
node,
mapRequest.Version,
changedNodes.ViewSlice(),
m.cfg,
)
if err != nil {
return nil, err
}
resp.PeersRemoved = removedIDs
// Sending patches as a part of a PeersChanged response
// is technically not supposed to be done, but they are
// applied after the PeersChanged. The patch list
// should _only_ contain Nodes that are not in the
// PeersChanged or PeersRemoved list and the caller
// should filter them out.
//
// From tailcfg docs:
// These are applied after Peers* above, but in practice the
// control server should only send these on their own, without
// the Peers* fields also set.
if patches != nil {
resp.PeersChangedPatch = patches
}
_, matchers := m.state.Filter()
// Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update.
tailnode, err := tailNode(
node, mapRequest.Version, m.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
},
m.cfg)
if err != nil {
return nil, err
}
resp.Node = tailnode
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
func (m *mapper) derpMapResponse(
nodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDERPMap().
Build()
}
// PeerChangedPatchResponse creates a patch MapResponse with
// incoming update from a state change.
func (m *Mapper) PeerChangedPatchResponse(
mapRequest tailcfg.MapRequest,
node types.NodeView,
func (m *mapper) peerChangedPatchResponse(
nodeID types.NodeID,
changed []*tailcfg.PeerChange,
) ([]byte, error) {
resp := m.baseMapResponse()
resp.PeersChangedPatch = changed
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
}
func (m *Mapper) marshalMapResponse(
mapRequest tailcfg.MapRequest,
resp *tailcfg.MapResponse,
node types.NodeView,
compression string,
messages ...string,
) ([]byte, error) {
atomic.AddUint64(&m.seq, 1)
jsonBody, err := json.Marshal(resp)
if err != nil {
return nil, fmt.Errorf("marshalling map response: %w", err)
}
if debugDumpMapResponsePath != "" {
data := map[string]any{
"Messages": messages,
"MapRequest": mapRequest,
"MapResponse": resp,
}
responseType := "keepalive"
switch {
case resp.Peers != nil && len(resp.Peers) > 0:
responseType = "full"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
responseType = "self"
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
responseType = "changed"
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
responseType = "patch"
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
responseType = "removed"
}
body, err := json.MarshalIndent(data, "", " ")
if err != nil {
return nil, fmt.Errorf("marshalling map response: %w", err)
}
perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
}
now := time.Now().Format("2006-01-02T15-04-05.999999999")
mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
err = os.WriteFile(mapResponsePath, body, perms)
if err != nil {
panic(err)
}
}
var respBody []byte
if compression == util.ZstdCompression {
respBody = zstdEncode(jsonBody)
} else {
respBody = jsonBody
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func zstdEncode(in []byte) []byte {
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
if !ok {
panic("invalid type in sync pool")
}
out := encoder.EncodeAll(in, nil)
_ = encoder.Close()
zstdEncoderPool.Put(encoder)
return out
}
var zstdEncoderPool = &sync.Pool{
New: func() any {
encoder, err := smallzstd.NewEncoder(
nil,
zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
panic(err)
}
return encoder
},
}
// baseMapResponse returns a tailcfg.MapResponse with
// KeepAlive false and ControlTime set to now.
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
now := time.Now()
resp := tailcfg.MapResponse{
KeepAlive: false,
ControlTime: &now,
// TODO(kradalby): Implement PingRequest?
}
return resp
}
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
// with the basic configuration from headscale set.
// It is used in for bigger updates, such as full and lite, not
// incremental.
func (m *Mapper) baseWithConfigMapResponse(
node types.NodeView,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse()
return m.NewMapResponseBuilder(nodeID).
WithPeerChangedPatch(changed).
Build()
}
_, matchers := m.state.Filter()
tailnode, err := tailNode(
node, capVer, m.state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
},
m.cfg)
// peerChangeResponse returns a MapResponse with changed or added nodes.
func (m *mapper) peerChangeResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
peers, err := m.listPeers(nodeID, changedNodeID)
if err != nil {
return nil, err
}
resp.Node = tailnode
resp.DERPMap = m.state.DERPMap()
resp.Domain = m.cfg.Domain()
// Do not instruct clients to collect services we do not
// support or do anything with them
resp.CollectServices = "false"
resp.KeepAlive = false
resp.Debug = &tailcfg.Debug{
DisableLogTail: !m.cfg.LogTail.Enabled,
}
return &resp, nil
return m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithSelfNode().
WithUserProfiles(peers).
WithPeerChanges(peers).
Build()
}
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
func (m *mapper) peerRemovedResponse(
nodeID types.NodeID,
removedNodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithPeersRemoved(removedNodeID).
Build()
}
func writeDebugMapResponse(
resp *tailcfg.MapResponse,
nodeID types.NodeID,
messages ...string,
) {
data := map[string]any{
"Messages": messages,
"MapResponse": resp,
}
responseType := "keepalive"
switch {
case len(resp.Peers) > 0:
responseType = "full"
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
responseType = "self"
case len(resp.PeersChanged) > 0:
responseType = "changed"
case len(resp.PeersChangedPatch) > 0:
responseType = "patch"
case len(resp.PeersRemoved) > 0:
responseType = "removed"
}
body, err := json.MarshalIndent(data, "", " ")
if err != nil {
panic(err)
}
perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
}
now := time.Now().Format("2006-01-02T15-04-05.999999999")
mapResponsePath := path.Join(
mPath,
fmt.Sprintf("%s-%s.json", now, responseType),
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
err = os.WriteFile(mapResponsePath, body, perms)
if err != nil {
panic(err)
}
}
// listPeers returns the peers of a node, regardless of any policy or whether the node is expired.
// If no peer IDs are given, all peers are returned.
// If at least one peer ID is given, only those peer nodes are returned.
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
peers, err := m.state.ListPeers(nodeID, peerIDs...)
if err != nil {
return nil, err
}
// TODO(kradalby): Add back online via batcher. This was removed
// to avoid a circular dependency between the mapper and the notification.
for _, peer := range peers {
online := m.notif.IsLikelyConnected(peer.ID)
online := m.batcher.IsConnected(peer.ID)
peer.IsOnline = &online
}
return peers, nil
}
// ListNodes queries the database for either all nodes if no parameters are given
// or for the given nodes if at least one node ID is given as parameter.
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
nodes, err := m.state.ListNodes(nodeIDs...)
if err != nil {
return nil, err
}
for _, node := range nodes {
online := m.notif.IsLikelyConnected(node.ID)
node.IsOnline = &online
}
return nodes, nil
}
// routeFilterFunc is a function that takes a node ID and returns a list of
// netip.Prefixes that are allowed for that node. It is used to filter routes
// from the primary route manager to the node.
type routeFilterFunc func(id types.NodeID) []netip.Prefix
// appendPeerChanges mutates a tailcfg.MapResponse with all the
// necessary changes when peers have changed.
func appendPeerChanges(
resp *tailcfg.MapResponse,
fullChange bool,
state *state.State,
node types.NodeView,
capVer tailcfg.CapabilityVersion,
changed views.Slice[types.NodeView],
cfg *types.Config,
) error {
filter, matchers := state.Filter()
sshPolicy, err := state.SSHPolicy(node)
if err != nil {
return err
}
// If there are filter rules present, see if there are any nodes that cannot
// access each other at all and remove them from the peers.
var reducedChanged views.Slice[types.NodeView]
if len(filter) > 0 {
reducedChanged = policy.ReduceNodes(node, changed, matchers)
} else {
reducedChanged = changed
}
profiles := generateUserProfiles(node, reducedChanged)
dnsConfig := generateDNSConfig(cfg, node)
tailPeers, err := tailNodes(
reducedChanged, capVer, state,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
},
cfg)
if err != nil {
return err
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
})
if fullChange {
resp.Peers = tailPeers
} else {
resp.PeersChanged = tailPeers
}
resp.DNSConfig = dnsConfig
resp.UserProfiles = profiles
resp.SSHPolicy = sshPolicy
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
// Currently, we do not send incremental packet filters, however using the
// new PacketFilters field and "base" allows us to send a full update when we
// have to send an empty list, avoiding the hack in the else block.
resp.PacketFilters = map[string][]tailcfg.FilterRule{
"base": policy.ReduceFilterRules(node, filter),
}
return nil
}

View File

@ -3,6 +3,7 @@ package mapper
import (
"fmt"
"net/netip"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
@ -70,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
&types.Config{
TailcfgDNSConfig: &dnsConfigOrig,
},
nodeInShared1.View(),
nodeInShared1,
)
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
@ -126,11 +127,8 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
// Filter peers by the provided IDs
var filtered types.Nodes
for _, peer := range m.peers {
for _, id := range peerIDs {
if peer.ID == id {
filtered = append(filtered, peer)
break
}
if slices.Contains(peerIDs, peer.ID) {
filtered = append(filtered, peer)
}
}
@ -152,11 +150,8 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
// Filter nodes by the provided IDs
var filtered types.Nodes
for _, node := range m.nodes {
for _, id := range nodeIDs {
if node.ID == id {
filtered = append(filtered, node)
break
}
if slices.Contains(nodeIDs, node.ID) {
filtered = append(filtered, node)
}
}

47
hscontrol/mapper/utils.go Normal file
View File

@ -0,0 +1,47 @@
package mapper
import "tailscale.com/tailcfg"
// mergePatch takes the current patch and a newer patch
// and overrides any field that has changed.
func mergePatch(currPatch, newPatch *tailcfg.PeerChange) {
if newPatch.DERPRegion != 0 {
currPatch.DERPRegion = newPatch.DERPRegion
}
if newPatch.Cap != 0 {
currPatch.Cap = newPatch.Cap
}
if newPatch.CapMap != nil {
currPatch.CapMap = newPatch.CapMap
}
if newPatch.Endpoints != nil {
currPatch.Endpoints = newPatch.Endpoints
}
if newPatch.Key != nil {
currPatch.Key = newPatch.Key
}
if newPatch.KeySignature != nil {
currPatch.KeySignature = newPatch.KeySignature
}
if newPatch.DiscoKey != nil {
currPatch.DiscoKey = newPatch.DiscoKey
}
if newPatch.Online != nil {
currPatch.Online = newPatch.Online
}
if newPatch.LastSeen != nil {
currPatch.LastSeen = newPatch.LastSeen
}
if newPatch.KeyExpiry != nil {
currPatch.KeyExpiry = newPatch.KeyExpiry
}
}
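A short sketch of mergePatch in use: later patches win field by field, and fields left at their zero value in the newer patch keep the current value (the literal values are illustrative):
// examplePatchMerge is a sketch only, not part of this file.
func examplePatchMerge() *tailcfg.PeerChange {
	curr := &tailcfg.PeerChange{NodeID: 2, DERPRegion: 5}
	next := &tailcfg.PeerChange{NodeID: 2, Cap: 54}

	mergePatch(curr, next)

	// curr now carries DERPRegion: 5 (kept) and Cap: 54 (applied).
	return curr
}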

View File

@ -221,7 +221,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
ns.nodeKey = nv.NodeKey()
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv)
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
sess.tracef("a node sending a MapRequest with Noise protocol")
if !sess.isStreaming() {
sess.serve()
@ -279,28 +279,33 @@ func (ns *noiseServer) NoiseRegistrationHandler(
return
}
respBody, err := json.Marshal(registerResponse)
if err != nil {
httpError(writer, err)
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
writer.Write(respBody)
// Ensure response is flushed to client
if flusher, ok := writer.(http.Flusher); ok {
flusher.Flush()
}
}
// getAndValidateNode retrieves the node from the database using the NodeKey
// and validates that it matches the MachineKey from the Noise session.
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
nv, err := ns.headscale.state.GetNodeViewByNodeKey(mapRequest.NodeKey)
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
}
return types.NodeView{}, err
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
}
nv := node.View()
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
if ns.machineKey != nv.MachineKey() {
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)

View File

@ -1,68 +0,0 @@
package notifier
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"tailscale.com/envknob"
)
const prometheusNamespace = "headscale"
var debugHighCardinalityMetrics = envknob.Bool("HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS")
var notifierUpdateSent *prometheus.CounterVec
func init() {
if debugHighCardinalityMetrics {
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "notifier_update_sent_total",
Help: "total count of update sent on nodes channel",
}, []string{"status", "type", "trigger", "id"})
} else {
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "notifier_update_sent_total",
Help: "total count of update sent on nodes channel",
}, []string{"status", "type", "trigger"})
}
}
var (
notifierWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_waiters_for_lock",
Help: "gauge of waiters for the notifier lock",
}, []string{"type", "action"})
notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "notifier_wait_for_lock_seconds",
Help: "histogram of time spent waiting for the notifier lock",
Buckets: []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10},
}, []string{"action"})
notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "notifier_update_received_total",
Help: "total count of updates received by notifier",
}, []string{"type", "trigger"})
notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_open_channels_total",
Help: "total count open channels in notifier",
})
notifierBatcherWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_batcher_waiters_for_lock",
Help: "gauge of waiters for the notifier batcher lock",
}, []string{"type", "action"})
notifierBatcherChanges = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_batcher_changes_pending",
Help: "gauge of full changes pending in the notifier batcher",
}, []string{})
notifierBatcherPatches = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_batcher_patches_pending",
Help: "gauge of patches pending in the notifier batcher",
}, []string{})
)

View File

@ -1,488 +0,0 @@
package notifier
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/util/set"
)
var (
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
)
func init() {
deadlock.Opts.Disable = !debugDeadlock
if debugDeadlock {
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
deadlock.Opts.PrintAllCurrentGoroutines = true
}
}
type Notifier struct {
l deadlock.Mutex
nodes map[types.NodeID]chan<- types.StateUpdate
connected *xsync.MapOf[types.NodeID, bool]
b *batcher
cfg *types.Config
closed bool
}
func NewNotifier(cfg *types.Config) *Notifier {
n := &Notifier{
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
connected: xsync.NewMapOf[types.NodeID, bool](),
cfg: cfg,
closed: false,
}
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
n.b = b
go b.doWork()
return n
}
// Close stops the batcher and closes all channels.
func (n *Notifier) Close() {
notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "close").Dec()
n.closed = true
n.b.close()
// Close channels safely using the helper method
for nodeID, c := range n.nodes {
n.safeCloseChannel(nodeID, c)
}
// Clear node map after closing channels
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
}
// safeCloseChannel closes a channel and panic recovers if already closed.
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
defer func() {
if r := recover(); r != nil {
log.Error().
Uint64("node.id", nodeID.Uint64()).
Any("recover", r).
Msg("recovered from panic when closing channel in Close()")
}
}()
close(c)
}
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
log.Trace().
Uint64("node.id", nID.Uint64()).
Int("open_chans", len(n.nodes)).Msgf(msg, args...)
}
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
start := time.Now()
notifierWaitersForLock.WithLabelValues("lock", "add").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
if n.closed {
return
}
// If a channel exists, it means the node has opened a new
// connection. Close the old channel and replace it.
if curr, ok := n.nodes[nodeID]; ok {
n.tracef(nodeID, "channel present, closing and replacing")
// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
// if/when someone is waiting to send on this channel
go func(ch chan<- types.StateUpdate) {
n.safeCloseChannel(nodeID, ch)
}(curr)
}
n.nodes[nodeID] = c
n.connected.Store(nodeID, true)
n.tracef(nodeID, "added new channel")
notifierNodeUpdateChans.Inc()
}
// RemoveNode removes a node and a given channel from the notifier.
// It checks that the channel is the same as currently being updated
// and ignores the removal if it is not.
// RemoveNode reports if the node/chan was removed.
func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool {
start := time.Now()
notifierWaitersForLock.WithLabelValues("lock", "remove").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
if n.closed {
return true
}
if len(n.nodes) == 0 {
return true
}
// If the channel exist, but it does not belong
// to the caller, ignore.
if curr, ok := n.nodes[nodeID]; ok {
if curr != c {
n.tracef(nodeID, "channel has been replaced, not removing")
return false
}
}
delete(n.nodes, nodeID)
n.connected.Store(nodeID, false)
n.tracef(nodeID, "removed channel")
notifierNodeUpdateChans.Dec()
return true
}
// IsConnected reports if a node is connected to headscale and has a
// poll session open.
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec()
if val, ok := n.connected.Load(nodeID); ok {
return val
}
return false
}
// IsLikelyConnected reports if a node is connected to headscale and has a
// poll session open, but doesn't lock, so might be wrong.
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
if val, ok := n.connected.Load(nodeID); ok {
return val
}
return false
}
// LikelyConnectedMap returns a thread safe map of connected nodes.
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
return n.connected
}
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
n.NotifyWithIgnore(ctx, update)
}
func (n *Notifier) NotifyWithIgnore(
ctx context.Context,
update types.StateUpdate,
ignoreNodeIDs ...types.NodeID,
) {
if n.closed {
return
}
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
n.b.addOrPassthrough(update)
}
func (n *Notifier) NotifyByNodeID(
ctx context.Context,
update types.StateUpdate,
nodeID types.NodeID,
) {
start := time.Now()
notifierWaitersForLock.WithLabelValues("lock", "notify").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
if n.closed {
return
}
if c, ok := n.nodes[nodeID]; ok {
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Uint64("node.id", nodeID.Uint64()).
Any("origin", types.NotifyOriginKey.Value(ctx)).
Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
Msgf("update not sent, context cancelled")
if debugHighCardinalityMetrics {
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
} else {
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
}
return
case c <- update:
n.tracef(nodeID, "update successfully sent on chan, origin: %s, origin-hostname: %s", ctx.Value("origin"), ctx.Value("hostname"))
if debugHighCardinalityMetrics {
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
} else {
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
}
}
}
}
func (n *Notifier) sendAll(update types.StateUpdate) {
start := time.Now()
notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
if n.closed {
return
}
for id, c := range n.nodes {
// Whenever an update is sent to all nodes, there is a chance that the node
// has disconnected and the goroutine that was supposed to consume the update
// has shut down the channel and is waiting for the lock held here in RemoveNode.
// This means that there is potential for a deadlock which would stop all updates
// going out to clients. This timeout prevents that from happening by moving on to the
// next node if the context is cancelled. After sendAll releases the lock, the add/remove
// call will succeed and the update will go to the correct nodes on the next call.
ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout)
defer cancel()
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Uint64("node.id", id.Uint64()).
Msgf("update not sent, context cancelled")
if debugHighCardinalityMetrics {
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all", id.String()).Inc()
} else {
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all").Inc()
}
return
case c <- update:
if debugHighCardinalityMetrics {
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all", id.String()).Inc()
} else {
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
}
}
}
}
func (n *Notifier) String() string {
notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
n.l.Lock()
defer n.l.Unlock()
notifierWaitersForLock.WithLabelValues("lock", "string").Dec()
var b strings.Builder
fmt.Fprintf(&b, "chans (%d):\n", len(n.nodes))
var keys []types.NodeID
n.connected.Range(func(key types.NodeID, value bool) bool {
keys = append(keys, key)
return true
})
sort.Slice(keys, func(i, j int) bool {
return keys[i] < keys[j]
})
for _, key := range keys {
fmt.Fprintf(&b, "\t%d: %p\n", key, n.nodes[key])
}
b.WriteString("\n")
fmt.Fprintf(&b, "connected (%d):\n", len(n.nodes))
for _, key := range keys {
val, _ := n.connected.Load(key)
fmt.Fprintf(&b, "\t%d: %t\n", key, val)
}
return b.String()
}
type batcher struct {
tick *time.Ticker
mu sync.Mutex
cancelCh chan struct{}
changedNodeIDs set.Slice[types.NodeID]
nodesChanged bool
patches map[types.NodeID]tailcfg.PeerChange
patchesChanged bool
n *Notifier
}
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
return &batcher{
tick: time.NewTicker(batchTime),
cancelCh: make(chan struct{}),
patches: make(map[types.NodeID]tailcfg.PeerChange),
n: n,
}
}
func (b *batcher) close() {
b.cancelCh <- struct{}{}
}
// addOrPassthrough adds the update to the batcher, if it is not a
// type that is currently batched, it will be sent immediately.
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Inc()
b.mu.Lock()
defer b.mu.Unlock()
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Dec()
switch update.Type {
case types.StatePeerChanged:
b.changedNodeIDs.Add(update.ChangeNodes...)
b.nodesChanged = true
notifierBatcherChanges.WithLabelValues().Set(float64(b.changedNodeIDs.Len()))
case types.StatePeerChangedPatch:
for _, newPatch := range update.ChangePatches {
if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
overwritePatch(&curr, newPatch)
b.patches[types.NodeID(newPatch.NodeID)] = curr
} else {
b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
}
}
b.patchesChanged = true
notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))
default:
b.n.sendAll(update)
}
}
// flush sends all the accumulated patches to all
// nodes in the notifier.
func (b *batcher) flush() {
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Inc()
b.mu.Lock()
defer b.mu.Unlock()
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Dec()
if b.nodesChanged || b.patchesChanged {
var patches []*tailcfg.PeerChange
// If a node is getting a full update from a changed-node
// update, then the patch can be dropped.
for nodeID, patch := range b.patches {
if b.changedNodeIDs.Contains(nodeID) {
delete(b.patches, nodeID)
} else {
patches = append(patches, &patch)
}
}
changedNodes := b.changedNodeIDs.Slice().AsSlice()
sort.Slice(changedNodes, func(i, j int) bool {
return changedNodes[i] < changedNodes[j]
})
if b.changedNodeIDs.Slice().Len() > 0 {
update := types.UpdatePeerChanged(changedNodes...)
b.n.sendAll(update)
}
if len(patches) > 0 {
patchUpdate := types.UpdatePeerPatch(patches...)
b.n.sendAll(patchUpdate)
}
b.changedNodeIDs = set.Slice[types.NodeID]{}
notifierBatcherChanges.WithLabelValues().Set(0)
b.nodesChanged = false
b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
notifierBatcherPatches.WithLabelValues().Set(0)
b.patchesChanged = false
}
}
func (b *batcher) doWork() {
for {
select {
case <-b.cancelCh:
return
case <-b.tick.C:
b.flush()
}
}
}
// overwritePatch takes the current patch and a newer patch
// and overrides any field that has changed.
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
if newPatch.DERPRegion != 0 {
currPatch.DERPRegion = newPatch.DERPRegion
}
if newPatch.Cap != 0 {
currPatch.Cap = newPatch.Cap
}
if newPatch.CapMap != nil {
currPatch.CapMap = newPatch.CapMap
}
if newPatch.Endpoints != nil {
currPatch.Endpoints = newPatch.Endpoints
}
if newPatch.Key != nil {
currPatch.Key = newPatch.Key
}
if newPatch.KeySignature != nil {
currPatch.KeySignature = newPatch.KeySignature
}
if newPatch.DiscoKey != nil {
currPatch.DiscoKey = newPatch.DiscoKey
}
if newPatch.Online != nil {
currPatch.Online = newPatch.Online
}
if newPatch.LastSeen != nil {
currPatch.LastSeen = newPatch.LastSeen
}
if newPatch.KeyExpiry != nil {
currPatch.KeyExpiry = newPatch.KeyExpiry
}
}

View File

@ -1,342 +0,0 @@
package notifier
import (
"fmt"
"math/rand"
"net/netip"
"slices"
"sort"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
)
func TestBatcher(t *testing.T) {
tests := []struct {
name string
updates []types.StateUpdate
want []types.StateUpdate
}{
{
name: "full-passthrough",
updates: []types.StateUpdate{
{
Type: types.StateFullUpdate,
},
},
want: []types.StateUpdate{
{
Type: types.StateFullUpdate,
},
},
},
{
name: "derp-passthrough",
updates: []types.StateUpdate{
{
Type: types.StateDERPUpdated,
},
},
want: []types.StateUpdate{
{
Type: types.StateDERPUpdated,
},
},
},
{
name: "single-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2,
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2,
},
},
},
},
{
name: "merge-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2, 4,
},
},
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2, 3,
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2, 3, 4,
},
},
},
},
{
name: "single-patch-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 5,
},
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 5,
},
},
},
},
},
{
name: "merge-patch-to-same-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 5,
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 6,
},
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 6,
},
},
},
},
},
{
name: "merge-patch-to-multiple-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 3,
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:9090"),
},
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 3,
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:9090"),
netip.MustParseAddrPort("2.2.2.2:8080"),
},
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 4,
DERPRegion: 6,
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 4,
Cap: tailcfg.CapabilityVersion(54),
},
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 3,
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:9090"),
netip.MustParseAddrPort("2.2.2.2:8080"),
},
},
{
NodeID: 4,
DERPRegion: 6,
Cap: tailcfg.CapabilityVersion(54),
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
n := NewNotifier(&types.Config{
Tuning: types.Tuning{
// We will call flush manually for the tests,
// so do not run the worker.
BatchChangeDelay: time.Hour,
// Since we do not load the config, we won't get the
// default, so set it manually so we don't time out
// and have flakes.
NotifierSendTimeout: time.Second,
},
})
ch := make(chan types.StateUpdate, 30)
defer close(ch)
n.AddNode(1, ch)
defer n.RemoveNode(1, ch)
for _, u := range tt.updates {
n.NotifyAll(t.Context(), u)
}
n.b.flush()
var got []types.StateUpdate
for len(ch) > 0 {
out := <-ch
got = append(got, out)
}
// Make the inner order stable for comparison.
for _, u := range got {
slices.Sort(u.ChangeNodes)
sort.Slice(u.ChangePatches, func(i, j int) bool {
return u.ChangePatches[i].NodeID < u.ChangePatches[j].NodeID
})
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
}
})
}
}
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
// mock config for the notifier
cfg := &types.Config{
Tuning: types.Tuning{
NotifierSendTimeout: 1 * time.Second,
BatchChangeDelay: 1 * time.Second,
NodeMapSessionBufferedChanSize: 30,
},
}
notifier := NewNotifier(cfg)
defer notifier.Close()
nodeID := types.NodeID(1)
updateChan := make(chan types.StateUpdate, 10)
var wg sync.WaitGroup
// Number of goroutines to spawn for concurrent access
concurrentAccessors := 100
iterations := 100
// Add node to notifier
notifier.AddNode(nodeID, updateChan)
// Track errors
errChan := make(chan string, concurrentAccessors*iterations)
// Start goroutines to cause a race
wg.Add(concurrentAccessors)
for i := range concurrentAccessors {
go func(routineID int) {
defer wg.Done()
for range iterations {
// Simulate race by having some goroutines check IsLikelyConnected
// while others add/remove the node
switch routineID % 3 {
case 0:
// This goroutine checks connection status
isConnected := notifier.IsLikelyConnected(nodeID)
if isConnected != true && isConnected != false {
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
}
case 1:
// This goroutine removes the node
notifier.RemoveNode(nodeID, updateChan)
default:
// This goroutine adds the node back
notifier.AddNode(nodeID, updateChan)
}
// Small random delay to increase chance of races
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
}
}(i)
}
wg.Wait()
close(errChan)
// Collate errors
var errors []string
for err := range errChan {
errors = append(errors, err)
}
if len(errors) > 0 {
t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
}
}

View File

@ -16,9 +16,8 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/db"
"github.com/juanfont/headscale/hscontrol/notifier"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
@ -56,11 +55,10 @@ type RegistrationInfo struct {
}
type AuthProviderOIDC struct {
h *Headscale
serverURL string
cfg *types.OIDCConfig
state *state.State
registrationCache *zcache.Cache[string, RegistrationInfo]
notifier *notifier.Notifier
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
@ -68,10 +66,9 @@ type AuthProviderOIDC struct {
func NewAuthProviderOIDC(
ctx context.Context,
h *Headscale,
serverURL string,
cfg *types.OIDCConfig,
state *state.State,
notif *notifier.Notifier,
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
@ -94,11 +91,10 @@ func NewAuthProviderOIDC(
)
return &AuthProviderOIDC{
h: h,
serverURL: serverURL,
cfg: cfg,
state: state,
registrationCache: registrationCache,
notifier: notif,
oidcProvider: oidcProvider,
oauth2Config: oauth2Config,
@ -318,8 +314,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-user-created", user.Name)
a.notifier.NotifyAll(ctx, types.UpdateFull())
a.h.Change(change.PolicyChange())
}
// TODO(kradalby): Is this comment right?
@ -360,8 +355,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Neither node nor machine key was found in the state cache meaning
// that we could not reauth nor register the node.
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
return
}
func extractCodeAndStateParamFromRequest(
@ -490,12 +483,14 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
var err error
var newUser bool
var policyChanged bool
user, err = a.state.GetUserByOIDCIdentifier(claims.Identifier())
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, false, fmt.Errorf("creating or updating user: %w", err)
}
// if the user is still not found, create a new empty user.
// TODO(kradalby): This might cause us to not have an ID below which
// is a problem.
if user == nil {
newUser = true
user = &types.User{}
@ -504,12 +499,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
user.FromClaim(claims)
if newUser {
user, policyChanged, err = a.state.CreateUser(*user)
user, policyChanged, err = a.h.state.CreateUser(*user)
if err != nil {
return nil, false, fmt.Errorf("creating user: %w", err)
}
} else {
_, policyChanged, err = a.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
*u = *user
return nil
})
@ -526,7 +521,7 @@ func (a *AuthProviderOIDC) handleRegistration(
registrationID types.RegistrationID,
expiry time.Time,
) (bool, error) {
node, newNode, err := a.state.HandleNodeFromAuthPath(
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
registrationID,
types.UserID(user.ID),
&expiry,
@ -547,31 +542,20 @@ func (a *AuthProviderOIDC) handleRegistration(
// ensure we send an update.
// This works, but might be another good candidate for doing some sort of
// eventbus.
routesChanged := a.state.AutoApproveRoutes(node)
_, policyChanged, err := a.state.SaveNode(node)
_ = a.h.state.AutoApproveRoutes(node)
_, policyChange, err := a.h.state.SaveNode(node)
if err != nil {
return false, fmt.Errorf("saving auto approved routes to node: %w", err)
}
// Send policy update notifications if needed (from SaveNode or route changes)
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-nodes-change", "all")
a.notifier.NotifyAll(ctx, types.UpdateFull())
// Policy updates are full and take precedence over node changes.
if !policyChange.Empty() {
a.h.Change(policyChange)
} else {
a.h.Change(nodeChange)
}
if routesChanged {
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
a.notifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID,
)
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
}
return newNode, nil
return !nodeChange.Empty(), nil
}
// TODO(kradalby):

View File

@ -113,6 +113,17 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
}
}
}
// Also check approved subnet routes - nodes should have access
// to subnets they're approved to route traffic for.
subnetRoutes := node.SubnetRoutes()
for _, subnetRoute := range subnetRoutes {
if expanded.OverlapsPrefix(subnetRoute) {
dests = append(dests, dest)
continue DEST_LOOP
}
}
}
if len(dests) > 0 {
@ -142,16 +153,23 @@ func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
newApproved = append(newApproved, route)
}
}
if newApproved != nil {
newApproved = append(newApproved, node.ApprovedRoutes...)
tsaddr.SortPrefixes(newApproved)
newApproved = slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
// Only modify ApprovedRoutes if we have new routes to approve.
// This prevents clearing existing approved routes when nodes
// temporarily don't have announced routes during policy changes.
if len(newApproved) > 0 {
combined := append(newApproved, node.ApprovedRoutes...)
tsaddr.SortPrefixes(combined)
combined = slices.Compact(combined)
combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
return route.IsValid()
})
node.ApprovedRoutes = newApproved
return true
// Only update if the routes actually changed
if !slices.Equal(node.ApprovedRoutes, combined) {
node.ApprovedRoutes = combined
return true
}
}
return false
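
The merge above follows a sort–compact–compare pattern so ApprovedRoutes is only rewritten when something actually changed. A minimal standalone sketch of that pattern (the `mergeApproved` helper is hypothetical and not part of this change; it assumes the same `tsaddr` and `slices` helpers the diff uses):

```go
package main

import (
	"fmt"
	"net/netip"
	"slices"

	"tailscale.com/net/tsaddr"
)

// mergeApproved combines newly auto-approved routes with the existing set,
// deduplicates, and reports whether anything actually changed.
func mergeApproved(existing, newApproved []netip.Prefix) ([]netip.Prefix, bool) {
	if len(newApproved) == 0 {
		return existing, false
	}
	combined := append(slices.Clone(newApproved), existing...)
	tsaddr.SortPrefixes(combined)
	combined = slices.Compact(combined)
	if slices.Equal(existing, combined) {
		return existing, false
	}
	return combined, true
}

func main() {
	existing := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
	incoming := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("192.168.1.0/24"),
	}
	merged, changed := mergeApproved(existing, incoming)
	fmt.Println(merged, changed) // [10.0.0.0/24 192.168.1.0/24] true
}
```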

View File

@ -56,10 +56,13 @@ func (pol *Policy) compileFilterRules(
}
if ips == nil {
log.Debug().Msgf("destination resolved to nil ips: %v", dest)
continue
}
for _, pref := range ips.Prefixes() {
prefixes := ips.Prefixes()
for _, pref := range prefixes {
for _, port := range dest.Ports {
pr := tailcfg.NetPortRange{
IP: pref.String(),
@ -103,6 +106,8 @@ func (pol *Policy) compileSSHPolicy(
return nil, nil
}
log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname())
var rules []*tailcfg.SSHRule
for index, rule := range pol.SSHs {
@ -137,7 +142,8 @@ func (pol *Policy) compileSSHPolicy(
var principals []*tailcfg.SSHPrincipal
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Err(err).Msgf("resolving source ips")
log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
continue // Skip this rule if we can't resolve sources
}
for addr := range util.IPSetAddrIter(srcIPs) {

View File

@ -70,7 +70,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
// TODO(kradalby): This could potentially be optimized by only clearing the
// policies for nodes that have changed. Particularly if the only difference is
// that nodes has been added or removed.
defer clear(pm.sshPolicyMap)
clear(pm.sshPolicyMap)
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
if err != nil {

View File

@ -1730,7 +1730,7 @@ func (u SSHUser) MarshalJSON() ([]byte, error) {
// In addition to unmarshalling, it will also validate the policy.
// This is the only entrypoint of reading a policy from a file or other source.
func unmarshalPolicy(b []byte) (*Policy, error) {
if b == nil || len(b) == 0 {
if len(b) == 0 {
return nil, nil
}

View File

@ -2,20 +2,20 @@ package hscontrol
import (
"context"
"encoding/binary"
"encoding/json"
"fmt"
"math/rand/v2"
"net/http"
"net/netip"
"slices"
"time"
"github.com/juanfont/headscale/hscontrol/mapper"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
xslices "golang.org/x/exp/slices"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/util/zstdframe"
)
const (
@ -31,18 +31,17 @@ type mapSession struct {
req tailcfg.MapRequest
ctx context.Context
capVer tailcfg.CapabilityVersion
mapper *mapper.Mapper
cancelChMu deadlock.Mutex
ch chan types.StateUpdate
ch chan *tailcfg.MapResponse
cancelCh chan struct{}
cancelChOpen bool
keepAlive time.Duration
keepAliveTicker *time.Ticker
node types.NodeView
node *types.Node
w http.ResponseWriter
warnf func(string, ...any)
@ -55,18 +54,9 @@ func (h *Headscale) newMapSession(
ctx context.Context,
req tailcfg.MapRequest,
w http.ResponseWriter,
nv types.NodeView,
node *types.Node,
) *mapSession {
warnf, infof, tracef, errf := logPollFuncView(req, nv)
var updateChan chan types.StateUpdate
if req.Stream {
// Use a buffered channel in case a node is not fully ready
// to receive a message to make sure we dont block the entire
// notifier.
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
updateChan <- types.UpdateFull()
}
warnf, infof, tracef, errf := logPollFunc(req, node)
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
@ -75,11 +65,10 @@ func (h *Headscale) newMapSession(
ctx: ctx,
req: req,
w: w,
node: nv,
node: node,
capVer: req.Version,
mapper: h.mapper,
ch: updateChan,
ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize),
cancelCh: make(chan struct{}),
cancelChOpen: true,
@ -95,15 +84,11 @@ func (h *Headscale) newMapSession(
}
func (m *mapSession) isStreaming() bool {
return m.req.Stream && !m.req.ReadOnly
return m.req.Stream
}
func (m *mapSession) isEndpointUpdate() bool {
return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers
}
func (m *mapSession) isReadOnlyUpdate() bool {
return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly
return !m.req.Stream && m.req.OmitPeers
}
func (m *mapSession) resetKeepAlive() {
@ -112,25 +97,22 @@ func (m *mapSession) resetKeepAlive() {
func (m *mapSession) beforeServeLongPoll() {
if m.node.IsEphemeral() {
m.h.ephemeralGC.Cancel(m.node.ID())
m.h.ephemeralGC.Cancel(m.node.ID)
}
}
func (m *mapSession) afterServeLongPoll() {
if m.node.IsEphemeral() {
m.h.ephemeralGC.Schedule(m.node.ID(), m.h.cfg.EphemeralNodeInactivityTimeout)
m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout)
}
}
// serve handles non-streaming requests.
func (m *mapSession) serve() {
// TODO(kradalby): A set of todos to harden:
// - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true
// This is the mechanism where the node gives us information about its
// current configuration.
//
// If OmitPeers is true, Stream is false, and ReadOnly is false,
// If OmitPeers is true and Stream is false,
// then the server will let clients update their endpoints without
// breaking existing long-polling (Stream == true) connections.
// In this case, the server can omit the entire response; the client
@ -138,26 +120,18 @@ func (m *mapSession) serve() {
//
// This is what Tailscale calls a Lite update, the client ignores
// the response and just wants a 200.
// !req.stream && !req.ReadOnly && req.OmitPeers
//
// TODO(kradalby): remove ReadOnly when we only support capVer 68+
// !req.stream && req.OmitPeers
if m.isEndpointUpdate() {
m.handleEndpointUpdate()
c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
if err != nil {
httpError(m.w, err)
return
}
return
}
m.h.Change(c)
// ReadOnly is whether the client just wants to fetch the
// MapResponse, without updating their Endpoints. The
// Endpoints field will be ignored and LastSeen will not be
// updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at
// start-up before their first real endpoint update.
if m.isReadOnlyUpdate() {
m.handleReadOnlyRequest()
return
m.w.WriteHeader(http.StatusOK)
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
}
}
@ -175,23 +149,15 @@ func (m *mapSession) serveLongPoll() {
close(m.cancelCh)
m.cancelChMu.Unlock()
// only update node status if the node channel was removed.
// In principle, it will be removed, but if the client rapidly
// reconnects, the channel might belong to another connection.
// In that case, it is not closed and the node is still online.
if m.h.nodeNotifier.RemoveNode(m.node.ID(), m.ch) {
// TODO(kradalby): This can likely be made more efficient, but likely most
// nodes have access to the same routes, so it might not be a big deal.
change, err := m.h.state.Disconnect(m.node.ID())
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname())
}
if change {
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname())
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
// TODO(kradalby): This can likely be made more efficient, but likely most
// nodes have access to the same routes, so it might not be a big deal.
disconnectChange, err := m.h.state.Disconnect(m.node)
if err != nil {
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
}
m.h.Change(disconnectChange)
m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
m.afterServeLongPoll()
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
@ -201,21 +167,30 @@ func (m *mapSession) serveLongPoll() {
m.h.pollNetMapStreamWG.Add(1)
defer m.h.pollNetMapStreamWG.Done()
m.h.state.Connect(m.node.ID())
// Upgrade the writer to a ResponseController
rc := http.NewResponseController(m.w)
// Longpolling will break if there is a write timeout,
// so it needs to be disabled.
rc.SetWriteDeadline(time.Time{})
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname()))
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
defer cancel()
m.keepAliveTicker = time.NewTicker(m.keepAlive)
m.h.nodeNotifier.AddNode(m.node.ID(), m.ch)
// Add node to batcher BEFORE sending Connect change to prevent race condition
// where the change is sent before the node is in the batcher's node map
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
// Send empty response to client to fail fast for invalid/non-existent nodes
select {
case m.ch <- &tailcfg.MapResponse{}:
default:
// Channel might be closed
}
return
}
// Now send the Connect change - the batcher handles NodeCameOnline internally
// but we still need to update routes and other state-level changes
connectChange := m.h.state.Connect(m.node)
if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
m.h.Change(connectChange)
}
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
@ -236,290 +211,94 @@ func (m *mapSession) serveLongPoll() {
// Consume updates sent to node
case update, ok := <-m.ch:
m.tracef("received update from channel, ok: %t", ok)
if !ok {
m.tracef("update channel closed, streaming session is likely being replaced")
return
}
// If the node has been removed from headscale, close the stream
if slices.Contains(update.Removed, m.node.ID()) {
m.tracef("node removed, closing stream")
if err := m.writeMap(update); err != nil {
m.errf(err, "cannot write update to client")
return
}
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
var data []byte
var err error
var lastMessage string
// Ensure the node view is updated, for example, there
// might have been a hostinfo update in a sidechannel
// which contains data needed to generate a map response.
m.node, err = m.h.state.GetNodeViewByID(m.node.ID())
if err != nil {
m.errf(err, "Could not get machine from db")
return
}
updateType := "full"
switch update.Type {
case types.StateFullUpdate:
m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
for _, nodeID := range update.ChangeNodes {
changed[nodeID] = true
}
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "change"
case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
updateType = "patch"
case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed))
for _, nodeID := range update.Removed {
changed[nodeID] = false
}
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateSelfUpdate:
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
updateType = "remove"
case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
updateType = "derp"
}
if err != nil {
m.errf(err, "Could not get the create map update")
return
}
// Only send update if there is change
if data != nil {
startWrite := time.Now()
_, err = m.w.Write(data)
if err != nil {
mapResponseSent.WithLabelValues("error", updateType).Inc()
m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m)
return
}
err = rc.Flush()
if err != nil {
mapResponseSent.WithLabelValues("error", updateType).Inc()
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
return
}
log.Trace().Str("node", m.node.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey().String()).Msg("finished writing mapresp to node")
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID().String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", updateType).Inc()
m.tracef("update sent")
m.resetKeepAlive()
}
m.tracef("update sent")
m.resetKeepAlive()
case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
if err != nil {
m.errf(err, "Error generating the keep alive msg")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return
}
_, err = m.w.Write(data)
if err != nil {
m.errf(err, "Cannot write keep alive message")
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
return
}
err = rc.Flush()
if err != nil {
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
if err := m.writeMap(&keepAlive); err != nil {
m.errf(err, "cannot write keep alive")
return
}
if debugHighCardinalityMetrics {
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID().String()).Set(float64(time.Now().Unix()))
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
}
}
}
func (m *mapSession) handleEndpointUpdate() {
m.tracef("received endpoint update")
// Get fresh node state from database for accurate route calculations
node, err := m.h.state.GetNodeByID(m.node.ID())
// writeMap writes the map response to the client.
// It handles compression if requested and any headers that need to be set.
// It also handles flushing the response if the ResponseWriter
// implements http.Flusher.
func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
jsonBody, err := json.Marshal(msg)
if err != nil {
m.errf(err, "Failed to get fresh node from database for endpoint update")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
return fmt.Errorf("marshalling map response: %w", err)
}
change := m.node.PeerChangeFromMapRequest(m.req)
online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID())
change.Online = &online
node.ApplyPeerChange(&change)
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, m.req.Hostinfo)
// The node might not set NetInfo if it has not changed; if the full
// HostInfo object were then overwritten, that information would be lost.
// If there is no NetInfo, keep the previous one.
// From 1.66 the client only sends it if changed:
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
// TODO(kradalby): evaluate if we need better comparing of hostinfo
// before we take the changes.
if m.req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
m.req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
}
node.Hostinfo = m.req.Hostinfo
logTracePeerChange(node.Hostname, sendUpdate, &change)
// If there are no changes and nothing to save,
// return early.
if peerChangeEmpty(change) && !sendUpdate {
mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
return
if m.req.Compress == util.ZstdCompression {
jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
}
// Auto approve any routes that have been defined in policy as
// auto approved. Check if this actually changed the node.
routesAutoApproved := m.h.state.AutoApproveRoutes(node)
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
data = append(data, jsonBody...)
// Always update routes for connected nodes to handle reconnection scenarios
// where routes need to be restored to the primary routes system
routesToSet := node.SubnetRoutes()
startWrite := time.Now()
if m.h.state.SetNodeRoutes(node.ID, routesToSet...) {
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
} else if routesChanged {
// Only send peer changed notification if routes actually changed
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
// TODO(kradalby): I am not sure if we need this?
// Send an update to the node itself with to ensure it
// has an updated packetfilter allowing the new route
// if it is defined in the ACL.
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", node.Hostname)
m.h.nodeNotifier.NotifyByNodeID(
ctx,
types.UpdateSelf(node.ID),
node.ID)
_, err = m.w.Write(data)
if err != nil {
return err
}
// If routes were auto-approved, we need to save the node to persist the changes
if routesAutoApproved {
if _, _, err := m.h.state.SaveNode(node); err != nil {
m.errf(err, "Failed to save auto-approved routes to node")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
if m.isStreaming() {
if f, ok := m.w.(http.Flusher); ok {
f.Flush()
} else {
m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush")
}
}
// Check if there has been a change to Hostname and update them
// in the database. Then send a Changed update
// (containing the whole node object) to peers to inform about
// the hostname change.
node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
_, policyChanged, err := m.h.state.SaveNode(node)
if err != nil {
m.errf(err, "Failed to persist/update node in the database")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
return
}
// Send policy update notifications if needed
if policyChanged {
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", node.Hostname)
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
}
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
m.h.nodeNotifier.NotifyWithIgnore(
ctx,
types.UpdatePeerChanged(node.ID),
node.ID,
)
m.w.WriteHeader(http.StatusOK)
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
return nil
}
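
For reference, the framing produced by writeMap above is a 4-byte little-endian length prefix followed by the JSON body, zstd-compressed only when the client asked for it. A minimal client-side decode sketch under those assumptions (`readMapResponse` is a hypothetical helper; it assumes `reservedResponseHeaderSize` is 4 and that `tailscale.com/util/zstdframe` exposes `AppendDecode` as the counterpart of the `AppendEncode` call used above):

```go
package mapclient

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"

	"tailscale.com/tailcfg"
	"tailscale.com/util/zstdframe"
)

// readMapResponse reads one framed MapResponse as written by writeMap:
// a 4-byte little-endian length, then the (possibly compressed) JSON body.
func readMapResponse(r io.Reader, compressed bool) (*tailcfg.MapResponse, error) {
	var hdr [4]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, err
	}
	body := make([]byte, binary.LittleEndian.Uint32(hdr[:]))
	if _, err := io.ReadFull(r, body); err != nil {
		return nil, err
	}
	if compressed {
		var err error
		body, err = zstdframe.AppendDecode(nil, body)
		if err != nil {
			return nil, fmt.Errorf("decompressing map response: %w", err)
		}
	}
	var resp tailcfg.MapResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return nil, fmt.Errorf("unmarshalling map response: %w", err)
	}
	return &resp, nil
}
```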
func (m *mapSession) handleReadOnlyRequest() {
m.tracef("Client asked for a lite update, responding without peers")
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
if err != nil {
m.errf(err, "Failed to create MapResponse")
http.Error(m.w, "", http.StatusInternalServerError)
mapResponseReadOnly.WithLabelValues("error").Inc()
return
}
m.w.Header().Set("Content-Type", "application/json; charset=utf-8")
m.w.WriteHeader(http.StatusOK)
_, err = m.w.Write(mapResp)
if err != nil {
m.errf(err, "Failed to write response")
mapResponseReadOnly.WithLabelValues("error").Inc()
return
}
m.w.WriteHeader(http.StatusOK)
mapResponseReadOnly.WithLabelValues("ok").Inc()
var keepAlive = tailcfg.MapResponse{
KeepAlive: true,
}
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname)
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
if change.Key != nil {
trace = trace.Str("node_key", change.Key.ShortString())
if peerChange.Key != nil {
trace = trace.Str("node_key", peerChange.Key.ShortString())
}
if change.DiscoKey != nil {
trace = trace.Str("disco_key", change.DiscoKey.ShortString())
if peerChange.DiscoKey != nil {
trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
}
if change.Online != nil {
trace = trace.Bool("online", *change.Online)
if peerChange.Online != nil {
trace = trace.Bool("online", *peerChange.Online)
}
if change.Endpoints != nil {
eps := make([]string, len(change.Endpoints))
for idx, ep := range change.Endpoints {
if peerChange.Endpoints != nil {
eps := make([]string, len(peerChange.Endpoints))
for idx, ep := range peerChange.Endpoints {
eps[idx] = ep.String()
}
@ -530,21 +309,11 @@ func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.Pe
trace = trace.Bool("hostinfo_changed", hostinfoChange)
}
if change.DERPRegion != 0 {
trace = trace.Int("derp_region", change.DERPRegion)
if peerChange.DERPRegion != 0 {
trace = trace.Int("derp_region", peerChange.DERPRegion)
}
trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received")
}
func peerChangeEmpty(chng tailcfg.PeerChange) bool {
return chng.Key == nil &&
chng.DiscoKey == nil &&
chng.Online == nil &&
chng.Endpoints == nil &&
chng.DERPRegion == 0 &&
chng.LastSeen == nil &&
chng.KeyExpiry == nil
trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received")
}
func logPollFunc(
@ -554,7 +323,6 @@ func logPollFunc(
return func(msg string, a ...any) {
log.Warn().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -564,7 +332,6 @@ func logPollFunc(
func(msg string, a ...any) {
log.Info().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -574,7 +341,6 @@ func logPollFunc(
func(msg string, a ...any) {
log.Trace().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -584,7 +350,6 @@ func logPollFunc(
func(err error, msg string, a ...any) {
log.Error().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
@ -593,91 +358,3 @@ func logPollFunc(
Msgf(msg, a...)
}
}
func logPollFuncView(
mapRequest tailcfg.MapRequest,
nodeView types.NodeView,
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
return func(msg string, a ...any) {
log.Warn().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Msgf(msg, a...)
},
func(msg string, a ...any) {
log.Info().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Msgf(msg, a...)
},
func(msg string, a ...any) {
log.Trace().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Msgf(msg, a...)
},
func(err error, msg string, a ...any) {
log.Error().
Caller().
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", nodeView.ID().Uint64()).
Str("node", nodeView.Hostname()).
Err(err).
Msgf(msg, a...)
}
}
// hostInfoChanged reports if hostInfo has changed in two ways,
// - first bool reports if an update needs to be sent to nodes
// - second reports if there have been changes to routes
// the caller can then use this info to save and update nodes
// and routes as needed.
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
if old.Equal(new) {
return false, false
}
if old == nil && new != nil {
return true, true
}
// Routes
oldRoutes := make([]netip.Prefix, 0)
if old != nil {
oldRoutes = old.RoutableIPs
}
newRoutes := new.RoutableIPs
tsaddr.SortPrefixes(oldRoutes)
tsaddr.SortPrefixes(newRoutes)
if !xslices.Equal(oldRoutes, newRoutes) {
return true, true
}
// Services is mostly useful for discovery and not critical,
// except for peerapi, which is how nodes talk to each other.
// If peerapi was not part of the initial mapresponse, we
// need to make sure it's sent out later as it is needed for
// Taildrop.
// TODO(kradalby): Length comparison is a bit naive, replace.
if len(old.Services) != len(new.Services) {
return true, false
}
return false, false
}

View File

@ -17,10 +17,13 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
xslices "golang.org/x/exp/slices"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
@ -46,12 +49,6 @@ type State struct {
// cfg holds the current Headscale configuration
cfg *types.Config
// in-memory data, protected by mu
// nodes contains the current set of registered nodes
nodes types.Nodes
// users contains the current set of users/namespaces
users types.Users
// subsystem keeping state
// db provides persistent storage and database operations
db *hsdb.HSDatabase
@ -113,9 +110,6 @@ func NewState(cfg *types.Config) (*State, error) {
return &State{
cfg: cfg,
nodes: nodes,
users: users,
db: db,
ipAlloc: ipAlloc,
// TODO(kradalby): Update DERPMap
@ -215,6 +209,7 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.db.DB.Save(&user).Error; err != nil {
return nil, false, fmt.Errorf("creating user: %w", err)
}
@ -226,6 +221,18 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
}
// Even if the policy manager doesn't detect a filter change, SSH policies
// might now be resolvable when they weren't before. If there are existing
// nodes, we should send a policy change to ensure they get updated SSH policies.
if !policyChanged {
nodes, err := s.ListNodes()
if err == nil && len(nodes) > 0 {
policyChanged = true
}
}
log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated")
// TODO(kradalby): implement the user in-memory cache
return &user, policyChanged, nil
@ -329,7 +336,7 @@ func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
}
// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, bool, error) {
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
@ -350,72 +357,100 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err
return node, nil
})
if err != nil {
return nil, false, err
return nil, change.EmptySet, err
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return node, false, fmt.Errorf("failed to update policy manager after node update: %w", err)
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
return node, policyChanged, nil
var c change.ChangeSet
if policyChanged {
c = change.PolicyChange()
} else {
// Basic node change without specific details since this is a generic update
c = change.NodeAdded(node.ID)
}
return node, c, nil
}
// SaveNode persists an existing node to the database and updates the policy manager.
func (s *State) SaveNode(node *types.Node) (*types.Node, bool, error) {
func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.db.DB.Save(node).Error; err != nil {
return nil, false, fmt.Errorf("saving node: %w", err)
return nil, change.EmptySet, fmt.Errorf("saving node: %w", err)
}
// Check if policy manager needs updating
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return node, false, fmt.Errorf("failed to update policy manager after node save: %w", err)
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
}
// TODO(kradalby): implement the node in-memory cache
return node, policyChanged, nil
if policyChanged {
return node, change.PolicyChange(), nil
}
return node, change.EmptySet, nil
}
// DeleteNode permanently removes a node and cleans up associated resources.
// Returns whether policies changed and any error. This operation is irreversible.
func (s *State) DeleteNode(node *types.Node) (bool, error) {
func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) {
err := s.db.DeleteNode(node)
if err != nil {
return false, err
return change.EmptySet, err
}
c := change.NodeRemoved(node.ID)
// Check if policy manager needs updating after node deletion
policyChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return false, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
}
return policyChanged, nil
if policyChanged {
c = change.PolicyChange()
}
return c, nil
}
func (s *State) Connect(id types.NodeID) {
func (s *State) Connect(node *types.Node) change.ChangeSet {
c := change.NodeOnline(node.ID)
routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
if routeChange {
c = change.NodeAdded(node.ID)
}
return c
}
func (s *State) Disconnect(id types.NodeID) (bool, error) {
// TODO(kradalby): This node should update the in memory state
_, polChanged, err := s.SetLastSeen(id, time.Now())
func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
c := change.NodeOffline(node.ID)
_, _, err := s.SetLastSeen(node.ID, time.Now())
if err != nil {
return false, fmt.Errorf("disconnecting node: %w", err)
return c, fmt.Errorf("disconnecting node: %w", err)
}
changed := s.primaryRoutes.SetRoutes(id)
if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange {
c = change.PolicyChange()
}
// TODO(kradalby): the returned change should be more nuanced allowing us to
// send more directed updates.
return changed || polChanged, nil
// TODO(kradalby): This node should update the in memory state
return c, nil
}
// GetNodeByID retrieves a node by ID.
@ -475,45 +510,93 @@ func (s *State) ListEphemeralNodes() (types.Nodes, error) {
}
// SetNodeExpiry updates the expiration time for a node.
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
}
if !c.IsFull() {
c = change.KeyExpiry(nodeID)
}
return n, c, nil
}
// SetNodeTags assigns tags to a node for use in access control policies.
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetTags(tx, nodeID, tags)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
}
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
return n, c, nil
}
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
}
// Update primary routes after changing approved routes
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...)
if routeChange || !c.IsFull() {
c = change.PolicyChange()
}
return n, c, nil
}
// RenameNode changes the display name of a node.
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.RenameNode(tx, nodeID, newName)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err)
}
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
return n, c, nil
}
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, bool, error) {
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
})
}
// AssignNodeToUser transfers a node to a different user.
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, bool, error) {
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) {
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
return hsdb.AssignNodeToUser(tx, nodeID, userID)
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err)
}
if !c.IsFull() {
c = change.NodeAdded(nodeID)
}
return n, c, nil
}
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
@ -523,7 +606,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
// ExpireExpiredNodes finds and processes expired nodes since the last check.
// Returns next check time, state update with expired nodes, and whether any were found.
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateUpdate, bool) {
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) {
return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck)
}
@ -568,8 +651,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
}
// SetNodeRoutes sets the primary routes for a node.
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) bool {
return s.primaryRoutes.SetRoutes(nodeID, routes...)
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet {
if s.primaryRoutes.SetRoutes(nodeID, routes...) {
// Route changes affect packet filters for all nodes, so trigger a policy change
// to ensure filters are regenerated across the entire network
return change.PolicyChange()
}
return change.EmptySet
}
// GetNodePrimaryRoutes returns the primary routes for a node.
@ -653,10 +742,10 @@ func (s *State) HandleNodeFromAuthPath(
userID types.UserID,
expiry *time.Time,
registrationMethod string,
) (*types.Node, bool, error) {
) (*types.Node, change.ChangeSet, error) {
ipv4, ipv6, err := s.ipAlloc.Next()
if err != nil {
return nil, false, err
return nil, change.EmptySet, err
}
return s.db.HandleNodeFromAuthPath(
@ -672,12 +761,15 @@ func (s *State) HandleNodeFromAuthPath(
func (s *State) HandleNodeFromPreAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (*types.Node, bool, error) {
) (*types.Node, change.ChangeSet, bool, error) {
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
if err != nil {
return nil, change.EmptySet, false, err
}
err = pak.Validate()
if err != nil {
return nil, false, err
return nil, change.EmptySet, false, err
}
nodeToRegister := types.Node{
@ -698,22 +790,13 @@ func (s *State) HandleNodeFromPreAuthKey(
AuthKeyID: &pak.ID,
}
// For auth key registration, ensure we don't keep an expired node
// This is especially important for re-registration after logout
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
if !regReq.Expiry.IsZero() {
nodeToRegister.Expiry = &regReq.Expiry
} else if !regReq.Expiry.IsZero() {
// If client is sending an expired time (e.g., after logout),
// don't set expiry so the node won't be considered expired
log.Debug().
Time("requested_expiry", regReq.Expiry).
Str("node", regReq.Hostinfo.Hostname).
Msg("Ignoring expired expiry time from auth key registration")
}
ipv4, ipv6, err := s.ipAlloc.Next()
if err != nil {
return nil, false, fmt.Errorf("allocating IPs: %w", err)
return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err)
}
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
@ -735,18 +818,38 @@ func (s *State) HandleNodeFromPreAuthKey(
return node, nil
})
if err != nil {
return nil, false, fmt.Errorf("writing node to database: %w", err)
return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err)
}
// Check if this is a logout request for an ephemeral node
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
// This is a logout request for an ephemeral node, delete it immediately
c, err := s.DeleteNode(node)
if err != nil {
return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err)
}
return nil, c, false, nil
}
// Check if policy manager needs updating
// This is necessary because we just created a new node.
// We need to ensure that the policy manager is aware of this new node.
policyChanged, err := s.updatePolicyManagerNodes()
// Also update users to ensure all users are known when evaluating policies.
usersChanged, err := s.updatePolicyManagerUsers()
if err != nil {
return nil, false, fmt.Errorf("failed to update policy manager after node registration: %w", err)
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
}
return node, policyChanged, nil
nodesChanged, err := s.updatePolicyManagerNodes()
if err != nil {
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
}
policyChanged := usersChanged || nodesChanged
c := change.NodeAdded(node.ID)
return node, c, policyChanged, nil
}
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
@ -766,11 +869,15 @@ func (s *State) updatePolicyManagerUsers() (bool, error) {
return false, fmt.Errorf("listing users for policy update: %w", err)
}
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
changed, err := s.polMan.SetUsers(users)
if err != nil {
return false, fmt.Errorf("updating policy manager users: %w", err)
}
log.Debug().Bool("changed", changed).Msg("Policy manager users updated")
return changed, nil
}
@ -835,3 +942,125 @@ func (s *State) autoApproveNodes() error {
return nil
}
// TODO(kradalby): This should just take the node ID?
func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) {
// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes,
// which means we could shortcut the whole change thing if there are no other important updates.
peerChange := node.PeerChangeFromMapRequest(req)
node.ApplyPeerChange(&peerChange)
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, req.Hostinfo)
// The node might not set NetInfo if it has not changed; if the full
// HostInfo object were then overwritten, that information would be lost.
// If there is no NetInfo, keep the previous one.
// From 1.66 the client only sends it if changed:
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
// TODO(kradalby): evaluate if we need better comparing of hostinfo
// before we take the changes.
if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
}
node.Hostinfo = req.Hostinfo
// If there are no changes and nothing to save,
// return early.
if peerChangeEmpty(peerChange) && !sendUpdate {
// mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
return change.EmptySet, nil
}
c := change.EmptySet
// Check if the Hostinfo of the node has changed.
// If it has changed, check if there has been a change to
// the routable IPs of the host and update them in
// the database. Then send a Changed update
// (containing the whole node object) to peers to inform about
// the route change.
// If the hostinfo has changed, but not the routes, just update
// hostinfo and let the function continue.
if routesChanged {
// Auto approve any routes that have been defined in policy as
// auto approved. Check if this actually changed the node.
_ = s.AutoApproveRoutes(node)
// Update the routes of the given node in the route manager to
// see if an update needs to be sent.
c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
}
// Check if there has been a change to Hostname and update them
// in the database. Then send a Changed update
// (containing the whole node object) to peers to inform about
// the hostname change.
node.ApplyHostnameFromHostInfo(req.Hostinfo)
_, policyChange, err := s.SaveNode(node)
if err != nil {
return change.EmptySet, err
}
if policyChange.IsFull() {
c = policyChange
}
if c.Empty() {
c = change.NodeAdded(node.ID)
}
return c, nil
}
// hostInfoChanged reports if hostInfo has changed in two ways,
// - first bool reports if an update needs to be sent to nodes
// - second reports if there have been changes to routes
// the caller can then use this info to save and update nodes
// and routes as needed.
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
if old.Equal(new) {
return false, false
}
if old == nil && new != nil {
return true, true
}
// Routes
oldRoutes := make([]netip.Prefix, 0)
if old != nil {
oldRoutes = old.RoutableIPs
}
newRoutes := new.RoutableIPs
tsaddr.SortPrefixes(oldRoutes)
tsaddr.SortPrefixes(newRoutes)
if !xslices.Equal(oldRoutes, newRoutes) {
return true, true
}
// Services is mostly useful for discovery and not critical,
// except for peerapi, which is how nodes talk to each other.
// If peerapi was not part of the initial mapresponse, we
// need to make sure it's sent out later as it is needed for
// Taildrop.
// TODO(kradalby): Length comparison is a bit naive, replace.
if len(old.Services) != len(new.Services) {
return true, false
}
return false, false
}
func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
return peerChange.Key == nil &&
peerChange.DiscoKey == nil &&
peerChange.Online == nil &&
peerChange.Endpoints == nil &&
peerChange.DERPRegion == 0 &&
peerChange.LastSeen == nil &&
peerChange.KeyExpiry == nil
}
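
A hypothetical table-style test inside the state package, illustrating the two booleans hostInfoChanged returns (the test name and cases are illustrative only, not part of this change):

```go
package state

import (
	"net/netip"
	"testing"

	"tailscale.com/tailcfg"
)

func TestHostInfoChangedSketch(t *testing.T) {
	base := &tailcfg.Hostinfo{
		RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
	}

	// Identical hostinfo: nothing to send, no route change.
	if send, routes := hostInfoChanged(base, base); send || routes {
		t.Errorf("identical hostinfo: got (%v, %v), want (false, false)", send, routes)
	}

	// Announced routes changed: send an update and recalculate routes.
	withNewRoute := &tailcfg.Hostinfo{
		RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
	}
	if send, routes := hostInfoChanged(base, withNewRoute); !send || !routes {
		t.Errorf("route change: got (%v, %v), want (true, true)", send, routes)
	}

	// Services changed (e.g. peerapi appeared): send an update, routes untouched.
	withService := &tailcfg.Hostinfo{
		RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
		Services:    []tailcfg.Service{{Proto: tailcfg.PeerAPI4, Port: 1}},
	}
	if send, routes := hostInfoChanged(base, withService); !send || routes {
		t.Errorf("service change: got (%v, %v), want (true, false)", send, routes)
	}
}
```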

View File

@ -0,0 +1,183 @@
//go:generate go tool stringer -type=Change
package change
import (
"errors"
"github.com/juanfont/headscale/hscontrol/types"
)
type (
NodeID = types.NodeID
UserID = types.UserID
)
type Change int
const (
ChangeUnknown Change = 0
// Deprecated: Use a specific change instead.
// Full is a legacy change so that places where we have not yet
// determined the specific update can still send one.
Full Change = 9
// Server changes.
Policy Change = 11
DERP Change = 12
ExtraRecords Change = 13
// Node changes.
NodeCameOnline Change = 21
NodeWentOffline Change = 22
NodeRemove Change = 23
NodeKeyExpiry Change = 24
NodeNewOrUpdate Change = 25
// User changes.
UserNewOrUpdate Change = 51
UserRemove Change = 52
)
// AlsoSelf reports whether this change should also be sent to the node itself.
func (c Change) AlsoSelf() bool {
switch c {
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
return true
}
return false
}
type ChangeSet struct {
Change Change
// SelfUpdateOnly indicates that this change should only be sent
// to the node itself, and not to other nodes.
// This is used for changes that are not relevant to other nodes.
// NodeID must be set if this is true.
SelfUpdateOnly bool
// NodeID if set, is the ID of the node that is being changed.
// It must be set if this is a node change.
NodeID types.NodeID
// UserID if set, is the ID of the user that is being changed.
// It must be set if this is a user change.
UserID types.UserID
// IsSubnetRouter indicates whether the node is a subnet router.
IsSubnetRouter bool
}
func (c *ChangeSet) Validate() error {
if c.Change >= NodeCameOnline && c.Change <= NodeNewOrUpdate {
if c.NodeID == 0 {
return errors.New("ChangeSet.NodeID must be set for node updates")
}
}
if c.Change >= UserNewOrUpdate && c.Change <= UserRemove {
if c.UserID == 0 {
return errors.New("ChangeSet.UserID must be set for user updates")
}
}
return nil
}
// Empty reports whether the ChangeSet is empty, meaning it does not
// represent any change.
func (c ChangeSet) Empty() bool {
return c.Change == ChangeUnknown && c.NodeID == 0 && c.UserID == 0
}
// IsFull reports whether the ChangeSet represents a full update.
func (c ChangeSet) IsFull() bool {
return c.Change == Full || c.Change == Policy
}
func (c ChangeSet) AlsoSelf() bool {
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
// so we consider it as a change that should be sent to all nodes.
if c.NodeID == 0 {
return true
}
return c.Change.AlsoSelf() || c.SelfUpdateOnly
}
var (
EmptySet = ChangeSet{Change: ChangeUnknown}
FullSet = ChangeSet{Change: Full}
DERPSet = ChangeSet{Change: DERP}
PolicySet = ChangeSet{Change: Policy}
ExtraRecordsSet = ChangeSet{Change: ExtraRecords}
)
func FullSelf(id types.NodeID) ChangeSet {
return ChangeSet{
Change: Full,
SelfUpdateOnly: true,
NodeID: id,
}
}
func NodeAdded(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeNewOrUpdate,
NodeID: id,
}
}
func NodeRemoved(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeRemove,
NodeID: id,
}
}
func NodeOnline(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeCameOnline,
NodeID: id,
}
}
func NodeOffline(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeWentOffline,
NodeID: id,
}
}
func KeyExpiry(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeKeyExpiry,
NodeID: id,
}
}
func UserAdded(id types.UserID) ChangeSet {
return ChangeSet{
Change: UserNewOrUpdate,
UserID: id,
}
}
func UserRemoved(id types.UserID) ChangeSet {
return ChangeSet{
Change: UserRemove,
UserID: id,
}
}
func PolicyChange() ChangeSet {
return ChangeSet{
Change: Policy,
}
}
func DERPChange() ChangeSet {
return ChangeSet{
Change: DERP,
}
}
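
A minimal usage sketch of the new change package (hypothetical caller code, assuming the import path shown in this diff), mirroring the "policy change takes precedence over node change" pattern used in the OIDC handler above:

```go
package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types/change"
)

// pick mirrors how callers fold two ChangeSets: a full/policy update
// always wins over a narrower per-node update.
func pick(policy, node change.ChangeSet) change.ChangeSet {
	if !policy.Empty() {
		return policy
	}
	return node
}

func main() {
	c := pick(change.PolicyChange(), change.NodeAdded(1))
	fmt.Println(c.Change, c.IsFull(), c.AlsoSelf()) // Policy true true
}
```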

View File

@ -0,0 +1,57 @@
// Code generated by "stringer -type=Change"; DO NOT EDIT.
package change
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ChangeUnknown-0]
_ = x[Full-9]
_ = x[Policy-11]
_ = x[DERP-12]
_ = x[ExtraRecords-13]
_ = x[NodeCameOnline-21]
_ = x[NodeWentOffline-22]
_ = x[NodeRemove-23]
_ = x[NodeKeyExpiry-24]
_ = x[NodeNewOrUpdate-25]
_ = x[UserNewOrUpdate-51]
_ = x[UserRemove-52]
}
const (
_Change_name_0 = "ChangeUnknown"
_Change_name_1 = "Full"
_Change_name_2 = "PolicyDERPExtraRecords"
_Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdate"
_Change_name_4 = "UserNewOrUpdateUserRemove"
)
var (
_Change_index_2 = [...]uint8{0, 6, 10, 22}
_Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67}
_Change_index_4 = [...]uint8{0, 15, 25}
)
func (i Change) String() string {
switch {
case i == 0:
return _Change_name_0
case i == 9:
return _Change_name_1
case 11 <= i && i <= 13:
i -= 11
return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]]
case 21 <= i && i <= 25:
i -= 21
return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]]
case 51 <= i && i <= 52:
i -= 51
return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]]
default:
return "Change(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View File

@ -1,16 +1,16 @@
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
//go:generate go tool viewer --type=User,Node,PreAuthKey
package types
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
import (
"context"
"errors"
"fmt"
"runtime"
"time"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
"tailscale.com/util/ctxkey"
)
const (
@ -150,18 +150,6 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
}
}
var (
NotifyOriginKey = ctxkey.New("notify.origin", "")
NotifyHostnameKey = ctxkey.New("notify.hostname", "")
)
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
ctx2, _ := context.WithTimeout(ctx, 3*time.Second)
ctx2 = NotifyOriginKey.WithValue(ctx2, origin)
ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname)
return ctx2
}
const RegistrationIDLength = 24
type RegistrationID string
@ -199,3 +187,20 @@ type RegisterNode struct {
Node Node
Registered chan *Node
}
// DefaultBatcherWorkers returns the default number of batcher workers.
// Defaults to 3/4 of CPU cores, minimum 1, no maximum.
func DefaultBatcherWorkers() int {
return DefaultBatcherWorkersFor(runtime.NumCPU())
}
// DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count.
// Defaults to 3/4 of CPU cores, minimum 1, no maximum.
func DefaultBatcherWorkersFor(cpuCount int) int {
defaultWorkers := (cpuCount * 3) / 4
if defaultWorkers < 1 {
defaultWorkers = 1
}
return defaultWorkers
}

View File

@ -0,0 +1,36 @@
package types
import (
"testing"
)
func TestDefaultBatcherWorkersFor(t *testing.T) {
tests := []struct {
cpuCount int
expected int
}{
{1, 1}, // (1*3)/4 = 0, should be minimum 1
{2, 1}, // (2*3)/4 = 1
{4, 3}, // (4*3)/4 = 3
{8, 6}, // (8*3)/4 = 6
{12, 9}, // (12*3)/4 = 9
{16, 12}, // (16*3)/4 = 12
{20, 15}, // (20*3)/4 = 15
{24, 18}, // (24*3)/4 = 18
}
for _, test := range tests {
result := DefaultBatcherWorkersFor(test.cpuCount)
if result != test.expected {
t.Errorf("DefaultBatcherWorkersFor(%d) = %d, expected %d", test.cpuCount, result, test.expected)
}
}
}
func TestDefaultBatcherWorkers(t *testing.T) {
// Just verify it returns a valid value (>= 1)
result := DefaultBatcherWorkers()
if result < 1 {
t.Errorf("DefaultBatcherWorkers() = %d, expected value >= 1", result)
}
}

View File

@ -234,6 +234,7 @@ type Tuning struct {
NotifierSendTimeout time.Duration
BatchChangeDelay time.Duration
NodeMapSessionBufferedChanSize int
BatcherWorkers int
}
func validatePKCEMethod(method string) error {
@ -991,6 +992,12 @@ func LoadServerConfig() (*Config, error) {
NodeMapSessionBufferedChanSize: viper.GetInt(
"tuning.node_mapsession_buffered_chan_size",
),
BatcherWorkers: func() int {
if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
return workers
}
return DefaultBatcherWorkers()
}(),
},
}, nil
}
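
A small sketch of the resolution order implied above for `tuning.batcher_workers` (the `batcherWorkers` helper is hypothetical, not part of the change): an explicit positive value from the config wins, otherwise the CPU-based default applies.

```go
package main

import "fmt"

// batcherWorkers mirrors the fallback above: a configured value > 0 wins,
// otherwise use 3/4 of the CPU count with a minimum of one worker.
func batcherWorkers(configured, cpuCount int) int {
	if configured > 0 {
		return configured
	}
	workers := (cpuCount * 3) / 4
	if workers < 1 {
		workers = 1
	}
	return workers
}

func main() {
	fmt.Println(batcherWorkers(0, 8)) // 6 (default on an 8-core host)
	fmt.Println(batcherWorkers(2, 8)) // 2 (explicit override)
	fmt.Println(batcherWorkers(0, 1)) // 1 (floor)
}
```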

View File

@ -431,6 +431,11 @@ func (node *Node) SubnetRoutes() []netip.Prefix {
return routes
}
// IsSubnetRouter reports if the node has any subnet routes.
func (node *Node) IsSubnetRouter() bool {
return len(node.SubnetRoutes()) > 0
}
func (node *Node) String() string {
return node.Hostname
}
@ -669,6 +674,13 @@ func (v NodeView) SubnetRoutes() []netip.Prefix {
return v.ж.SubnetRoutes()
}
func (v NodeView) IsSubnetRouter() bool {
if !v.Valid() {
return false
}
return v.ж.IsSubnetRouter()
}
func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) {
if !v.Valid() {
return

View File

@ -1,17 +1,16 @@
package types
import (
"fmt"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
)
type PAKError string
func (e PAKError) Error() string { return string(e) }
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }
// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {
@ -60,6 +59,21 @@ func (pak *PreAuthKey) Validate() error {
if pak == nil {
return PAKError("invalid authkey")
}
log.Debug().
Str("key", pak.Key).
Bool("hasExpiration", pak.Expiration != nil).
Time("expiration", func() time.Time {
if pak.Expiration != nil {
return *pak.Expiration
}
return time.Time{}
}()).
Time("now", time.Now()).
Bool("reusable", pak.Reusable).
Bool("used", pak.Used).
Msg("PreAuthKey.Validate: checking key")
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return PAKError("authkey expired")
}

View File

@ -5,6 +5,8 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"tailscale.com/util/dnsname"
"tailscale.com/util/must"
)
func TestCheckForFQDNRules(t *testing.T) {
@ -102,59 +104,16 @@ func TestConvertWithFQDNRules(t *testing.T) {
func TestMagicDNSRootDomains100(t *testing.T) {
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("100.64.0.0/10"))
found := false
for _, domain := range domains {
if domain == "64.100.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
found = false
for _, domain := range domains {
if domain == "100.100.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
found = false
for _, domain := range domains {
if domain == "127.100.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("64.100.in-addr.arpa.")))
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("100.100.in-addr.arpa.")))
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("127.100.in-addr.arpa.")))
}
func TestMagicDNSRootDomains172(t *testing.T) {
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("172.16.0.0/16"))
found := false
for _, domain := range domains {
if domain == "0.16.172.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
found = false
for _, domain := range domains {
if domain == "255.16.172.in-addr.arpa." {
found = true
break
}
}
assert.True(t, found)
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("0.16.172.in-addr.arpa.")))
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("255.16.172.in-addr.arpa.")))
}
// Happens when netmask is a multiple of 4 bits (sounds likely).

View File

@ -143,7 +143,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
// Parse latencies
for j := 5; j <= 7; j++ {
if matches[j] != "" {
if j < len(matches) && matches[j] != "" {
ms, err := strconv.ParseFloat(matches[j], 64)
if err != nil {
return Traceroute{}, fmt.Errorf("parsing latency: %w", err)

View File

@ -88,7 +88,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match before logout count")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count")
}, 20*time.Second, 1*time.Second)
for _, node := range listNodes {
@ -123,7 +123,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match after HTTPS reconnection")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection")
}, 30*time.Second, 2*time.Second)
for _, node := range listNodes {
@ -161,7 +161,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
listNodes, err = headscale.ListNodes()
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
require.Len(t, listNodes, nodeCountBeforeLogout)
for _, node := range listNodes {
assertLastSeenSet(t, node)
}
@ -355,7 +355,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
"--user",
strconv.FormatUint(userMap[userName].GetId(), 10),
"expire",
key.Key,
key.GetKey(),
})
assertNoErr(t, err)

View File

@ -604,7 +604,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()
assert.NoError(ct, err)
assert.NotContains(ct, []string{"Starting", "Running"}, status.BackendState,
assert.NotContains(ct, []string{"Starting", "Running"}, status.BackendState,
"Expected node to be logged out, backend state: %s", status.BackendState)
}, 30*time.Second, 2*time.Second)

View File

@ -147,3 +147,9 @@ func DockerAllowNetworkAdministration(config *docker.HostConfig) {
config.CapAdd = append(config.CapAdd, "NET_ADMIN")
config.Privileged = true
}
// DockerMemoryLimit sets memory limit and disables OOM kill for containers.
func DockerMemoryLimit(config *docker.HostConfig) {
config.Memory = 2 * 1024 * 1024 * 1024 // 2GB in bytes
config.OOMKillDisable = true
}
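Sketch of how such a host-config option is typically applied (editor's illustration, not part of this commit); pool and runOpts are hypothetical, assuming ory/dockertest's variadic host-config parameters, mirroring the tsic.go wiring later in this commit.
	resource, err := pool.RunWithOptions(
		runOpts,
		dockertestutil.DockerRestartPolicy,
		dockertestutil.DockerMemoryLimit, // caps the container at 2 GB and disables the OOM killer
	)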

View File

@ -145,9 +145,9 @@ func derpServerScenario(
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
for _, health := range status.Health {
assert.NotContains(ct, health, "could not connect to any relay server",
assert.NotContains(ct, health, "could not connect to any relay server",
"Client %s should be connected to DERP relay", client.Hostname())
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
"Client %s should be connected to Headscale Embedded DERP", client.Hostname())
}
}, 30*time.Second, 2*time.Second)
@ -166,9 +166,9 @@ func derpServerScenario(
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
for _, health := range status.Health {
assert.NotContains(ct, health, "could not connect to any relay server",
assert.NotContains(ct, health, "could not connect to any relay server",
"Client %s should be connected to DERP relay after first run", client.Hostname())
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
"Client %s should be connected to Headscale Embedded DERP after first run", client.Hostname())
}
}, 30*time.Second, 2*time.Second)
@ -191,9 +191,9 @@ func derpServerScenario(
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
for _, health := range status.Health {
assert.NotContains(ct, health, "could not connect to any relay server",
assert.NotContains(ct, health, "could not connect to any relay server",
"Client %s should be connected to DERP relay after second run", client.Hostname())
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
"Client %s should be connected to Headscale Embedded DERP after second run", client.Hostname())
}
}, 30*time.Second, 2*time.Second)

View File

@ -883,6 +883,10 @@ func TestNodeOnlineStatus(t *testing.T) {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()
assert.NoError(ct, err)
if status == nil {
assert.Fail(ct, "status is nil")
return
}
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
@ -984,16 +988,11 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
}
// Wait for sync and successful pings after nodes come back up
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
err = scenario.WaitForTailscaleSync()
assert.NoError(ct, err)
success := pingAllHelper(t, allClients, allAddrs)
assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up")
}, 30*time.Second, 2*time.Second)
err = scenario.WaitForTailscaleSync()
assert.NoError(t, err)
success := pingAllHelper(t, allClients, allAddrs)
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps))
}
}

View File

@ -260,7 +260,9 @@ func WithDERPConfig(derpMap tailcfg.DERPMap) Option {
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
return func(hsic *HeadscaleInContainer) {
hsic.env["HEADSCALE_TUNING_BATCH_CHANGE_DELAY"] = batchTimeout.String()
hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(mapSessionChanSize)
hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(
mapSessionChanSize,
)
}
}
@ -279,10 +281,16 @@ func WithDebugPort(port int) Option {
// buildEntrypoint builds the container entrypoint command based on configuration.
func (hsic *HeadscaleInContainer) buildEntrypoint() []string {
debugCmd := fmt.Sprintf("/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve", hsic.debugPort)
entrypoint := fmt.Sprintf("/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30", debugCmd)
debugCmd := fmt.Sprintf(
"/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve",
hsic.debugPort,
)
entrypoint := fmt.Sprintf(
"/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30",
debugCmd,
)
return []string{"/bin/bash", "-c", entrypoint}
}
@ -447,8 +455,12 @@ func New(
log.Printf("Created %s container\n", hsic.hostname)
hsic.container = container
log.Printf("Debug ports for %s: delve=%s, metrics/pprof=49090\n", hsic.hostname, hsic.GetHostDebugPort())
log.Printf(
"Debug ports for %s: delve=%s, metrics/pprof=49090\n",
hsic.hostname,
hsic.GetHostDebugPort(),
)
// Write the CA certificates to the container
for i, cert := range hsic.caCerts {
@ -684,14 +696,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
return nil
}
// First, let's see what files are actually in /tmp
tmpListing, err := t.Execute([]string{"ls", "-la", "/tmp/"})
if err != nil {
log.Printf("Warning: could not list /tmp directory: %v", err)
} else {
log.Printf("Contents of /tmp in container %s:\n%s", t.hostname, tmpListing)
}
// Also check for any .sqlite files
sqliteFiles, err := t.Execute([]string{"find", "/tmp", "-name", "*.sqlite*", "-type", "f"})
if err != nil {
@ -718,12 +722,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
return errors.New("database file exists but has no schema (empty database)")
}
// Show a preview of the schema (first 500 chars)
schemaPreview := schemaCheck
if len(schemaPreview) > 500 {
schemaPreview = schemaPreview[:500] + "..."
}
tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
if err != nil {
return fmt.Errorf("failed to fetch database file: %w", err)
@ -740,7 +738,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
return fmt.Errorf("failed to read tar header: %w", err)
}
log.Printf("Found file in tar: %s (type: %d, size: %d)", header.Name, header.Typeflag, header.Size)
log.Printf(
"Found file in tar: %s (type: %d, size: %d)",
header.Name,
header.Typeflag,
header.Size,
)
// Extract the first regular file we find
if header.Typeflag == tar.TypeReg {
@ -756,11 +759,20 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
return fmt.Errorf("failed to copy database file: %w", err)
}
log.Printf("Extracted database file: %s (%d bytes written, header claimed %d bytes)", dbPath, written, header.Size)
log.Printf(
"Extracted database file: %s (%d bytes written, header claimed %d bytes)",
dbPath,
written,
header.Size,
)
// Check if we actually wrote something
if written == 0 {
return fmt.Errorf("database file is empty (size: %d, header size: %d)", written, header.Size)
return fmt.Errorf(
"database file is empty (size: %d, header size: %d)",
written,
header.Size,
)
}
return nil
@ -871,7 +883,15 @@ func (t *HeadscaleInContainer) WaitForRunning() error {
func (t *HeadscaleInContainer) CreateUser(
user string,
) (*v1.User, error) {
command := []string{"headscale", "users", "create", user, fmt.Sprintf("--email=%s@test.no", user), "--output", "json"}
command := []string{
"headscale",
"users",
"create",
user,
fmt.Sprintf("--email=%s@test.no", user),
"--output",
"json",
}
result, _, err := dockertestutil.ExecuteCommand(
t.container,
@ -1182,13 +1202,18 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
[]string{},
)
if err != nil {
return nil, fmt.Errorf("failed to execute list node command: %w", err)
return nil, fmt.Errorf(
"failed to execute approve routes command (node %d, routes %v): %w",
id,
routes,
err,
)
}
var node *v1.Node
err = json.Unmarshal([]byte(result), &node)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal nodes: %w", err)
return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err)
}
return node, nil

View File

@ -310,7 +310,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Enable route on node 1
t.Logf("Enabling route on subnet router 1, no HA")
_, err = headscale.ApproveRoutes(
1,
MustFindNode(subRouter1.Hostname(), nodes).GetId(),
[]netip.Prefix{pref},
)
require.NoError(t, err)
@ -366,7 +366,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Enable route on node 2, now we will have a HA subnet router
t.Logf("Enabling route on subnet router 2, now HA, subnetrouter 1 is primary, 2 is standby")
_, err = headscale.ApproveRoutes(
2,
MustFindNode(subRouter2.Hostname(), nodes).GetId(),
[]netip.Prefix{pref},
)
require.NoError(t, err)
@ -422,7 +422,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// be enabled.
t.Logf("Enabling route on subnet router 3, now HA, subnetrouter 1 is primary, 2 and 3 is standby")
_, err = headscale.ApproveRoutes(
3,
MustFindNode(subRouter3.Hostname(), nodes).GetId(),
[]netip.Prefix{pref},
)
require.NoError(t, err)
@ -639,7 +639,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf("disabling route in subnet router r3 (%s)", subRouter3.Hostname())
t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
_, err = headscale.ApproveRoutes(nodes[2].GetId(), []netip.Prefix{})
_, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
time.Sleep(5 * time.Second)
@ -647,9 +647,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
@ -684,7 +684,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Disable the route of subnet router 1, making it failover to 2
t.Logf("disabling route in subnet router r1 (%s)", subRouter1.Hostname())
t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{})
_, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
time.Sleep(5 * time.Second)
@ -692,9 +692,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 0, 0)
requireNodeRouteCount(t, nodes[1], 1, 1, 1)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()
@ -729,9 +729,10 @@ func TestHASubnetRouterFailover(t *testing.T) {
// enable the route of subnet router 1, no change expected
t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
t.Logf("both online, expecting r2 (%s) to still be primary (no flapping)", subRouter2.Hostname())
r1Node := MustFindNode(subRouter1.Hostname(), nodes)
_, err = headscale.ApproveRoutes(
nodes[0].GetId(),
util.MustStringsToPrefixes(nodes[0].GetAvailableRoutes()),
r1Node.GetId(),
util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
)
time.Sleep(5 * time.Second)
@ -740,9 +741,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
require.NoError(t, err)
assert.Len(t, nodes, 6)
requireNodeRouteCount(t, nodes[0], 1, 1, 0)
requireNodeRouteCount(t, nodes[1], 1, 1, 1)
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
// Verify that the route is announced from subnet router 1
clientStatus, err = client.Status()

View File

@ -223,7 +223,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
s.userToNetwork = userToNetwork
if spec.OIDCUsers != nil && len(spec.OIDCUsers) != 0 {
if len(spec.OIDCUsers) != 0 {
ttl := defaultAccessTTL
if spec.OIDCAccessTTL != 0 {
ttl = spec.OIDCAccessTTL

View File

@ -370,10 +370,12 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
}
func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
t.Helper()
return doSSHWithRetry(t, client, peer, true)
}
func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
t.Helper()
return doSSHWithRetry(t, client, peer, false)
}

View File

@ -319,6 +319,7 @@ func New(
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
case "unstable":
tailscaleOptions.Repository = "tailscale/tailscale"
@ -329,6 +330,7 @@ func New(
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
default:
tailscaleOptions.Repository = "tailscale/tailscale"
@ -339,6 +341,7 @@ func New(
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
}

View File

@ -22,11 +22,11 @@ import (
const (
// derpPingTimeout defines the timeout for individual DERP ping operations
// Used in DERP connectivity tests to verify relay server communication
// Used in DERP connectivity tests to verify relay server communication.
derpPingTimeout = 2 * time.Second
// derpPingCount defines the number of ping attempts for DERP connectivity tests
// Higher count provides better reliability assessment of DERP connectivity
// Higher count provides better reliability assessment of DERP connectivity.
derpPingCount = 10
)
@ -317,11 +317,11 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) {
// assertCommandOutputContains executes a command with exponential backoff retry until the output
// contains the expected string or timeout is reached (10 seconds).
// This implements eventual consistency patterns and should be used instead of time.Sleep
// This implements eventual consistency patterns and should be used instead of time.Sleep
// before executing commands that depend on network state propagation.
//
// Timeout: 10 seconds with exponential backoff
// Use cases: DNS resolution, route propagation, policy updates
// Use cases: DNS resolution, route propagation, policy updates.
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
t.Helper()
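A hypothetical call (editor's illustration, not part of this commit) showing how the helper replaces a fixed sleep; the command and expected substring are made up for the example.
	assertCommandOutputContains(t, client, []string{"tailscale", "status"}, "10.33.0.0/16")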
@ -361,10 +361,10 @@ func isSelfClient(client TailscaleClient, addr string) bool {
}
func dockertestMaxWait() time.Duration {
wait := 120 * time.Second //nolint
wait := 300 * time.Second //nolint
if util.IsCI() {
wait = 300 * time.Second //nolint
wait = 600 * time.Second //nolint
}
return wait

View File

@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"regexp"
@ -21,7 +20,7 @@ import (
const (
releasesURL = "https://api.github.com/repos/tailscale/tailscale/releases"
rawFileURL = "https://github.com/tailscale/tailscale/raw/refs/tags/%s/tailcfg/tailcfg.go"
outputFile = "../capver_generated.go"
outputFile = "../../hscontrol/capver/capver_generated.go"
)
type Release struct {
@ -105,7 +104,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
sortedVersions := xmaps.Keys(versions)
sort.Strings(sortedVersions)
for _, version := range sortedVersions {
file.WriteString(fmt.Sprintf("\t\"%s\": %d,\n", version, versions[version]))
fmt.Fprintf(file, "\t\"%s\": %d,\n", version, versions[version])
}
file.WriteString("}\n")
@ -115,16 +114,13 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string)
for _, v := range sortedVersions {
cap := versions[v]
log.Printf("cap for v: %d, %s", cap, v)
// If it is already set, skip and continue,
	// we only want the first tailscale version per
	// capability version.
if _, ok := capVarToTailscaleVer[cap]; ok {
log.Printf("Skipping %d, %s", cap, v)
continue
}
log.Printf("Storing %d, %s", cap, v)
capVarToTailscaleVer[cap] = v
}
@ -133,7 +129,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
return capsSorted[i] < capsSorted[j]
})
for _, capVer := range capsSorted {
file.WriteString(fmt.Sprintf("\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer]))
fmt.Fprintf(file, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
}
file.WriteString("}\n")