mirror of
https://github.com/juanfont/headscale.git
synced 2025-07-29 13:33:44 +00:00
mapper: produce map before poll (#2628)
This commit is contained in:
parent
b2a18830ed
commit
a058bf3cd3
@ -17,3 +17,7 @@ LICENSE
|
||||
.vscode
|
||||
|
||||
*.sock
|
||||
|
||||
node_modules/
|
||||
package-lock.json
|
||||
package.json
|
||||
|
55
.github/workflows/check-generated.yml
vendored
Normal file
55
.github/workflows/check-generated.yml
vendored
Normal file
@ -0,0 +1,55 @@
|
||||
name: Check Generated Files
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
check-generated:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
|
||||
with:
|
||||
filters: |
|
||||
files:
|
||||
- '*.nix'
|
||||
- 'go.*'
|
||||
- '**/*.go'
|
||||
- '**/*.proto'
|
||||
- 'buf.gen.yaml'
|
||||
- 'tools/**'
|
||||
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
|
||||
if: steps.changed-files.outputs.files == 'true'
|
||||
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
|
||||
if: steps.changed-files.outputs.files == 'true'
|
||||
with:
|
||||
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', '**/flake.lock') }}
|
||||
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
|
||||
|
||||
- name: Run make generate
|
||||
if: steps.changed-files.outputs.files == 'true'
|
||||
run: nix develop --command -- make generate
|
||||
|
||||
- name: Check for uncommitted changes
|
||||
if: steps.changed-files.outputs.files == 'true'
|
||||
run: |
|
||||
if ! git diff --exit-code; then
|
||||
echo "❌ Generated files are not up to date!"
|
||||
echo "Please run 'make generate' and commit the changes."
|
||||
exit 1
|
||||
else
|
||||
echo "✅ All generated files are up to date."
|
||||
fi
|
@ -77,7 +77,7 @@ jobs:
|
||||
attempt_delay: 300000 # 5 min
|
||||
attempt_limit: 2
|
||||
command: |
|
||||
nix develop --command -- hi run "^${{ inputs.test }}$" \
|
||||
nix develop --command -- hi run --stats --ts-memory-limit=300 --hs-memory-limit=500 "^${{ inputs.test }}$" \
|
||||
--timeout=120m \
|
||||
${{ inputs.postgres_flag }}
|
||||
- uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
|
||||
|
7
.gitignore
vendored
7
.gitignore
vendored
@ -1,6 +1,9 @@
|
||||
ignored/
|
||||
tailscale/
|
||||
.vscode/
|
||||
.claude/
|
||||
|
||||
*.prof
|
||||
|
||||
# Binaries for programs and plugins
|
||||
*.exe
|
||||
@ -46,3 +49,7 @@ integration_test/etc/config.dump.yaml
|
||||
/site
|
||||
|
||||
__debug_bin
|
||||
|
||||
node_modules/
|
||||
package-lock.json
|
||||
package.json
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
## Next
|
||||
|
||||
**Minimum supported Tailscale client version: v1.64.0**
|
||||
|
||||
### Database integrity improvements
|
||||
|
||||
This release includes a significant database migration that addresses longstanding
|
||||
|
395
CLAUDE.md
Normal file
395
CLAUDE.md
Normal file
@ -0,0 +1,395 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Overview
|
||||
|
||||
Headscale is an open-source implementation of the Tailscale control server written in Go. It provides self-hosted coordination for Tailscale networks (tailnets), managing node registration, IP allocation, policy enforcement, and DERP routing.
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Quick Setup
|
||||
```bash
|
||||
# Recommended: Use Nix for dependency management
|
||||
nix develop
|
||||
|
||||
# Full development workflow
|
||||
make dev # runs fmt + lint + test + build
|
||||
```
|
||||
|
||||
### Essential Commands
|
||||
```bash
|
||||
# Build headscale binary
|
||||
make build
|
||||
|
||||
# Run tests
|
||||
make test
|
||||
go test ./... # All unit tests
|
||||
go test -race ./... # With race detection
|
||||
|
||||
# Run specific integration test
|
||||
go run ./cmd/hi run "TestName" --postgres
|
||||
|
||||
# Code formatting and linting
|
||||
make fmt # Format all code (Go, docs, proto)
|
||||
make lint # Lint all code (Go, proto)
|
||||
make fmt-go # Format Go code only
|
||||
make lint-go # Lint Go code only
|
||||
|
||||
# Protocol buffer generation (after modifying proto/)
|
||||
make generate
|
||||
|
||||
# Clean build artifacts
|
||||
make clean
|
||||
```
|
||||
|
||||
### Integration Testing
|
||||
```bash
|
||||
# Use the hi (Headscale Integration) test runner
|
||||
go run ./cmd/hi doctor # Check system requirements
|
||||
go run ./cmd/hi run "TestPattern" # Run specific test
|
||||
go run ./cmd/hi run "TestPattern" --postgres # With PostgreSQL backend
|
||||
|
||||
# Test artifacts are saved to control_logs/ with logs and debug data
|
||||
```
|
||||
|
||||
## Project Structure & Architecture
|
||||
|
||||
### Top-Level Organization
|
||||
|
||||
```
|
||||
headscale/
|
||||
├── cmd/ # Command-line applications
|
||||
│ ├── headscale/ # Main headscale server binary
|
||||
│ └── hi/ # Headscale Integration test runner
|
||||
├── hscontrol/ # Core control plane logic
|
||||
├── integration/ # End-to-end Docker-based tests
|
||||
├── proto/ # Protocol buffer definitions
|
||||
├── gen/ # Generated code (protobuf)
|
||||
├── docs/ # Documentation
|
||||
└── packaging/ # Distribution packaging
|
||||
```
|
||||
|
||||
### Core Packages (`hscontrol/`)
|
||||
|
||||
**Main Server (`hscontrol/`)**
|
||||
- `app.go`: Application setup, dependency injection, server lifecycle
|
||||
- `handlers.go`: HTTP/gRPC API endpoints for management operations
|
||||
- `grpcv1.go`: gRPC service implementation for headscale API
|
||||
- `poll.go`: **Critical** - Handles Tailscale MapRequest/MapResponse protocol
|
||||
- `noise.go`: Noise protocol implementation for secure client communication
|
||||
- `auth.go`: Authentication flows (web, OIDC, command-line)
|
||||
- `oidc.go`: OpenID Connect integration for user authentication
|
||||
|
||||
**State Management (`hscontrol/state/`)**
|
||||
- `state.go`: Central coordinator for all subsystems (database, policy, IP allocation, DERP)
|
||||
- `node_store.go`: **Performance-critical** - In-memory cache with copy-on-write semantics
|
||||
- Thread-safe operations with deadlock detection
|
||||
- Coordinates between database persistence and real-time operations
|
||||
|
||||
**Database Layer (`hscontrol/db/`)**
|
||||
- `db.go`: Database abstraction, GORM setup, migration management
|
||||
- `node.go`: Node lifecycle, registration, expiration, IP assignment
|
||||
- `users.go`: User management, namespace isolation
|
||||
- `api_key.go`: API authentication tokens
|
||||
- `preauth_keys.go`: Pre-authentication keys for automated node registration
|
||||
- `ip.go`: IP address allocation and management
|
||||
- `policy.go`: Policy storage and retrieval
|
||||
- Schema migrations in `schema.sql` with extensive test data coverage
|
||||
|
||||
**Policy Engine (`hscontrol/policy/`)**
|
||||
- `policy.go`: Core ACL evaluation logic, HuJSON parsing
|
||||
- `v2/`: Next-generation policy system with improved filtering
|
||||
- `matcher/`: ACL rule matching and evaluation engine
|
||||
- Determines peer visibility, route approval, and network access rules
|
||||
- Supports both file-based and database-stored policies
|
||||
|
||||
**Network Management (`hscontrol/`)**
|
||||
- `derp/`: DERP (Designated Encrypted Relay for Packets) server implementation
|
||||
- NAT traversal when direct connections fail
|
||||
- Fallback relay for firewall-restricted environments
|
||||
- `mapper/`: Converts internal Headscale state to Tailscale's wire protocol format
|
||||
- `tail.go`: Tailscale-specific data structure generation
|
||||
- `routes/`: Subnet route management and primary route selection
|
||||
- `dns/`: DNS record management and MagicDNS implementation
|
||||
|
||||
**Utilities & Support (`hscontrol/`)**
|
||||
- `types/`: Core data structures, configuration, validation
|
||||
- `util/`: Helper functions for networking, DNS, key management
|
||||
- `templates/`: Client configuration templates (Apple, Windows, etc.)
|
||||
- `notifier/`: Event notification system for real-time updates
|
||||
- `metrics.go`: Prometheus metrics collection
|
||||
- `capver/`: Tailscale capability version management
|
||||
|
||||
### Key Subsystem Interactions
|
||||
|
||||
**Node Registration Flow**
|
||||
1. **Client Connection**: `noise.go` handles secure protocol handshake
|
||||
2. **Authentication**: `auth.go` validates credentials (web/OIDC/preauth)
|
||||
3. **State Creation**: `state.go` coordinates IP allocation via `db/ip.go`
|
||||
4. **Storage**: `db/node.go` persists node, `NodeStore` caches in memory
|
||||
5. **Network Setup**: `mapper/` generates initial Tailscale network map
|
||||
|
||||
**Ongoing Operations**
|
||||
1. **Poll Requests**: `poll.go` receives periodic client updates
|
||||
2. **State Updates**: `NodeStore` maintains real-time node information
|
||||
3. **Policy Application**: `policy/` evaluates ACL rules for peer relationships
|
||||
4. **Map Distribution**: `mapper/` sends network topology to all affected clients
|
||||
|
||||
**Route Management**
|
||||
1. **Advertisement**: Clients announce routes via `poll.go` Hostinfo updates
|
||||
2. **Storage**: `db/` persists routes, `NodeStore` caches for performance
|
||||
3. **Approval**: `policy/` auto-approves routes based on ACL rules
|
||||
4. **Distribution**: `routes/` selects primary routes, `mapper/` distributes to peers
|
||||
|
||||
### Command-Line Tools (`cmd/`)
|
||||
|
||||
**Main Server (`cmd/headscale/`)**
|
||||
- `headscale.go`: CLI parsing, configuration loading, server startup
|
||||
- Supports daemon mode, CLI operations (user/node management), database operations
|
||||
|
||||
**Integration Test Runner (`cmd/hi/`)**
|
||||
- `main.go`: Test execution framework with Docker orchestration
|
||||
- `run.go`: Individual test execution with artifact collection
|
||||
- `doctor.go`: System requirements validation
|
||||
- `docker.go`: Container lifecycle management
|
||||
- Essential for validating changes against real Tailscale clients
|
||||
|
||||
### Generated & External Code
|
||||
|
||||
**Protocol Buffers (`proto/` → `gen/`)**
|
||||
- Defines gRPC API for headscale management operations
|
||||
- Client libraries can generate from these definitions
|
||||
- Run `make generate` after modifying `.proto` files
|
||||
|
||||
**Integration Testing (`integration/`)**
|
||||
- `scenario.go`: Docker test environment setup
|
||||
- `tailscale.go`: Tailscale client container management
|
||||
- Individual test files for specific functionality areas
|
||||
- Real end-to-end validation with network isolation
|
||||
|
||||
### Critical Performance Paths
|
||||
|
||||
**High-Frequency Operations**
|
||||
1. **MapRequest Processing** (`poll.go`): Every 15-60 seconds per client
|
||||
2. **NodeStore Reads** (`node_store.go`): Every operation requiring node data
|
||||
3. **Policy Evaluation** (`policy/`): On every peer relationship calculation
|
||||
4. **Route Lookups** (`routes/`): During network map generation
|
||||
|
||||
**Database Write Patterns**
|
||||
- **Frequent**: Node heartbeats, endpoint updates, route changes
|
||||
- **Moderate**: User operations, policy updates, API key management
|
||||
- **Rare**: Schema migrations, bulk operations
|
||||
|
||||
### Configuration & Deployment
|
||||
|
||||
**Configuration (`hscontrol/types/config.go`)**
|
||||
- Database connection settings (SQLite/PostgreSQL)
|
||||
- Network configuration (IP ranges, DNS settings)
|
||||
- Policy mode (file vs database)
|
||||
- DERP relay configuration
|
||||
- OIDC provider settings
|
||||
|
||||
**Key Dependencies**
|
||||
- **GORM**: Database ORM with migration support
|
||||
- **Tailscale Libraries**: Core networking and protocol code
|
||||
- **Zerolog**: Structured logging throughout the application
|
||||
- **Buf**: Protocol buffer toolchain for code generation
|
||||
|
||||
### Development Workflow Integration
|
||||
|
||||
The architecture supports incremental development:
|
||||
- **Unit Tests**: Focus on individual packages (`*_test.go` files)
|
||||
- **Integration Tests**: Validate cross-component interactions
|
||||
- **Database Tests**: Extensive migration and data integrity validation
|
||||
- **Policy Tests**: ACL rule evaluation and edge cases
|
||||
- **Performance Tests**: NodeStore and high-frequency operation validation
|
||||
|
||||
## Integration Test System
|
||||
|
||||
### Overview
|
||||
Integration tests use Docker containers running real Tailscale clients against a Headscale server. Tests validate end-to-end functionality including routing, ACLs, node lifecycle, and network coordination.
|
||||
|
||||
### Running Integration Tests
|
||||
|
||||
**System Requirements**
|
||||
```bash
|
||||
# Check if your system is ready
|
||||
go run ./cmd/hi doctor
|
||||
```
|
||||
This verifies Docker, Go, required images, and disk space.
|
||||
|
||||
**Test Execution Patterns**
|
||||
```bash
|
||||
# Run a single test (recommended for development)
|
||||
go run ./cmd/hi run "TestSubnetRouterMultiNetwork"
|
||||
|
||||
# Run with PostgreSQL backend (for database-heavy tests)
|
||||
go run ./cmd/hi run "TestExpireNode" --postgres
|
||||
|
||||
# Run multiple tests with pattern matching
|
||||
go run ./cmd/hi run "TestSubnet*"
|
||||
|
||||
# Run all integration tests (CI/full validation)
|
||||
go test ./integration -timeout 30m
|
||||
```
|
||||
|
||||
**Test Categories & Timing**
|
||||
- **Fast tests** (< 2 min): Basic functionality, CLI operations
|
||||
- **Medium tests** (2-5 min): Route management, ACL validation
|
||||
- **Slow tests** (5+ min): Node expiration, HA failover
|
||||
- **Long-running tests** (10+ min): `TestNodeOnlineStatus` (12 min duration)
|
||||
|
||||
### Test Infrastructure
|
||||
|
||||
**Docker Setup**
|
||||
- Headscale server container with configurable database backend
|
||||
- Multiple Tailscale client containers with different versions
|
||||
- Isolated networks per test scenario
|
||||
- Automatic cleanup after test completion
|
||||
|
||||
**Test Artifacts**
|
||||
All test runs save artifacts to `control_logs/TIMESTAMP-ID/`:
|
||||
```
|
||||
control_logs/20250713-213106-iajsux/
|
||||
├── hs-testname-abc123.stderr.log # Headscale server logs
|
||||
├── hs-testname-abc123.stdout.log
|
||||
├── hs-testname-abc123.db # Database snapshot
|
||||
├── hs-testname-abc123_metrics.txt # Prometheus metrics
|
||||
├── hs-testname-abc123-mapresponses/ # Protocol debug data
|
||||
├── ts-client-xyz789.stderr.log # Tailscale client logs
|
||||
├── ts-client-xyz789.stdout.log
|
||||
└── ts-client-xyz789_status.json # Client status dump
|
||||
```
|
||||
|
||||
### Test Development Guidelines
|
||||
|
||||
**Timing Considerations**
|
||||
Integration tests involve real network operations and Docker container lifecycle:
|
||||
|
||||
```go
|
||||
// ❌ Wrong: Immediate assertions after async operations
|
||||
client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"})
|
||||
nodes, _ := headscale.ListNodes()
|
||||
require.Len(t, nodes[0].GetAvailableRoutes(), 1) // May fail due to timing
|
||||
|
||||
// ✅ Correct: Wait for async operations to complete
|
||||
client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"})
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err := headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes[0].GetAvailableRoutes(), 1)
|
||||
}, 10*time.Second, 100*time.Millisecond, "route should be advertised")
|
||||
```
|
||||
|
||||
**Common Test Patterns**
|
||||
- **Route Advertisement**: Use `EventuallyWithT` for route propagation
|
||||
- **Node State Changes**: Wait for NodeStore synchronization
|
||||
- **ACL Policy Changes**: Allow time for policy recalculation
|
||||
- **Network Connectivity**: Use ping tests with retries
|
||||
|
||||
**Test Data Management**
|
||||
```go
|
||||
// Node identification: Don't assume array ordering
|
||||
expectedRoutes := map[string]string{"1": "10.33.0.0/16"}
|
||||
for _, node := range nodes {
|
||||
nodeIDStr := fmt.Sprintf("%d", node.GetId())
|
||||
if route, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute {
|
||||
// Test the node that should have the route
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Troubleshooting Integration Tests
|
||||
|
||||
**Common Failure Patterns**
|
||||
1. **Timing Issues**: Test assertions run before async operations complete
|
||||
- **Solution**: Use `EventuallyWithT` with appropriate timeouts
|
||||
- **Timeout Guidelines**: 3-5s for route operations, 10s for complex scenarios
|
||||
|
||||
2. **Infrastructure Problems**: Disk space, Docker issues, network conflicts
|
||||
- **Check**: `go run ./cmd/hi doctor` for system health
|
||||
- **Clean**: Remove old test containers and networks
|
||||
|
||||
3. **NodeStore Synchronization**: Tests expecting immediate data availability
|
||||
- **Key Points**: Route advertisements must propagate through poll requests
|
||||
- **Fix**: Wait for NodeStore updates after Hostinfo changes
|
||||
|
||||
4. **Database Backend Differences**: SQLite vs PostgreSQL behavior differences
|
||||
- **Use**: `--postgres` flag for database-intensive tests
|
||||
- **Note**: Some timing characteristics differ between backends
|
||||
|
||||
**Debugging Failed Tests**
|
||||
1. **Check test artifacts** in `control_logs/` for detailed logs
|
||||
2. **Examine MapResponse JSON** files for protocol-level debugging
|
||||
3. **Review Headscale stderr logs** for server-side error messages
|
||||
4. **Check Tailscale client status** for network-level issues
|
||||
|
||||
**Resource Management**
|
||||
- Tests require significant disk space (each run ~100MB of logs)
|
||||
- Docker containers are cleaned up automatically on success
|
||||
- Failed tests may leave containers running - clean manually if needed
|
||||
- Use `docker system prune` periodically to reclaim space
|
||||
|
||||
### Best Practices for Test Modifications
|
||||
|
||||
1. **Always test locally** before committing integration test changes
|
||||
2. **Use appropriate timeouts** - too short causes flaky tests, too long slows CI
|
||||
3. **Clean up properly** - ensure tests don't leave persistent state
|
||||
4. **Handle both success and failure paths** in test scenarios
|
||||
5. **Document timing requirements** for complex test scenarios
|
||||
|
||||
## NodeStore Implementation Details
|
||||
|
||||
**Key Insight from Recent Work**: The NodeStore is a critical performance optimization that caches node data in memory while ensuring consistency with the database. When working with route advertisements or node state changes:
|
||||
|
||||
1. **Timing Considerations**: Route advertisements need time to propagate from clients to server. Use `require.EventuallyWithT()` patterns in tests instead of immediate assertions.
|
||||
|
||||
2. **Synchronization Points**: NodeStore updates happen at specific points like `poll.go:420` after Hostinfo changes. Ensure these are maintained when modifying the polling logic.
|
||||
|
||||
3. **Peer Visibility**: The NodeStore's `peersFunc` determines which nodes are visible to each other. Policy-based filtering is separate from monitoring visibility - expired nodes should remain visible for debugging but marked as expired.
|
||||
|
||||
## Testing Guidelines
|
||||
|
||||
### Integration Test Patterns
|
||||
```go
|
||||
// Use EventuallyWithT for async operations
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err := headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
// Check expected state
|
||||
}, 10*time.Second, 100*time.Millisecond, "description")
|
||||
|
||||
// Node route checking by actual node properties, not array position
|
||||
var routeNode *v1.Node
|
||||
for _, node := range nodes {
|
||||
if nodeIDStr := fmt.Sprintf("%d", node.GetId()); expectedRoutes[nodeIDStr] != "" {
|
||||
routeNode = node
|
||||
break
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Running Problematic Tests
|
||||
- Some tests require significant time (e.g., `TestNodeOnlineStatus` runs for 12 minutes)
|
||||
- Infrastructure issues like disk space can cause test failures unrelated to code changes
|
||||
- Use `--postgres` flag when testing database-heavy scenarios
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Dependencies**: Use `nix develop` for consistent toolchain (Go, buf, protobuf tools, linting)
|
||||
- **Protocol Buffers**: Changes to `proto/` require `make generate` and should be committed separately
|
||||
- **Code Style**: Enforced via golangci-lint with golines (width 88) and gofumpt formatting
|
||||
- **Database**: Supports both SQLite (development) and PostgreSQL (production/testing)
|
||||
- **Integration Tests**: Require Docker and can consume significant disk space
|
||||
- **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management
|
||||
|
||||
## Debugging Integration Tests
|
||||
|
||||
Test artifacts are preserved in `control_logs/TIMESTAMP-ID/` including:
|
||||
- Headscale server logs (stderr/stdout)
|
||||
- Tailscale client logs and status
|
||||
- Database dumps and network captures
|
||||
- MapResponse JSON files for protocol debugging
|
||||
|
||||
When tests fail, check these artifacts first before assuming code issues.
|
7
Makefile
7
Makefile
@ -87,10 +87,9 @@ lint-proto: check-deps $(PROTO_SOURCES)
|
||||
|
||||
# Code generation
|
||||
.PHONY: generate
|
||||
generate: check-deps $(PROTO_SOURCES)
|
||||
@echo "Generating code from Protocol Buffers..."
|
||||
rm -rf gen
|
||||
buf generate proto
|
||||
generate: check-deps
|
||||
@echo "Generating code..."
|
||||
go generate ./...
|
||||
|
||||
# Clean targets
|
||||
.PHONY: clean
|
||||
|
@ -212,13 +212,10 @@ var listUsersCmd = &cobra.Command{
|
||||
switch {
|
||||
case id > 0:
|
||||
request.Id = uint64(id)
|
||||
break
|
||||
case username != "":
|
||||
request.Name = username
|
||||
break
|
||||
case email != "":
|
||||
request.Email = email
|
||||
break
|
||||
}
|
||||
|
||||
response, err := client.ListUsers(ctx, request)
|
||||
|
@ -90,6 +90,32 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
|
||||
log.Printf("Starting test: %s", config.TestPattern)
|
||||
|
||||
// Start stats collection for container resource monitoring (if enabled)
|
||||
var statsCollector *StatsCollector
|
||||
if config.Stats {
|
||||
var err error
|
||||
statsCollector, err = NewStatsCollector()
|
||||
if err != nil {
|
||||
if config.Verbose {
|
||||
log.Printf("Warning: failed to create stats collector: %v", err)
|
||||
}
|
||||
statsCollector = nil
|
||||
}
|
||||
|
||||
if statsCollector != nil {
|
||||
defer statsCollector.Close()
|
||||
|
||||
// Start stats collection immediately - no need for complex retry logic
|
||||
// The new implementation monitors Docker events and will catch containers as they start
|
||||
if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
|
||||
if config.Verbose {
|
||||
log.Printf("Warning: failed to start stats collection: %v", err)
|
||||
}
|
||||
}
|
||||
defer statsCollector.StopCollection()
|
||||
}
|
||||
}
|
||||
|
||||
exitCode, err := streamAndWait(ctx, cli, resp.ID)
|
||||
|
||||
// Ensure all containers have finished and logs are flushed before extracting artifacts
|
||||
@ -105,6 +131,20 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
||||
// Always list control files regardless of test outcome
|
||||
listControlFiles(logsDir)
|
||||
|
||||
// Print stats summary and check memory limits if enabled
|
||||
if config.Stats && statsCollector != nil {
|
||||
violations := statsCollector.PrintSummaryAndCheckLimits(config.HSMemoryLimit, config.TSMemoryLimit)
|
||||
if len(violations) > 0 {
|
||||
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
|
||||
log.Printf("=================================")
|
||||
for _, violation := range violations {
|
||||
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
|
||||
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
|
||||
}
|
||||
return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
|
||||
}
|
||||
}
|
||||
|
||||
shouldCleanup := config.CleanAfter && (!config.KeepOnFailure || exitCode == 0)
|
||||
if shouldCleanup {
|
||||
if config.Verbose {
|
||||
@ -379,10 +419,37 @@ func getDockerSocketPath() string {
|
||||
return "/var/run/docker.sock"
|
||||
}
|
||||
|
||||
// ensureImageAvailable pulls the specified Docker image to ensure it's available.
|
||||
// checkImageAvailableLocally checks if the specified Docker image is available locally.
|
||||
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
|
||||
_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
|
||||
if err != nil {
|
||||
if client.IsErrNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ensureImageAvailable checks if the image is available locally first, then pulls if needed.
|
||||
func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName string, verbose bool) error {
|
||||
// First check if image is available locally
|
||||
available, err := checkImageAvailableLocally(ctx, cli, imageName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check local image availability: %w", err)
|
||||
}
|
||||
|
||||
if available {
|
||||
if verbose {
|
||||
log.Printf("Image %s is available locally", imageName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Image not available locally, try to pull it
|
||||
if verbose {
|
||||
log.Printf("Pulling image %s...", imageName)
|
||||
log.Printf("Image %s not found locally, pulling...", imageName)
|
||||
}
|
||||
|
||||
reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{})
|
||||
|
@ -190,7 +190,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
|
||||
}
|
||||
}
|
||||
|
||||
// checkGolangImage verifies we can access the golang Docker image.
|
||||
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
|
||||
func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
cli, err := createDockerClient()
|
||||
if err != nil {
|
||||
@ -205,17 +205,40 @@ func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
goVersion := detectGoVersion()
|
||||
imageName := "golang:" + goVersion
|
||||
|
||||
// Check if we can pull the image
|
||||
// First check if image is available locally
|
||||
available, err := checkImageAvailableLocally(ctx, cli, imageName)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
Status: "FAIL",
|
||||
Message: fmt.Sprintf("Cannot check golang image %s: %v", imageName, err),
|
||||
Suggestions: []string{
|
||||
"Check Docker daemon status",
|
||||
"Try: docker images | grep golang",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if available {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
Status: "PASS",
|
||||
Message: fmt.Sprintf("Golang image %s is available locally", imageName),
|
||||
}
|
||||
}
|
||||
|
||||
// Image not available locally, try to pull it
|
||||
err = ensureImageAvailable(ctx, cli, imageName, false)
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
Status: "FAIL",
|
||||
Message: fmt.Sprintf("Cannot pull golang image %s: %v", imageName, err),
|
||||
Message: fmt.Sprintf("Golang image %s not available locally and cannot pull: %v", imageName, err),
|
||||
Suggestions: []string{
|
||||
"Check internet connectivity",
|
||||
"Verify Docker Hub access",
|
||||
"Try: docker pull " + imageName,
|
||||
"Or run tests offline if image was pulled previously",
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -223,7 +246,7 @@ func checkGolangImage(ctx context.Context) DoctorResult {
|
||||
return DoctorResult{
|
||||
Name: "Golang Image",
|
||||
Status: "PASS",
|
||||
Message: fmt.Sprintf("Golang image %s is available", imageName),
|
||||
Message: fmt.Sprintf("Golang image %s is now available", imageName),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -24,6 +24,9 @@ type RunConfig struct {
|
||||
KeepOnFailure bool `flag:"keep-on-failure,default=false,Keep containers on test failure"`
|
||||
LogsDir string `flag:"logs-dir,default=control_logs,Control logs directory"`
|
||||
Verbose bool `flag:"verbose,default=false,Verbose output"`
|
||||
Stats bool `flag:"stats,default=false,Collect and display container resource usage statistics"`
|
||||
HSMemoryLimit float64 `flag:"hs-memory-limit,default=0,Fail test if any Headscale container exceeds this memory limit in MB (0 = disabled)"`
|
||||
TSMemoryLimit float64 `flag:"ts-memory-limit,default=0,Fail test if any Tailscale container exceeds this memory limit in MB (0 = disabled)"`
|
||||
}
|
||||
|
||||
// runIntegrationTest executes the integration test workflow.
|
||||
|
468
cmd/hi/stats.go
Normal file
468
cmd/hi/stats.go
Normal file
@ -0,0 +1,468 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/docker/docker/api/types/events"
|
||||
"github.com/docker/docker/api/types/filters"
|
||||
"github.com/docker/docker/client"
|
||||
)
|
||||
|
||||
// ContainerStats represents statistics for a single container.
type ContainerStats struct {
	ContainerID   string        // Docker container ID
	ContainerName string        // container name as reported by Docker — presumably an "hs-"/"ts-" test container; producer code not visible in this chunk
	Stats         []StatsSample // collected samples, in collection order
	mutex         sync.RWMutex  // NOTE(review): presumably guards Stats; the appending code is outside this chunk — confirm
}
|
||||
|
||||
// StatsSample represents a single stats measurement.
type StatsSample struct {
	Timestamp time.Time // wall-clock time the sample was recorded
	CPUUsage  float64   // CPU usage percentage
	MemoryMB  float64   // Memory usage in MB
}
|
||||
|
||||
// StatsCollector manages collection of container statistics.
type StatsCollector struct {
	client            *client.Client             // Docker API client (listing, events, stats)
	containers        map[string]*ContainerStats // tracked containers — presumably keyed by container ID; confirm against startStatsForContainer (not visible here)
	stopChan          chan struct{}              // closed in StopCollection to signal collector goroutines to exit; never recreated
	wg                sync.WaitGroup             // tracks the monitor goroutines started in StartCollection
	mutex             sync.RWMutex               // guards collectionStarted (and, presumably, containers)
	collectionStarted bool                       // true while collection is active; set in StartCollection, cleared in StopCollection
}
|
||||
|
||||
// NewStatsCollector creates a new stats collector instance
|
||||
func NewStatsCollector() (*StatsCollector, error) {
|
||||
cli, err := createDockerClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Docker client: %w", err)
|
||||
}
|
||||
|
||||
return &StatsCollector{
|
||||
client: cli,
|
||||
containers: make(map[string]*ContainerStats),
|
||||
stopChan: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// StartCollection begins monitoring all containers and collecting stats for hs- and ts- containers with matching run ID
|
||||
func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, verbose bool) error {
|
||||
sc.mutex.Lock()
|
||||
defer sc.mutex.Unlock()
|
||||
|
||||
if sc.collectionStarted {
|
||||
return fmt.Errorf("stats collection already started")
|
||||
}
|
||||
|
||||
sc.collectionStarted = true
|
||||
|
||||
// Start monitoring existing containers
|
||||
sc.wg.Add(1)
|
||||
go sc.monitorExistingContainers(ctx, runID, verbose)
|
||||
|
||||
// Start Docker events monitoring for new containers
|
||||
sc.wg.Add(1)
|
||||
go sc.monitorDockerEvents(ctx, runID, verbose)
|
||||
|
||||
if verbose {
|
||||
log.Printf("Started container monitoring for run ID %s", runID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StopCollection stops all stats collection
|
||||
func (sc *StatsCollector) StopCollection() {
|
||||
// Check if already stopped without holding lock
|
||||
sc.mutex.RLock()
|
||||
if !sc.collectionStarted {
|
||||
sc.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
// Signal stop to all goroutines
|
||||
close(sc.stopChan)
|
||||
|
||||
// Wait for all goroutines to finish
|
||||
sc.wg.Wait()
|
||||
|
||||
// Mark as stopped
|
||||
sc.mutex.Lock()
|
||||
sc.collectionStarted = false
|
||||
sc.mutex.Unlock()
|
||||
}
|
||||
|
||||
// monitorExistingContainers checks for existing containers that match our criteria
|
||||
func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID string, verbose bool) {
|
||||
defer sc.wg.Done()
|
||||
|
||||
containers, err := sc.client.ContainerList(ctx, container.ListOptions{})
|
||||
if err != nil {
|
||||
if verbose {
|
||||
log.Printf("Failed to list existing containers: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, cont := range containers {
|
||||
if sc.shouldMonitorContainer(cont, runID) {
|
||||
sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// monitorDockerEvents listens for container start events and begins monitoring relevant containers
|
||||
func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, verbose bool) {
|
||||
defer sc.wg.Done()
|
||||
|
||||
filter := filters.NewArgs()
|
||||
filter.Add("type", "container")
|
||||
filter.Add("event", "start")
|
||||
|
||||
eventOptions := events.ListOptions{
|
||||
Filters: filter,
|
||||
}
|
||||
|
||||
events, errs := sc.client.Events(ctx, eventOptions)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-sc.stopChan:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case event := <-events:
|
||||
if event.Type == "container" && event.Action == "start" {
|
||||
// Get container details
|
||||
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert to types.Container format for consistency
|
||||
cont := types.Container{
|
||||
ID: containerInfo.ID,
|
||||
Names: []string{containerInfo.Name},
|
||||
Labels: containerInfo.Config.Labels,
|
||||
}
|
||||
|
||||
if sc.shouldMonitorContainer(cont, runID) {
|
||||
sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose)
|
||||
}
|
||||
}
|
||||
case err := <-errs:
|
||||
if verbose {
|
||||
log.Printf("Error in Docker events stream: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shouldMonitorContainer determines if a container should be monitored
|
||||
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
|
||||
// Check if it has the correct run ID label
|
||||
if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if it's an hs- or ts- container
|
||||
for _, name := range cont.Names {
|
||||
containerName := strings.TrimPrefix(name, "/")
|
||||
if strings.HasPrefix(containerName, "hs-") || strings.HasPrefix(containerName, "ts-") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// startStatsForContainer begins stats collection for a specific container
|
||||
func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerID, containerName string, verbose bool) {
|
||||
containerName = strings.TrimPrefix(containerName, "/")
|
||||
|
||||
sc.mutex.Lock()
|
||||
// Check if we're already monitoring this container
|
||||
if _, exists := sc.containers[containerID]; exists {
|
||||
sc.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
sc.containers[containerID] = &ContainerStats{
|
||||
ContainerID: containerID,
|
||||
ContainerName: containerName,
|
||||
Stats: make([]StatsSample, 0),
|
||||
}
|
||||
sc.mutex.Unlock()
|
||||
|
||||
if verbose {
|
||||
log.Printf("Starting stats collection for container %s (%s)", containerName, containerID[:12])
|
||||
}
|
||||
|
||||
sc.wg.Add(1)
|
||||
go sc.collectStatsForContainer(ctx, containerID, verbose)
|
||||
}
|
||||
|
||||
// collectStatsForContainer collects stats for a specific container using the
// Docker API streaming endpoint (one JSON document per interval).
//
// The first decoded sample is used only as a baseline: CPU percentage is a
// delta-based calculation and needs a previous sample, so nothing is stored
// until the second sample arrives. The goroutine exits on stream end (EOF,
// e.g. the container stopped), on decode error, or when stop/context
// cancellation is observed between samples.
func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containerID string, verbose bool) {
	defer sc.wg.Done()

	// Use Docker API streaming stats - much more efficient than CLI
	statsResponse, err := sc.client.ContainerStats(ctx, containerID, true)
	if err != nil {
		if verbose {
			log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
		}
		return
	}
	defer statsResponse.Body.Close()

	decoder := json.NewDecoder(statsResponse.Body)
	var prevStats *container.Stats

	for {
		select {
		case <-sc.stopChan:
			return
		case <-ctx.Done():
			return
		default:
			// NOTE(review): Decode blocks until the next sample arrives, so
			// stop/ctx cancellation is only noticed between samples. Context
			// cancellation also tears down the underlying HTTP stream, which
			// surfaces here as a decode error, so the goroutine still exits.
			var stats container.Stats
			if err := decoder.Decode(&stats); err != nil {
				// EOF is expected when container stops or stream ends
				if err.Error() != "EOF" && verbose {
					log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
				}
				return
			}

			// Calculate CPU percentage (only if we have previous stats)
			var cpuPercent float64
			if prevStats != nil {
				cpuPercent = calculateCPUPercent(prevStats, &stats)
			}

			// Calculate memory usage in MB
			memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024)

			// Store the sample (skip first sample since CPU calculation needs previous stats)
			if prevStats != nil {
				// Get container stats reference without holding the main mutex
				var containerStats *ContainerStats
				var exists bool

				sc.mutex.RLock()
				containerStats, exists = sc.containers[containerID]
				sc.mutex.RUnlock()

				// Append under the per-container lock only, so summary reads
				// and other collectors are not blocked by the main mutex.
				if exists && containerStats != nil {
					containerStats.mutex.Lock()
					containerStats.Stats = append(containerStats.Stats, StatsSample{
						Timestamp: time.Now(),
						CPUUsage:  cpuPercent,
						MemoryMB:  memoryMB,
					})
					containerStats.mutex.Unlock()
				}
			}

			// Save current stats for next iteration
			prevStats = &stats
		}
	}
}
|
||||
|
||||
// calculateCPUPercent calculates CPU usage percentage from Docker stats
|
||||
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
|
||||
// CPU calculation based on Docker's implementation
|
||||
cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
|
||||
systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
|
||||
|
||||
if systemDelta > 0 && cpuDelta >= 0 {
|
||||
// Calculate CPU percentage: (container CPU delta / system CPU delta) * number of CPUs * 100
|
||||
numCPUs := float64(len(stats.CPUStats.CPUUsage.PercpuUsage))
|
||||
if numCPUs == 0 {
|
||||
// Fallback: if PercpuUsage is not available, assume 1 CPU
|
||||
numCPUs = 1.0
|
||||
}
|
||||
return (cpuDelta / systemDelta) * numCPUs * 100.0
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// ContainerStatsSummary represents summary statistics for a container:
// how many samples were collected plus min/max/average CPU and memory.
type ContainerStatsSummary struct {
	ContainerName string       // container name (without leading "/")
	SampleCount   int          // number of samples the summary is based on
	CPU           StatsSummary // CPU usage percentage statistics
	Memory        StatsSummary // memory usage statistics, in MB
}
|
||||
|
||||
// MemoryViolation represents a container whose peak memory usage exceeded
// the configured limit for its class.
type MemoryViolation struct {
	ContainerName string  // offending container
	MaxMemoryMB   float64 // observed peak memory usage in MB
	LimitMB       float64 // the limit that was exceeded, in MB
}
|
||||
|
||||
// StatsSummary represents min, max, and average for a single metric.
type StatsSummary struct {
	Min     float64
	Max     float64
	Average float64
}
|
||||
|
||||
// GetSummary returns a summary of collected statistics
|
||||
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
|
||||
// Take snapshot of container references without holding main lock long
|
||||
sc.mutex.RLock()
|
||||
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
|
||||
for _, containerStats := range sc.containers {
|
||||
containerRefs = append(containerRefs, containerStats)
|
||||
}
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
|
||||
|
||||
for _, containerStats := range containerRefs {
|
||||
containerStats.mutex.RLock()
|
||||
stats := make([]StatsSample, len(containerStats.Stats))
|
||||
copy(stats, containerStats.Stats)
|
||||
containerName := containerStats.ContainerName
|
||||
containerStats.mutex.RUnlock()
|
||||
|
||||
if len(stats) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
summary := ContainerStatsSummary{
|
||||
ContainerName: containerName,
|
||||
SampleCount: len(stats),
|
||||
}
|
||||
|
||||
// Calculate CPU stats
|
||||
cpuValues := make([]float64, len(stats))
|
||||
memoryValues := make([]float64, len(stats))
|
||||
|
||||
for i, sample := range stats {
|
||||
cpuValues[i] = sample.CPUUsage
|
||||
memoryValues[i] = sample.MemoryMB
|
||||
}
|
||||
|
||||
summary.CPU = calculateStatsSummary(cpuValues)
|
||||
summary.Memory = calculateStatsSummary(memoryValues)
|
||||
|
||||
summaries = append(summaries, summary)
|
||||
}
|
||||
|
||||
// Sort by container name for consistent output
|
||||
sort.Slice(summaries, func(i, j int) bool {
|
||||
return summaries[i].ContainerName < summaries[j].ContainerName
|
||||
})
|
||||
|
||||
return summaries
|
||||
}
|
||||
|
||||
// calculateStatsSummary calculates min, max, and average for a slice of values
|
||||
func calculateStatsSummary(values []float64) StatsSummary {
|
||||
if len(values) == 0 {
|
||||
return StatsSummary{}
|
||||
}
|
||||
|
||||
min := values[0]
|
||||
max := values[0]
|
||||
sum := 0.0
|
||||
|
||||
for _, value := range values {
|
||||
if value < min {
|
||||
min = value
|
||||
}
|
||||
if value > max {
|
||||
max = value
|
||||
}
|
||||
sum += value
|
||||
}
|
||||
|
||||
return StatsSummary{
|
||||
Min: min,
|
||||
Max: max,
|
||||
Average: sum / float64(len(values)),
|
||||
}
|
||||
}
|
||||
|
||||
// PrintSummary prints the statistics summary to the console
|
||||
func (sc *StatsCollector) PrintSummary() {
|
||||
summaries := sc.GetSummary()
|
||||
|
||||
if len(summaries) == 0 {
|
||||
log.Printf("No container statistics collected")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("Container Resource Usage Summary:")
|
||||
log.Printf("================================")
|
||||
|
||||
for _, summary := range summaries {
|
||||
log.Printf("Container: %s (%d samples)", summary.ContainerName, summary.SampleCount)
|
||||
log.Printf(" CPU Usage: Min: %6.2f%% Max: %6.2f%% Avg: %6.2f%%",
|
||||
summary.CPU.Min, summary.CPU.Max, summary.CPU.Average)
|
||||
log.Printf(" Memory Usage: Min: %6.1f MB Max: %6.1f MB Avg: %6.1f MB",
|
||||
summary.Memory.Min, summary.Memory.Max, summary.Memory.Average)
|
||||
log.Printf("")
|
||||
}
|
||||
}
|
||||
|
||||
// CheckMemoryLimits checks if any containers exceeded their memory limits
|
||||
func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
|
||||
if hsLimitMB <= 0 && tsLimitMB <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
summaries := sc.GetSummary()
|
||||
var violations []MemoryViolation
|
||||
|
||||
for _, summary := range summaries {
|
||||
var limitMB float64
|
||||
if strings.HasPrefix(summary.ContainerName, "hs-") {
|
||||
limitMB = hsLimitMB
|
||||
} else if strings.HasPrefix(summary.ContainerName, "ts-") {
|
||||
limitMB = tsLimitMB
|
||||
} else {
|
||||
continue // Skip containers that don't match our patterns
|
||||
}
|
||||
|
||||
if limitMB > 0 && summary.Memory.Max > limitMB {
|
||||
violations = append(violations, MemoryViolation{
|
||||
ContainerName: summary.ContainerName,
|
||||
MaxMemoryMB: summary.Memory.Max,
|
||||
LimitMB: limitMB,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return violations
|
||||
}
|
||||
|
||||
// PrintSummaryAndCheckLimits prints the statistics summary and then returns
// memory violations (if any) for containers exceeding the given per-class
// limits. Convenience wrapper around PrintSummary and CheckMemoryLimits.
func (sc *StatsCollector) PrintSummaryAndCheckLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
	sc.PrintSummary()
	return sc.CheckMemoryLimits(hsLimitMB, tsLimitMB)
}
|
||||
|
||||
// Close closes the stats collector and cleans up resources: it stops any
// running collection (waiting for all goroutines to exit) and then closes
// the underlying Docker client.
func (sc *StatsCollector) Close() error {
	sc.StopCollection()
	return sc.client.Close()
}
|
@ -19,7 +19,7 @@
|
||||
overlay = _: prev: let
|
||||
pkgs = nixpkgs.legacyPackages.${prev.system};
|
||||
buildGo = pkgs.buildGo124Module;
|
||||
vendorHash = "sha256-S2GnCg2dyfjIyi5gXhVEuRs5Bop2JAhZcnhg1fu4/Gg=";
|
||||
vendorHash = "sha256-83L2NMyOwKCHWqcowStJ7Ze/U9CJYhzleDRLrJNhX2g=";
|
||||
in {
|
||||
headscale = buildGo {
|
||||
pname = "headscale";
|
||||
|
27
go.mod
27
go.mod
@ -23,7 +23,6 @@ require (
|
||||
github.com/gorilla/mux v1.8.1
|
||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.0
|
||||
github.com/jagottsicher/termcolor v1.0.2
|
||||
github.com/klauspost/compress v1.18.0
|
||||
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
||||
github.com/ory/dockertest/v3 v3.12.0
|
||||
github.com/philip-bui/grpc-zerolog v1.0.1
|
||||
@ -43,11 +42,11 @@ require (
|
||||
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97
|
||||
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||
golang.org/x/crypto v0.39.0
|
||||
golang.org/x/crypto v0.40.0
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
|
||||
golang.org/x/net v0.41.0
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.15.0
|
||||
golang.org/x/sync v0.16.0
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822
|
||||
google.golang.org/grpc v1.73.0
|
||||
google.golang.org/protobuf v1.36.6
|
||||
@ -55,7 +54,7 @@ require (
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/postgres v1.6.0
|
||||
gorm.io/gorm v1.30.0
|
||||
tailscale.com v1.84.2
|
||||
tailscale.com v1.84.3
|
||||
zgo.at/zcache/v2 v2.2.0
|
||||
zombiezen.com/go/postgrestest v1.0.1
|
||||
)
|
||||
@ -81,7 +80,7 @@ require (
|
||||
modernc.org/libc v1.62.1 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.10.0 // indirect
|
||||
modernc.org/sqlite v1.37.0 // indirect
|
||||
modernc.org/sqlite v1.37.0
|
||||
)
|
||||
|
||||
require (
|
||||
@ -166,6 +165,7 @@ require (
|
||||
github.com/jmespath/go-jmespath v0.4.0 // indirect
|
||||
github.com/jsimonetti/rtnetlink v1.4.1 // indirect
|
||||
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
@ -231,14 +231,19 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.36.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
|
||||
golang.org/x/mod v0.25.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/term v0.32.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
golang.org/x/mod v0.26.0 // indirect
|
||||
golang.org/x/sys v0.34.0 // indirect
|
||||
golang.org/x/term v0.33.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
golang.org/x/time v0.10.0 // indirect
|
||||
golang.org/x/tools v0.33.0 // indirect
|
||||
golang.org/x/tools v0.35.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
|
||||
gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 // indirect
|
||||
)
|
||||
|
||||
tool (
|
||||
golang.org/x/tools/cmd/stringer
|
||||
tailscale.com/cmd/viewer
|
||||
)
|
||||
|
34
go.sum
34
go.sum
@ -555,8 +555,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
|
||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
|
||||
golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
|
||||
@ -567,8 +567,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
@ -577,8 +577,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@ -587,8 +587,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@ -615,8 +615,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
@ -624,8 +624,8 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
|
||||
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
|
||||
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
|
||||
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
@ -633,8 +633,8 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@ -643,8 +643,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY
|
||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
@ -714,6 +714,8 @@ software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB
|
||||
software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
|
||||
tailscale.com v1.84.2 h1:v6aM4RWUgYiV52LRAx6ET+dlGnvO/5lnqPXb7/pMnR0=
|
||||
tailscale.com v1.84.2/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
|
||||
tailscale.com v1.84.3 h1:Ur9LMedSgicwbqpy5xn7t49G8490/s6rqAJOk5Q5AYE=
|
||||
tailscale.com v1.84.3/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
|
||||
zgo.at/zcache/v2 v2.2.0 h1:K29/IPjMniZfveYE+IRXfrl11tMzHkIPuyGrfVZ2fGo=
|
||||
zgo.at/zcache/v2 v2.2.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
|
||||
zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=
|
||||
|
130
hscontrol/app.go
130
hscontrol/app.go
@ -28,14 +28,15 @@ import (
|
||||
derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
|
||||
"github.com/juanfont/headscale/hscontrol/dns"
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
zerolog "github.com/philip-bui/grpc-zerolog"
|
||||
"github.com/pkg/profile"
|
||||
zl "github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
"golang.org/x/crypto/acme"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@ -64,6 +65,19 @@ var (
|
||||
)
|
||||
)
|
||||
|
||||
var (
|
||||
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
|
||||
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
|
||||
)
|
||||
|
||||
func init() {
|
||||
deadlock.Opts.Disable = !debugDeadlock
|
||||
if debugDeadlock {
|
||||
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
|
||||
deadlock.Opts.PrintAllCurrentGoroutines = true
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
AuthPrefix = "Bearer "
|
||||
updateInterval = 5 * time.Second
|
||||
@ -82,9 +96,8 @@ type Headscale struct {
|
||||
|
||||
// Things that generate changes
|
||||
extraRecordMan *dns.ExtraRecordsMan
|
||||
mapper *mapper.Mapper
|
||||
nodeNotifier *notifier.Notifier
|
||||
authProvider AuthProvider
|
||||
mapBatcher mapper.Batcher
|
||||
|
||||
pollNetMapStreamWG sync.WaitGroup
|
||||
}
|
||||
@ -118,7 +131,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
cfg: cfg,
|
||||
noisePrivateKey: noisePrivateKey,
|
||||
pollNetMapStreamWG: sync.WaitGroup{},
|
||||
nodeNotifier: notifier.NewNotifier(cfg),
|
||||
state: s,
|
||||
}
|
||||
|
||||
@ -136,12 +148,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "ephemeral-gc-policy", node.Hostname)
|
||||
app.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
app.Change(policyChanged)
|
||||
log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
|
||||
})
|
||||
app.ephemeralGC = ephemeralGC
|
||||
@ -153,10 +160,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
||||
defer cancel()
|
||||
oidcProvider, err := NewAuthProviderOIDC(
|
||||
ctx,
|
||||
&app,
|
||||
cfg.ServerURL,
|
||||
&cfg.OIDC,
|
||||
app.state,
|
||||
app.nodeNotifier,
|
||||
)
|
||||
if err != nil {
|
||||
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
|
||||
@ -262,16 +268,18 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
return
|
||||
|
||||
case <-expireTicker.C:
|
||||
var update types.StateUpdate
|
||||
var expiredNodeChanges []change.ChangeSet
|
||||
var changed bool
|
||||
|
||||
lastExpiryCheck, update, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||
|
||||
if changed {
|
||||
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
|
||||
log.Trace().Interface("changes", expiredNodeChanges).Msgf("expiring nodes")
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, update)
|
||||
// Send the changes directly since they're already in the new format
|
||||
for _, nodeChange := range expiredNodeChanges {
|
||||
h.Change(nodeChange)
|
||||
}
|
||||
}
|
||||
|
||||
case <-derpTickerChan:
|
||||
@ -282,11 +290,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
derpMap.Regions[region.RegionID] = ®ion
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||
Type: types.StateDERPUpdated,
|
||||
DERPMap: derpMap,
|
||||
})
|
||||
h.Change(change.DERPSet)
|
||||
|
||||
case records, ok := <-extraRecordsUpdate:
|
||||
if !ok {
|
||||
@ -294,19 +298,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
||||
}
|
||||
h.cfg.TailcfgDNSConfig.ExtraRecords = records
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
|
||||
// TODO(kradalby): We can probably do better than sending a full update here,
|
||||
// but for now this will ensure that all of the nodes get the new records.
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
h.Change(change.ExtraRecordsSet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
|
||||
req interface{},
|
||||
req any,
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (interface{}, error) {
|
||||
) (any, error) {
|
||||
// Check if the request is coming from the on-server client.
|
||||
// This is not secure, but it is to maintain maintainability
|
||||
// with the "legacy" database-based client
|
||||
@ -484,58 +485,6 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
||||
return router
|
||||
}
|
||||
|
||||
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||
// // Maybe we should attempt a new in memory state and not go via the DB?
|
||||
// // Maybe this should be implemented as an event bus?
|
||||
// // A bool is returned indicating if a full update was sent to all nodes
|
||||
// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
|
||||
// users, err := db.ListUsers()
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// changed, err := polMan.SetUsers(users)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// if changed {
|
||||
// ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
|
||||
// notif.NotifyAll(ctx, types.UpdateFull())
|
||||
// }
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||
// // Maybe we should attempt a new in memory state and not go via the DB?
|
||||
// // Maybe this should be implemented as an event bus?
|
||||
// // A bool is returned indicating if a full update was sent to all nodes
|
||||
// func nodesChangedHook(
|
||||
// db *db.HSDatabase,
|
||||
// polMan policy.PolicyManager,
|
||||
// notif *notifier.Notifier,
|
||||
// ) (bool, error) {
|
||||
// nodes, err := db.ListNodes()
|
||||
// if err != nil {
|
||||
// return false, err
|
||||
// }
|
||||
|
||||
// filterChanged, err := polMan.SetNodes(nodes)
|
||||
// if err != nil {
|
||||
// return false, err
|
||||
// }
|
||||
|
||||
// if filterChanged {
|
||||
// ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
|
||||
// notif.NotifyAll(ctx, types.UpdateFull())
|
||||
|
||||
// return true, nil
|
||||
// }
|
||||
|
||||
// return false, nil
|
||||
// }
|
||||
|
||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||
func (h *Headscale) Serve() error {
|
||||
capver.CanOldCodeBeCleanedUp()
|
||||
@ -562,8 +511,9 @@ func (h *Headscale) Serve() error {
|
||||
Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
|
||||
Msg("Clients with a lower minimum version will be rejected")
|
||||
|
||||
// Fetch an initial DERP Map before we start serving
|
||||
h.mapper = mapper.NewMapper(h.state, h.cfg, h.nodeNotifier)
|
||||
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
||||
h.mapBatcher.Start()
|
||||
defer h.mapBatcher.Close()
|
||||
|
||||
// TODO(kradalby): fix state part.
|
||||
if h.cfg.DERP.ServerEnabled {
|
||||
@ -838,8 +788,12 @@ func (h *Headscale) Serve() error {
|
||||
log.Info().
|
||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
err = h.state.AutoApproveNodes()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to approve routes after new policy")
|
||||
}
|
||||
|
||||
h.Change(change.PolicySet)
|
||||
}
|
||||
default:
|
||||
info := func(msg string) { log.Info().Msg(msg) }
|
||||
@ -865,7 +819,6 @@ func (h *Headscale) Serve() error {
|
||||
}
|
||||
|
||||
info("closing node notifier")
|
||||
h.nodeNotifier.Close()
|
||||
|
||||
info("waiting for netmap stream to close")
|
||||
h.pollNetMapStreamWG.Wait()
|
||||
@ -1047,3 +1000,10 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
||||
|
||||
return &machineKey, nil
|
||||
}
|
||||
|
||||
// Change is used to send changes to nodes.
|
||||
// All change should be enqueued here and empty will be automatically
|
||||
// ignored.
|
||||
func (h *Headscale) Change(c change.ChangeSet) {
|
||||
h.mapBatcher.AddWork(c)
|
||||
}
|
||||
|
@ -10,6 +10,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
@ -32,6 +34,21 @@ func (h *Headscale) handleRegister(
|
||||
}
|
||||
|
||||
if node != nil {
|
||||
// If an existing node is trying to register with an auth key,
|
||||
// we need to validate the auth key even for existing nodes
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
|
||||
if err != nil {
|
||||
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return nil, httpErr
|
||||
}
|
||||
return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
resp, err := h.handleExistingNode(node, regReq, machineKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("handling existing node: %w", err)
|
||||
@ -47,6 +64,11 @@ func (h *Headscale) handleRegister(
|
||||
if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
|
||||
resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
|
||||
if err != nil {
|
||||
// Preserve HTTPError types so they can be handled properly by the HTTP layer
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
return nil, httpErr
|
||||
}
|
||||
return nil, fmt.Errorf("handling register with auth key: %w", err)
|
||||
}
|
||||
|
||||
@ -66,11 +88,13 @@ func (h *Headscale) handleExistingNode(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
|
||||
if node.MachineKey != machineKey {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
|
||||
}
|
||||
|
||||
expired := node.IsExpired()
|
||||
|
||||
if !expired && !regReq.Expiry.IsZero() {
|
||||
requestExpiry := regReq.Expiry
|
||||
|
||||
@ -82,42 +106,26 @@ func (h *Headscale) handleExistingNode(
|
||||
// If the request expiry is in the past, we consider it a logout.
|
||||
if requestExpiry.Before(time.Now()) {
|
||||
if node.IsEphemeral() {
|
||||
policyChanged, err := h.state.DeleteNode(node)
|
||||
c, err := h.state.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("deleting ephemeral node: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "auth-logout-ephemeral-policy", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
|
||||
}
|
||||
h.Change(c)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
|
||||
_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "auth-expiry-policy", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
|
||||
h.Change(c)
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(n), nil
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node), nil
|
||||
return nodeToRegisterResponse(node), nil
|
||||
}
|
||||
|
||||
func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
|
||||
@ -168,7 +176,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*tailcfg.RegisterResponse, error) {
|
||||
node, changed, err := h.state.HandleNodeFromPreAuthKey(
|
||||
node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
|
||||
regReq,
|
||||
machineKey,
|
||||
)
|
||||
@ -184,6 +192,12 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If node is nil, it means an ephemeral node was deleted during logout
|
||||
if node == nil {
|
||||
h.Change(changed)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||
// dependency here.
|
||||
// Because the way the policy manager works, we need to have the node
|
||||
@ -195,23 +209,22 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
||||
// ensure we send an update.
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := h.state.AutoApproveRoutes(node)
|
||||
// TODO(kradalby): This needs to be ran as part of the batcher maybe?
|
||||
// now since we dont update the node/pol here anymore
|
||||
routeChange := h.state.AutoApproveRoutes(node)
|
||||
if _, _, err := h.state.SaveNode(node); err != nil {
|
||||
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
|
||||
} else if changed {
|
||||
ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
// Existing node re-registering without route changes
|
||||
// Still need to notify peers about the node being active again
|
||||
// Use UpdateFull to ensure all peers get complete peer maps
|
||||
ctx := types.NotifyCtx(context.Background(), "node re-registered", node.Hostname)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
if routeChange && changed.Empty() {
|
||||
changed = change.NodeAdded(node.ID)
|
||||
}
|
||||
h.Change(changed)
|
||||
|
||||
// If policy changed due to node registration, send a separate policy change
|
||||
if policyChanged {
|
||||
policyChange := change.PolicyChange()
|
||||
h.Change(policyChange)
|
||||
}
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
|
@ -1,5 +1,7 @@
|
||||
package capver
|
||||
|
||||
//go:generate go run ../../tools/capver/main.go
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"sort"
|
||||
@ -10,7 +12,7 @@ import (
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 88
|
||||
const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 90
|
||||
|
||||
// CanOldCodeBeCleanedUp is intended to be called on startup to see if
|
||||
// there are old code that can ble cleaned up, entries should contain
|
||||
|
@ -1,14 +1,10 @@
|
||||
package capver
|
||||
|
||||
// Generated DO NOT EDIT
|
||||
//Generated DO NOT EDIT
|
||||
|
||||
import "tailscale.com/tailcfg"
|
||||
|
||||
var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
|
||||
"v1.60.0": 87,
|
||||
"v1.60.1": 87,
|
||||
"v1.62.0": 88,
|
||||
"v1.62.1": 88,
|
||||
"v1.64.0": 90,
|
||||
"v1.64.1": 90,
|
||||
"v1.64.2": 90,
|
||||
@ -36,18 +32,21 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
|
||||
"v1.80.3": 113,
|
||||
"v1.82.0": 115,
|
||||
"v1.82.5": 115,
|
||||
"v1.84.0": 116,
|
||||
"v1.84.1": 116,
|
||||
"v1.84.2": 116,
|
||||
}
|
||||
|
||||
|
||||
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
|
||||
87: "v1.60.0",
|
||||
88: "v1.62.0",
|
||||
90: "v1.64.0",
|
||||
95: "v1.66.0",
|
||||
97: "v1.68.0",
|
||||
102: "v1.70.0",
|
||||
104: "v1.72.0",
|
||||
106: "v1.74.0",
|
||||
109: "v1.78.0",
|
||||
113: "v1.80.0",
|
||||
115: "v1.82.0",
|
||||
90: "v1.64.0",
|
||||
95: "v1.66.0",
|
||||
97: "v1.68.0",
|
||||
102: "v1.70.0",
|
||||
104: "v1.72.0",
|
||||
106: "v1.74.0",
|
||||
109: "v1.78.0",
|
||||
113: "v1.80.0",
|
||||
115: "v1.82.0",
|
||||
116: "v1.84.0",
|
||||
}
|
||||
|
@ -13,11 +13,10 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
|
||||
stripV bool
|
||||
expected []string
|
||||
}{
|
||||
{3, false, []string{"v1.78", "v1.80", "v1.82"}},
|
||||
{2, true, []string{"1.80", "1.82"}},
|
||||
{3, false, []string{"v1.80", "v1.82", "v1.84"}},
|
||||
{2, true, []string{"1.82", "1.84"}},
|
||||
// Lazy way to see all supported versions
|
||||
{10, true, []string{
|
||||
"1.64",
|
||||
"1.66",
|
||||
"1.68",
|
||||
"1.70",
|
||||
@ -27,6 +26,7 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
|
||||
"1.78",
|
||||
"1.80",
|
||||
"1.82",
|
||||
"1.84",
|
||||
}},
|
||||
{0, false, nil},
|
||||
}
|
||||
@ -46,7 +46,6 @@ func TestCapVerMinimumTailscaleVersion(t *testing.T) {
|
||||
input tailcfg.CapabilityVersion
|
||||
expected string
|
||||
}{
|
||||
{88, "v1.62.0"},
|
||||
{90, "v1.64.0"},
|
||||
{95, "v1.66.0"},
|
||||
{106, "v1.74.0"},
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@ -362,8 +361,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedKeys, keys, cmp.Comparer(func(a, b []string) bool {
|
||||
sort.Sort(sort.StringSlice(a))
|
||||
sort.Sort(sort.StringSlice(b))
|
||||
slices.Sort(a)
|
||||
slices.Sort(b)
|
||||
return slices.Equal(a, b)
|
||||
}), cmpopts.IgnoreFields(types.PreAuthKey{}, "User", "CreatedAt", "Reusable", "Ephemeral", "Used", "Expiration")); diff != "" {
|
||||
t.Errorf("TestSQLiteMigrationAndDataValidation() pre-auth key tags migration mismatch (-want +got):\n%s", diff)
|
||||
|
@ -7,15 +7,19 @@ import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -39,9 +43,7 @@ var (
|
||||
// If no peer IDs are given, all peers are returned.
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListPeers(rx, nodeID, peerIDs...)
|
||||
})
|
||||
return ListPeers(hsdb.DB, nodeID, peerIDs...)
|
||||
}
|
||||
|
||||
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
@ -66,9 +68,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
// or for the given nodes if at least one node ID is given as parameter.
|
||||
func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
|
||||
return ListNodes(rx, nodeIDs...)
|
||||
})
|
||||
return ListNodes(hsdb.DB, nodeIDs...)
|
||||
}
|
||||
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
@ -120,9 +120,7 @@ func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByID(rx, id)
|
||||
})
|
||||
return GetNodeByID(hsdb.DB, id)
|
||||
}
|
||||
|
||||
// GetNodeByID finds a Node by ID and returns the Node struct.
|
||||
@ -140,9 +138,7 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByMachineKey(rx, machineKey)
|
||||
})
|
||||
return GetNodeByMachineKey(hsdb.DB, machineKey)
|
||||
}
|
||||
|
||||
// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
|
||||
@ -163,9 +159,7 @@ func GetNodeByMachineKey(
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
|
||||
return GetNodeByNodeKey(rx, nodeKey)
|
||||
})
|
||||
return GetNodeByNodeKey(hsdb.DB, nodeKey)
|
||||
}
|
||||
|
||||
// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct.
|
||||
@ -352,8 +346,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
registrationMethod string,
|
||||
ipv4 *netip.Addr,
|
||||
ipv6 *netip.Addr,
|
||||
) (*types.Node, bool, error) {
|
||||
var newNode bool
|
||||
) (*types.Node, change.ChangeSet, error) {
|
||||
var nodeChange change.ChangeSet
|
||||
node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if reg, ok := hsdb.regCache.Get(registrationID); ok {
|
||||
if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
|
||||
@ -405,7 +399,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
}
|
||||
close(reg.Registered)
|
||||
|
||||
newNode = true
|
||||
nodeChange = change.NodeAdded(node.ID)
|
||||
|
||||
return node, err
|
||||
} else {
|
||||
@ -415,6 +409,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nodeChange = change.KeyExpiry(node.ID)
|
||||
|
||||
return node, nil
|
||||
}
|
||||
}
|
||||
@ -422,7 +418,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
|
||||
return nil, ErrNodeNotFoundRegistrationCache
|
||||
})
|
||||
|
||||
return node, newNode, err
|
||||
return node, nodeChange, err
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
|
||||
@ -448,6 +444,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
|
||||
if oldNode != nil && oldNode.UserID == node.UserID {
|
||||
node.ID = oldNode.ID
|
||||
node.GivenName = oldNode.GivenName
|
||||
node.ApprovedRoutes = oldNode.ApprovedRoutes
|
||||
ipv4 = oldNode.IPv4
|
||||
ipv6 = oldNode.IPv6
|
||||
}
|
||||
@ -594,17 +591,18 @@ func ensureUniqueGivenName(
|
||||
// containing the expired nodes, and a boolean indicating if any nodes were found.
|
||||
func ExpireExpiredNodes(tx *gorm.DB,
|
||||
lastCheck time.Time,
|
||||
) (time.Time, types.StateUpdate, bool) {
|
||||
) (time.Time, []change.ChangeSet, bool) {
|
||||
// use the time of the start of the function to ensure we
|
||||
// dont miss some nodes by returning it _after_ we have
|
||||
// checked everything.
|
||||
started := time.Now()
|
||||
|
||||
expired := make([]*tailcfg.PeerChange, 0)
|
||||
var updates []change.ChangeSet
|
||||
|
||||
nodes, err := ListNodes(tx)
|
||||
if err != nil {
|
||||
return time.Unix(0, 0), types.StateUpdate{}, false
|
||||
return time.Unix(0, 0), nil, false
|
||||
}
|
||||
for _, node := range nodes {
|
||||
if node.IsExpired() && node.Expiry.After(lastCheck) {
|
||||
@ -612,14 +610,15 @@ func ExpireExpiredNodes(tx *gorm.DB,
|
||||
NodeID: tailcfg.NodeID(node.ID),
|
||||
KeyExpiry: node.Expiry,
|
||||
})
|
||||
updates = append(updates, change.KeyExpiry(node.ID))
|
||||
}
|
||||
}
|
||||
|
||||
if len(expired) > 0 {
|
||||
return started, types.UpdatePeerPatch(expired...), true
|
||||
return started, updates, true
|
||||
}
|
||||
|
||||
return started, types.StateUpdate{}, false
|
||||
return started, nil, false
|
||||
}
|
||||
|
||||
// EphemeralGarbageCollector is a garbage collector that will delete nodes after
|
||||
@ -732,3 +731,114 @@ func (e *EphemeralGarbageCollector) Start() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) *types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateNodeForTest can only be called during tests")
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
panic("CreateNodeForTest requires a valid user")
|
||||
}
|
||||
|
||||
nodeName := "testnode"
|
||||
if len(hostname) > 0 && hostname[0] != "" {
|
||||
nodeName = hostname[0]
|
||||
}
|
||||
|
||||
// Create a preauth key for the node
|
||||
pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
|
||||
}
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
discoKey := key.NewDisco()
|
||||
|
||||
node := &types.Node{
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
DiscoKey: discoKey.Public(),
|
||||
Hostname: nodeName,
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
|
||||
err = hsdb.DB.Save(node).Error
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create test node: %v", err))
|
||||
}
|
||||
|
||||
return node
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateRegisteredNodeForTest can only be called during tests")
|
||||
}
|
||||
|
||||
node := hsdb.CreateNodeForTest(user, hostname...)
|
||||
|
||||
err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||
_, err := RegisterNode(tx, *node, nil, nil)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to register test node: %v", err))
|
||||
}
|
||||
|
||||
registeredNode, err := hsdb.GetNodeByID(node.ID)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to get registered test node: %v", err))
|
||||
}
|
||||
|
||||
return registeredNode
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateNodesForTest can only be called during tests")
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
panic("CreateNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
|
||||
nodes := make([]*types.Node, count)
|
||||
for i := range count {
|
||||
hostname := prefix + "-" + strconv.Itoa(i)
|
||||
nodes[i] = hsdb.CreateNodeForTest(user, hostname)
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
|
||||
if !testing.Testing() {
|
||||
panic("CreateRegisteredNodesForTest can only be called during tests")
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
panic("CreateRegisteredNodesForTest requires a valid user")
|
||||
}
|
||||
|
||||
prefix := "testnode"
|
||||
if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
|
||||
prefix = hostnamePrefix[0]
|
||||
}
|
||||
|
||||
nodes := make([]*types.Node, count)
|
||||
for i := range count {
|
||||
hostname := prefix + "-" + strconv.Itoa(i)
|
||||
nodes[i] = hsdb.CreateRegisteredNodeForTest(user, hostname)
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"math/big"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@ -26,82 +25,36 @@ import (
|
||||
)
|
||||
|
||||
func (s *Suite) TestGetNode(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||
_, err := db.getNode(types.UserID(user.ID), "testnode")
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := &types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
node := db.CreateNodeForTest(user, "testnode")
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(node.Hostname, check.Equals, "testnode")
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetNodeByID(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
_, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
node := db.CreateNodeForTest(user, "testnode")
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
retrievedNode, err := db.GetNodeByID(node.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(retrievedNode.Hostname, check.Equals, "testnode")
|
||||
}
|
||||
|
||||
func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
node := db.CreateNodeForTest(user, "testnode3")
|
||||
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
node := types.Node{
|
||||
ID: 0,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode3",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
|
||||
err = db.DeleteNode(&node)
|
||||
err := db.DeleteNode(node)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.getNode(types.UserID(user.ID), "testnode3")
|
||||
@ -109,42 +62,21 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestListPeers(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
_, err = db.GetNodeByID(0)
|
||||
_, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := range 11 {
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
nodes := db.CreateNodesForTest(user, 11, "testnode")
|
||||
|
||||
node := types.Node{
|
||||
ID: types.NodeID(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
Hostname: "testnode" + strconv.Itoa(index),
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
}
|
||||
|
||||
node0ByID, err := db.GetNodeByID(0)
|
||||
firstNode := nodes[0]
|
||||
peersOfFirstNode, err := db.ListPeers(firstNode.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfNode0, err := db.ListPeers(node0ByID.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
c.Assert(len(peersOfNode0), check.Equals, 9)
|
||||
c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2")
|
||||
c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7")
|
||||
c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10")
|
||||
c.Assert(len(peersOfFirstNode), check.Equals, 10)
|
||||
c.Assert(peersOfFirstNode[0].Hostname, check.Equals, "testnode-1")
|
||||
c.Assert(peersOfFirstNode[5].Hostname, check.Equals, "testnode-6")
|
||||
c.Assert(peersOfFirstNode[9].Hostname, check.Equals, "testnode-10")
|
||||
}
|
||||
|
||||
func (s *Suite) TestExpireNode(c *check.C) {
|
||||
@ -807,13 +739,13 @@ func TestListPeers(t *testing.T) {
|
||||
// No parameter means no filter, should return all peers
|
||||
nodes, err = db.ListPeers(1)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Empty node list should return all peers
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// No match in IDs should return empty list and no error
|
||||
@ -824,13 +756,13 @@ func TestListPeers(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs, but node ID is still filtered out
|
||||
nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
}
|
||||
|
||||
@ -892,14 +824,14 @@ func TestListNodes(t *testing.T) {
|
||||
// No parameter means no filter, should return all nodes
|
||||
nodes, err = db.ListNodes()
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
// Empty node list should return all nodes
|
||||
nodes, err = db.ListNodes(types.NodeIDs{}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
|
||||
@ -911,13 +843,13 @@ func TestListNodes(t *testing.T) {
|
||||
// Partial match in IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 1)
|
||||
assert.Equal(t, 1, len(nodes))
|
||||
assert.Equal(t, "test2", nodes[0].Hostname)
|
||||
|
||||
// Several matched IDs
|
||||
nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 2)
|
||||
assert.Equal(t, 2, len(nodes))
|
||||
assert.Equal(t, "test1", nodes[0].Hostname)
|
||||
assert.Equal(t, "test2", nodes[1].Hostname)
|
||||
}
|
||||
|
@ -109,9 +109,7 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
|
||||
return GetPreAuthKey(rx, key)
|
||||
})
|
||||
return GetPreAuthKey(hsdb.DB, key)
|
||||
}
|
||||
|
||||
// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
|
||||
@ -155,11 +153,8 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
|
||||
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
|
||||
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
|
||||
if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
now := time.Now()
|
||||
return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
|
||||
}
|
||||
|
||||
func generateKey() (string, error) {
|
||||
|
@ -1,7 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
@ -57,7 +57,7 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
|
||||
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
|
||||
c.Assert(err, check.IsNil)
|
||||
gotTags := listedPaks[0].Proto().GetAclTags()
|
||||
sort.Sort(sort.StringSlice(gotTags))
|
||||
slices.Sort(gotTags)
|
||||
c.Assert(gotTags, check.DeepEquals, tags)
|
||||
}
|
||||
|
||||
|
@ -3,6 +3,8 @@ package db
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
@ -110,9 +112,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
|
||||
return GetUserByID(rx, uid)
|
||||
})
|
||||
return GetUserByID(hsdb.DB, uid)
|
||||
}
|
||||
|
||||
func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
|
||||
@ -146,9 +146,7 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
||||
return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
|
||||
return ListUsers(rx, where...)
|
||||
})
|
||||
return ListUsers(hsdb.DB, where...)
|
||||
}
|
||||
|
||||
// ListUsers gets all the existing users.
|
||||
@ -217,3 +215,40 @@ func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateUserForTest(name ...string) *types.User {
|
||||
if !testing.Testing() {
|
||||
panic("CreateUserForTest can only be called during tests")
|
||||
}
|
||||
|
||||
userName := "testuser"
|
||||
if len(name) > 0 && name[0] != "" {
|
||||
userName = name[0]
|
||||
}
|
||||
|
||||
user, err := hsdb.CreateUser(types.User{Name: userName})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to create test user: %v", err))
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
func (hsdb *HSDatabase) CreateUsersForTest(count int, namePrefix ...string) []*types.User {
|
||||
if !testing.Testing() {
|
||||
panic("CreateUsersForTest can only be called during tests")
|
||||
}
|
||||
|
||||
prefix := "testuser"
|
||||
if len(namePrefix) > 0 && namePrefix[0] != "" {
|
||||
prefix = namePrefix[0]
|
||||
}
|
||||
|
||||
users := make([]*types.User, count)
|
||||
for i := range count {
|
||||
name := prefix + "-" + strconv.Itoa(i)
|
||||
users[i] = hsdb.CreateUserForTest(name)
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
@ -11,8 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
c.Assert(user.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
@ -30,8 +29,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||
err := db.DestroyUser(9998)
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
user := db.CreateUserForTest("test")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
@ -64,8 +62,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestRenameUser(c *check.C) {
|
||||
userTest, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
userTest := db.CreateUserForTest("test")
|
||||
c.Assert(userTest.Name, check.Equals, "test")
|
||||
|
||||
users, err := db.ListUsers()
|
||||
@ -86,8 +83,7 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||
err = db.RenameUser(99988, "test")
|
||||
c.Assert(err, check.Equals, ErrUserNotFound)
|
||||
|
||||
userTest2, err := db.CreateUser(types.User{Name: "test2"})
|
||||
c.Assert(err, check.IsNil)
|
||||
userTest2 := db.CreateUserForTest("test2")
|
||||
c.Assert(userTest2.Name, check.Equals, "test2")
|
||||
|
||||
want := "UNIQUE constraint failed"
|
||||
@ -98,11 +94,8 @@ func (s *Suite) TestRenameUser(c *check.C) {
|
||||
}
|
||||
|
||||
func (s *Suite) TestSetMachineUser(c *check.C) {
|
||||
oldUser, err := db.CreateUser(types.User{Name: "old"})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
newUser, err := db.CreateUser(types.User{Name: "new"})
|
||||
c.Assert(err, check.IsNil)
|
||||
oldUser := db.CreateUserForTest("old")
|
||||
newUser := db.CreateUserForTest("new")
|
||||
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
@ -17,10 +17,6 @@ import (
|
||||
func (h *Headscale) debugHTTPServer() *http.Server {
|
||||
debugMux := http.NewServeMux()
|
||||
debug := tsweb.Debugger(debugMux)
|
||||
debug.Handle("notifier", "Connected nodes in notifier", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(h.nodeNotifier.String()))
|
||||
}))
|
||||
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
config, err := json.MarshalIndent(h.cfg, "", " ")
|
||||
if err != nil {
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"maps"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -72,9 +73,7 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
|
||||
}
|
||||
|
||||
for _, derpMap := range derpMaps {
|
||||
for id, region := range derpMap.Regions {
|
||||
result.Regions[id] = region
|
||||
}
|
||||
maps.Copy(result.Regions, derpMap.Regions)
|
||||
}
|
||||
|
||||
return &result
|
||||
|
@ -1,3 +1,5 @@
|
||||
//go:generate buf generate --template ../buf.gen.yaml -o .. ../proto
|
||||
|
||||
// nolint
|
||||
package hscontrol
|
||||
|
||||
@ -27,6 +29,7 @@ import (
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
)
|
||||
|
||||
@ -56,12 +59,14 @@ func (api headscaleV1APIServer) CreateUser(
|
||||
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
|
||||
c := change.UserAdded(types.UserID(user.ID))
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-user-created", user.Name)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
c.Change = change.Policy
|
||||
}
|
||||
|
||||
api.h.Change(c)
|
||||
|
||||
return &v1.CreateUserResponse{User: user.Proto()}, nil
|
||||
}
|
||||
|
||||
@ -81,8 +86,7 @@ func (api headscaleV1APIServer) RenameUser(
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-user-renamed", request.GetNewName())
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
api.h.Change(change.PolicyChange())
|
||||
}
|
||||
|
||||
newUser, err := api.h.state.GetUserByName(request.GetNewName())
|
||||
@ -107,6 +111,8 @@ func (api headscaleV1APIServer) DeleteUser(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
api.h.Change(change.UserRemoved(types.UserID(user.ID)))
|
||||
|
||||
return &v1.DeleteUserResponse{}, nil
|
||||
}
|
||||
|
||||
@ -246,7 +252,7 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
return nil, fmt.Errorf("looking up user: %w", err)
|
||||
}
|
||||
|
||||
node, _, err := api.h.state.HandleNodeFromAuthPath(
|
||||
node, nodeChange, err := api.h.state.HandleNodeFromAuthPath(
|
||||
registrationId,
|
||||
types.UserID(user.ID),
|
||||
nil,
|
||||
@ -267,22 +273,13 @@ func (api headscaleV1APIServer) RegisterNode(
|
||||
// ensure we send an update.
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := api.h.state.AutoApproveRoutes(node)
|
||||
_, policyChanged, err := api.h.state.SaveNode(node)
|
||||
_ = api.h.state.AutoApproveRoutes(node)
|
||||
_, _, err = api.h.state.SaveNode(node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed (from SaveNode or route changes)
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-nodes-change", "all")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
|
||||
}
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
@ -300,7 +297,7 @@ func (api headscaleV1APIServer) GetNode(
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
|
||||
resp.Online = api.h.mapBatcher.IsConnected(node.ID)
|
||||
|
||||
return &v1.GetNodeResponse{Node: resp}, nil
|
||||
}
|
||||
@ -316,21 +313,14 @@ func (api headscaleV1APIServer) SetTags(
|
||||
}
|
||||
}
|
||||
|
||||
node, policyChanged, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
|
||||
node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
|
||||
if err != nil {
|
||||
return &v1.SetTagsResponse{
|
||||
Node: nil,
|
||||
}, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-tags", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
@ -362,23 +352,19 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
||||
tsaddr.SortPrefixes(routes)
|
||||
routes = slices.Compact(routes)
|
||||
|
||||
node, policyChanged, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-routes-approved", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||
|
||||
if api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) {
|
||||
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else {
|
||||
ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
// Always propagate node changes from SetApprovedRoutes
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
// If routes changed, propagate those changes too
|
||||
if !routeChange.Empty() {
|
||||
api.h.Change(routeChange)
|
||||
}
|
||||
|
||||
proto := node.Proto()
|
||||
@ -409,19 +395,12 @@ func (api headscaleV1APIServer) DeleteNode(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
policyChanged, err := api.h.state.DeleteNode(node)
|
||||
nodeChange, err := api.h.state.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-deleted", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
return &v1.DeleteNodeResponse{}, nil
|
||||
}
|
||||
@ -432,25 +411,13 @@ func (api headscaleV1APIServer) ExpireNode(
|
||||
) (*v1.ExpireNodeResponse, error) {
|
||||
now := time.Now()
|
||||
|
||||
node, policyChanged, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
|
||||
node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-expired", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID)
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
|
||||
// TODO(kradalby): Ensure that both the selfupdate and peer updates are sent
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
@ -464,22 +431,13 @@ func (api headscaleV1APIServer) RenameNode(
|
||||
ctx context.Context,
|
||||
request *v1.RenameNodeRequest,
|
||||
) (*v1.RenameNodeResponse, error) {
|
||||
node, policyChanged, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
|
||||
node, nodeChange, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-renamed", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
// TODO(kradalby): investigate if we need selfupdate
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
log.Trace().
|
||||
Str("node", node.Hostname).
|
||||
@ -498,7 +456,7 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
// probably be done once.
|
||||
// TODO(kradalby): This should be done in one tx.
|
||||
|
||||
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
||||
IsConnected := api.h.mapBatcher.ConnectedMap()
|
||||
if request.GetUser() != "" {
|
||||
user, err := api.h.state.GetUserByName(request.GetUser())
|
||||
if err != nil {
|
||||
@ -510,7 +468,7 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
@ -523,18 +481,18 @@ func (api headscaleV1APIServer) ListNodes(
|
||||
return nodes[i].ID < nodes[j].ID
|
||||
})
|
||||
|
||||
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
|
||||
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||
}
|
||||
|
||||
func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
||||
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
||||
response := make([]*v1.Node, len(nodes))
|
||||
for index, node := range nodes {
|
||||
resp := node.Proto()
|
||||
|
||||
// Populate the online field based on
|
||||
// currently connected nodes.
|
||||
if val, ok := isLikelyConnected.Load(node.ID); ok && val {
|
||||
if val, ok := IsConnected.Load(node.ID); ok && val {
|
||||
resp.Online = true
|
||||
}
|
||||
|
||||
@ -556,24 +514,14 @@ func (api headscaleV1APIServer) MoveNode(
|
||||
ctx context.Context,
|
||||
request *v1.MoveNodeRequest,
|
||||
) (*v1.MoveNodeResponse, error) {
|
||||
node, policyChanged, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
|
||||
node, nodeChange, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-moved", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID)
|
||||
ctx = types.NotifyCtx(ctx, "cli-movenode", node.Hostname)
|
||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
// TODO(kradalby): Ensure the policy is also sent
|
||||
// TODO(kradalby): ensure that both the selfupdate and peer updates are sent
|
||||
api.h.Change(nodeChange)
|
||||
|
||||
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
|
||||
}
|
||||
@ -754,8 +702,7 @@ func (api headscaleV1APIServer) SetPolicy(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
|
||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
api.h.Change(change.PolicyChange())
|
||||
}
|
||||
|
||||
response := &v1.SetPolicyResponse{
|
||||
|
155
hscontrol/mapper/batcher.go
Normal file
155
hscontrol/mapper/batcher.go
Normal file
@ -0,0 +1,155 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
|
||||
|
||||
// Batcher defines the common interface for all batcher implementations.
|
||||
type Batcher interface {
|
||||
Start()
|
||||
Close()
|
||||
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
|
||||
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
|
||||
IsConnected(id types.NodeID) bool
|
||||
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||
AddWork(c change.ChangeSet)
|
||||
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
|
||||
}
|
||||
|
||||
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
|
||||
return &LockFreeBatcher{
|
||||
mapper: mapper,
|
||||
workers: workers,
|
||||
tick: time.NewTicker(batchTime),
|
||||
|
||||
// The size of this channel is arbitrary chosen, the sizing should be revisited.
|
||||
workCh: make(chan work, workers*200),
|
||||
nodes: xsync.NewMap[types.NodeID, *nodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBatcherAndMapper creates a Batcher implementation.
|
||||
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
|
||||
m := newMapper(cfg, state)
|
||||
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
|
||||
m.batcher = b
|
||||
return b
|
||||
}
|
||||
|
||||
// nodeConnection interface for different connection implementations.
|
||||
type nodeConnection interface {
|
||||
nodeID() types.NodeID
|
||||
version() tailcfg.CapabilityVersion
|
||||
send(data *tailcfg.MapResponse) error
|
||||
}
|
||||
|
||||
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
|
||||
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
|
||||
if c.Empty() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Validate inputs before processing
|
||||
if nodeID == 0 {
|
||||
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
|
||||
}
|
||||
|
||||
if mapper == nil {
|
||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||
}
|
||||
|
||||
var mapResp *tailcfg.MapResponse
|
||||
var err error
|
||||
|
||||
switch c.Change {
|
||||
case change.DERP:
|
||||
mapResp, err = mapper.derpMapResponse(nodeID)
|
||||
|
||||
case change.NodeCameOnline, change.NodeWentOffline:
|
||||
if c.IsSubnetRouter {
|
||||
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
} else {
|
||||
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: c.NodeID.NodeID(),
|
||||
Online: ptr.To(c.Change == change.NodeCameOnline),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
case change.NodeNewOrUpdate:
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
|
||||
case change.NodeRemove:
|
||||
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
|
||||
|
||||
default:
|
||||
// The following will always hit this:
|
||||
// change.Full, change.Policy
|
||||
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Is this necessary?
|
||||
// Validate the generated map response - only check for nil response
|
||||
// Note: mapResp.Node can be nil for peer updates, which is valid
|
||||
if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
|
||||
return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
|
||||
}
|
||||
|
||||
return mapResp, nil
|
||||
}
|
||||
|
||||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
|
||||
if nc == nil {
|
||||
return fmt.Errorf("nodeConnection is nil")
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
// No data to send is valid for some change types
|
||||
return nil
|
||||
}
|
||||
|
||||
// Send the map response
|
||||
if err := nc.send(data); err != nil {
|
||||
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// workResult represents the result of processing a change.
|
||||
type workResult struct {
|
||||
mapResponse *tailcfg.MapResponse
|
||||
err error
|
||||
}
|
||||
|
||||
// work represents a unit of work to be processed by workers.
|
||||
type work struct {
|
||||
c change.ChangeSet
|
||||
nodeID types.NodeID
|
||||
resultCh chan<- workResult // optional channel for synchronous operations
|
||||
}
|
491
hscontrol/mapper/batcher_lockfree.go
Normal file
491
hscontrol/mapper/batcher_lockfree.go
Normal file
@ -0,0 +1,491 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||
type LockFreeBatcher struct {
|
||||
tick *time.Ticker
|
||||
mapper *mapper
|
||||
workers int
|
||||
|
||||
// Lock-free concurrent maps
|
||||
nodes *xsync.Map[types.NodeID, *nodeConn]
|
||||
connected *xsync.Map[types.NodeID, *time.Time]
|
||||
|
||||
// Work queue channel
|
||||
workCh chan work
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Batching state
|
||||
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
|
||||
batchMutex sync.RWMutex
|
||||
|
||||
// Metrics
|
||||
totalNodes atomic.Int64
|
||||
totalUpdates atomic.Int64
|
||||
workQueuedCount atomic.Int64
|
||||
workProcessed atomic.Int64
|
||||
workErrors atomic.Int64
|
||||
}
|
||||
|
||||
// AddNode registers a new node connection with the batcher and sends an initial map response.
|
||||
// It creates or updates the node's connection data, validates the initial map generation,
|
||||
// and notifies other nodes that this node has come online.
|
||||
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
|
||||
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
|
||||
// First validate that we can generate initial map before doing anything else
|
||||
fullSelfChange := change.FullSelf(id)
|
||||
|
||||
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
|
||||
// This currently means that the goroutine for the node connection will do the processing
|
||||
// which means that we might have uncontrolled concurrency.
|
||||
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
|
||||
// it to be processed in a more controlled manner.
|
||||
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Only after validation succeeds, create or update node connection
|
||||
newConn := newNodeConn(id, c, version, b.mapper)
|
||||
|
||||
var conn *nodeConn
|
||||
if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
|
||||
// Update existing connection
|
||||
existing.updateConnection(c, version)
|
||||
conn = existing
|
||||
} else {
|
||||
b.totalNodes.Add(1)
|
||||
conn = newConn
|
||||
}
|
||||
|
||||
// Mark as connected only after validation succeeds
|
||||
b.connected.Store(id, nil) // nil = connected
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
|
||||
|
||||
// Send the validated initial map
|
||||
if initialMap != nil {
|
||||
if err := conn.send(initialMap); err != nil {
|
||||
// Clean up the connection state on send failure
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Delete(id)
|
||||
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
|
||||
}
|
||||
|
||||
// Notify other nodes that this node came online
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
|
||||
// It validates the connection channel matches the current one, closes the connection,
|
||||
// and notifies other nodes that this node has gone offline.
|
||||
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
|
||||
// Check if this is the current connection and mark it as closed
|
||||
if existing, ok := b.nodes.Load(id); ok {
|
||||
if !existing.matchesChannel(c) {
|
||||
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
|
||||
return // Not the current connection, not an error
|
||||
}
|
||||
|
||||
// Mark the connection as closed to prevent further sends
|
||||
if connData := existing.connData.Load(); connData != nil {
|
||||
connData.closed.Store(true)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
|
||||
|
||||
// Remove node and mark disconnected atomically
|
||||
b.nodes.Delete(id)
|
||||
b.connected.Store(id, ptr.To(time.Now()))
|
||||
b.totalNodes.Add(-1)
|
||||
|
||||
// Notify other nodes that this node went offline
|
||||
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
|
||||
}
|
||||
|
||||
// AddWork queues a change to be processed by the batcher.
|
||||
// Critical changes are processed immediately, while others are batched for efficiency.
|
||||
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
|
||||
b.addWork(c)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Start() {
|
||||
b.ctx, b.cancel = context.WithCancel(context.Background())
|
||||
go b.doWork()
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) Close() {
|
||||
if b.cancel != nil {
|
||||
b.cancel()
|
||||
}
|
||||
close(b.workCh)
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) doWork() {
|
||||
log.Debug().Msg("batcher doWork loop started")
|
||||
defer log.Debug().Msg("batcher doWork loop stopped")
|
||||
|
||||
for i := range b.workers {
|
||||
go b.worker(i + 1)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-b.tick.C:
|
||||
// Process batched changes
|
||||
b.processBatchedChanges()
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *LockFreeBatcher) worker(workerID int) {
|
||||
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
|
||||
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
|
||||
|
||||
for {
|
||||
select {
|
||||
case w, ok := <-b.workCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
b.workProcessed.Add(1)
|
||||
|
||||
// If the resultCh is set, it means that this is a work request
|
||||
// where there is a blocking function waiting for the map that
|
||||
// is being generated.
|
||||
// This is used for synchronous map generation.
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("failed to generate map response for synchronous work")
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Msg("node not found for synchronous work")
|
||||
}
|
||||
|
||||
// Send result
|
||||
select {
|
||||
case w.resultCh <- result:
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow synchronous work processing")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// If resultCh is nil, this is an asynchronous work request
|
||||
// that should be processed and sent to the node instead of
|
||||
// returned to the caller.
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
// Check if this connection is still active before processing
|
||||
if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("skipping work for closed connection")
|
||||
continue
|
||||
}
|
||||
|
||||
err := nc.change(w.c)
|
||||
if err != nil {
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(err).
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.c.NodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("failed to apply change")
|
||||
}
|
||||
} else {
|
||||
log.Debug().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Msg("node not found for asynchronous work - node may have disconnected")
|
||||
}
|
||||
|
||||
duration := time.Since(startTime)
|
||||
if duration > 100*time.Millisecond {
|
||||
log.Warn().
|
||||
Int("workerID", workerID).
|
||||
Uint64("node.id", w.nodeID.Uint64()).
|
||||
Str("change", w.c.Change.String()).
|
||||
Dur("duration", duration).
|
||||
Msg("slow asynchronous work processing")
|
||||
}
|
||||
|
||||
case <-b.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addWork routes a change either straight to the worker pool (for
// critical changes, per shouldProcessImmediately) or into the pending
// batch for later flushing.
func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
	// For critical changes that need immediate processing, send directly
	if b.shouldProcessImmediately(c) {
		if c.SelfUpdateOnly {
			// Only the changed node itself needs an update.
			b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
			return
		}
		b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
			// Skip the originating node unless the change also applies to itself.
			if c.NodeID == nodeID && !c.AlsoSelf() {
				return true
			}
			b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
			return true
		})
		return
	}

	// For non-critical changes, add to batch
	b.addToBatch(c)
}
|
||||
|
||||
// queueWork safely queues work
|
||||
func (b *LockFreeBatcher) queueWork(w work) {
|
||||
b.workQueuedCount.Add(1)
|
||||
|
||||
select {
|
||||
case b.workCh <- w:
|
||||
// Successfully queued
|
||||
case <-b.ctx.Done():
|
||||
// Batcher is shutting down
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// shouldProcessImmediately determines if a change should bypass batching
|
||||
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
|
||||
// Process these changes immediately to avoid delaying critical functionality
|
||||
switch c.Change {
|
||||
case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// addToBatch records a change as pending for every affected node; the
// batch is flushed later by processBatchedChanges. batchMutex guards the
// read-modify-write of each node's pending slice.
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
	b.batchMutex.Lock()
	defer b.batchMutex.Unlock()

	if c.SelfUpdateOnly {
		// Only the changed node itself is affected.
		changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
		changes = append(changes, c)
		b.pendingChanges.Store(c.NodeID, changes)
		return
	}

	b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
		// Skip the originating node unless the change also applies to itself.
		if c.NodeID == nodeID && !c.AlsoSelf() {
			return true
		}

		changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
		changes = append(changes, c)
		b.pendingChanges.Store(nodeID, changes)
		return true
	})
}
|
||||
|
||||
// processBatchedChanges flushes all pending batched changes to the worker
// pool and clears the batch. batchMutex is held for the whole flush so
// concurrent addToBatch calls cannot interleave with the drain.
func (b *LockFreeBatcher) processBatchedChanges() {
	b.batchMutex.Lock()
	defer b.batchMutex.Unlock()

	if b.pendingChanges == nil {
		return
	}

	// Process all pending changes
	b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
		if len(changes) == 0 {
			return true
		}

		// Send all batched changes for this node
		for _, c := range changes {
			b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
		}

		// Clear the pending changes for this node
		b.pendingChanges.Delete(nodeID)
		return true
	})
}
|
||||
|
||||
// IsConnected is lock-free read.
|
||||
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||
if val, ok := b.connected.Load(id); ok {
|
||||
// nil means connected
|
||||
return val == nil
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ConnectedMap returns a lock-free snapshot of connectivity for every
// node the batcher has tracked: the value is true for currently connected
// nodes and false for nodes that have disconnected (non-nil timestamp).
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
	ret := xsync.NewMap[types.NodeID, bool]()

	b.connected.Range(func(id types.NodeID, val *time.Time) bool {
		// nil means connected
		ret.Store(id, val == nil)
		return true
	})

	return ret
}
|
||||
|
||||
// MapResponseFromChange queues work to generate a map response and waits for the result.
// This allows synchronous map generation using the same worker pool.
// It blocks until a worker produces the response or the batcher shuts down.
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
	// Buffered so the worker's send never blocks even if we bail out first.
	resultCh := make(chan workResult, 1)

	// Queue the work with a result channel using the safe queueing method
	b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})

	// Wait for the result
	select {
	case result := <-resultCh:
		return result.mapResponse, result.err
	case <-b.ctx.Done():
		return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
	}
}
|
||||
|
||||
// connectionData holds the channel and connection parameters for a single
// poll connection. A nodeConn swaps the whole struct atomically on
// reconnect, so readers always observe a consistent channel/version pair.
type connectionData struct {
	c       chan<- *tailcfg.MapResponse
	version tailcfg.CapabilityVersion
	closed  atomic.Bool // Track if this connection has been closed
}
|
||||
|
||||
// nodeConn describes the node connection and its associated data.
type nodeConn struct {
	id     types.NodeID
	mapper *mapper

	// Atomic pointer to connection data - allows lock-free updates
	// when the node reconnects with a new channel or capability version.
	connData atomic.Pointer[connectionData]

	// updateCount tracks how many map responses have been sent on this
	// connection (incremented by send).
	updateCount atomic.Int64
}
|
||||
|
||||
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
|
||||
nc := &nodeConn{
|
||||
id: id,
|
||||
mapper: mapper,
|
||||
}
|
||||
|
||||
// Initialize connection data
|
||||
data := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(data)
|
||||
|
||||
return nc
|
||||
}
|
||||
|
||||
// updateConnection atomically updates connection parameters.
|
||||
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
|
||||
newData := &connectionData{
|
||||
c: c,
|
||||
version: version,
|
||||
}
|
||||
nc.connData.Store(newData)
|
||||
}
|
||||
|
||||
// matchesChannel checks if the given channel matches the current
// connection. Channels are compared by identity, which distinguishes a
// stale (superseded) connection from the active one.
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
	data := nc.connData.Load()
	if data == nil {
		return false
	}
	// Compare channel pointers directly
	return data.c == c
}
|
||||
|
||||
// compressAndVersion atomically reads connection settings.
|
||||
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
|
||||
data := nc.connData.Load()
|
||||
if data == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return data.version
|
||||
}
|
||||
|
||||
// nodeID returns the ID of the node this connection belongs to.
func (nc *nodeConn) nodeID() types.NodeID {
	return nc.id
}
|
||||
|
||||
// change applies a ChangeSet to this node by delegating to the shared
// node-change handler, using this connection's mapper.
func (nc *nodeConn) change(c change.ChangeSet) error {
	return handleNodeChange(nc, nc.mapper, c)
}
|
||||
|
||||
// send sends data to the node's channel.
// The node will pick it up and send it to the HTTP handler.
// It fails when no connection data is present or the connection has been
// marked closed; otherwise it blocks until the channel accepts the value.
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
	connData := nc.connData.Load()
	if connData == nil {
		return fmt.Errorf("node %d: no connection data", nc.id)
	}

	// Check if connection has been closed
	if connData.closed.Load() {
		return fmt.Errorf("node %d: connection closed", nc.id)
	}

	// TODO(kradalby): We might need some sort of timeout here if the client is not reading
	// the channel. That might mean that we are sending to a node that has gone offline, but
	// the channel is still open.
	connData.c <- data
	nc.updateCount.Add(1)
	return nil
}
|
1977
hscontrol/mapper/batcher_test.go
Normal file
1977
hscontrol/mapper/batcher_test.go
Normal file
File diff suppressed because it is too large
Load Diff
259
hscontrol/mapper/builder.go
Normal file
259
hscontrol/mapper/builder.go
Normal file
@ -0,0 +1,259 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/multierr"
|
||||
)
|
||||
|
||||
// MapResponseBuilder provides a fluent interface for building
// tailcfg.MapResponse values. Each With* step mutates resp; errors are
// accumulated in errs and surfaced by Build, so steps can be chained
// unconditionally.
type MapResponseBuilder struct {
	resp   *tailcfg.MapResponse
	mapper *mapper
	nodeID types.NodeID // the node the response is being built for
	capVer tailcfg.CapabilityVersion
	errs   []error
}
|
||||
|
||||
// NewMapResponseBuilder creates a new builder with basic fields set:
// a non-keepalive response stamped with the current control time, bound
// to the given node.
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
	now := time.Now()
	return &MapResponseBuilder{
		resp: &tailcfg.MapResponse{
			KeepAlive:   false,
			ControlTime: &now,
		},
		mapper: m,
		nodeID: nodeID,
		errs:   nil,
	}
}
|
||||
|
||||
// addError adds an error to the builder's error list
|
||||
func (b *MapResponseBuilder) addError(err error) {
|
||||
if err != nil {
|
||||
b.errs = append(b.errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
// hasErrors returns true if the builder has accumulated any errors
|
||||
func (b *MapResponseBuilder) hasErrors() bool {
|
||||
return len(b.errs) > 0
|
||||
}
|
||||
|
||||
// WithCapabilityVersion sets the client capability version used when
// converting nodes for the response.
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
	b.capVer = capVer
	return b
}
|
||||
|
||||
// WithSelfNode adds the requesting node to the response, converting it to
// a tailcfg.Node with its primary routes reduced by the current policy
// matchers.
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	_, matchers := b.mapper.state.Filter()
	tailnode, err := tailNode(
		node.View(), b.capVer, b.mapper.state,
		// Route resolver: only routes permitted by policy are exposed.
		func(id types.NodeID) []netip.Prefix {
			return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
		},
		b.mapper.cfg)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.Node = tailnode
	return b
}
|
||||
|
||||
// WithDERPMap adds the current DERP relay map from state to the response.
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
	b.resp.DERPMap = b.mapper.state.DERPMap()
	return b
}
|
||||
|
||||
// WithDomain adds the configured tailnet domain to the response.
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
	b.resp.Domain = b.mapper.cfg.Domain()
	return b
}
|
||||
|
||||
// WithCollectServicesDisabled explicitly sets the collect-services flag
// to false (as opposed to leaving it unset).
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
	b.resp.CollectServices.Set(false)
	return b
}
|
||||
|
||||
// WithDebugConfig adds debug configuration
|
||||
// It disables log tailing if the mapper's LogTail is not enabled
|
||||
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// WithSSHPolicy adds the SSH policy computed for the requesting node.
// Fails (accumulates an error) if the node cannot be loaded or the policy
// cannot be derived.
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.SSHPolicy = sshPolicy
	return b
}
|
||||
|
||||
// WithDNSConfig adds the DNS configuration generated for the requesting
// node from the server config.
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
	return b
}
|
||||
|
||||
// WithUserProfiles adds user profiles for the requesting node and the
// given peers, so the client can display owner information.
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.UserProfiles = generateUserProfiles(node, peers)
	return b
}
|
||||
|
||||
// WithPacketFilters adds packet filter rules for the requesting node,
// reduced from the global policy filter.
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		b.addError(err)
		return b
	}

	filter, _ := b.mapper.state.Filter()

	// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
	// Currently, we do not send incremental package filters, however using the
	// new PacketFilters field and "base" allows us to send a full update when we
	// have to send an empty list, avoiding the hack in the else block.
	b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
		"base": policy.ReduceFilterRules(node.View(), filter),
	}

	return b
}
|
||||
|
||||
// WithPeers adds the full peer list with policy filtering (for a full map
// response); peers are converted and sorted by buildTailPeers.
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
	tailPeers, err := b.buildTailPeers(peers)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.Peers = tailPeers
	return b
}
|
||||
|
||||
// WithPeerChanges adds changed peers with policy filtering (for
// incremental updates); they are placed in PeersChanged rather than Peers.
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
	tailPeers, err := b.buildTailPeers(peers)
	if err != nil {
		b.addError(err)
		return b
	}

	b.resp.PeersChanged = tailPeers
	return b
}
|
||||
|
||||
// buildTailPeers converts types.Nodes to []tailcfg.Node, applying policy
// filtering (dropping peers the requesting node cannot reach at all) and
// returning the result sorted by node ID.
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
	node, err := b.mapper.state.GetNodeByID(b.nodeID)
	if err != nil {
		return nil, err
	}

	filter, matchers := b.mapper.state.Filter()

	// If there are filter rules present, see if there are any nodes that cannot
	// access each-other at all and remove them from the peers.
	var changedViews views.Slice[types.NodeView]
	if len(filter) > 0 {
		changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
	} else {
		changedViews = peers.ViewSlice()
	}

	tailPeers, err := tailNodes(
		changedViews, b.capVer, b.mapper.state,
		// Route resolver: only routes permitted by policy are exposed.
		func(id types.NodeID) []netip.Prefix {
			return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
		},
		b.mapper.cfg)
	if err != nil {
		return nil, err
	}

	// Peers is always returned sorted by Node.ID.
	sort.SliceStable(tailPeers, func(x, y int) bool {
		return tailPeers[x].ID < tailPeers[y].ID
	})

	return tailPeers, nil
}
|
||||
|
||||
// WithPeerChangedPatch adds lightweight peer change patches to the
// response; a later call replaces any earlier ones.
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
	b.resp.PeersChangedPatch = changes
	return b
}
|
||||
|
||||
// WithPeersRemoved adds removed peer IDs
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
return b
|
||||
}
|
||||
|
||||
// Build finalizes the response, returning the assembled MapResponse or a
// combined error if any builder step failed.
// NOTE(review): the original comment claimed this returns "marshaled
// bytes", but it returns the *tailcfg.MapResponse itself; also, the
// messages parameter is unused here — confirm its intended purpose.
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
	if len(b.errs) > 0 {
		return nil, multierr.New(b.errs...)
	}
	// Optionally dump the response to disk when the debug env knob is set.
	if debugDumpMapResponsePath != "" {
		writeDebugMapResponse(b.resp, b.nodeID)
	}

	return b.resp, nil
}
|
347
hscontrol/mapper/builder_test.go
Normal file
347
hscontrol/mapper/builder_test.go
Normal file
@ -0,0 +1,347 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// TestMapResponseBuilder_Basic verifies that a fresh builder carries the
// node ID and a non-keepalive response stamped with the current time.
func TestMapResponseBuilder_Basic(t *testing.T) {
	cfg := &types.Config{
		BaseDomain: "example.com",
		LogTail: types.LogTailConfig{
			Enabled: true,
		},
	}

	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID)

	// Test basic builder creation
	assert.NotNil(t, builder)
	assert.Equal(t, nodeID, builder.nodeID)
	assert.NotNil(t, builder.resp)
	assert.False(t, builder.resp.KeepAlive)
	assert.NotNil(t, builder.resp.ControlTime)
	assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second)
}
|
||||
|
||||
// TestMapResponseBuilder_WithCapabilityVersion checks the capability
// version is stored on the builder without producing errors.
func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	capVer := tailcfg.CapabilityVersion(42)

	builder := m.NewMapResponseBuilder(nodeID).
		WithCapabilityVersion(capVer)

	assert.Equal(t, capVer, builder.capVer)
	assert.False(t, builder.hasErrors())
}
|
||||
|
||||
// TestMapResponseBuilder_WithDomain checks the configured base domain is
// copied onto the response.
func TestMapResponseBuilder_WithDomain(t *testing.T) {
	domain := "test.example.com"
	cfg := &types.Config{
		ServerURL:  "https://test.example.com",
		BaseDomain: domain,
	}

	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID).
		WithDomain()

	assert.Equal(t, domain, builder.resp.Domain)
	assert.False(t, builder.hasErrors())
}
|
||||
|
||||
// TestMapResponseBuilder_WithCollectServicesDisabled checks the
// collect-services opt is explicitly set to false (not merely unset).
func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID).
		WithCollectServicesDisabled()

	value, isSet := builder.resp.CollectServices.Get()
	assert.True(t, isSet)
	assert.False(t, value)
	assert.False(t, builder.hasErrors())
}
|
||||
|
||||
// TestMapResponseBuilder_WithDebugConfig checks DisableLogTail mirrors the
// inverse of the LogTail.Enabled config flag.
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
	tests := []struct {
		name           string
		logTailEnabled bool
		expected       bool
	}{
		{
			name:           "LogTail enabled",
			logTailEnabled: true,
			expected:       false, // DisableLogTail should be false when LogTail is enabled
		},
		{
			name:           "LogTail disabled",
			logTailEnabled: false,
			expected:       true, // DisableLogTail should be true when LogTail is disabled
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cfg := &types.Config{
				LogTail: types.LogTailConfig{
					Enabled: tt.logTailEnabled,
				},
			}
			mockState := &state.State{}
			m := &mapper{
				cfg:   cfg,
				state: mockState,
			}

			nodeID := types.NodeID(1)

			builder := m.NewMapResponseBuilder(nodeID).
				WithDebugConfig()

			require.NotNil(t, builder.resp.Debug)
			assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
			assert.False(t, builder.hasErrors())
		})
	}
}
|
||||
|
||||
// TestMapResponseBuilder_WithPeerChangedPatch checks patches are stored
// verbatim on the response.
func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	changes := []*tailcfg.PeerChange{
		{
			NodeID:     123,
			DERPRegion: 1,
		},
		{
			NodeID:     456,
			DERPRegion: 2,
		},
	}

	builder := m.NewMapResponseBuilder(nodeID).
		WithPeerChangedPatch(changes)

	assert.Equal(t, changes, builder.resp.PeersChangedPatch)
	assert.False(t, builder.hasErrors())
}
|
||||
|
||||
// TestMapResponseBuilder_WithPeersRemoved checks removed node IDs are
// converted to tailcfg.NodeID and recorded on the response.
func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	removedID1 := types.NodeID(123)
	removedID2 := types.NodeID(456)

	builder := m.NewMapResponseBuilder(nodeID).
		WithPeersRemoved(removedID1, removedID2)

	expected := []tailcfg.NodeID{
		removedID1.NodeID(),
		removedID2.NodeID(),
	}
	assert.Equal(t, expected, builder.resp.PeersRemoved)
	assert.False(t, builder.hasErrors())
}
|
||||
|
||||
// TestMapResponseBuilder_ErrorHandling checks that an accumulated error
// survives further chained calls and makes Build fail.
func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	// Simulate an error in the builder
	builder := m.NewMapResponseBuilder(nodeID)
	builder.addError(assert.AnError)

	// All subsequent calls should continue to work and accumulate errors
	result := builder.
		WithDomain().
		WithCollectServicesDisabled().
		WithDebugConfig()

	assert.True(t, result.hasErrors())
	assert.Len(t, result.errs, 1)
	assert.Equal(t, assert.AnError, result.errs[0])

	// Build should return the error
	data, err := result.Build("none")
	assert.Nil(t, data)
	assert.Error(t, err)
}
|
||||
|
||||
// TestMapResponseBuilder_ChainedCalls checks several With* steps compose
// without interfering with one another.
func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
	domain := "chained.example.com"
	cfg := &types.Config{
		ServerURL:  "https://chained.example.com",
		BaseDomain: domain,
		LogTail: types.LogTailConfig{
			Enabled: false,
		},
	}

	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	capVer := tailcfg.CapabilityVersion(99)

	builder := m.NewMapResponseBuilder(nodeID).
		WithCapabilityVersion(capVer).
		WithDomain().
		WithCollectServicesDisabled().
		WithDebugConfig()

	// Verify all fields are set correctly
	assert.Equal(t, capVer, builder.capVer)
	assert.Equal(t, domain, builder.resp.Domain)
	value, isSet := builder.resp.CollectServices.Get()
	assert.True(t, isSet)
	assert.False(t, value)
	assert.NotNil(t, builder.resp.Debug)
	assert.True(t, builder.resp.Debug.DisableLogTail)
	assert.False(t, builder.hasErrors())
}
|
||||
|
||||
// TestMapResponseBuilder_MultipleWithPeersRemoved documents that a second
// WithPeersRemoved call replaces (not appends to) the first.
func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)
	removedID1 := types.NodeID(100)
	removedID2 := types.NodeID(200)

	// Test calling WithPeersRemoved multiple times
	builder := m.NewMapResponseBuilder(nodeID).
		WithPeersRemoved(removedID1).
		WithPeersRemoved(removedID2)

	// Second call should overwrite the first
	expected := []tailcfg.NodeID{removedID2.NodeID()}
	assert.Equal(t, expected, builder.resp.PeersRemoved)
	assert.False(t, builder.hasErrors())
}
|
||||
|
||||
// TestMapResponseBuilder_EmptyPeerChangedPatch checks an empty patch slice
// is accepted and leaves the response's patch list empty.
func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID).
		WithPeerChangedPatch([]*tailcfg.PeerChange{})

	assert.Empty(t, builder.resp.PeersChangedPatch)
	assert.False(t, builder.hasErrors())
}
|
||||
|
||||
// TestMapResponseBuilder_NilPeerChangedPatch checks a nil patch slice is
// stored as nil without error.
func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	builder := m.NewMapResponseBuilder(nodeID).
		WithPeerChangedPatch(nil)

	assert.Nil(t, builder.resp.PeersChangedPatch)
	assert.False(t, builder.hasErrors())
}
|
||||
|
||||
// TestMapResponseBuilder_MultipleErrors checks multiple accumulated
// errors (nil ignored) are combined by Build into a multierr.
func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
	cfg := &types.Config{}
	mockState := &state.State{}
	m := &mapper{
		cfg:   cfg,
		state: mockState,
	}

	nodeID := types.NodeID(1)

	// Create a builder and add multiple errors
	builder := m.NewMapResponseBuilder(nodeID)
	builder.addError(assert.AnError)
	builder.addError(assert.AnError)
	builder.addError(nil) // This should be ignored

	// All subsequent calls should continue to work
	result := builder.
		WithDomain().
		WithCollectServicesDisabled()

	assert.True(t, result.hasErrors())
	assert.Len(t, result.errs, 2) // nil error should be ignored

	// Build should return a multierr
	data, err := result.Build("none")
	assert.Nil(t, data)
	assert.Error(t, err)

	// The error should contain information about multiple errors
	assert.Contains(t, err.Error(), "multiple errors")
}
|
@ -1,7 +1,6 @@
|
||||
package mapper
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
@ -10,31 +9,21 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/smallzstd"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
const (
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
reservedResponseHeaderSize = 4
|
||||
mapperIDLength = 8
|
||||
debugMapResponsePerm = 0o755
|
||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||
mapperIDLength = 8
|
||||
debugMapResponsePerm = 0o755
|
||||
)
|
||||
|
||||
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
|
||||
@ -50,15 +39,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
|
||||
// - Create a "minifier" that removes info not needed for the node
|
||||
// - some sort of batching, wait for 5 or 60 seconds before sending
|
||||
|
||||
type Mapper struct {
|
||||
type mapper struct {
|
||||
// Configuration
|
||||
state *state.State
|
||||
cfg *types.Config
|
||||
notif *notifier.Notifier
|
||||
state *state.State
|
||||
cfg *types.Config
|
||||
batcher Batcher
|
||||
|
||||
uid string
|
||||
created time.Time
|
||||
seq uint64
|
||||
}
|
||||
|
||||
type patch struct {
|
||||
@ -66,41 +53,31 @@ type patch struct {
|
||||
change *tailcfg.PeerChange
|
||||
}
|
||||
|
||||
func NewMapper(
|
||||
state *state.State,
|
||||
func newMapper(
|
||||
cfg *types.Config,
|
||||
notif *notifier.Notifier,
|
||||
) *Mapper {
|
||||
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
state *state.State,
|
||||
) *mapper {
|
||||
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
|
||||
return &Mapper{
|
||||
return &mapper{
|
||||
state: state,
|
||||
cfg: cfg,
|
||||
notif: notif,
|
||||
|
||||
uid: uid,
|
||||
created: time.Now(),
|
||||
seq: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mapper) String() string {
|
||||
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
|
||||
}
|
||||
|
||||
func generateUserProfiles(
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
node *types.Node,
|
||||
peers types.Nodes,
|
||||
) []tailcfg.UserProfile {
|
||||
userMap := make(map[uint]*types.User)
|
||||
ids := make([]uint, 0, peers.Len()+1)
|
||||
user := node.User()
|
||||
userMap[user.ID] = &user
|
||||
ids = append(ids, user.ID)
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.User()
|
||||
userMap[peerUser.ID] = &peerUser
|
||||
ids = append(ids, peerUser.ID)
|
||||
ids := make([]uint, 0, len(userMap))
|
||||
userMap[node.User.ID] = &node.User
|
||||
ids = append(ids, node.User.ID)
|
||||
for _, peer := range peers {
|
||||
userMap[peer.User.ID] = &peer.User
|
||||
ids = append(ids, peer.User.ID)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
@ -117,7 +94,7 @@ func generateUserProfiles(
|
||||
|
||||
func generateDNSConfig(
|
||||
cfg *types.Config,
|
||||
node types.NodeView,
|
||||
node *types.Node,
|
||||
) *tailcfg.DNSConfig {
|
||||
if cfg.TailcfgDNSConfig == nil {
|
||||
return nil
|
||||
@ -137,17 +114,16 @@ func generateDNSConfig(
|
||||
//
|
||||
// This will produce a resolver like:
|
||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||
for _, resolver := range resolvers {
|
||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||
attrs := url.Values{
|
||||
"device_name": []string{node.Hostname()},
|
||||
"device_model": []string{node.Hostinfo().OS()},
|
||||
"device_name": []string{node.Hostname},
|
||||
"device_model": []string{node.Hostinfo.OS},
|
||||
}
|
||||
|
||||
nodeIPs := node.IPs()
|
||||
if len(nodeIPs) > 0 {
|
||||
attrs.Add("device_ip", nodeIPs[0].String())
|
||||
if len(node.IPs()) > 0 {
|
||||
attrs.Add("device_ip", node.IPs()[0].String())
|
||||
}
|
||||
|
||||
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
||||
@ -155,434 +131,151 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
||||
}
|
||||
}
|
||||
|
||||
// fullMapResponse creates a complete MapResponse for a node.
|
||||
// It is a separate function to make testing easier.
|
||||
func (m *Mapper) fullMapResponse(
|
||||
node types.NodeView,
|
||||
peers views.Slice[types.NodeView],
|
||||
// fullMapResponse returns a MapResponse for the given node.
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
messages ...string,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
||||
peers, err := m.listPeers(nodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = appendPeerChanges(
|
||||
resp,
|
||||
true, // full change
|
||||
m.state,
|
||||
node,
|
||||
capVer,
|
||||
peers,
|
||||
m.cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithDERPMap().
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig().
|
||||
WithSSHPolicy().
|
||||
WithDNSConfig().
|
||||
WithUserProfiles(peers).
|
||||
WithPacketFilters().
|
||||
WithPeers(peers).
|
||||
Build(messages...)
|
||||
}
|
||||
|
||||
// FullMapResponse returns a MapResponse for the given node.
|
||||
func (m *Mapper) FullMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
peers, err := m.ListPeers(node.ID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
// ReadOnlyMapResponse returns a MapResponse for the given node.
|
||||
// Lite means that the peers has been omitted, this is intended
|
||||
// to be used to answer MapRequests with OmitPeers set to true.
|
||||
func (m *Mapper) ReadOnlyMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
||||
}
|
||||
|
||||
func (m *Mapper) KeepAliveResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.KeepAlive = true
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) DERPMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
derpMap *tailcfg.DERPMap,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.DERPMap = derpMap
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) PeerChangedResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
changed map[types.NodeID]bool,
|
||||
patches []*tailcfg.PeerChange,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
var err error
|
||||
resp := m.baseMapResponse()
|
||||
|
||||
var removedIDs []tailcfg.NodeID
|
||||
var changedIDs []types.NodeID
|
||||
for nodeID, nodeChanged := range changed {
|
||||
if nodeChanged {
|
||||
if nodeID != node.ID() {
|
||||
changedIDs = append(changedIDs, nodeID)
|
||||
}
|
||||
} else {
|
||||
removedIDs = append(removedIDs, nodeID.NodeID())
|
||||
}
|
||||
}
|
||||
changedNodes := types.Nodes{}
|
||||
if len(changedIDs) > 0 {
|
||||
changedNodes, err = m.ListNodes(changedIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = appendPeerChanges(
|
||||
&resp,
|
||||
false, // partial change
|
||||
m.state,
|
||||
node,
|
||||
mapRequest.Version,
|
||||
changedNodes.ViewSlice(),
|
||||
m.cfg,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp.PeersRemoved = removedIDs
|
||||
|
||||
// Sending patches as a part of a PeersChanged response
|
||||
// is technically not suppose to be done, but they are
|
||||
// applied after the PeersChanged. The patch list
|
||||
// should _only_ contain Nodes that are not in the
|
||||
// PeersChanged or PeersRemoved list and the caller
|
||||
// should filter them out.
|
||||
//
|
||||
// From tailcfg docs:
|
||||
// These are applied after Peers* above, but in practice the
|
||||
// control server should only send these on their own, without
|
||||
// the Peers* fields also set.
|
||||
if patches != nil {
|
||||
resp.PeersChangedPatch = patches
|
||||
}
|
||||
|
||||
_, matchers := m.state.Filter()
|
||||
// Add the node itself, it might have changed, and particularly
|
||||
// if there are no patches or changes, this is a self update.
|
||||
tailnode, err := tailNode(
|
||||
node, mapRequest.Version, m.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
m.cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Node = tailnode
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
|
||||
func (m *mapper) derpMapResponse(
|
||||
nodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDERPMap().
|
||||
Build()
|
||||
}
|
||||
|
||||
// PeerChangedPatchResponse creates a patch MapResponse with
|
||||
// incoming update from a state change.
|
||||
func (m *Mapper) PeerChangedPatchResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
node types.NodeView,
|
||||
func (m *mapper) peerChangedPatchResponse(
|
||||
nodeID types.NodeID,
|
||||
changed []*tailcfg.PeerChange,
|
||||
) ([]byte, error) {
|
||||
resp := m.baseMapResponse()
|
||||
resp.PeersChangedPatch = changed
|
||||
|
||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
||||
}
|
||||
|
||||
func (m *Mapper) marshalMapResponse(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
resp *tailcfg.MapResponse,
|
||||
node types.NodeView,
|
||||
compression string,
|
||||
messages ...string,
|
||||
) ([]byte, error) {
|
||||
atomic.AddUint64(&m.seq, 1)
|
||||
|
||||
jsonBody, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
if debugDumpMapResponsePath != "" {
|
||||
data := map[string]any{
|
||||
"Messages": messages,
|
||||
"MapRequest": mapRequest,
|
||||
"MapResponse": resp,
|
||||
}
|
||||
|
||||
responseType := "keepalive"
|
||||
|
||||
switch {
|
||||
case resp.Peers != nil && len(resp.Peers) > 0:
|
||||
responseType = "full"
|
||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||
responseType = "self"
|
||||
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
|
||||
responseType = "changed"
|
||||
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
|
||||
responseType = "patch"
|
||||
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
|
||||
responseType = "removed"
|
||||
}
|
||||
|
||||
body, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
var respBody []byte
|
||||
if compression == util.ZstdCompression {
|
||||
respBody = zstdEncode(jsonBody)
|
||||
} else {
|
||||
respBody = jsonBody
|
||||
}
|
||||
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
||||
data = append(data, respBody...)
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func zstdEncode(in []byte) []byte {
|
||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
||||
if !ok {
|
||||
panic("invalid type in sync pool")
|
||||
}
|
||||
out := encoder.EncodeAll(in, nil)
|
||||
_ = encoder.Close()
|
||||
zstdEncoderPool.Put(encoder)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
var zstdEncoderPool = &sync.Pool{
|
||||
New: func() any {
|
||||
encoder, err := smallzstd.NewEncoder(
|
||||
nil,
|
||||
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return encoder
|
||||
},
|
||||
}
|
||||
|
||||
// baseMapResponse returns a tailcfg.MapResponse with
|
||||
// KeepAlive false and ControlTime set to now.
|
||||
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
||||
now := time.Now()
|
||||
|
||||
resp := tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
ControlTime: &now,
|
||||
// TODO(kradalby): Implement PingRequest?
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
|
||||
// with the basic configuration from headscale set.
|
||||
// It is used in for bigger updates, such as full and lite, not
|
||||
// incremental.
|
||||
func (m *Mapper) baseWithConfigMapResponse(
|
||||
node types.NodeView,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
resp := m.baseMapResponse()
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithPeerChangedPatch(changed).
|
||||
Build()
|
||||
}
|
||||
|
||||
_, matchers := m.state.Filter()
|
||||
tailnode, err := tailNode(
|
||||
node, capVer, m.state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
m.cfg)
|
||||
// peerChangeResponse returns a MapResponse with changed or added nodes.
|
||||
func (m *mapper) peerChangeResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers, err := m.listPeers(nodeID, changedNodeID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp.Node = tailnode
|
||||
|
||||
resp.DERPMap = m.state.DERPMap()
|
||||
|
||||
resp.Domain = m.cfg.Domain()
|
||||
|
||||
// Do not instruct clients to collect services we do not
|
||||
// support or do anything with them
|
||||
resp.CollectServices = "false"
|
||||
|
||||
resp.KeepAlive = false
|
||||
|
||||
resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !m.cfg.LogTail.Enabled,
|
||||
}
|
||||
|
||||
return &resp, nil
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithUserProfiles(peers).
|
||||
WithPeerChanges(peers).
|
||||
Build()
|
||||
}
|
||||
|
||||
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
|
||||
func (m *mapper) peerRemovedResponse(
|
||||
nodeID types.NodeID,
|
||||
removedNodeID types.NodeID,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithPeersRemoved(removedNodeID).
|
||||
Build()
|
||||
}
|
||||
|
||||
func writeDebugMapResponse(
|
||||
resp *tailcfg.MapResponse,
|
||||
nodeID types.NodeID,
|
||||
messages ...string,
|
||||
) {
|
||||
data := map[string]any{
|
||||
"Messages": messages,
|
||||
"MapResponse": resp,
|
||||
}
|
||||
|
||||
responseType := "keepalive"
|
||||
|
||||
switch {
|
||||
case len(resp.Peers) > 0:
|
||||
responseType = "full"
|
||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||
responseType = "self"
|
||||
case len(resp.PeersChanged) > 0:
|
||||
responseType = "changed"
|
||||
case len(resp.PeersChangedPatch) > 0:
|
||||
responseType = "patch"
|
||||
case len(resp.PeersRemoved) > 0:
|
||||
responseType = "removed"
|
||||
}
|
||||
|
||||
body, err := json.MarshalIndent(data, "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
||||
|
||||
mapResponsePath := path.Join(
|
||||
mPath,
|
||||
fmt.Sprintf("%s-%s.json", now, responseType),
|
||||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
// If no peer IDs are given, all peers are returned.
|
||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
peers, err := m.state.ListPeers(nodeID, peerIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(kradalby): Add back online via batcher. This was removed
|
||||
// to avoid a circular dependency between the mapper and the notification.
|
||||
for _, peer := range peers {
|
||||
online := m.notif.IsLikelyConnected(peer.ID)
|
||||
online := m.batcher.IsConnected(peer.ID)
|
||||
peer.IsOnline = &online
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// ListNodes queries the database for either all nodes if no parameters are given
|
||||
// or for the given nodes if at least one node ID is given as parameter.
|
||||
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
nodes, err := m.state.ListNodes(nodeIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
online := m.notif.IsLikelyConnected(node.ID)
|
||||
node.IsOnline = &online
|
||||
}
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
// routeFilterFunc is a function that takes a node ID and returns a list of
|
||||
// netip.Prefixes that are allowed for that node. It is used to filter routes
|
||||
// from the primary route manager to the node.
|
||||
type routeFilterFunc func(id types.NodeID) []netip.Prefix
|
||||
|
||||
// appendPeerChanges mutates a tailcfg.MapResponse with all the
|
||||
// necessary changes when peers have changed.
|
||||
func appendPeerChanges(
|
||||
resp *tailcfg.MapResponse,
|
||||
|
||||
fullChange bool,
|
||||
state *state.State,
|
||||
node types.NodeView,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
changed views.Slice[types.NodeView],
|
||||
cfg *types.Config,
|
||||
) error {
|
||||
filter, matchers := state.Filter()
|
||||
|
||||
sshPolicy, err := state.SSHPolicy(node)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there are filter rules present, see if there are any nodes that cannot
|
||||
// access each-other at all and remove them from the peers.
|
||||
var reducedChanged views.Slice[types.NodeView]
|
||||
if len(filter) > 0 {
|
||||
reducedChanged = policy.ReduceNodes(node, changed, matchers)
|
||||
} else {
|
||||
reducedChanged = changed
|
||||
}
|
||||
|
||||
profiles := generateUserProfiles(node, reducedChanged)
|
||||
|
||||
dnsConfig := generateDNSConfig(cfg, node)
|
||||
|
||||
tailPeers, err := tailNodes(
|
||||
reducedChanged, capVer, state,
|
||||
func(id types.NodeID) []netip.Prefix {
|
||||
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
|
||||
},
|
||||
cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
})
|
||||
|
||||
if fullChange {
|
||||
resp.Peers = tailPeers
|
||||
} else {
|
||||
resp.PeersChanged = tailPeers
|
||||
}
|
||||
resp.DNSConfig = dnsConfig
|
||||
resp.UserProfiles = profiles
|
||||
resp.SSHPolicy = sshPolicy
|
||||
|
||||
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
|
||||
// Currently, we do not send incremental package filters, however using the
|
||||
// new PacketFilters field and "base" allows us to send a full update when we
|
||||
// have to send an empty list, avoiding the hack in the else block.
|
||||
resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||
"base": policy.ReduceFilterRules(node, filter),
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package mapper
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@ -70,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
||||
&types.Config{
|
||||
TailcfgDNSConfig: &dnsConfigOrig,
|
||||
},
|
||||
nodeInShared1.View(),
|
||||
nodeInShared1,
|
||||
)
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||
@ -126,11 +127,8 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
|
||||
// Filter peers by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
for _, id := range peerIDs {
|
||||
if peer.ID == id {
|
||||
filtered = append(filtered, peer)
|
||||
break
|
||||
}
|
||||
if slices.Contains(peerIDs, peer.ID) {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
@ -152,11 +150,8 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
// Filter nodes by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, node := range m.nodes {
|
||||
for _, id := range nodeIDs {
|
||||
if node.ID == id {
|
||||
filtered = append(filtered, node)
|
||||
break
|
||||
}
|
||||
if slices.Contains(nodeIDs, node.ID) {
|
||||
filtered = append(filtered, node)
|
||||
}
|
||||
}
|
||||
|
||||
|
47
hscontrol/mapper/utils.go
Normal file
47
hscontrol/mapper/utils.go
Normal file
@ -0,0 +1,47 @@
|
||||
package mapper
|
||||
|
||||
import "tailscale.com/tailcfg"
|
||||
|
||||
// mergePatch takes the current patch and a newer patch
|
||||
// and override any field that has changed.
|
||||
func mergePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
||||
if newPatch.DERPRegion != 0 {
|
||||
currPatch.DERPRegion = newPatch.DERPRegion
|
||||
}
|
||||
|
||||
if newPatch.Cap != 0 {
|
||||
currPatch.Cap = newPatch.Cap
|
||||
}
|
||||
|
||||
if newPatch.CapMap != nil {
|
||||
currPatch.CapMap = newPatch.CapMap
|
||||
}
|
||||
|
||||
if newPatch.Endpoints != nil {
|
||||
currPatch.Endpoints = newPatch.Endpoints
|
||||
}
|
||||
|
||||
if newPatch.Key != nil {
|
||||
currPatch.Key = newPatch.Key
|
||||
}
|
||||
|
||||
if newPatch.KeySignature != nil {
|
||||
currPatch.KeySignature = newPatch.KeySignature
|
||||
}
|
||||
|
||||
if newPatch.DiscoKey != nil {
|
||||
currPatch.DiscoKey = newPatch.DiscoKey
|
||||
}
|
||||
|
||||
if newPatch.Online != nil {
|
||||
currPatch.Online = newPatch.Online
|
||||
}
|
||||
|
||||
if newPatch.LastSeen != nil {
|
||||
currPatch.LastSeen = newPatch.LastSeen
|
||||
}
|
||||
|
||||
if newPatch.KeyExpiry != nil {
|
||||
currPatch.KeyExpiry = newPatch.KeyExpiry
|
||||
}
|
||||
}
|
@ -221,7 +221,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
||||
|
||||
ns.nodeKey = nv.NodeKey()
|
||||
|
||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv)
|
||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
|
||||
sess.tracef("a node sending a MapRequest with Noise protocol")
|
||||
if !sess.isStreaming() {
|
||||
sess.serve()
|
||||
@ -279,28 +279,33 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
||||
return
|
||||
}
|
||||
|
||||
respBody, err := json.Marshal(registerResponse)
|
||||
if err != nil {
|
||||
httpError(writer, err)
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
|
||||
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
|
||||
return
|
||||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write(respBody)
|
||||
// Ensure response is flushed to client
|
||||
if flusher, ok := writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// getAndValidateNode retrieves the node from the database using the NodeKey
|
||||
// and validates that it matches the MachineKey from the Noise session.
|
||||
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
|
||||
nv, err := ns.headscale.state.GetNodeViewByNodeKey(mapRequest.NodeKey)
|
||||
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
|
||||
}
|
||||
return types.NodeView{}, err
|
||||
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
|
||||
}
|
||||
|
||||
nv := node.View()
|
||||
|
||||
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
|
||||
if ns.machineKey != nv.MachineKey() {
|
||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
|
||||
|
@ -1,68 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
"tailscale.com/envknob"
|
||||
)
|
||||
|
||||
const prometheusNamespace = "headscale"
|
||||
|
||||
var debugHighCardinalityMetrics = envknob.Bool("HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS")
|
||||
|
||||
var notifierUpdateSent *prometheus.CounterVec
|
||||
|
||||
func init() {
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_sent_total",
|
||||
Help: "total count of update sent on nodes channel",
|
||||
}, []string{"status", "type", "trigger", "id"})
|
||||
} else {
|
||||
notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_sent_total",
|
||||
Help: "total count of update sent on nodes channel",
|
||||
}, []string{"status", "type", "trigger"})
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
notifierWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_waiters_for_lock",
|
||||
Help: "gauge of waiters for the notifier lock",
|
||||
}, []string{"type", "action"})
|
||||
notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_wait_for_lock_seconds",
|
||||
Help: "histogram of time spent waiting for the notifier lock",
|
||||
Buckets: []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10},
|
||||
}, []string{"action"})
|
||||
notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_update_received_total",
|
||||
Help: "total count of updates received by notifier",
|
||||
}, []string{"type", "trigger"})
|
||||
notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_open_channels_total",
|
||||
Help: "total count open channels in notifier",
|
||||
})
|
||||
notifierBatcherWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_batcher_waiters_for_lock",
|
||||
Help: "gauge of waiters for the notifier batcher lock",
|
||||
}, []string{"type", "action"})
|
||||
notifierBatcherChanges = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_batcher_changes_pending",
|
||||
Help: "gauge of full changes pending in the notifier batcher",
|
||||
}, []string{})
|
||||
notifierBatcherPatches = promauto.NewGaugeVec(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "notifier_batcher_patches_pending",
|
||||
Help: "gauge of patches pending in the notifier batcher",
|
||||
}, []string{})
|
||||
)
|
@ -1,488 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
var (
|
||||
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
|
||||
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
|
||||
)
|
||||
|
||||
func init() {
|
||||
deadlock.Opts.Disable = !debugDeadlock
|
||||
if debugDeadlock {
|
||||
deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
|
||||
deadlock.Opts.PrintAllCurrentGoroutines = true
|
||||
}
|
||||
}
|
||||
|
||||
type Notifier struct {
|
||||
l deadlock.Mutex
|
||||
nodes map[types.NodeID]chan<- types.StateUpdate
|
||||
connected *xsync.MapOf[types.NodeID, bool]
|
||||
b *batcher
|
||||
cfg *types.Config
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewNotifier(cfg *types.Config) *Notifier {
|
||||
n := &Notifier{
|
||||
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
||||
connected: xsync.NewMapOf[types.NodeID, bool](),
|
||||
cfg: cfg,
|
||||
closed: false,
|
||||
}
|
||||
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
|
||||
n.b = b
|
||||
|
||||
go b.doWork()
|
||||
|
||||
return n
|
||||
}
|
||||
|
||||
// Close stops the batcher and closes all channels.
|
||||
func (n *Notifier) Close() {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "close").Dec()
|
||||
|
||||
n.closed = true
|
||||
n.b.close()
|
||||
|
||||
// Close channels safely using the helper method
|
||||
for nodeID, c := range n.nodes {
|
||||
n.safeCloseChannel(nodeID, c)
|
||||
}
|
||||
|
||||
// Clear node map after closing channels
|
||||
n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
|
||||
}
|
||||
|
||||
// safeCloseChannel closes a channel and panic recovers if already closed.
|
||||
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Error().
|
||||
Uint64("node.id", nodeID.Uint64()).
|
||||
Any("recover", r).
|
||||
Msg("recovered from panic when closing channel in Close()")
|
||||
}
|
||||
}()
|
||||
close(c)
|
||||
}
|
||||
|
||||
func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
|
||||
log.Trace().
|
||||
Uint64("node.id", nID.Uint64()).
|
||||
Int("open_chans", len(n.nodes)).Msgf(msg, args...)
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "add").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
|
||||
notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
// If a channel exists, it means the node has opened a new
|
||||
// connection. Close the old channel and replace it.
|
||||
if curr, ok := n.nodes[nodeID]; ok {
|
||||
n.tracef(nodeID, "channel present, closing and replacing")
|
||||
// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
|
||||
// if/when someone is waiting to send on this channel
|
||||
go func(ch chan<- types.StateUpdate) {
|
||||
n.safeCloseChannel(nodeID, ch)
|
||||
}(curr)
|
||||
}
|
||||
|
||||
n.nodes[nodeID] = c
|
||||
n.connected.Store(nodeID, true)
|
||||
|
||||
n.tracef(nodeID, "added new channel")
|
||||
notifierNodeUpdateChans.Inc()
|
||||
}
|
||||
|
||||
// RemoveNode removes a node and a given channel from the notifier.
|
||||
// It checks that the channel is the same as currently being updated
|
||||
// and ignores the removal if it is not.
|
||||
// RemoveNode reports if the node/chan was removed.
|
||||
func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "remove").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
|
||||
notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return true
|
||||
}
|
||||
|
||||
if len(n.nodes) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// If the channel exist, but it does not belong
|
||||
// to the caller, ignore.
|
||||
if curr, ok := n.nodes[nodeID]; ok {
|
||||
if curr != c {
|
||||
n.tracef(nodeID, "channel has been replaced, not removing")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
delete(n.nodes, nodeID)
|
||||
n.connected.Store(nodeID, false)
|
||||
|
||||
n.tracef(nodeID, "removed channel")
|
||||
notifierNodeUpdateChans.Dec()
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsConnected reports if a node is connected to headscale and has a
|
||||
// poll session open.
|
||||
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec()
|
||||
|
||||
if val, ok := n.connected.Load(nodeID); ok {
|
||||
return val
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsLikelyConnected reports if a node is connected to headscale and has a
|
||||
// poll session open, but doesn't lock, so might be wrong.
|
||||
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
||||
if val, ok := n.connected.Load(nodeID); ok {
|
||||
return val
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// LikelyConnectedMap returns a thread safe map of connected nodes.
|
||||
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
|
||||
return n.connected
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
|
||||
n.NotifyWithIgnore(ctx, update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyWithIgnore(
|
||||
ctx context.Context,
|
||||
update types.StateUpdate,
|
||||
ignoreNodeIDs ...types.NodeID,
|
||||
) {
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
n.b.addOrPassthrough(update)
|
||||
}
|
||||
|
||||
func (n *Notifier) NotifyByNodeID(
|
||||
ctx context.Context,
|
||||
update types.StateUpdate,
|
||||
nodeID types.NodeID,
|
||||
) {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "notify").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
|
||||
notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
if c, ok := n.nodes[nodeID]; ok {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Uint64("node.id", nodeID.Uint64()).
|
||||
Any("origin", types.NotifyOriginKey.Value(ctx)).
|
||||
Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
|
||||
Msgf("update not sent, context cancelled")
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
}
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
n.tracef(nodeID, "update successfully sent on chan, origin: %s, origin-hostname: %s", ctx.Value("origin"), ctx.Value("hostname"))
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) sendAll(update types.StateUpdate) {
|
||||
start := time.Now()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
|
||||
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
|
||||
|
||||
if n.closed {
|
||||
return
|
||||
}
|
||||
|
||||
for id, c := range n.nodes {
|
||||
// Whenever an update is sent to all nodes, there is a chance that the node
|
||||
// has disconnected and the goroutine that was supposed to consume the update
|
||||
// has shut down the channel and is waiting for the lock held here in RemoveNode.
|
||||
// This means that there is potential for a deadlock which would stop all updates
|
||||
// going out to clients. This timeout prevents that from happening by moving on to the
|
||||
// next node if the context is cancelled. After sendAll releases the lock, the add/remove
|
||||
// call will succeed and the update will go to the correct nodes on the next call.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout)
|
||||
defer cancel()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Error().
|
||||
Err(ctx.Err()).
|
||||
Uint64("node.id", id.Uint64()).
|
||||
Msgf("update not sent, context cancelled")
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all", id.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all").Inc()
|
||||
}
|
||||
|
||||
return
|
||||
case c <- update:
|
||||
if debugHighCardinalityMetrics {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all", id.String()).Inc()
|
||||
} else {
|
||||
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) String() string {
|
||||
notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
|
||||
n.l.Lock()
|
||||
defer n.l.Unlock()
|
||||
notifierWaitersForLock.WithLabelValues("lock", "string").Dec()
|
||||
|
||||
var b strings.Builder
|
||||
fmt.Fprintf(&b, "chans (%d):\n", len(n.nodes))
|
||||
|
||||
var keys []types.NodeID
|
||||
n.connected.Range(func(key types.NodeID, value bool) bool {
|
||||
keys = append(keys, key)
|
||||
return true
|
||||
})
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return keys[i] < keys[j]
|
||||
})
|
||||
|
||||
for _, key := range keys {
|
||||
fmt.Fprintf(&b, "\t%d: %p\n", key, n.nodes[key])
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
fmt.Fprintf(&b, "connected (%d):\n", len(n.nodes))
|
||||
|
||||
for _, key := range keys {
|
||||
val, _ := n.connected.Load(key)
|
||||
fmt.Fprintf(&b, "\t%d: %t\n", key, val)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type batcher struct {
|
||||
tick *time.Ticker
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
cancelCh chan struct{}
|
||||
|
||||
changedNodeIDs set.Slice[types.NodeID]
|
||||
nodesChanged bool
|
||||
patches map[types.NodeID]tailcfg.PeerChange
|
||||
patchesChanged bool
|
||||
|
||||
n *Notifier
|
||||
}
|
||||
|
||||
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
|
||||
return &batcher{
|
||||
tick: time.NewTicker(batchTime),
|
||||
cancelCh: make(chan struct{}),
|
||||
patches: make(map[types.NodeID]tailcfg.PeerChange),
|
||||
n: n,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batcher) close() {
|
||||
b.cancelCh <- struct{}{}
|
||||
}
|
||||
|
||||
// addOrPassthrough adds the update to the batcher, if it is not a
|
||||
// type that is currently batched, it will be sent immediately.
|
||||
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Inc()
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Dec()
|
||||
|
||||
switch update.Type {
|
||||
case types.StatePeerChanged:
|
||||
b.changedNodeIDs.Add(update.ChangeNodes...)
|
||||
b.nodesChanged = true
|
||||
notifierBatcherChanges.WithLabelValues().Set(float64(b.changedNodeIDs.Len()))
|
||||
|
||||
case types.StatePeerChangedPatch:
|
||||
for _, newPatch := range update.ChangePatches {
|
||||
if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
|
||||
overwritePatch(&curr, newPatch)
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = curr
|
||||
} else {
|
||||
b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
|
||||
}
|
||||
}
|
||||
b.patchesChanged = true
|
||||
notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))
|
||||
|
||||
default:
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
}
|
||||
|
||||
// flush sends all the accumulated patches to all
|
||||
// nodes in the notifier.
|
||||
func (b *batcher) flush() {
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Inc()
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Dec()
|
||||
|
||||
if b.nodesChanged || b.patchesChanged {
|
||||
var patches []*tailcfg.PeerChange
|
||||
// If a node is getting a full update from a change
|
||||
// node update, then the patch can be dropped.
|
||||
for nodeID, patch := range b.patches {
|
||||
if b.changedNodeIDs.Contains(nodeID) {
|
||||
delete(b.patches, nodeID)
|
||||
} else {
|
||||
patches = append(patches, &patch)
|
||||
}
|
||||
}
|
||||
|
||||
changedNodes := b.changedNodeIDs.Slice().AsSlice()
|
||||
sort.Slice(changedNodes, func(i, j int) bool {
|
||||
return changedNodes[i] < changedNodes[j]
|
||||
})
|
||||
|
||||
if b.changedNodeIDs.Slice().Len() > 0 {
|
||||
update := types.UpdatePeerChanged(changedNodes...)
|
||||
|
||||
b.n.sendAll(update)
|
||||
}
|
||||
|
||||
if len(patches) > 0 {
|
||||
patchUpdate := types.UpdatePeerPatch(patches...)
|
||||
|
||||
b.n.sendAll(patchUpdate)
|
||||
}
|
||||
|
||||
b.changedNodeIDs = set.Slice[types.NodeID]{}
|
||||
notifierBatcherChanges.WithLabelValues().Set(0)
|
||||
b.nodesChanged = false
|
||||
b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
|
||||
notifierBatcherPatches.WithLabelValues().Set(0)
|
||||
b.patchesChanged = false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batcher) doWork() {
|
||||
for {
|
||||
select {
|
||||
case <-b.cancelCh:
|
||||
return
|
||||
case <-b.tick.C:
|
||||
b.flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// overwritePatch takes the current patch and a newer patch
|
||||
// and override any field that has changed.
|
||||
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
||||
if newPatch.DERPRegion != 0 {
|
||||
currPatch.DERPRegion = newPatch.DERPRegion
|
||||
}
|
||||
|
||||
if newPatch.Cap != 0 {
|
||||
currPatch.Cap = newPatch.Cap
|
||||
}
|
||||
|
||||
if newPatch.CapMap != nil {
|
||||
currPatch.CapMap = newPatch.CapMap
|
||||
}
|
||||
|
||||
if newPatch.Endpoints != nil {
|
||||
currPatch.Endpoints = newPatch.Endpoints
|
||||
}
|
||||
|
||||
if newPatch.Key != nil {
|
||||
currPatch.Key = newPatch.Key
|
||||
}
|
||||
|
||||
if newPatch.KeySignature != nil {
|
||||
currPatch.KeySignature = newPatch.KeySignature
|
||||
}
|
||||
|
||||
if newPatch.DiscoKey != nil {
|
||||
currPatch.DiscoKey = newPatch.DiscoKey
|
||||
}
|
||||
|
||||
if newPatch.Online != nil {
|
||||
currPatch.Online = newPatch.Online
|
||||
}
|
||||
|
||||
if newPatch.LastSeen != nil {
|
||||
currPatch.LastSeen = newPatch.LastSeen
|
||||
}
|
||||
|
||||
if newPatch.KeyExpiry != nil {
|
||||
currPatch.KeyExpiry = newPatch.KeyExpiry
|
||||
}
|
||||
}
|
@ -1,342 +0,0 @@
|
||||
package notifier
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestBatcher(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
updates []types.StateUpdate
|
||||
want []types.StateUpdate
|
||||
}{
|
||||
{
|
||||
name: "full-passthrough",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateFullUpdate,
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateFullUpdate,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "derp-passthrough",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateDERPUpdated,
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StateDERPUpdated,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 4,
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 3,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: []types.NodeID{
|
||||
2, 3, 4,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single-patch-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-patch-to-same-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 5,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 2,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "merge-patch-to-multiple-node-update",
|
||||
updates: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
netip.MustParseAddrPort("2.2.2.2:8080"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 4,
|
||||
DERPRegion: 6,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 4,
|
||||
Cap: tailcfg.CapabilityVersion(54),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: []types.StateUpdate{
|
||||
{
|
||||
Type: types.StatePeerChangedPatch,
|
||||
ChangePatches: []*tailcfg.PeerChange{
|
||||
{
|
||||
NodeID: 3,
|
||||
Endpoints: []netip.AddrPort{
|
||||
netip.MustParseAddrPort("1.1.1.1:9090"),
|
||||
netip.MustParseAddrPort("2.2.2.2:8080"),
|
||||
},
|
||||
},
|
||||
{
|
||||
NodeID: 4,
|
||||
DERPRegion: 6,
|
||||
Cap: tailcfg.CapabilityVersion(54),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
n := NewNotifier(&types.Config{
|
||||
Tuning: types.Tuning{
|
||||
// We will call flush manually for the tests,
|
||||
// so do not run the worker.
|
||||
BatchChangeDelay: time.Hour,
|
||||
|
||||
// Since we do not load the config, we won't get the
|
||||
// default, so set it manually so we dont time out
|
||||
// and have flakes.
|
||||
NotifierSendTimeout: time.Second,
|
||||
},
|
||||
})
|
||||
|
||||
ch := make(chan types.StateUpdate, 30)
|
||||
defer close(ch)
|
||||
n.AddNode(1, ch)
|
||||
defer n.RemoveNode(1, ch)
|
||||
|
||||
for _, u := range tt.updates {
|
||||
n.NotifyAll(t.Context(), u)
|
||||
}
|
||||
|
||||
n.b.flush()
|
||||
|
||||
var got []types.StateUpdate
|
||||
for len(ch) > 0 {
|
||||
out := <-ch
|
||||
got = append(got, out)
|
||||
}
|
||||
|
||||
// Make the inner order stable for comparison.
|
||||
for _, u := range got {
|
||||
slices.Sort(u.ChangeNodes)
|
||||
sort.Slice(u.ChangePatches, func(i, j int) bool {
|
||||
return u.ChangePatches[i].NodeID < u.ChangePatches[j].NodeID
|
||||
})
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
|
||||
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
|
||||
// close a channel that was already closed, which can happen when a node changes
|
||||
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
|
||||
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
|
||||
// mock config for the notifier
|
||||
cfg := &types.Config{
|
||||
Tuning: types.Tuning{
|
||||
NotifierSendTimeout: 1 * time.Second,
|
||||
BatchChangeDelay: 1 * time.Second,
|
||||
NodeMapSessionBufferedChanSize: 30,
|
||||
},
|
||||
}
|
||||
|
||||
notifier := NewNotifier(cfg)
|
||||
defer notifier.Close()
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
updateChan := make(chan types.StateUpdate, 10)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Number of goroutines to spawn for concurrent access
|
||||
concurrentAccessors := 100
|
||||
iterations := 100
|
||||
|
||||
// Add node to notifier
|
||||
notifier.AddNode(nodeID, updateChan)
|
||||
|
||||
// Track errors
|
||||
errChan := make(chan string, concurrentAccessors*iterations)
|
||||
|
||||
// Start goroutines to cause a race
|
||||
wg.Add(concurrentAccessors)
|
||||
for i := range concurrentAccessors {
|
||||
go func(routineID int) {
|
||||
defer wg.Done()
|
||||
|
||||
for range iterations {
|
||||
// Simulate race by having some goroutines check IsLikelyConnected
|
||||
// while others add/remove the node
|
||||
switch routineID % 3 {
|
||||
case 0:
|
||||
// This goroutine checks connection status
|
||||
isConnected := notifier.IsLikelyConnected(nodeID)
|
||||
if isConnected != true && isConnected != false {
|
||||
errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
|
||||
}
|
||||
case 1:
|
||||
// This goroutine removes the node
|
||||
notifier.RemoveNode(nodeID, updateChan)
|
||||
default:
|
||||
// This goroutine adds the node back
|
||||
notifier.AddNode(nodeID, updateChan)
|
||||
}
|
||||
|
||||
// Small random delay to increase chance of races
|
||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// Collate errors
|
||||
var errors []string
|
||||
for err := range errChan {
|
||||
errors = append(errors, err)
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
|
||||
}
|
||||
}
|
@ -16,9 +16,8 @@ import (
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||
"github.com/juanfont/headscale/hscontrol/state"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
@ -56,11 +55,10 @@ type RegistrationInfo struct {
|
||||
}
|
||||
|
||||
type AuthProviderOIDC struct {
|
||||
h *Headscale
|
||||
serverURL string
|
||||
cfg *types.OIDCConfig
|
||||
state *state.State
|
||||
registrationCache *zcache.Cache[string, RegistrationInfo]
|
||||
notifier *notifier.Notifier
|
||||
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
@ -68,10 +66,9 @@ type AuthProviderOIDC struct {
|
||||
|
||||
func NewAuthProviderOIDC(
|
||||
ctx context.Context,
|
||||
h *Headscale,
|
||||
serverURL string,
|
||||
cfg *types.OIDCConfig,
|
||||
state *state.State,
|
||||
notif *notifier.Notifier,
|
||||
) (*AuthProviderOIDC, error) {
|
||||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
@ -94,11 +91,10 @@ func NewAuthProviderOIDC(
|
||||
)
|
||||
|
||||
return &AuthProviderOIDC{
|
||||
h: h,
|
||||
serverURL: serverURL,
|
||||
cfg: cfg,
|
||||
state: state,
|
||||
registrationCache: registrationCache,
|
||||
notifier: notif,
|
||||
|
||||
oidcProvider: oidcProvider,
|
||||
oauth2Config: oauth2Config,
|
||||
@ -318,8 +314,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-user-created", user.Name)
|
||||
a.notifier.NotifyAll(ctx, types.UpdateFull())
|
||||
a.h.Change(change.PolicyChange())
|
||||
}
|
||||
|
||||
// TODO(kradalby): Is this comment right?
|
||||
@ -360,8 +355,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
||||
// Neither node nor machine key was found in the state cache meaning
|
||||
// that we could not reauth nor register the node.
|
||||
httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func extractCodeAndStateParamFromRequest(
|
||||
@ -490,12 +483,14 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
var err error
|
||||
var newUser bool
|
||||
var policyChanged bool
|
||||
user, err = a.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, false, fmt.Errorf("creating or updating user: %w", err)
|
||||
}
|
||||
|
||||
// if the user is still not found, create a new empty user.
|
||||
// TODO(kradalby): This might cause us to not have an ID below which
|
||||
// is a problem.
|
||||
if user == nil {
|
||||
newUser = true
|
||||
user = &types.User{}
|
||||
@ -504,12 +499,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
||||
user.FromClaim(claims)
|
||||
|
||||
if newUser {
|
||||
user, policyChanged, err = a.state.CreateUser(*user)
|
||||
user, policyChanged, err = a.h.state.CreateUser(*user)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, policyChanged, err = a.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
|
||||
*u = *user
|
||||
return nil
|
||||
})
|
||||
@ -526,7 +521,7 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
registrationID types.RegistrationID,
|
||||
expiry time.Time,
|
||||
) (bool, error) {
|
||||
node, newNode, err := a.state.HandleNodeFromAuthPath(
|
||||
node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
|
||||
registrationID,
|
||||
types.UserID(user.ID),
|
||||
&expiry,
|
||||
@ -547,31 +542,20 @@ func (a *AuthProviderOIDC) handleRegistration(
|
||||
// ensure we send an update.
|
||||
// This works, but might be another good candidate for doing some sort of
|
||||
// eventbus.
|
||||
routesChanged := a.state.AutoApproveRoutes(node)
|
||||
_, policyChanged, err := a.state.SaveNode(node)
|
||||
_ = a.h.state.AutoApproveRoutes(node)
|
||||
_, policyChange, err := a.h.state.SaveNode(node)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("saving auto approved routes to node: %w", err)
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed (from SaveNode or route changes)
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-nodes-change", "all")
|
||||
a.notifier.NotifyAll(ctx, types.UpdateFull())
|
||||
// Policy updates are full and take precedence over node changes.
|
||||
if !policyChange.Empty() {
|
||||
a.h.Change(policyChange)
|
||||
} else {
|
||||
a.h.Change(nodeChange)
|
||||
}
|
||||
|
||||
if routesChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
|
||||
a.notifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID,
|
||||
)
|
||||
|
||||
ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
|
||||
a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
}
|
||||
|
||||
return newNode, nil
|
||||
return !nodeChange.Empty(), nil
|
||||
}
|
||||
|
||||
// TODO(kradalby):
|
||||
|
@ -113,6 +113,17 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check approved subnet routes - nodes should have access
|
||||
// to subnets they're approved to route traffic for.
|
||||
subnetRoutes := node.SubnetRoutes()
|
||||
|
||||
for _, subnetRoute := range subnetRoutes {
|
||||
if expanded.OverlapsPrefix(subnetRoute) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(dests) > 0 {
|
||||
@ -142,16 +153,23 @@ func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
|
||||
newApproved = append(newApproved, route)
|
||||
}
|
||||
}
|
||||
if newApproved != nil {
|
||||
newApproved = append(newApproved, node.ApprovedRoutes...)
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
|
||||
|
||||
// Only modify ApprovedRoutes if we have new routes to approve.
|
||||
// This prevents clearing existing approved routes when nodes
|
||||
// temporarily don't have announced routes during policy changes.
|
||||
if len(newApproved) > 0 {
|
||||
combined := append(newApproved, node.ApprovedRoutes...)
|
||||
tsaddr.SortPrefixes(combined)
|
||||
combined = slices.Compact(combined)
|
||||
combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
|
||||
return route.IsValid()
|
||||
})
|
||||
node.ApprovedRoutes = newApproved
|
||||
|
||||
return true
|
||||
// Only update if the routes actually changed
|
||||
if !slices.Equal(node.ApprovedRoutes, combined) {
|
||||
node.ApprovedRoutes = combined
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
|
@ -56,10 +56,13 @@ func (pol *Policy) compileFilterRules(
|
||||
}
|
||||
|
||||
if ips == nil {
|
||||
log.Debug().Msgf("destination resolved to nil ips: %v", dest)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, pref := range ips.Prefixes() {
|
||||
prefixes := ips.Prefixes()
|
||||
|
||||
for _, pref := range prefixes {
|
||||
for _, port := range dest.Ports {
|
||||
pr := tailcfg.NetPortRange{
|
||||
IP: pref.String(),
|
||||
@ -103,6 +106,8 @@ func (pol *Policy) compileSSHPolicy(
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname())
|
||||
|
||||
var rules []*tailcfg.SSHRule
|
||||
|
||||
for index, rule := range pol.SSHs {
|
||||
@ -137,7 +142,8 @@ func (pol *Policy) compileSSHPolicy(
|
||||
var principals []*tailcfg.SSHPrincipal
|
||||
srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msgf("resolving source ips")
|
||||
log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
|
||||
continue // Skip this rule if we can't resolve sources
|
||||
}
|
||||
|
||||
for addr := range util.IPSetAddrIter(srcIPs) {
|
||||
|
@ -70,7 +70,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
||||
// TODO(kradalby): This could potentially be optimized by only clearing the
|
||||
// policies for nodes that have changed. Particularly if the only difference is
|
||||
// that nodes has been added or removed.
|
||||
defer clear(pm.sshPolicyMap)
|
||||
clear(pm.sshPolicyMap)
|
||||
|
||||
filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
|
@ -1730,7 +1730,7 @@ func (u SSHUser) MarshalJSON() ([]byte, error) {
|
||||
// In addition to unmarshalling, it will also validate the policy.
|
||||
// This is the only entrypoint of reading a policy from a file or other source.
|
||||
func unmarshalPolicy(b []byte) (*Policy, error) {
|
||||
if b == nil || len(b) == 0 {
|
||||
if len(b) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
@ -2,20 +2,20 @@ package hscontrol
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/mapper"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/zstdframe"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -31,18 +31,17 @@ type mapSession struct {
|
||||
req tailcfg.MapRequest
|
||||
ctx context.Context
|
||||
capVer tailcfg.CapabilityVersion
|
||||
mapper *mapper.Mapper
|
||||
|
||||
cancelChMu deadlock.Mutex
|
||||
|
||||
ch chan types.StateUpdate
|
||||
ch chan *tailcfg.MapResponse
|
||||
cancelCh chan struct{}
|
||||
cancelChOpen bool
|
||||
|
||||
keepAlive time.Duration
|
||||
keepAliveTicker *time.Ticker
|
||||
|
||||
node types.NodeView
|
||||
node *types.Node
|
||||
w http.ResponseWriter
|
||||
|
||||
warnf func(string, ...any)
|
||||
@ -55,18 +54,9 @@ func (h *Headscale) newMapSession(
|
||||
ctx context.Context,
|
||||
req tailcfg.MapRequest,
|
||||
w http.ResponseWriter,
|
||||
nv types.NodeView,
|
||||
node *types.Node,
|
||||
) *mapSession {
|
||||
warnf, infof, tracef, errf := logPollFuncView(req, nv)
|
||||
|
||||
var updateChan chan types.StateUpdate
|
||||
if req.Stream {
|
||||
// Use a buffered channel in case a node is not fully ready
|
||||
// to receive a message to make sure we dont block the entire
|
||||
// notifier.
|
||||
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
|
||||
updateChan <- types.UpdateFull()
|
||||
}
|
||||
warnf, infof, tracef, errf := logPollFunc(req, node)
|
||||
|
||||
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
|
||||
|
||||
@ -75,11 +65,10 @@ func (h *Headscale) newMapSession(
|
||||
ctx: ctx,
|
||||
req: req,
|
||||
w: w,
|
||||
node: nv,
|
||||
node: node,
|
||||
capVer: req.Version,
|
||||
mapper: h.mapper,
|
||||
|
||||
ch: updateChan,
|
||||
ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize),
|
||||
cancelCh: make(chan struct{}),
|
||||
cancelChOpen: true,
|
||||
|
||||
@ -95,15 +84,11 @@ func (h *Headscale) newMapSession(
|
||||
}
|
||||
|
||||
func (m *mapSession) isStreaming() bool {
|
||||
return m.req.Stream && !m.req.ReadOnly
|
||||
return m.req.Stream
|
||||
}
|
||||
|
||||
func (m *mapSession) isEndpointUpdate() bool {
|
||||
return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers
|
||||
}
|
||||
|
||||
func (m *mapSession) isReadOnlyUpdate() bool {
|
||||
return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly
|
||||
return !m.req.Stream && m.req.OmitPeers
|
||||
}
|
||||
|
||||
func (m *mapSession) resetKeepAlive() {
|
||||
@ -112,25 +97,22 @@ func (m *mapSession) resetKeepAlive() {
|
||||
|
||||
func (m *mapSession) beforeServeLongPoll() {
|
||||
if m.node.IsEphemeral() {
|
||||
m.h.ephemeralGC.Cancel(m.node.ID())
|
||||
m.h.ephemeralGC.Cancel(m.node.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mapSession) afterServeLongPoll() {
|
||||
if m.node.IsEphemeral() {
|
||||
m.h.ephemeralGC.Schedule(m.node.ID(), m.h.cfg.EphemeralNodeInactivityTimeout)
|
||||
m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
// serve handles non-streaming requests.
|
||||
func (m *mapSession) serve() {
|
||||
// TODO(kradalby): A set todos to harden:
|
||||
// - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true
|
||||
|
||||
// This is the mechanism where the node gives us information about its
|
||||
// current configuration.
|
||||
//
|
||||
// If OmitPeers is true, Stream is false, and ReadOnly is false,
|
||||
// If OmitPeers is true and Stream is false
|
||||
// then the server will let clients update their endpoints without
|
||||
// breaking existing long-polling (Stream == true) connections.
|
||||
// In this case, the server can omit the entire response; the client
|
||||
@ -138,26 +120,18 @@ func (m *mapSession) serve() {
|
||||
//
|
||||
// This is what Tailscale calls a Lite update, the client ignores
|
||||
// the response and just wants a 200.
|
||||
// !req.stream && !req.ReadOnly && req.OmitPeers
|
||||
//
|
||||
// TODO(kradalby): remove ReadOnly when we only support capVer 68+
|
||||
// !req.stream && req.OmitPeers
|
||||
if m.isEndpointUpdate() {
|
||||
m.handleEndpointUpdate()
|
||||
c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
|
||||
if err != nil {
|
||||
httpError(m.w, err)
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
m.h.Change(c)
|
||||
|
||||
// ReadOnly is whether the client just wants to fetch the
|
||||
// MapResponse, without updating their Endpoints. The
|
||||
// Endpoints field will be ignored and LastSeen will not be
|
||||
// updated and peers will not be notified of changes.
|
||||
//
|
||||
// The intended use is for clients to discover the DERP map at
|
||||
// start-up before their first real endpoint update.
|
||||
if m.isReadOnlyUpdate() {
|
||||
m.handleReadOnlyRequest()
|
||||
|
||||
return
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
|
||||
}
|
||||
}
|
||||
|
||||
@ -175,23 +149,15 @@ func (m *mapSession) serveLongPoll() {
|
||||
close(m.cancelCh)
|
||||
m.cancelChMu.Unlock()
|
||||
|
||||
// only update node status if the node channel was removed.
|
||||
// in principal, it will be removed, but the client rapidly
|
||||
// reconnects, the channel might be of another connection.
|
||||
// In that case, it is not closed and the node is still online.
|
||||
if m.h.nodeNotifier.RemoveNode(m.node.ID(), m.ch) {
|
||||
// TODO(kradalby): This can likely be made more effective, but likely most
|
||||
// nodes has access to the same routes, so it might not be a big deal.
|
||||
change, err := m.h.state.Disconnect(m.node.ID())
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to disconnect node %s", m.node.Hostname())
|
||||
}
|
||||
|
||||
if change {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname())
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
// TODO(kradalby): This can likely be made more effective, but likely most
|
||||
// nodes has access to the same routes, so it might not be a big deal.
|
||||
disconnectChange, err := m.h.state.Disconnect(m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
|
||||
}
|
||||
m.h.Change(disconnectChange)
|
||||
|
||||
m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
|
||||
|
||||
m.afterServeLongPoll()
|
||||
m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
|
||||
@ -201,21 +167,30 @@ func (m *mapSession) serveLongPoll() {
|
||||
m.h.pollNetMapStreamWG.Add(1)
|
||||
defer m.h.pollNetMapStreamWG.Done()
|
||||
|
||||
m.h.state.Connect(m.node.ID())
|
||||
|
||||
// Upgrade the writer to a ResponseController
|
||||
rc := http.NewResponseController(m.w)
|
||||
|
||||
// Longpolling will break if there is a write timeout,
|
||||
// so it needs to be disabled.
|
||||
rc.SetWriteDeadline(time.Time{})
|
||||
|
||||
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname()))
|
||||
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
|
||||
defer cancel()
|
||||
|
||||
m.keepAliveTicker = time.NewTicker(m.keepAlive)
|
||||
|
||||
m.h.nodeNotifier.AddNode(m.node.ID(), m.ch)
|
||||
// Add node to batcher BEFORE sending Connect change to prevent race condition
|
||||
// where the change is sent before the node is in the batcher's node map
|
||||
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
|
||||
m.errf(err, "failed to add node to batcher")
|
||||
// Send empty response to client to fail fast for invalid/non-existent nodes
|
||||
select {
|
||||
case m.ch <- &tailcfg.MapResponse{}:
|
||||
default:
|
||||
// Channel might be closed
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Now send the Connect change - the batcher handles NodeCameOnline internally
|
||||
// but we still need to update routes and other state-level changes
|
||||
connectChange := m.h.state.Connect(m.node)
|
||||
if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
|
||||
m.h.Change(connectChange)
|
||||
}
|
||||
|
||||
m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
|
||||
|
||||
@ -236,290 +211,94 @@ func (m *mapSession) serveLongPoll() {
|
||||
|
||||
// Consume updates sent to node
|
||||
case update, ok := <-m.ch:
|
||||
m.tracef("received update from channel, ok: %t", ok)
|
||||
if !ok {
|
||||
m.tracef("update channel closed, streaming session is likely being replaced")
|
||||
return
|
||||
}
|
||||
|
||||
// If the node has been removed from headscale, close the stream
|
||||
if slices.Contains(update.Removed, m.node.ID()) {
|
||||
m.tracef("node removed, closing stream")
|
||||
if err := m.writeMap(update); err != nil {
|
||||
m.errf(err, "cannot write update to client")
|
||||
return
|
||||
}
|
||||
|
||||
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
|
||||
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
|
||||
|
||||
var data []byte
|
||||
var err error
|
||||
var lastMessage string
|
||||
|
||||
// Ensure the node view is updated, for example, there
|
||||
// might have been a hostinfo update in a sidechannel
|
||||
// which contains data needed to generate a map response.
|
||||
m.node, err = m.h.state.GetNodeViewByID(m.node.ID())
|
||||
if err != nil {
|
||||
m.errf(err, "Could not get machine from db")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
updateType := "full"
|
||||
switch update.Type {
|
||||
case types.StateFullUpdate:
|
||||
m.tracef("Sending Full MapResponse")
|
||||
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
|
||||
case types.StatePeerChanged:
|
||||
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
|
||||
|
||||
for _, nodeID := range update.ChangeNodes {
|
||||
changed[nodeID] = true
|
||||
}
|
||||
|
||||
lastMessage = update.Message
|
||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
|
||||
updateType = "change"
|
||||
|
||||
case types.StatePeerChangedPatch:
|
||||
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
|
||||
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
|
||||
updateType = "patch"
|
||||
case types.StatePeerRemoved:
|
||||
changed := make(map[types.NodeID]bool, len(update.Removed))
|
||||
|
||||
for _, nodeID := range update.Removed {
|
||||
changed[nodeID] = false
|
||||
}
|
||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
|
||||
updateType = "remove"
|
||||
case types.StateSelfUpdate:
|
||||
lastMessage = update.Message
|
||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||
// create the map so an empty (self) update is sent
|
||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
|
||||
updateType = "remove"
|
||||
case types.StateDERPUpdated:
|
||||
m.tracef("Sending DERPUpdate MapResponse")
|
||||
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
|
||||
updateType = "derp"
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
m.errf(err, "Could not get the create map update")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Only send update if there is change
|
||||
if data != nil {
|
||||
startWrite := time.Now()
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
mapResponseSent.WithLabelValues("error", updateType).Inc()
|
||||
m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m)
|
||||
return
|
||||
}
|
||||
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
mapResponseSent.WithLabelValues("error", updateType).Inc()
|
||||
m.errf(err, "flushing the map response to client, for mapSession: %p", m)
|
||||
return
|
||||
}
|
||||
|
||||
log.Trace().Str("node", m.node.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey().String()).Msg("finished writing mapresp to node")
|
||||
|
||||
if debugHighCardinalityMetrics {
|
||||
mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID().String()).Set(float64(time.Now().Unix()))
|
||||
}
|
||||
mapResponseSent.WithLabelValues("ok", updateType).Inc()
|
||||
m.tracef("update sent")
|
||||
m.resetKeepAlive()
|
||||
}
|
||||
m.tracef("update sent")
|
||||
m.resetKeepAlive()
|
||||
|
||||
case <-m.keepAliveTicker.C:
|
||||
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Error generating the keep alive msg")
|
||||
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||
return
|
||||
}
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
m.errf(err, "Cannot write keep alive message")
|
||||
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||
return
|
||||
}
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
|
||||
mapResponseSent.WithLabelValues("error", "keepalive").Inc()
|
||||
if err := m.writeMap(&keepAlive); err != nil {
|
||||
m.errf(err, "cannot write keep alive")
|
||||
return
|
||||
}
|
||||
|
||||
if debugHighCardinalityMetrics {
|
||||
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID().String()).Set(float64(time.Now().Unix()))
|
||||
mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
|
||||
}
|
||||
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mapSession) handleEndpointUpdate() {
|
||||
m.tracef("received endpoint update")
|
||||
|
||||
// Get fresh node state from database for accurate route calculations
|
||||
node, err := m.h.state.GetNodeByID(m.node.ID())
|
||||
// writeMap writes the map response to the client.
|
||||
// It handles compression if requested and any headers that need to be set.
|
||||
// It also handles flushing the response if the ResponseWriter
|
||||
// implements http.Flusher.
|
||||
func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
|
||||
jsonBody, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to get fresh node from database for endpoint update")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
return
|
||||
return fmt.Errorf("marshalling map response: %w", err)
|
||||
}
|
||||
|
||||
change := m.node.PeerChangeFromMapRequest(m.req)
|
||||
|
||||
online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID())
|
||||
change.Online = &online
|
||||
|
||||
node.ApplyPeerChange(&change)
|
||||
|
||||
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, m.req.Hostinfo)
|
||||
|
||||
// The node might not set NetInfo if it has not changed and if
|
||||
// the full HostInfo object is overwritten, the information is lost.
|
||||
// If there is no NetInfo, keep the previous one.
|
||||
// From 1.66 the client only sends it if changed:
|
||||
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
|
||||
// TODO(kradalby): evaluate if we need better comparing of hostinfo
|
||||
// before we take the changes.
|
||||
if m.req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
|
||||
m.req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
|
||||
}
|
||||
node.Hostinfo = m.req.Hostinfo
|
||||
|
||||
logTracePeerChange(node.Hostname, sendUpdate, &change)
|
||||
|
||||
// If there is no changes and nothing to save,
|
||||
// return early.
|
||||
if peerChangeEmpty(change) && !sendUpdate {
|
||||
mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
|
||||
return
|
||||
if m.req.Compress == util.ZstdCompression {
|
||||
jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
|
||||
}
|
||||
|
||||
// Auto approve any routes that have been defined in policy as
|
||||
// auto approved. Check if this actually changed the node.
|
||||
routesAutoApproved := m.h.state.AutoApproveRoutes(node)
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
|
||||
data = append(data, jsonBody...)
|
||||
|
||||
// Always update routes for connected nodes to handle reconnection scenarios
|
||||
// where routes need to be restored to the primary routes system
|
||||
routesToSet := node.SubnetRoutes()
|
||||
startWrite := time.Now()
|
||||
|
||||
if m.h.state.SetNodeRoutes(node.ID, routesToSet...) {
|
||||
ctx := types.NotifyCtx(m.ctx, "poll-primary-change", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
} else if routesChanged {
|
||||
// Only send peer changed notification if routes actually changed
|
||||
ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
||||
|
||||
// TODO(kradalby): I am not sure if we need this?
|
||||
// Send an update to the node itself with to ensure it
|
||||
// has an updated packetfilter allowing the new route
|
||||
// if it is defined in the ACL.
|
||||
ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyByNodeID(
|
||||
ctx,
|
||||
types.UpdateSelf(node.ID),
|
||||
node.ID)
|
||||
_, err = m.w.Write(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If routes were auto-approved, we need to save the node to persist the changes
|
||||
if routesAutoApproved {
|
||||
if _, _, err := m.h.state.SaveNode(node); err != nil {
|
||||
m.errf(err, "Failed to save auto-approved routes to node")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
return
|
||||
if m.isStreaming() {
|
||||
if f, ok := m.w.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
} else {
|
||||
m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush")
|
||||
}
|
||||
}
|
||||
|
||||
// Check if there has been a change to Hostname and update them
|
||||
// in the database. Then send a Changed update
|
||||
// (containing the whole node object) to peers to inform about
|
||||
// the hostname change.
|
||||
node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
|
||||
log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
|
||||
|
||||
_, policyChanged, err := m.h.state.SaveNode(node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to persist/update node in the database")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Send policy update notifications if needed
|
||||
if policyChanged {
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
||||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
|
||||
m.h.nodeNotifier.NotifyWithIgnore(
|
||||
ctx,
|
||||
types.UpdatePeerChanged(node.ID),
|
||||
node.ID,
|
||||
)
|
||||
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mapSession) handleReadOnlyRequest() {
|
||||
m.tracef("Client asked for a lite update, responding without peers")
|
||||
|
||||
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to create MapResponse")
|
||||
http.Error(m.w, "", http.StatusInternalServerError)
|
||||
mapResponseReadOnly.WithLabelValues("error").Inc()
|
||||
return
|
||||
}
|
||||
|
||||
m.w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
_, err = m.w.Write(mapResp)
|
||||
if err != nil {
|
||||
m.errf(err, "Failed to write response")
|
||||
mapResponseReadOnly.WithLabelValues("error").Inc()
|
||||
return
|
||||
}
|
||||
|
||||
m.w.WriteHeader(http.StatusOK)
|
||||
mapResponseReadOnly.WithLabelValues("ok").Inc()
|
||||
var keepAlive = tailcfg.MapResponse{
|
||||
KeepAlive: true,
|
||||
}
|
||||
|
||||
func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
|
||||
trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname)
|
||||
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
|
||||
trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
|
||||
|
||||
if change.Key != nil {
|
||||
trace = trace.Str("node_key", change.Key.ShortString())
|
||||
if peerChange.Key != nil {
|
||||
trace = trace.Str("node_key", peerChange.Key.ShortString())
|
||||
}
|
||||
|
||||
if change.DiscoKey != nil {
|
||||
trace = trace.Str("disco_key", change.DiscoKey.ShortString())
|
||||
if peerChange.DiscoKey != nil {
|
||||
trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
|
||||
}
|
||||
|
||||
if change.Online != nil {
|
||||
trace = trace.Bool("online", *change.Online)
|
||||
if peerChange.Online != nil {
|
||||
trace = trace.Bool("online", *peerChange.Online)
|
||||
}
|
||||
|
||||
if change.Endpoints != nil {
|
||||
eps := make([]string, len(change.Endpoints))
|
||||
for idx, ep := range change.Endpoints {
|
||||
if peerChange.Endpoints != nil {
|
||||
eps := make([]string, len(peerChange.Endpoints))
|
||||
for idx, ep := range peerChange.Endpoints {
|
||||
eps[idx] = ep.String()
|
||||
}
|
||||
|
||||
@ -530,21 +309,11 @@ func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.Pe
|
||||
trace = trace.Bool("hostinfo_changed", hostinfoChange)
|
||||
}
|
||||
|
||||
if change.DERPRegion != 0 {
|
||||
trace = trace.Int("derp_region", change.DERPRegion)
|
||||
if peerChange.DERPRegion != 0 {
|
||||
trace = trace.Int("derp_region", peerChange.DERPRegion)
|
||||
}
|
||||
|
||||
trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received")
|
||||
}
|
||||
|
||||
func peerChangeEmpty(chng tailcfg.PeerChange) bool {
|
||||
return chng.Key == nil &&
|
||||
chng.DiscoKey == nil &&
|
||||
chng.Online == nil &&
|
||||
chng.Endpoints == nil &&
|
||||
chng.DERPRegion == 0 &&
|
||||
chng.LastSeen == nil &&
|
||||
chng.KeyExpiry == nil
|
||||
trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received")
|
||||
}
|
||||
|
||||
func logPollFunc(
|
||||
@ -554,7 +323,6 @@ func logPollFunc(
|
||||
return func(msg string, a ...any) {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@ -564,7 +332,6 @@ func logPollFunc(
|
||||
func(msg string, a ...any) {
|
||||
log.Info().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@ -574,7 +341,6 @@ func logPollFunc(
|
||||
func(msg string, a ...any) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@ -584,7 +350,6 @@ func logPollFunc(
|
||||
func(err error, msg string, a ...any) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", node.ID.Uint64()).
|
||||
@ -593,91 +358,3 @@ func logPollFunc(
|
||||
Msgf(msg, a...)
|
||||
}
|
||||
}
|
||||
|
||||
func logPollFuncView(
|
||||
mapRequest tailcfg.MapRequest,
|
||||
nodeView types.NodeView,
|
||||
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
|
||||
return func(msg string, a ...any) {
|
||||
log.Warn().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(msg string, a ...any) {
|
||||
log.Info().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(msg string, a ...any) {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Msgf(msg, a...)
|
||||
},
|
||||
func(err error, msg string, a ...any) {
|
||||
log.Error().
|
||||
Caller().
|
||||
Bool("readOnly", mapRequest.ReadOnly).
|
||||
Bool("omitPeers", mapRequest.OmitPeers).
|
||||
Bool("stream", mapRequest.Stream).
|
||||
Uint64("node.id", nodeView.ID().Uint64()).
|
||||
Str("node", nodeView.Hostname()).
|
||||
Err(err).
|
||||
Msgf(msg, a...)
|
||||
}
|
||||
}
|
||||
|
||||
// hostInfoChanged reports if hostInfo has changed in two ways,
|
||||
// - first bool reports if an update needs to be sent to nodes
|
||||
// - second reports if there has been changes to routes
|
||||
// the caller can then use this info to save and update nodes
|
||||
// and routes as needed.
|
||||
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
||||
if old.Equal(new) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if old == nil && new != nil {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Routes
|
||||
oldRoutes := make([]netip.Prefix, 0)
|
||||
if old != nil {
|
||||
oldRoutes = old.RoutableIPs
|
||||
}
|
||||
newRoutes := new.RoutableIPs
|
||||
|
||||
tsaddr.SortPrefixes(oldRoutes)
|
||||
tsaddr.SortPrefixes(newRoutes)
|
||||
|
||||
if !xslices.Equal(oldRoutes, newRoutes) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Services is mostly useful for discovery and not critical,
|
||||
// except for peerapi, which is how nodes talk to each other.
|
||||
// If peerapi was not part of the initial mapresponse, we
|
||||
// need to make sure its sent out later as it is needed for
|
||||
// Taildrop.
|
||||
// TODO(kradalby): Length comparison is a bit naive, replace.
|
||||
if len(old.Services) != len(new.Services) {
|
||||
return true, false
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
@ -17,10 +17,13 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/sasha-s/go-deadlock"
|
||||
xslices "golang.org/x/exp/slices"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
@ -46,12 +49,6 @@ type State struct {
|
||||
// cfg holds the current Headscale configuration
|
||||
cfg *types.Config
|
||||
|
||||
// in-memory data, protected by mu
|
||||
// nodes contains the current set of registered nodes
|
||||
nodes types.Nodes
|
||||
// users contains the current set of users/namespaces
|
||||
users types.Users
|
||||
|
||||
// subsystem keeping state
|
||||
// db provides persistent storage and database operations
|
||||
db *hsdb.HSDatabase
|
||||
@ -113,9 +110,6 @@ func NewState(cfg *types.Config) (*State, error) {
|
||||
return &State{
|
||||
cfg: cfg,
|
||||
|
||||
nodes: nodes,
|
||||
users: users,
|
||||
|
||||
db: db,
|
||||
ipAlloc: ipAlloc,
|
||||
// TODO(kradalby): Update DERPMap
|
||||
@ -215,6 +209,7 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
|
||||
if err := s.db.DB.Save(&user).Error; err != nil {
|
||||
return nil, false, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
@ -226,6 +221,18 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
|
||||
return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
|
||||
}
|
||||
|
||||
// Even if the policy manager doesn't detect a filter change, SSH policies
|
||||
// might now be resolvable when they weren't before. If there are existing
|
||||
// nodes, we should send a policy change to ensure they get updated SSH policies.
|
||||
if !policyChanged {
|
||||
nodes, err := s.ListNodes()
|
||||
if err == nil && len(nodes) > 0 {
|
||||
policyChanged = true
|
||||
}
|
||||
}
|
||||
|
||||
log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated")
|
||||
|
||||
// TODO(kradalby): implement the user in-memory cache
|
||||
|
||||
return &user, policyChanged, nil
|
||||
@ -329,7 +336,7 @@ func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
|
||||
}
|
||||
|
||||
// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
|
||||
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, bool, error) {
|
||||
func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
@ -350,72 +357,100 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err
|
||||
return node, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, err
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return node, false, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): implement the node in-memory cache
|
||||
|
||||
return node, policyChanged, nil
|
||||
var c change.ChangeSet
|
||||
if policyChanged {
|
||||
c = change.PolicyChange()
|
||||
} else {
|
||||
// Basic node change without specific details since this is a generic update
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return node, c, nil
|
||||
}
|
||||
|
||||
// SaveNode persists an existing node to the database and updates the policy manager.
|
||||
func (s *State) SaveNode(node *types.Node) (*types.Node, bool, error) {
|
||||
func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.db.DB.Save(node).Error; err != nil {
|
||||
return nil, false, fmt.Errorf("saving node: %w", err)
|
||||
return nil, change.EmptySet, fmt.Errorf("saving node: %w", err)
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return node, false, fmt.Errorf("failed to update policy manager after node save: %w", err)
|
||||
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): implement the node in-memory cache
|
||||
|
||||
return node, policyChanged, nil
|
||||
if policyChanged {
|
||||
return node, change.PolicyChange(), nil
|
||||
}
|
||||
|
||||
return node, change.EmptySet, nil
|
||||
}
|
||||
|
||||
// DeleteNode permanently removes a node and cleans up associated resources.
|
||||
// Returns whether policies changed and any error. This operation is irreversible.
|
||||
func (s *State) DeleteNode(node *types.Node) (bool, error) {
|
||||
func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) {
|
||||
err := s.db.DeleteNode(node)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return change.EmptySet, err
|
||||
}
|
||||
|
||||
c := change.NodeRemoved(node.ID)
|
||||
|
||||
// Check if policy manager needs updating after node deletion
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
|
||||
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
|
||||
}
|
||||
|
||||
return policyChanged, nil
|
||||
if policyChanged {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *State) Connect(id types.NodeID) {
|
||||
func (s *State) Connect(node *types.Node) change.ChangeSet {
|
||||
c := change.NodeOnline(node.ID)
|
||||
routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
|
||||
|
||||
if routeChange {
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *State) Disconnect(id types.NodeID) (bool, error) {
|
||||
// TODO(kradalby): This node should update the in memory state
|
||||
_, polChanged, err := s.SetLastSeen(id, time.Now())
|
||||
func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
|
||||
c := change.NodeOffline(node.ID)
|
||||
|
||||
_, _, err := s.SetLastSeen(node.ID, time.Now())
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("disconnecting node: %w", err)
|
||||
return c, fmt.Errorf("disconnecting node: %w", err)
|
||||
}
|
||||
|
||||
changed := s.primaryRoutes.SetRoutes(id)
|
||||
if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
// TODO(kradalby): the returned change should be more nuanced allowing us to
|
||||
// send more directed updates.
|
||||
return changed || polChanged, nil
|
||||
// TODO(kradalby): This node should update the in memory state
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetNodeByID retrieves a node by ID.
|
||||
@ -475,45 +510,93 @@ func (s *State) ListEphemeralNodes() (types.Nodes, error) {
|
||||
}
|
||||
|
||||
// SetNodeExpiry updates the expiration time for a node.
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.KeyExpiry(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetNodeTags assigns tags to a node for use in access control policies.
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetTags(tx, nodeID, tags)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
|
||||
}
|
||||
|
||||
// Update primary routes after changing approved routes
|
||||
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...)
|
||||
|
||||
if routeChange || !c.IsFull() {
|
||||
c = change.PolicyChange()
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// RenameNode changes the display name of a node.
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.RenameNode(tx, nodeID, newName)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
|
||||
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, bool, error) {
|
||||
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
|
||||
})
|
||||
}
|
||||
|
||||
// AssignNodeToUser transfers a node to a different user.
|
||||
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, bool, error) {
|
||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) {
|
||||
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||
return hsdb.AssignNodeToUser(tx, nodeID, userID)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err)
|
||||
}
|
||||
|
||||
if !c.IsFull() {
|
||||
c = change.NodeAdded(nodeID)
|
||||
}
|
||||
|
||||
return n, c, nil
|
||||
}
|
||||
|
||||
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
|
||||
@ -523,7 +606,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
|
||||
|
||||
// ExpireExpiredNodes finds and processes expired nodes since the last check.
|
||||
// Returns next check time, state update with expired nodes, and whether any were found.
|
||||
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateUpdate, bool) {
|
||||
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) {
|
||||
return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck)
|
||||
}
|
||||
|
||||
@ -568,8 +651,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
|
||||
}
|
||||
|
||||
// SetNodeRoutes sets the primary routes for a node.
|
||||
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) bool {
|
||||
return s.primaryRoutes.SetRoutes(nodeID, routes...)
|
||||
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet {
|
||||
if s.primaryRoutes.SetRoutes(nodeID, routes...) {
|
||||
// Route changes affect packet filters for all nodes, so trigger a policy change
|
||||
// to ensure filters are regenerated across the entire network
|
||||
return change.PolicyChange()
|
||||
}
|
||||
|
||||
return change.EmptySet
|
||||
}
|
||||
|
||||
// GetNodePrimaryRoutes returns the primary routes for a node.
|
||||
@ -653,10 +742,10 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
userID types.UserID,
|
||||
expiry *time.Time,
|
||||
registrationMethod string,
|
||||
) (*types.Node, bool, error) {
|
||||
) (*types.Node, change.ChangeSet, error) {
|
||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, err
|
||||
}
|
||||
|
||||
return s.db.HandleNodeFromAuthPath(
|
||||
@ -672,12 +761,15 @@ func (s *State) HandleNodeFromAuthPath(
|
||||
func (s *State) HandleNodeFromPreAuthKey(
|
||||
regReq tailcfg.RegisterRequest,
|
||||
machineKey key.MachinePublic,
|
||||
) (*types.Node, bool, error) {
|
||||
) (*types.Node, change.ChangeSet, bool, error) {
|
||||
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, err
|
||||
}
|
||||
|
||||
err = pak.Validate()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
return nil, change.EmptySet, false, err
|
||||
}
|
||||
|
||||
nodeToRegister := types.Node{
|
||||
@ -698,22 +790,13 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
AuthKeyID: &pak.ID,
|
||||
}
|
||||
|
||||
// For auth key registration, ensure we don't keep an expired node
|
||||
// This is especially important for re-registration after logout
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
|
||||
if !regReq.Expiry.IsZero() {
|
||||
nodeToRegister.Expiry = ®Req.Expiry
|
||||
} else if !regReq.Expiry.IsZero() {
|
||||
// If client is sending an expired time (e.g., after logout),
|
||||
// don't set expiry so the node won't be considered expired
|
||||
log.Debug().
|
||||
Time("requested_expiry", regReq.Expiry).
|
||||
Str("node", regReq.Hostinfo.Hostname).
|
||||
Msg("Ignoring expired expiry time from auth key registration")
|
||||
}
|
||||
|
||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("allocating IPs: %w", err)
|
||||
return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err)
|
||||
}
|
||||
|
||||
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
@ -735,18 +818,38 @@ func (s *State) HandleNodeFromPreAuthKey(
|
||||
return node, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("writing node to database: %w", err)
|
||||
return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err)
|
||||
}
|
||||
|
||||
// Check if this is a logout request for an ephemeral node
|
||||
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
||||
// This is a logout request for an ephemeral node, delete it immediately
|
||||
c, err := s.DeleteNode(node)
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
||||
}
|
||||
return nil, c, false, nil
|
||||
}
|
||||
|
||||
// Check if policy manager needs updating
|
||||
// This is necessary because we just created a new node.
|
||||
// We need to ensure that the policy manager is aware of this new node.
|
||||
policyChanged, err := s.updatePolicyManagerNodes()
|
||||
// Also update users to ensure all users are known when evaluating policies.
|
||||
usersChanged, err := s.updatePolicyManagerUsers()
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to update policy manager after node registration: %w", err)
|
||||
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
|
||||
}
|
||||
|
||||
return node, policyChanged, nil
|
||||
nodesChanged, err := s.updatePolicyManagerNodes()
|
||||
if err != nil {
|
||||
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
||||
}
|
||||
|
||||
policyChanged := usersChanged || nodesChanged
|
||||
|
||||
c := change.NodeAdded(node.ID)
|
||||
|
||||
return node, c, policyChanged, nil
|
||||
}
|
||||
|
||||
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
|
||||
@ -766,11 +869,15 @@ func (s *State) updatePolicyManagerUsers() (bool, error) {
|
||||
return false, fmt.Errorf("listing users for policy update: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
|
||||
|
||||
changed, err := s.polMan.SetUsers(users)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("updating policy manager users: %w", err)
|
||||
}
|
||||
|
||||
log.Debug().Bool("changed", changed).Msg("Policy manager users updated")
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
@ -835,3 +942,125 @@ func (s *State) autoApproveNodes() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(kradalby): This should just take the node ID?
|
||||
func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) {
|
||||
// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes,
|
||||
// which means we could shortcut the whole change thing if there are no other important updates.
|
||||
peerChange := node.PeerChangeFromMapRequest(req)
|
||||
|
||||
node.ApplyPeerChange(&peerChange)
|
||||
|
||||
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, req.Hostinfo)
|
||||
|
||||
// The node might not set NetInfo if it has not changed and if
|
||||
// the full HostInfo object is overwritten, the information is lost.
|
||||
// If there is no NetInfo, keep the previous one.
|
||||
// From 1.66 the client only sends it if changed:
|
||||
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
|
||||
// TODO(kradalby): evaluate if we need better comparing of hostinfo
|
||||
// before we take the changes.
|
||||
if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
|
||||
req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
|
||||
}
|
||||
node.Hostinfo = req.Hostinfo
|
||||
|
||||
// If there is no changes and nothing to save,
|
||||
// return early.
|
||||
if peerChangeEmpty(peerChange) && !sendUpdate {
|
||||
// mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
|
||||
return change.EmptySet, nil
|
||||
}
|
||||
|
||||
c := change.EmptySet
|
||||
|
||||
// Check if the Hostinfo of the node has changed.
|
||||
// If it has changed, check if there has been a change to
|
||||
// the routable IPs of the host and update them in
|
||||
// the database. Then send a Changed update
|
||||
// (containing the whole node object) to peers to inform about
|
||||
// the route change.
|
||||
// If the hostinfo has changed, but not the routes, just update
|
||||
// hostinfo and let the function continue.
|
||||
if routesChanged {
|
||||
// Auto approve any routes that have been defined in policy as
|
||||
// auto approved. Check if this actually changed the node.
|
||||
_ = s.AutoApproveRoutes(node)
|
||||
|
||||
// Update the routes of the given node in the route manager to
|
||||
// see if an update needs to be sent.
|
||||
c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||
}
|
||||
|
||||
// Check if there has been a change to Hostname and update them
|
||||
// in the database. Then send a Changed update
|
||||
// (containing the whole node object) to peers to inform about
|
||||
// the hostname change.
|
||||
node.ApplyHostnameFromHostInfo(req.Hostinfo)
|
||||
|
||||
_, policyChange, err := s.SaveNode(node)
|
||||
if err != nil {
|
||||
return change.EmptySet, err
|
||||
}
|
||||
|
||||
if policyChange.IsFull() {
|
||||
c = policyChange
|
||||
}
|
||||
|
||||
if c.Empty() {
|
||||
c = change.NodeAdded(node.ID)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// hostInfoChanged reports if hostInfo has changed in two ways,
|
||||
// - first bool reports if an update needs to be sent to nodes
|
||||
// - second reports if there has been changes to routes
|
||||
// the caller can then use this info to save and update nodes
|
||||
// and routes as needed.
|
||||
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
||||
if old.Equal(new) {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if old == nil && new != nil {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Routes
|
||||
oldRoutes := make([]netip.Prefix, 0)
|
||||
if old != nil {
|
||||
oldRoutes = old.RoutableIPs
|
||||
}
|
||||
newRoutes := new.RoutableIPs
|
||||
|
||||
tsaddr.SortPrefixes(oldRoutes)
|
||||
tsaddr.SortPrefixes(newRoutes)
|
||||
|
||||
if !xslices.Equal(oldRoutes, newRoutes) {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// Services is mostly useful for discovery and not critical,
|
||||
// except for peerapi, which is how nodes talk to each other.
|
||||
// If peerapi was not part of the initial mapresponse, we
|
||||
// need to make sure its sent out later as it is needed for
|
||||
// Taildrop.
|
||||
// TODO(kradalby): Length comparison is a bit naive, replace.
|
||||
if len(old.Services) != len(new.Services) {
|
||||
return true, false
|
||||
}
|
||||
|
||||
return false, false
|
||||
}
|
||||
|
||||
func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
|
||||
return peerChange.Key == nil &&
|
||||
peerChange.DiscoKey == nil &&
|
||||
peerChange.Online == nil &&
|
||||
peerChange.Endpoints == nil &&
|
||||
peerChange.DERPRegion == 0 &&
|
||||
peerChange.LastSeen == nil &&
|
||||
peerChange.KeyExpiry == nil
|
||||
}
|
||||
|
183
hscontrol/types/change/change.go
Normal file
183
hscontrol/types/change/change.go
Normal file
@ -0,0 +1,183 @@
|
||||
//go:generate go tool stringer -type=Change
|
||||
package change
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
)
|
||||
|
||||
type (
|
||||
NodeID = types.NodeID
|
||||
UserID = types.UserID
|
||||
)
|
||||
|
||||
type Change int
|
||||
|
||||
const (
|
||||
ChangeUnknown Change = 0
|
||||
|
||||
// Deprecated: Use specific change instead
|
||||
// Full is a legacy change to ensure places where we
|
||||
// have not yet determined the specific update, can send.
|
||||
Full Change = 9
|
||||
|
||||
// Server changes.
|
||||
Policy Change = 11
|
||||
DERP Change = 12
|
||||
ExtraRecords Change = 13
|
||||
|
||||
// Node changes.
|
||||
NodeCameOnline Change = 21
|
||||
NodeWentOffline Change = 22
|
||||
NodeRemove Change = 23
|
||||
NodeKeyExpiry Change = 24
|
||||
NodeNewOrUpdate Change = 25
|
||||
|
||||
// User changes.
|
||||
UserNewOrUpdate Change = 51
|
||||
UserRemove Change = 52
|
||||
)
|
||||
|
||||
// AlsoSelf reports whether this change should also be sent to the node itself.
|
||||
func (c Change) AlsoSelf() bool {
|
||||
switch c {
|
||||
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type ChangeSet struct {
|
||||
Change Change
|
||||
|
||||
// SelfUpdateOnly indicates that this change should only be sent
|
||||
// to the node itself, and not to other nodes.
|
||||
// This is used for changes that are not relevant to other nodes.
|
||||
// NodeID must be set if this is true.
|
||||
SelfUpdateOnly bool
|
||||
|
||||
// NodeID if set, is the ID of the node that is being changed.
|
||||
// It must be set if this is a node change.
|
||||
NodeID types.NodeID
|
||||
|
||||
// UserID if set, is the ID of the user that is being changed.
|
||||
// It must be set if this is a user change.
|
||||
UserID types.UserID
|
||||
|
||||
// IsSubnetRouter indicates whether the node is a subnet router.
|
||||
IsSubnetRouter bool
|
||||
}
|
||||
|
||||
func (c *ChangeSet) Validate() error {
|
||||
if c.Change >= NodeCameOnline || c.Change <= NodeNewOrUpdate {
|
||||
if c.NodeID == 0 {
|
||||
return errors.New("ChangeSet.NodeID must be set for node updates")
|
||||
}
|
||||
}
|
||||
|
||||
if c.Change >= UserNewOrUpdate || c.Change <= UserRemove {
|
||||
if c.UserID == 0 {
|
||||
return errors.New("ChangeSet.UserID must be set for user updates")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Empty reports whether the ChangeSet is empty, meaning it does not
|
||||
// represent any change.
|
||||
func (c ChangeSet) Empty() bool {
|
||||
return c.Change == ChangeUnknown && c.NodeID == 0 && c.UserID == 0
|
||||
}
|
||||
|
||||
// IsFull reports whether the ChangeSet represents a full update.
|
||||
func (c ChangeSet) IsFull() bool {
|
||||
return c.Change == Full || c.Change == Policy
|
||||
}
|
||||
|
||||
func (c ChangeSet) AlsoSelf() bool {
|
||||
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
|
||||
// so we consider it as a change that should be sent to all nodes.
|
||||
if c.NodeID == 0 {
|
||||
return true
|
||||
}
|
||||
return c.Change.AlsoSelf() || c.SelfUpdateOnly
|
||||
}
|
||||
|
||||
var (
|
||||
EmptySet = ChangeSet{Change: ChangeUnknown}
|
||||
FullSet = ChangeSet{Change: Full}
|
||||
DERPSet = ChangeSet{Change: DERP}
|
||||
PolicySet = ChangeSet{Change: Policy}
|
||||
ExtraRecordsSet = ChangeSet{Change: ExtraRecords}
|
||||
)
|
||||
|
||||
func FullSelf(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: Full,
|
||||
SelfUpdateOnly: true,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeAdded(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeNewOrUpdate,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeRemoved(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeRemove,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeOnline(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeCameOnline,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func NodeOffline(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeWentOffline,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func KeyExpiry(id types.NodeID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: NodeKeyExpiry,
|
||||
NodeID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func UserAdded(id types.UserID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: UserNewOrUpdate,
|
||||
UserID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func UserRemoved(id types.UserID) ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: UserRemove,
|
||||
UserID: id,
|
||||
}
|
||||
}
|
||||
|
||||
func PolicyChange() ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: Policy,
|
||||
}
|
||||
}
|
||||
|
||||
func DERPChange() ChangeSet {
|
||||
return ChangeSet{
|
||||
Change: DERP,
|
||||
}
|
||||
}
|
57
hscontrol/types/change/change_string.go
Normal file
57
hscontrol/types/change/change_string.go
Normal file
@ -0,0 +1,57 @@
|
||||
// Code generated by "stringer -type=Change"; DO NOT EDIT.
|
||||
|
||||
package change
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[ChangeUnknown-0]
|
||||
_ = x[Full-9]
|
||||
_ = x[Policy-11]
|
||||
_ = x[DERP-12]
|
||||
_ = x[ExtraRecords-13]
|
||||
_ = x[NodeCameOnline-21]
|
||||
_ = x[NodeWentOffline-22]
|
||||
_ = x[NodeRemove-23]
|
||||
_ = x[NodeKeyExpiry-24]
|
||||
_ = x[NodeNewOrUpdate-25]
|
||||
_ = x[UserNewOrUpdate-51]
|
||||
_ = x[UserRemove-52]
|
||||
}
|
||||
|
||||
const (
|
||||
_Change_name_0 = "ChangeUnknown"
|
||||
_Change_name_1 = "Full"
|
||||
_Change_name_2 = "PolicyDERPExtraRecords"
|
||||
_Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdate"
|
||||
_Change_name_4 = "UserNewOrUpdateUserRemove"
|
||||
)
|
||||
|
||||
var (
|
||||
_Change_index_2 = [...]uint8{0, 6, 10, 22}
|
||||
_Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67}
|
||||
_Change_index_4 = [...]uint8{0, 15, 25}
|
||||
)
|
||||
|
||||
func (i Change) String() string {
|
||||
switch {
|
||||
case i == 0:
|
||||
return _Change_name_0
|
||||
case i == 9:
|
||||
return _Change_name_1
|
||||
case 11 <= i && i <= 13:
|
||||
i -= 11
|
||||
return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]]
|
||||
case 21 <= i && i <= 25:
|
||||
i -= 21
|
||||
return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]]
|
||||
case 51 <= i && i <= 52:
|
||||
i -= 51
|
||||
return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]]
|
||||
default:
|
||||
return "Change(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
@ -1,16 +1,16 @@
|
||||
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
|
||||
|
||||
//go:generate go tool viewer --type=User,Node,PreAuthKey
|
||||
package types
|
||||
|
||||
//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/ctxkey"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -150,18 +150,6 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
NotifyOriginKey = ctxkey.New("notify.origin", "")
|
||||
NotifyHostnameKey = ctxkey.New("notify.hostname", "")
|
||||
)
|
||||
|
||||
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
|
||||
ctx2, _ := context.WithTimeout(ctx, 3*time.Second)
|
||||
ctx2 = NotifyOriginKey.WithValue(ctx2, origin)
|
||||
ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname)
|
||||
return ctx2
|
||||
}
|
||||
|
||||
const RegistrationIDLength = 24
|
||||
|
||||
type RegistrationID string
|
||||
@ -199,3 +187,20 @@ type RegisterNode struct {
|
||||
Node Node
|
||||
Registered chan *Node
|
||||
}
|
||||
|
||||
// DefaultBatcherWorkers returns the default number of batcher workers.
|
||||
// Default to 3/4 of CPU cores, minimum 1, no maximum.
|
||||
func DefaultBatcherWorkers() int {
|
||||
return DefaultBatcherWorkersFor(runtime.NumCPU())
|
||||
}
|
||||
|
||||
// DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count.
|
||||
// Default to 3/4 of CPU cores, minimum 1, no maximum.
|
||||
func DefaultBatcherWorkersFor(cpuCount int) int {
|
||||
defaultWorkers := (cpuCount * 3) / 4
|
||||
if defaultWorkers < 1 {
|
||||
defaultWorkers = 1
|
||||
}
|
||||
|
||||
return defaultWorkers
|
||||
}
|
||||
|
36
hscontrol/types/common_test.go
Normal file
36
hscontrol/types/common_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultBatcherWorkersFor(t *testing.T) {
|
||||
tests := []struct {
|
||||
cpuCount int
|
||||
expected int
|
||||
}{
|
||||
{1, 1}, // (1*3)/4 = 0, should be minimum 1
|
||||
{2, 1}, // (2*3)/4 = 1
|
||||
{4, 3}, // (4*3)/4 = 3
|
||||
{8, 6}, // (8*3)/4 = 6
|
||||
{12, 9}, // (12*3)/4 = 9
|
||||
{16, 12}, // (16*3)/4 = 12
|
||||
{20, 15}, // (20*3)/4 = 15
|
||||
{24, 18}, // (24*3)/4 = 18
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result := DefaultBatcherWorkersFor(test.cpuCount)
|
||||
if result != test.expected {
|
||||
t.Errorf("DefaultBatcherWorkersFor(%d) = %d, expected %d", test.cpuCount, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultBatcherWorkers(t *testing.T) {
|
||||
// Just verify it returns a valid value (>= 1)
|
||||
result := DefaultBatcherWorkers()
|
||||
if result < 1 {
|
||||
t.Errorf("DefaultBatcherWorkers() = %d, expected value >= 1", result)
|
||||
}
|
||||
}
|
@ -234,6 +234,7 @@ type Tuning struct {
|
||||
NotifierSendTimeout time.Duration
|
||||
BatchChangeDelay time.Duration
|
||||
NodeMapSessionBufferedChanSize int
|
||||
BatcherWorkers int
|
||||
}
|
||||
|
||||
func validatePKCEMethod(method string) error {
|
||||
@ -991,6 +992,12 @@ func LoadServerConfig() (*Config, error) {
|
||||
NodeMapSessionBufferedChanSize: viper.GetInt(
|
||||
"tuning.node_mapsession_buffered_chan_size",
|
||||
),
|
||||
BatcherWorkers: func() int {
|
||||
if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
|
||||
return workers
|
||||
}
|
||||
return DefaultBatcherWorkers()
|
||||
}(),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
@ -431,6 +431,11 @@ func (node *Node) SubnetRoutes() []netip.Prefix {
|
||||
return routes
|
||||
}
|
||||
|
||||
// IsSubnetRouter reports if the node has any subnet routes.
|
||||
func (node *Node) IsSubnetRouter() bool {
|
||||
return len(node.SubnetRoutes()) > 0
|
||||
}
|
||||
|
||||
func (node *Node) String() string {
|
||||
return node.Hostname
|
||||
}
|
||||
@ -669,6 +674,13 @@ func (v NodeView) SubnetRoutes() []netip.Prefix {
|
||||
return v.ж.SubnetRoutes()
|
||||
}
|
||||
|
||||
func (v NodeView) IsSubnetRouter() bool {
|
||||
if !v.Valid() {
|
||||
return false
|
||||
}
|
||||
return v.ж.IsSubnetRouter()
|
||||
}
|
||||
|
||||
func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) {
|
||||
if !v.Valid() {
|
||||
return
|
||||
|
@ -1,17 +1,16 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/rs/zerolog/log"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
)
|
||||
|
||||
type PAKError string
|
||||
|
||||
func (e PAKError) Error() string { return string(e) }
|
||||
func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }
|
||||
|
||||
// PreAuthKey describes a pre-authorization key usable in a particular user.
|
||||
type PreAuthKey struct {
|
||||
@ -60,6 +59,21 @@ func (pak *PreAuthKey) Validate() error {
|
||||
if pak == nil {
|
||||
return PAKError("invalid authkey")
|
||||
}
|
||||
|
||||
log.Debug().
|
||||
Str("key", pak.Key).
|
||||
Bool("hasExpiration", pak.Expiration != nil).
|
||||
Time("expiration", func() time.Time {
|
||||
if pak.Expiration != nil {
|
||||
return *pak.Expiration
|
||||
}
|
||||
return time.Time{}
|
||||
}()).
|
||||
Time("now", time.Now()).
|
||||
Bool("reusable", pak.Reusable).
|
||||
Bool("used", pak.Used).
|
||||
Msg("PreAuthKey.Validate: checking key")
|
||||
|
||||
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
|
||||
return PAKError("authkey expired")
|
||||
}
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
func TestCheckForFQDNRules(t *testing.T) {
|
||||
@ -102,59 +104,16 @@ func TestConvertWithFQDNRules(t *testing.T) {
|
||||
func TestMagicDNSRootDomains100(t *testing.T) {
|
||||
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("100.64.0.0/10"))
|
||||
|
||||
found := false
|
||||
for _, domain := range domains {
|
||||
if domain == "64.100.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
|
||||
found = false
|
||||
for _, domain := range domains {
|
||||
if domain == "100.100.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
|
||||
found = false
|
||||
for _, domain := range domains {
|
||||
if domain == "127.100.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("64.100.in-addr.arpa.")))
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("100.100.in-addr.arpa.")))
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("127.100.in-addr.arpa.")))
|
||||
}
|
||||
|
||||
func TestMagicDNSRootDomains172(t *testing.T) {
|
||||
domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("172.16.0.0/16"))
|
||||
|
||||
found := false
|
||||
for _, domain := range domains {
|
||||
if domain == "0.16.172.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
|
||||
found = false
|
||||
for _, domain := range domains {
|
||||
if domain == "255.16.172.in-addr.arpa." {
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("0.16.172.in-addr.arpa.")))
|
||||
assert.Contains(t, domains, must.Get(dnsname.ToFQDN("255.16.172.in-addr.arpa.")))
|
||||
}
|
||||
|
||||
// Happens when netmask is a multiple of 4 bits (sounds likely).
|
||||
|
@ -143,7 +143,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
||||
|
||||
// Parse latencies
|
||||
for j := 5; j <= 7; j++ {
|
||||
if matches[j] != "" {
|
||||
if j < len(matches) && matches[j] != "" {
|
||||
ms, err := strconv.ParseFloat(matches[j], 64)
|
||||
if err != nil {
|
||||
return Traceroute{}, fmt.Errorf("parsing latency: %w", err)
|
||||
|
@ -88,7 +88,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match before logout count")
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count")
|
||||
}, 20*time.Second, 1*time.Second)
|
||||
|
||||
for _, node := range listNodes {
|
||||
@ -123,7 +123,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
var err error
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match after HTTPS reconnection")
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection")
|
||||
}, 30*time.Second, 2*time.Second)
|
||||
|
||||
for _, node := range listNodes {
|
||||
@ -161,7 +161,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
||||
}
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
require.Equal(t, nodeCountBeforeLogout, len(listNodes))
|
||||
require.Len(t, listNodes, nodeCountBeforeLogout)
|
||||
for _, node := range listNodes {
|
||||
assertLastSeenSet(t, node)
|
||||
}
|
||||
@ -355,7 +355,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
||||
"--user",
|
||||
strconv.FormatUint(userMap[userName].GetId(), 10),
|
||||
"expire",
|
||||
key.Key,
|
||||
key.GetKey(),
|
||||
})
|
||||
assertNoErr(t, err)
|
||||
|
||||
|
@ -604,7 +604,7 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(ct, err)
|
||||
assert.NotContains(ct, []string{"Starting", "Running"}, status.BackendState,
|
||||
assert.NotContains(ct, []string{"Starting", "Running"}, status.BackendState,
|
||||
"Expected node to be logged out, backend state: %s", status.BackendState)
|
||||
}, 30*time.Second, 2*time.Second)
|
||||
|
||||
|
@ -147,3 +147,9 @@ func DockerAllowNetworkAdministration(config *docker.HostConfig) {
|
||||
config.CapAdd = append(config.CapAdd, "NET_ADMIN")
|
||||
config.Privileged = true
|
||||
}
|
||||
|
||||
// DockerMemoryLimit sets memory limit and disables OOM kill for containers.
|
||||
func DockerMemoryLimit(config *docker.HostConfig) {
|
||||
config.Memory = 2 * 1024 * 1024 * 1024 // 2GB in bytes
|
||||
config.OOMKillDisable = true
|
||||
}
|
||||
|
@ -145,9 +145,9 @@ func derpServerScenario(
|
||||
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
|
||||
|
||||
for _, health := range status.Health {
|
||||
assert.NotContains(ct, health, "could not connect to any relay server",
|
||||
assert.NotContains(ct, health, "could not connect to any relay server",
|
||||
"Client %s should be connected to DERP relay", client.Hostname())
|
||||
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
|
||||
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
|
||||
"Client %s should be connected to Headscale Embedded DERP", client.Hostname())
|
||||
}
|
||||
}, 30*time.Second, 2*time.Second)
|
||||
@ -166,9 +166,9 @@ func derpServerScenario(
|
||||
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
|
||||
|
||||
for _, health := range status.Health {
|
||||
assert.NotContains(ct, health, "could not connect to any relay server",
|
||||
assert.NotContains(ct, health, "could not connect to any relay server",
|
||||
"Client %s should be connected to DERP relay after first run", client.Hostname())
|
||||
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
|
||||
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
|
||||
"Client %s should be connected to Headscale Embedded DERP after first run", client.Hostname())
|
||||
}
|
||||
}, 30*time.Second, 2*time.Second)
|
||||
@ -191,9 +191,9 @@ func derpServerScenario(
|
||||
assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname())
|
||||
|
||||
for _, health := range status.Health {
|
||||
assert.NotContains(ct, health, "could not connect to any relay server",
|
||||
assert.NotContains(ct, health, "could not connect to any relay server",
|
||||
"Client %s should be connected to DERP relay after second run", client.Hostname())
|
||||
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
|
||||
assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.",
|
||||
"Client %s should be connected to Headscale Embedded DERP after second run", client.Hostname())
|
||||
}
|
||||
}, 30*time.Second, 2*time.Second)
|
||||
|
@ -883,6 +883,10 @@ func TestNodeOnlineStatus(t *testing.T) {
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
assert.NoError(ct, err)
|
||||
if status == nil {
|
||||
assert.Fail(ct, "status is nil")
|
||||
return
|
||||
}
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
@ -984,16 +988,11 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
|
||||
}
|
||||
|
||||
// Wait for sync and successful pings after nodes come back up
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assert.NoError(ct, err)
|
||||
|
||||
success := pingAllHelper(t, allClients, allAddrs)
|
||||
assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up")
|
||||
}, 30*time.Second, 2*time.Second)
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
assert.NoError(t, err)
|
||||
|
||||
success := pingAllHelper(t, allClients, allAddrs)
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -260,7 +260,9 @@ func WithDERPConfig(derpMap tailcfg.DERPMap) Option {
|
||||
func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
|
||||
return func(hsic *HeadscaleInContainer) {
|
||||
hsic.env["HEADSCALE_TUNING_BATCH_CHANGE_DELAY"] = batchTimeout.String()
|
||||
hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(mapSessionChanSize)
|
||||
hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(
|
||||
mapSessionChanSize,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,10 +281,16 @@ func WithDebugPort(port int) Option {
|
||||
|
||||
// buildEntrypoint builds the container entrypoint command based on configuration.
|
||||
func (hsic *HeadscaleInContainer) buildEntrypoint() []string {
|
||||
debugCmd := fmt.Sprintf("/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve", hsic.debugPort)
|
||||
|
||||
entrypoint := fmt.Sprintf("/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30", debugCmd)
|
||||
|
||||
debugCmd := fmt.Sprintf(
|
||||
"/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve",
|
||||
hsic.debugPort,
|
||||
)
|
||||
|
||||
entrypoint := fmt.Sprintf(
|
||||
"/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30",
|
||||
debugCmd,
|
||||
)
|
||||
|
||||
return []string{"/bin/bash", "-c", entrypoint}
|
||||
}
|
||||
|
||||
@ -447,8 +455,12 @@ func New(
|
||||
log.Printf("Created %s container\n", hsic.hostname)
|
||||
|
||||
hsic.container = container
|
||||
|
||||
log.Printf("Debug ports for %s: delve=%s, metrics/pprof=49090\n", hsic.hostname, hsic.GetHostDebugPort())
|
||||
|
||||
log.Printf(
|
||||
"Debug ports for %s: delve=%s, metrics/pprof=49090\n",
|
||||
hsic.hostname,
|
||||
hsic.GetHostDebugPort(),
|
||||
)
|
||||
|
||||
// Write the CA certificates to the container
|
||||
for i, cert := range hsic.caCerts {
|
||||
@ -684,14 +696,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// First, let's see what files are actually in /tmp
|
||||
tmpListing, err := t.Execute([]string{"ls", "-la", "/tmp/"})
|
||||
if err != nil {
|
||||
log.Printf("Warning: could not list /tmp directory: %v", err)
|
||||
} else {
|
||||
log.Printf("Contents of /tmp in container %s:\n%s", t.hostname, tmpListing)
|
||||
}
|
||||
|
||||
// Also check for any .sqlite files
|
||||
sqliteFiles, err := t.Execute([]string{"find", "/tmp", "-name", "*.sqlite*", "-type", "f"})
|
||||
if err != nil {
|
||||
@ -718,12 +722,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
|
||||
return errors.New("database file exists but has no schema (empty database)")
|
||||
}
|
||||
|
||||
// Show a preview of the schema (first 500 chars)
|
||||
schemaPreview := schemaCheck
|
||||
if len(schemaPreview) > 500 {
|
||||
schemaPreview = schemaPreview[:500] + "..."
|
||||
}
|
||||
|
||||
tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch database file: %w", err)
|
||||
@ -740,7 +738,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
|
||||
return fmt.Errorf("failed to read tar header: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Found file in tar: %s (type: %d, size: %d)", header.Name, header.Typeflag, header.Size)
|
||||
log.Printf(
|
||||
"Found file in tar: %s (type: %d, size: %d)",
|
||||
header.Name,
|
||||
header.Typeflag,
|
||||
header.Size,
|
||||
)
|
||||
|
||||
// Extract the first regular file we find
|
||||
if header.Typeflag == tar.TypeReg {
|
||||
@ -756,11 +759,20 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
|
||||
return fmt.Errorf("failed to copy database file: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Extracted database file: %s (%d bytes written, header claimed %d bytes)", dbPath, written, header.Size)
|
||||
log.Printf(
|
||||
"Extracted database file: %s (%d bytes written, header claimed %d bytes)",
|
||||
dbPath,
|
||||
written,
|
||||
header.Size,
|
||||
)
|
||||
|
||||
// Check if we actually wrote something
|
||||
if written == 0 {
|
||||
return fmt.Errorf("database file is empty (size: %d, header size: %d)", written, header.Size)
|
||||
return fmt.Errorf(
|
||||
"database file is empty (size: %d, header size: %d)",
|
||||
written,
|
||||
header.Size,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -871,7 +883,15 @@ func (t *HeadscaleInContainer) WaitForRunning() error {
|
||||
func (t *HeadscaleInContainer) CreateUser(
|
||||
user string,
|
||||
) (*v1.User, error) {
|
||||
command := []string{"headscale", "users", "create", user, fmt.Sprintf("--email=%s@test.no", user), "--output", "json"}
|
||||
command := []string{
|
||||
"headscale",
|
||||
"users",
|
||||
"create",
|
||||
user,
|
||||
fmt.Sprintf("--email=%s@test.no", user),
|
||||
"--output",
|
||||
"json",
|
||||
}
|
||||
|
||||
result, _, err := dockertestutil.ExecuteCommand(
|
||||
t.container,
|
||||
@ -1182,13 +1202,18 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
|
||||
[]string{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute list node command: %w", err)
|
||||
return nil, fmt.Errorf(
|
||||
"failed to execute approve routes command (node %d, routes %v): %w",
|
||||
id,
|
||||
routes,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
var node *v1.Node
|
||||
err = json.Unmarshal([]byte(result), &node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal nodes: %w", err)
|
||||
return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err)
|
||||
}
|
||||
|
||||
return node, nil
|
||||
|
@ -310,7 +310,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
// Enable route on node 1
|
||||
t.Logf("Enabling route on subnet router 1, no HA")
|
||||
_, err = headscale.ApproveRoutes(
|
||||
1,
|
||||
MustFindNode(subRouter1.Hostname(), nodes).GetId(),
|
||||
[]netip.Prefix{pref},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@ -366,7 +366,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
// Enable route on node 2, now we will have a HA subnet router
|
||||
t.Logf("Enabling route on subnet router 2, now HA, subnetrouter 1 is primary, 2 is standby")
|
||||
_, err = headscale.ApproveRoutes(
|
||||
2,
|
||||
MustFindNode(subRouter2.Hostname(), nodes).GetId(),
|
||||
[]netip.Prefix{pref},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@ -422,7 +422,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
// be enabled.
|
||||
t.Logf("Enabling route on subnet router 3, now HA, subnetrouter 1 is primary, 2 and 3 is standby")
|
||||
_, err = headscale.ApproveRoutes(
|
||||
3,
|
||||
MustFindNode(subRouter3.Hostname(), nodes).GetId(),
|
||||
[]netip.Prefix{pref},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@ -639,7 +639,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
|
||||
t.Logf("disabling route in subnet router r3 (%s)", subRouter3.Hostname())
|
||||
t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
|
||||
_, err = headscale.ApproveRoutes(nodes[2].GetId(), []netip.Prefix{})
|
||||
_, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
@ -647,9 +647,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 0)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
@ -684,7 +684,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
// Disable the route of subnet router 1, making it failover to 2
|
||||
t.Logf("disabling route in subnet router r1 (%s)", subRouter1.Hostname())
|
||||
t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
|
||||
_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{})
|
||||
_, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
@ -692,9 +692,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 0, 0)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
@ -729,9 +729,10 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
// enable the route of subnet router 1, no change expected
|
||||
t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
|
||||
t.Logf("both online, expecting r2 (%s) to still be primary (no flapping)", subRouter2.Hostname())
|
||||
r1Node := MustFindNode(subRouter1.Hostname(), nodes)
|
||||
_, err = headscale.ApproveRoutes(
|
||||
nodes[0].GetId(),
|
||||
util.MustStringsToPrefixes(nodes[0].GetAvailableRoutes()),
|
||||
r1Node.GetId(),
|
||||
util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
|
||||
)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
@ -740,9 +741,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, nodes, 6)
|
||||
|
||||
requireNodeRouteCount(t, nodes[0], 1, 1, 0)
|
||||
requireNodeRouteCount(t, nodes[1], 1, 1, 1)
|
||||
requireNodeRouteCount(t, nodes[2], 1, 0, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
|
||||
requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
|
||||
|
||||
// Verify that the route is announced from subnet router 1
|
||||
clientStatus, err = client.Status()
|
||||
|
@ -223,7 +223,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
|
||||
|
||||
s.userToNetwork = userToNetwork
|
||||
|
||||
if spec.OIDCUsers != nil && len(spec.OIDCUsers) != 0 {
|
||||
if len(spec.OIDCUsers) != 0 {
|
||||
ttl := defaultAccessTTL
|
||||
if spec.OIDCAccessTTL != 0 {
|
||||
ttl = spec.OIDCAccessTTL
|
||||
|
@ -370,10 +370,12 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
|
||||
}
|
||||
|
||||
func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
|
||||
t.Helper()
|
||||
return doSSHWithRetry(t, client, peer, true)
|
||||
}
|
||||
|
||||
func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
|
||||
t.Helper()
|
||||
return doSSHWithRetry(t, client, peer, false)
|
||||
}
|
||||
|
||||
|
@ -319,6 +319,7 @@ func New(
|
||||
dockertestutil.DockerRestartPolicy,
|
||||
dockertestutil.DockerAllowLocalIPv6,
|
||||
dockertestutil.DockerAllowNetworkAdministration,
|
||||
dockertestutil.DockerMemoryLimit,
|
||||
)
|
||||
case "unstable":
|
||||
tailscaleOptions.Repository = "tailscale/tailscale"
|
||||
@ -329,6 +330,7 @@ func New(
|
||||
dockertestutil.DockerRestartPolicy,
|
||||
dockertestutil.DockerAllowLocalIPv6,
|
||||
dockertestutil.DockerAllowNetworkAdministration,
|
||||
dockertestutil.DockerMemoryLimit,
|
||||
)
|
||||
default:
|
||||
tailscaleOptions.Repository = "tailscale/tailscale"
|
||||
@ -339,6 +341,7 @@ func New(
|
||||
dockertestutil.DockerRestartPolicy,
|
||||
dockertestutil.DockerAllowLocalIPv6,
|
||||
dockertestutil.DockerAllowNetworkAdministration,
|
||||
dockertestutil.DockerMemoryLimit,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -22,11 +22,11 @@ import (
|
||||
|
||||
const (
|
||||
// derpPingTimeout defines the timeout for individual DERP ping operations
|
||||
// Used in DERP connectivity tests to verify relay server communication
|
||||
// Used in DERP connectivity tests to verify relay server communication.
|
||||
derpPingTimeout = 2 * time.Second
|
||||
|
||||
|
||||
// derpPingCount defines the number of ping attempts for DERP connectivity tests
|
||||
// Higher count provides better reliability assessment of DERP connectivity
|
||||
// Higher count provides better reliability assessment of DERP connectivity.
|
||||
derpPingCount = 10
|
||||
)
|
||||
|
||||
@ -317,11 +317,11 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) {
|
||||
|
||||
// assertCommandOutputContains executes a command with exponential backoff retry until the output
|
||||
// contains the expected string or timeout is reached (10 seconds).
|
||||
// This implements eventual consistency patterns and should be used instead of time.Sleep
|
||||
// This implements eventual consistency patterns and should be used instead of time.Sleep
|
||||
// before executing commands that depend on network state propagation.
|
||||
//
|
||||
// Timeout: 10 seconds with exponential backoff
|
||||
// Use cases: DNS resolution, route propagation, policy updates
|
||||
// Use cases: DNS resolution, route propagation, policy updates.
|
||||
func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
|
||||
t.Helper()
|
||||
|
||||
@ -361,10 +361,10 @@ func isSelfClient(client TailscaleClient, addr string) bool {
|
||||
}
|
||||
|
||||
func dockertestMaxWait() time.Duration {
|
||||
wait := 120 * time.Second //nolint
|
||||
wait := 300 * time.Second //nolint
|
||||
|
||||
if util.IsCI() {
|
||||
wait = 300 * time.Second //nolint
|
||||
wait = 600 * time.Second //nolint
|
||||
}
|
||||
|
||||
return wait
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
@ -21,7 +20,7 @@ import (
|
||||
const (
|
||||
releasesURL = "https://api.github.com/repos/tailscale/tailscale/releases"
|
||||
rawFileURL = "https://github.com/tailscale/tailscale/raw/refs/tags/%s/tailcfg/tailcfg.go"
|
||||
outputFile = "../capver_generated.go"
|
||||
outputFile = "../../hscontrol/capver/capver_generated.go"
|
||||
)
|
||||
|
||||
type Release struct {
|
||||
@ -105,7 +104,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
|
||||
sortedVersions := xmaps.Keys(versions)
|
||||
sort.Strings(sortedVersions)
|
||||
for _, version := range sortedVersions {
|
||||
file.WriteString(fmt.Sprintf("\t\"%s\": %d,\n", version, versions[version]))
|
||||
fmt.Fprintf(file, "\t\"%s\": %d,\n", version, versions[version])
|
||||
}
|
||||
file.WriteString("}\n")
|
||||
|
||||
@ -115,16 +114,13 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
|
||||
capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string)
|
||||
for _, v := range sortedVersions {
|
||||
cap := versions[v]
|
||||
log.Printf("cap for v: %d, %s", cap, v)
|
||||
|
||||
// If it is already set, skip and continue,
|
||||
// we only want the first tailscale vsion per
|
||||
// capability vsion.
|
||||
if _, ok := capVarToTailscaleVer[cap]; ok {
|
||||
log.Printf("Skipping %d, %s", cap, v)
|
||||
continue
|
||||
}
|
||||
log.Printf("Storing %d, %s", cap, v)
|
||||
capVarToTailscaleVer[cap] = v
|
||||
}
|
||||
|
||||
@ -133,7 +129,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
|
||||
return capsSorted[i] < capsSorted[j]
|
||||
})
|
||||
for _, capVer := range capsSorted {
|
||||
file.WriteString(fmt.Sprintf("\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer]))
|
||||
fmt.Fprintf(file, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
|
||||
}
|
||||
file.WriteString("}\n")
|
||||
|
Loading…
x
Reference in New Issue
Block a user