Mirror of https://github.com/juanfont/headscale.git, synced 2025-07-30 18:43:43 +00:00

mapper: produce map before poll (#2628)

parent: b2a18830ed
commit: a058bf3cd3
@@ -17,3 +17,7 @@ LICENSE
 .vscode
 *.sock
 
+node_modules/
+package-lock.json
+package.json
.github/workflows/check-generated.yml (new file, 55 lines, vendored)
@@ -0,0 +1,55 @@
name: Check Generated Files

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

concurrency:
  group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

jobs:
  check-generated:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 2
      - name: Get changed files
        id: changed-files
        uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
        with:
          filters: |
            files:
              - '*.nix'
              - 'go.*'
              - '**/*.go'
              - '**/*.proto'
              - 'buf.gen.yaml'
              - 'tools/**'
      - uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
        if: steps.changed-files.outputs.files == 'true'
      - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
        if: steps.changed-files.outputs.files == 'true'
        with:
          primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', '**/flake.lock') }}
          restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}

      - name: Run make generate
        if: steps.changed-files.outputs.files == 'true'
        run: nix develop --command -- make generate

      - name: Check for uncommitted changes
        if: steps.changed-files.outputs.files == 'true'
        run: |
          if ! git diff --exit-code; then
            echo "❌ Generated files are not up to date!"
            echo "Please run 'make generate' and commit the changes."
            exit 1
          else
            echo "✅ All generated files are up to date."
          fi
@@ -77,7 +77,7 @@ jobs:
           attempt_delay: 300000 # 5 min
           attempt_limit: 2
           command: |
-            nix develop --command -- hi run "^${{ inputs.test }}$" \
+            nix develop --command -- hi run --stats --ts-memory-limit=300 --hs-memory-limit=500 "^${{ inputs.test }}$" \
               --timeout=120m \
               ${{ inputs.postgres_flag }}
       - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
.gitignore (7 lines changed, vendored)
@@ -1,6 +1,9 @@
 ignored/
 tailscale/
 .vscode/
+.claude/
+
+*.prof
 
 # Binaries for programs and plugins
 *.exe
@@ -46,3 +49,7 @@ integration_test/etc/config.dump.yaml
 /site
 
 __debug_bin
+
+node_modules/
+package-lock.json
+package.json
@@ -2,6 +2,8 @@
 
 ## Next
 
+**Minimum supported Tailscale client version: v1.64.0**
+
 ### Database integrity improvements
 
 This release includes a significant database migration that addresses longstanding
CLAUDE.md (new file, 395 lines)
@@ -0,0 +1,395 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Overview

Headscale is an open-source implementation of the Tailscale control server written in Go. It provides self-hosted coordination for Tailscale networks (tailnets), managing node registration, IP allocation, policy enforcement, and DERP routing.

## Development Commands

### Quick Setup
```bash
# Recommended: Use Nix for dependency management
nix develop

# Full development workflow
make dev  # runs fmt + lint + test + build
```

### Essential Commands
```bash
# Build headscale binary
make build

# Run tests
make test
go test ./...         # All unit tests
go test -race ./...   # With race detection

# Run specific integration test
go run ./cmd/hi run "TestName" --postgres

# Code formatting and linting
make fmt      # Format all code (Go, docs, proto)
make lint     # Lint all code (Go, proto)
make fmt-go   # Format Go code only
make lint-go  # Lint Go code only

# Protocol buffer generation (after modifying proto/)
make generate

# Clean build artifacts
make clean
```

### Integration Testing
```bash
# Use the hi (Headscale Integration) test runner
go run ./cmd/hi doctor                        # Check system requirements
go run ./cmd/hi run "TestPattern"             # Run specific test
go run ./cmd/hi run "TestPattern" --postgres  # With PostgreSQL backend

# Test artifacts are saved to control_logs/ with logs and debug data
```

## Project Structure & Architecture

### Top-Level Organization

```
headscale/
├── cmd/                 # Command-line applications
│   ├── headscale/       # Main headscale server binary
│   └── hi/              # Headscale Integration test runner
├── hscontrol/           # Core control plane logic
├── integration/         # End-to-end Docker-based tests
├── proto/               # Protocol buffer definitions
├── gen/                 # Generated code (protobuf)
├── docs/                # Documentation
└── packaging/           # Distribution packaging
```

### Core Packages (`hscontrol/`)

**Main Server (`hscontrol/`)**
- `app.go`: Application setup, dependency injection, server lifecycle
- `handlers.go`: HTTP/gRPC API endpoints for management operations
- `grpcv1.go`: gRPC service implementation for headscale API
- `poll.go`: **Critical** - Handles Tailscale MapRequest/MapResponse protocol
- `noise.go`: Noise protocol implementation for secure client communication
- `auth.go`: Authentication flows (web, OIDC, command-line)
- `oidc.go`: OpenID Connect integration for user authentication

**State Management (`hscontrol/state/`)**
- `state.go`: Central coordinator for all subsystems (database, policy, IP allocation, DERP)
- `node_store.go`: **Performance-critical** - In-memory cache with copy-on-write semantics
- Thread-safe operations with deadlock detection
- Coordinates between database persistence and real-time operations

**Database Layer (`hscontrol/db/`)**
- `db.go`: Database abstraction, GORM setup, migration management
- `node.go`: Node lifecycle, registration, expiration, IP assignment
- `users.go`: User management, namespace isolation
- `api_key.go`: API authentication tokens
- `preauth_keys.go`: Pre-authentication keys for automated node registration
- `ip.go`: IP address allocation and management
- `policy.go`: Policy storage and retrieval
- Schema migrations in `schema.sql` with extensive test data coverage

**Policy Engine (`hscontrol/policy/`)**
- `policy.go`: Core ACL evaluation logic, HuJSON parsing
- `v2/`: Next-generation policy system with improved filtering
- `matcher/`: ACL rule matching and evaluation engine
- Determines peer visibility, route approval, and network access rules
- Supports both file-based and database-stored policies

**Network Management (`hscontrol/`)**
- `derp/`: DERP (Designated Encrypted Relay for Packets) server implementation
- NAT traversal when direct connections fail
- Fallback relay for firewall-restricted environments
- `mapper/`: Converts internal Headscale state to Tailscale's wire protocol format
- `tail.go`: Tailscale-specific data structure generation
- `routes/`: Subnet route management and primary route selection
- `dns/`: DNS record management and MagicDNS implementation

**Utilities & Support (`hscontrol/`)**
- `types/`: Core data structures, configuration, validation
- `util/`: Helper functions for networking, DNS, key management
- `templates/`: Client configuration templates (Apple, Windows, etc.)
- `notifier/`: Event notification system for real-time updates
- `metrics.go`: Prometheus metrics collection
- `capver/`: Tailscale capability version management

### Key Subsystem Interactions

**Node Registration Flow**
1. **Client Connection**: `noise.go` handles secure protocol handshake
2. **Authentication**: `auth.go` validates credentials (web/OIDC/preauth)
3. **State Creation**: `state.go` coordinates IP allocation via `db/ip.go`
4. **Storage**: `db/node.go` persists node, `NodeStore` caches in memory
5. **Network Setup**: `mapper/` generates initial Tailscale network map

**Ongoing Operations**
1. **Poll Requests**: `poll.go` receives periodic client updates
2. **State Updates**: `NodeStore` maintains real-time node information
3. **Policy Application**: `policy/` evaluates ACL rules for peer relationships
4. **Map Distribution**: `mapper/` sends network topology to all affected clients

**Route Management**
1. **Advertisement**: Clients announce routes via `poll.go` Hostinfo updates
2. **Storage**: `db/` persists routes, `NodeStore` caches for performance
3. **Approval**: `policy/` auto-approves routes based on ACL rules
4. **Distribution**: `routes/` selects primary routes, `mapper/` distributes to peers

### Command-Line Tools (`cmd/`)

**Main Server (`cmd/headscale/`)**
- `headscale.go`: CLI parsing, configuration loading, server startup
- Supports daemon mode, CLI operations (user/node management), database operations

**Integration Test Runner (`cmd/hi/`)**
- `main.go`: Test execution framework with Docker orchestration
- `run.go`: Individual test execution with artifact collection
- `doctor.go`: System requirements validation
- `docker.go`: Container lifecycle management
- Essential for validating changes against real Tailscale clients

### Generated & External Code

**Protocol Buffers (`proto/` → `gen/`)**
- Defines gRPC API for headscale management operations
- Client libraries can generate from these definitions
- Run `make generate` after modifying `.proto` files

**Integration Testing (`integration/`)**
- `scenario.go`: Docker test environment setup
- `tailscale.go`: Tailscale client container management
- Individual test files for specific functionality areas
- Real end-to-end validation with network isolation

### Critical Performance Paths

**High-Frequency Operations**
1. **MapRequest Processing** (`poll.go`): Every 15-60 seconds per client
2. **NodeStore Reads** (`node_store.go`): Every operation requiring node data
3. **Policy Evaluation** (`policy/`): On every peer relationship calculation
4. **Route Lookups** (`routes/`): During network map generation

**Database Write Patterns**
- **Frequent**: Node heartbeats, endpoint updates, route changes
- **Moderate**: User operations, policy updates, API key management
- **Rare**: Schema migrations, bulk operations

### Configuration & Deployment

**Configuration (`hscontrol/types/config.go`)**
- Database connection settings (SQLite/PostgreSQL)
- Network configuration (IP ranges, DNS settings)
- Policy mode (file vs database)
- DERP relay configuration
- OIDC provider settings

**Key Dependencies**
- **GORM**: Database ORM with migration support
- **Tailscale Libraries**: Core networking and protocol code
- **Zerolog**: Structured logging throughout the application
- **Buf**: Protocol buffer toolchain for code generation

### Development Workflow Integration

The architecture supports incremental development:
- **Unit Tests**: Focus on individual packages (`*_test.go` files)
- **Integration Tests**: Validate cross-component interactions
- **Database Tests**: Extensive migration and data integrity validation
- **Policy Tests**: ACL rule evaluation and edge cases
- **Performance Tests**: NodeStore and high-frequency operation validation

## Integration Test System

### Overview
Integration tests use Docker containers running real Tailscale clients against a Headscale server. Tests validate end-to-end functionality including routing, ACLs, node lifecycle, and network coordination.

### Running Integration Tests

**System Requirements**
```bash
# Check if your system is ready
go run ./cmd/hi doctor
```
This verifies Docker, Go, required images, and disk space.

**Test Execution Patterns**
```bash
# Run a single test (recommended for development)
go run ./cmd/hi run "TestSubnetRouterMultiNetwork"

# Run with PostgreSQL backend (for database-heavy tests)
go run ./cmd/hi run "TestExpireNode" --postgres

# Run multiple tests with pattern matching
go run ./cmd/hi run "TestSubnet*"

# Run all integration tests (CI/full validation)
go test ./integration -timeout 30m
```

**Test Categories & Timing**
- **Fast tests** (< 2 min): Basic functionality, CLI operations
- **Medium tests** (2-5 min): Route management, ACL validation
- **Slow tests** (5+ min): Node expiration, HA failover
- **Long-running tests** (10+ min): `TestNodeOnlineStatus` (12 min duration)

### Test Infrastructure

**Docker Setup**
- Headscale server container with configurable database backend
- Multiple Tailscale client containers with different versions
- Isolated networks per test scenario
- Automatic cleanup after test completion

**Test Artifacts**
All test runs save artifacts to `control_logs/TIMESTAMP-ID/`:
```
control_logs/20250713-213106-iajsux/
├── hs-testname-abc123.stderr.log      # Headscale server logs
├── hs-testname-abc123.stdout.log
├── hs-testname-abc123.db              # Database snapshot
├── hs-testname-abc123_metrics.txt     # Prometheus metrics
├── hs-testname-abc123-mapresponses/   # Protocol debug data
├── ts-client-xyz789.stderr.log        # Tailscale client logs
├── ts-client-xyz789.stdout.log
└── ts-client-xyz789_status.json       # Client status dump
```

### Test Development Guidelines

**Timing Considerations**
Integration tests involve real network operations and Docker container lifecycle:

```go
// ❌ Wrong: Immediate assertions after async operations
client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"})
nodes, _ := headscale.ListNodes()
require.Len(t, nodes[0].GetAvailableRoutes(), 1) // May fail due to timing

// ✅ Correct: Wait for async operations to complete
client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"})
require.EventuallyWithT(t, func(c *assert.CollectT) {
    nodes, err := headscale.ListNodes()
    assert.NoError(c, err)
    assert.Len(c, nodes[0].GetAvailableRoutes(), 1)
}, 10*time.Second, 100*time.Millisecond, "route should be advertised")
```

**Common Test Patterns**
- **Route Advertisement**: Use `EventuallyWithT` for route propagation
- **Node State Changes**: Wait for NodeStore synchronization
- **ACL Policy Changes**: Allow time for policy recalculation
- **Network Connectivity**: Use ping tests with retries (see the sketch below)
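As a sketch of the last pattern above (connectivity checks with retries), the snippet mirrors the `EventuallyWithT` style used elsewhere in this guide. It is illustrative only: the `client.Execute` return values and the `pong` substring check are assumptions to verify against the real helpers in `integration/`:

```go
// Illustrative only: retry a one-packet tailscale ping until the peer answers
// or the deadline passes. client and peerIP come from the scenario setup.
require.EventuallyWithT(t, func(c *assert.CollectT) {
    stdout, _, err := client.Execute([]string{"tailscale", "ping", "--c=1", peerIP})
    assert.NoError(c, err)
    assert.Contains(c, stdout, "pong") // assumed success marker in the CLI output
}, 30*time.Second, 500*time.Millisecond, "peer should become reachable")
```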
**Test Data Management**
```go
// Node identification: Don't assume array ordering
expectedRoutes := map[string]string{"1": "10.33.0.0/16"}
for _, node := range nodes {
    nodeIDStr := fmt.Sprintf("%d", node.GetId())
    if route, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute {
        // Test the node that should have the route
    }
}
```

### Troubleshooting Integration Tests

**Common Failure Patterns**
1. **Timing Issues**: Test assertions run before async operations complete
   - **Solution**: Use `EventuallyWithT` with appropriate timeouts
   - **Timeout Guidelines**: 3-5s for route operations, 10s for complex scenarios

2. **Infrastructure Problems**: Disk space, Docker issues, network conflicts
   - **Check**: `go run ./cmd/hi doctor` for system health
   - **Clean**: Remove old test containers and networks

3. **NodeStore Synchronization**: Tests expecting immediate data availability
   - **Key Points**: Route advertisements must propagate through poll requests
   - **Fix**: Wait for NodeStore updates after Hostinfo changes

4. **Database Backend Differences**: SQLite vs PostgreSQL behavior differences
   - **Use**: `--postgres` flag for database-intensive tests
   - **Note**: Some timing characteristics differ between backends

**Debugging Failed Tests**
1. **Check test artifacts** in `control_logs/` for detailed logs
2. **Examine MapResponse JSON** files for protocol-level debugging
3. **Review Headscale stderr logs** for server-side error messages
4. **Check Tailscale client status** for network-level issues

**Resource Management**
- Tests require significant disk space (each run ~100MB of logs)
- Docker containers are cleaned up automatically on success
- Failed tests may leave containers running - clean manually if needed
- Use `docker system prune` periodically to reclaim space

### Best Practices for Test Modifications

1. **Always test locally** before committing integration test changes
2. **Use appropriate timeouts** - too short causes flaky tests, too long slows CI
3. **Clean up properly** - ensure tests don't leave persistent state
4. **Handle both success and failure paths** in test scenarios
5. **Document timing requirements** for complex test scenarios

## NodeStore Implementation Details

**Key Insight from Recent Work**: The NodeStore is a critical performance optimization that caches node data in memory while ensuring consistency with the database. When working with route advertisements or node state changes:

1. **Timing Considerations**: Route advertisements need time to propagate from clients to server. Use `require.EventuallyWithT()` patterns in tests instead of immediate assertions.

2. **Synchronization Points**: NodeStore updates happen at specific points like `poll.go:420` after Hostinfo changes. Ensure these are maintained when modifying the polling logic.

3. **Peer Visibility**: The NodeStore's `peersFunc` determines which nodes are visible to each other. Policy-based filtering is separate from monitoring visibility - expired nodes should remain visible for debugging but marked as expired.
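The copy-on-write idea behind the NodeStore can be illustrated with a minimal, self-contained sketch. This is not the actual `node_store.go` API (names and types here are hypothetical); it only shows the general pattern of readers loading an immutable snapshot while writers copy, mutate, and atomically swap:

```go
package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

// cowStore publishes immutable map snapshots: readers do one atomic load and
// never block; writers serialize on a mutex, copy the map, then swap it in.
type cowStore struct {
    mu       sync.Mutex                       // serializes writers only
    snapshot atomic.Pointer[map[int64]string] // treated as read-only once published
}

func newCowStore() *cowStore {
    s := &cowStore{}
    m := map[int64]string{}
    s.snapshot.Store(&m)
    return s
}

func (s *cowStore) Get(id int64) (string, bool) {
    m := *s.snapshot.Load()
    v, ok := m[id]
    return v, ok
}

func (s *cowStore) Put(id int64, name string) {
    s.mu.Lock()
    defer s.mu.Unlock()
    old := *s.snapshot.Load()
    next := make(map[int64]string, len(old)+1)
    for k, v := range old { // copy the current snapshot
        next[k] = v
    }
    next[id] = name // mutate only the private copy
    s.snapshot.Store(&next)
}

func main() {
    s := newCowStore()
    s.Put(1, "node-a")
    if name, ok := s.Get(1); ok {
        fmt.Println(name) // node-a
    }
}
```

The trade-off is that every write copies the map, so writes are comparatively expensive while reads stay cheap and lock-free.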
## Testing Guidelines

### Integration Test Patterns
```go
// Use EventuallyWithT for async operations
require.EventuallyWithT(t, func(c *assert.CollectT) {
    nodes, err := headscale.ListNodes()
    assert.NoError(c, err)
    // Check expected state
}, 10*time.Second, 100*time.Millisecond, "description")

// Node route checking by actual node properties, not array position
var routeNode *v1.Node
for _, node := range nodes {
    if nodeIDStr := fmt.Sprintf("%d", node.GetId()); expectedRoutes[nodeIDStr] != "" {
        routeNode = node
        break
    }
}
```

### Running Problematic Tests
- Some tests require significant time (e.g., `TestNodeOnlineStatus` runs for 12 minutes)
- Infrastructure issues like disk space can cause test failures unrelated to code changes
- Use `--postgres` flag when testing database-heavy scenarios

## Important Notes

- **Dependencies**: Use `nix develop` for consistent toolchain (Go, buf, protobuf tools, linting)
- **Protocol Buffers**: Changes to `proto/` require `make generate` and should be committed separately
- **Code Style**: Enforced via golangci-lint with golines (width 88) and gofumpt formatting
- **Database**: Supports both SQLite (development) and PostgreSQL (production/testing)
- **Integration Tests**: Require Docker and can consume significant disk space
- **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management

## Debugging Integration Tests

Test artifacts are preserved in `control_logs/TIMESTAMP-ID/` including:
- Headscale server logs (stderr/stdout)
- Tailscale client logs and status
- Database dumps and network captures
- MapResponse JSON files for protocol debugging

When tests fail, check these artifacts first before assuming code issues.
Makefile (7 lines changed)
@@ -87,10 +87,9 @@ lint-proto: check-deps $(PROTO_SOURCES)
 
 # Code generation
 .PHONY: generate
-generate: check-deps $(PROTO_SOURCES)
-	@echo "Generating code from Protocol Buffers..."
-	rm -rf gen
-	buf generate proto
+generate: check-deps
+	@echo "Generating code..."
+	go generate ./...
 
 # Clean targets
 .PHONY: clean
@@ -212,13 +212,10 @@ var listUsersCmd = &cobra.Command{
 		switch {
 		case id > 0:
 			request.Id = uint64(id)
-			break
 		case username != "":
 			request.Name = username
-			break
 		case email != "":
 			request.Email = email
-			break
 		}
 
 		response, err := client.ListUsers(ctx, request)
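The removed `break` statements were redundant: unlike C, a Go `switch` case stops after its last statement and never falls through unless `fallthrough` is written explicitly, so the behaviour of `listUsersCmd` is unchanged. A small standalone illustration:

```go
package main

import "fmt"

func main() {
    id, username := int64(7), ""
    switch {
    case id > 0:
        fmt.Println("filter by id") // execution stops here, no implicit fall-through
    case username != "":
        fmt.Println("filter by name") // not reached when id > 0
    }
}
```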
@@ -90,6 +90,32 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
 
 	log.Printf("Starting test: %s", config.TestPattern)
 
+	// Start stats collection for container resource monitoring (if enabled)
+	var statsCollector *StatsCollector
+	if config.Stats {
+		var err error
+		statsCollector, err = NewStatsCollector()
+		if err != nil {
+			if config.Verbose {
+				log.Printf("Warning: failed to create stats collector: %v", err)
+			}
+			statsCollector = nil
+		}
+
+		if statsCollector != nil {
+			defer statsCollector.Close()
+
+			// Start stats collection immediately - no need for complex retry logic
+			// The new implementation monitors Docker events and will catch containers as they start
+			if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
+				if config.Verbose {
+					log.Printf("Warning: failed to start stats collection: %v", err)
+				}
+			}
+			defer statsCollector.StopCollection()
+		}
+	}
+
 	exitCode, err := streamAndWait(ctx, cli, resp.ID)
 
 	// Ensure all containers have finished and logs are flushed before extracting artifacts
@@ -105,6 +131,20 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
 	// Always list control files regardless of test outcome
 	listControlFiles(logsDir)
 
+	// Print stats summary and check memory limits if enabled
+	if config.Stats && statsCollector != nil {
+		violations := statsCollector.PrintSummaryAndCheckLimits(config.HSMemoryLimit, config.TSMemoryLimit)
+		if len(violations) > 0 {
+			log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
+			log.Printf("=================================")
+			for _, violation := range violations {
+				log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
+					violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
+			}
+			return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
+		}
+	}
+
 	shouldCleanup := config.CleanAfter && (!config.KeepOnFailure || exitCode == 0)
 	if shouldCleanup {
 		if config.Verbose {
@@ -379,10 +419,37 @@ func getDockerSocketPath() string {
 	return "/var/run/docker.sock"
 }
 
-// ensureImageAvailable pulls the specified Docker image to ensure it's available.
+// checkImageAvailableLocally checks if the specified Docker image is available locally.
+func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
+	_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
+	if err != nil {
+		if client.IsErrNotFound(err) {
+			return false, nil
+		}
+		return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
+	}
+
+	return true, nil
+}
+
+// ensureImageAvailable checks if the image is available locally first, then pulls if needed.
 func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName string, verbose bool) error {
+	// First check if image is available locally
+	available, err := checkImageAvailableLocally(ctx, cli, imageName)
+	if err != nil {
+		return fmt.Errorf("failed to check local image availability: %w", err)
+	}
+
+	if available {
+		if verbose {
+			log.Printf("Image %s is available locally", imageName)
+		}
+		return nil
+	}
+
+	// Image not available locally, try to pull it
 	if verbose {
-		log.Printf("Pulling image %s...", imageName)
+		log.Printf("Image %s not found locally, pulling...", imageName)
 	}
 
 	reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{})
@@ -190,7 +190,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
 	}
 }
 
-// checkGolangImage verifies we can access the golang Docker image.
+// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
 func checkGolangImage(ctx context.Context) DoctorResult {
 	cli, err := createDockerClient()
 	if err != nil {
@@ -205,17 +205,40 @@ func checkGolangImage(ctx context.Context) DoctorResult {
 	goVersion := detectGoVersion()
 	imageName := "golang:" + goVersion
 
-	// Check if we can pull the image
+	// First check if image is available locally
+	available, err := checkImageAvailableLocally(ctx, cli, imageName)
+	if err != nil {
+		return DoctorResult{
+			Name:    "Golang Image",
+			Status:  "FAIL",
+			Message: fmt.Sprintf("Cannot check golang image %s: %v", imageName, err),
+			Suggestions: []string{
+				"Check Docker daemon status",
+				"Try: docker images | grep golang",
+			},
+		}
+	}
+
+	if available {
+		return DoctorResult{
+			Name:    "Golang Image",
+			Status:  "PASS",
+			Message: fmt.Sprintf("Golang image %s is available locally", imageName),
+		}
+	}
+
+	// Image not available locally, try to pull it
 	err = ensureImageAvailable(ctx, cli, imageName, false)
 	if err != nil {
 		return DoctorResult{
 			Name:    "Golang Image",
 			Status:  "FAIL",
-			Message: fmt.Sprintf("Cannot pull golang image %s: %v", imageName, err),
+			Message: fmt.Sprintf("Golang image %s not available locally and cannot pull: %v", imageName, err),
 			Suggestions: []string{
 				"Check internet connectivity",
 				"Verify Docker Hub access",
 				"Try: docker pull " + imageName,
+				"Or run tests offline if image was pulled previously",
 			},
 		}
 	}
@@ -223,7 +246,7 @@ func checkGolangImage(ctx context.Context) DoctorResult {
 	return DoctorResult{
 		Name:    "Golang Image",
 		Status:  "PASS",
-		Message: fmt.Sprintf("Golang image %s is available", imageName),
+		Message: fmt.Sprintf("Golang image %s is now available", imageName),
 	}
 }
 
@@ -24,6 +24,9 @@ type RunConfig struct {
 	KeepOnFailure bool   `flag:"keep-on-failure,default=false,Keep containers on test failure"`
 	LogsDir       string `flag:"logs-dir,default=control_logs,Control logs directory"`
 	Verbose       bool   `flag:"verbose,default=false,Verbose output"`
+	Stats         bool    `flag:"stats,default=false,Collect and display container resource usage statistics"`
+	HSMemoryLimit float64 `flag:"hs-memory-limit,default=0,Fail test if any Headscale container exceeds this memory limit in MB (0 = disabled)"`
+	TSMemoryLimit float64 `flag:"ts-memory-limit,default=0,Fail test if any Tailscale container exceeds this memory limit in MB (0 = disabled)"`
 }
 
 // runIntegrationTest executes the integration test workflow.
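With these flags in place, a local run that mirrors the CI invocation earlier in this commit would look like `go run ./cmd/hi run --stats --ts-memory-limit=300 --hs-memory-limit=500 "^TestName$"`; the 300 MB and 500 MB limits are the values used in the workflow change above, not built-in defaults (both default to 0, i.e. disabled).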
cmd/hi/stats.go (new file, 468 lines)
@@ -0,0 +1,468 @@
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"log"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/docker/docker/api/types"
	"github.com/docker/docker/api/types/container"
	"github.com/docker/docker/api/types/events"
	"github.com/docker/docker/api/types/filters"
	"github.com/docker/docker/client"
)

// ContainerStats represents statistics for a single container
type ContainerStats struct {
	ContainerID   string
	ContainerName string
	Stats         []StatsSample
	mutex         sync.RWMutex
}

// StatsSample represents a single stats measurement
type StatsSample struct {
	Timestamp time.Time
	CPUUsage  float64 // CPU usage percentage
	MemoryMB  float64 // Memory usage in MB
}

// StatsCollector manages collection of container statistics
type StatsCollector struct {
	client            *client.Client
	containers        map[string]*ContainerStats
	stopChan          chan struct{}
	wg                sync.WaitGroup
	mutex             sync.RWMutex
	collectionStarted bool
}

// NewStatsCollector creates a new stats collector instance
func NewStatsCollector() (*StatsCollector, error) {
	cli, err := createDockerClient()
	if err != nil {
		return nil, fmt.Errorf("failed to create Docker client: %w", err)
	}

	return &StatsCollector{
		client:     cli,
		containers: make(map[string]*ContainerStats),
		stopChan:   make(chan struct{}),
	}, nil
}

// StartCollection begins monitoring all containers and collecting stats for hs- and ts- containers with matching run ID
func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, verbose bool) error {
	sc.mutex.Lock()
	defer sc.mutex.Unlock()

	if sc.collectionStarted {
		return fmt.Errorf("stats collection already started")
	}

	sc.collectionStarted = true

	// Start monitoring existing containers
	sc.wg.Add(1)
	go sc.monitorExistingContainers(ctx, runID, verbose)

	// Start Docker events monitoring for new containers
	sc.wg.Add(1)
	go sc.monitorDockerEvents(ctx, runID, verbose)

	if verbose {
		log.Printf("Started container monitoring for run ID %s", runID)
	}

	return nil
}

// StopCollection stops all stats collection
func (sc *StatsCollector) StopCollection() {
	// Check if already stopped without holding lock
	sc.mutex.RLock()
	if !sc.collectionStarted {
		sc.mutex.RUnlock()
		return
	}
	sc.mutex.RUnlock()

	// Signal stop to all goroutines
	close(sc.stopChan)

	// Wait for all goroutines to finish
	sc.wg.Wait()

	// Mark as stopped
	sc.mutex.Lock()
	sc.collectionStarted = false
	sc.mutex.Unlock()
}

// monitorExistingContainers checks for existing containers that match our criteria
func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID string, verbose bool) {
	defer sc.wg.Done()

	containers, err := sc.client.ContainerList(ctx, container.ListOptions{})
	if err != nil {
		if verbose {
			log.Printf("Failed to list existing containers: %v", err)
		}
		return
	}

	for _, cont := range containers {
		if sc.shouldMonitorContainer(cont, runID) {
			sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose)
		}
	}
}

// monitorDockerEvents listens for container start events and begins monitoring relevant containers
func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, verbose bool) {
	defer sc.wg.Done()

	filter := filters.NewArgs()
	filter.Add("type", "container")
	filter.Add("event", "start")

	eventOptions := events.ListOptions{
		Filters: filter,
	}

	events, errs := sc.client.Events(ctx, eventOptions)

	for {
		select {
		case <-sc.stopChan:
			return
		case <-ctx.Done():
			return
		case event := <-events:
			if event.Type == "container" && event.Action == "start" {
				// Get container details
				containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
				if err != nil {
					continue
				}

				// Convert to types.Container format for consistency
				cont := types.Container{
					ID:     containerInfo.ID,
					Names:  []string{containerInfo.Name},
					Labels: containerInfo.Config.Labels,
				}

				if sc.shouldMonitorContainer(cont, runID) {
					sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose)
				}
			}
		case err := <-errs:
			if verbose {
				log.Printf("Error in Docker events stream: %v", err)
			}
			return
		}
	}
}

// shouldMonitorContainer determines if a container should be monitored
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
	// Check if it has the correct run ID label
	if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
		return false
	}

	// Check if it's an hs- or ts- container
	for _, name := range cont.Names {
		containerName := strings.TrimPrefix(name, "/")
		if strings.HasPrefix(containerName, "hs-") || strings.HasPrefix(containerName, "ts-") {
			return true
		}
	}

	return false
}

// startStatsForContainer begins stats collection for a specific container
func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerID, containerName string, verbose bool) {
	containerName = strings.TrimPrefix(containerName, "/")

	sc.mutex.Lock()
	// Check if we're already monitoring this container
	if _, exists := sc.containers[containerID]; exists {
		sc.mutex.Unlock()
		return
	}

	sc.containers[containerID] = &ContainerStats{
		ContainerID:   containerID,
		ContainerName: containerName,
		Stats:         make([]StatsSample, 0),
	}
	sc.mutex.Unlock()

	if verbose {
		log.Printf("Starting stats collection for container %s (%s)", containerName, containerID[:12])
	}

	sc.wg.Add(1)
	go sc.collectStatsForContainer(ctx, containerID, verbose)
}

// collectStatsForContainer collects stats for a specific container using Docker API streaming
func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containerID string, verbose bool) {
	defer sc.wg.Done()

	// Use Docker API streaming stats - much more efficient than CLI
	statsResponse, err := sc.client.ContainerStats(ctx, containerID, true)
	if err != nil {
		if verbose {
			log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
		}
		return
	}
	defer statsResponse.Body.Close()

	decoder := json.NewDecoder(statsResponse.Body)
	var prevStats *container.Stats

	for {
		select {
		case <-sc.stopChan:
			return
		case <-ctx.Done():
			return
		default:
			var stats container.Stats
			if err := decoder.Decode(&stats); err != nil {
				// EOF is expected when container stops or stream ends
				if err.Error() != "EOF" && verbose {
					log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
				}
				return
			}

			// Calculate CPU percentage (only if we have previous stats)
			var cpuPercent float64
			if prevStats != nil {
				cpuPercent = calculateCPUPercent(prevStats, &stats)
			}

			// Calculate memory usage in MB
			memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024)

			// Store the sample (skip first sample since CPU calculation needs previous stats)
			if prevStats != nil {
				// Get container stats reference without holding the main mutex
				var containerStats *ContainerStats
				var exists bool

				sc.mutex.RLock()
				containerStats, exists = sc.containers[containerID]
				sc.mutex.RUnlock()

				if exists && containerStats != nil {
					containerStats.mutex.Lock()
					containerStats.Stats = append(containerStats.Stats, StatsSample{
						Timestamp: time.Now(),
						CPUUsage:  cpuPercent,
						MemoryMB:  memoryMB,
					})
					containerStats.mutex.Unlock()
				}
			}

			// Save current stats for next iteration
			prevStats = &stats
		}
	}
}

// calculateCPUPercent calculates CPU usage percentage from Docker stats
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
	// CPU calculation based on Docker's implementation
	cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
	systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)

	if systemDelta > 0 && cpuDelta >= 0 {
		// Calculate CPU percentage: (container CPU delta / system CPU delta) * number of CPUs * 100
		numCPUs := float64(len(stats.CPUStats.CPUUsage.PercpuUsage))
		if numCPUs == 0 {
			// Fallback: if PercpuUsage is not available, assume 1 CPU
			numCPUs = 1.0
		}
		return (cpuDelta / systemDelta) * numCPUs * 100.0
	}
	return 0.0
}

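// Worked example for the formula above (illustrative numbers, not taken from a
// real run): if the container's cumulative CPU time grew by 5e7 ns while the
// system's grew by 1e9 ns between two samples on a 4-CPU host, the reported
// usage is (5e7 / 1e9) * 4 * 100 = 20%.
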
// ContainerStatsSummary represents summary statistics for a container
type ContainerStatsSummary struct {
	ContainerName string
	SampleCount   int
	CPU           StatsSummary
	Memory        StatsSummary
}

// MemoryViolation represents a container that exceeded the memory limit
type MemoryViolation struct {
	ContainerName string
	MaxMemoryMB   float64
	LimitMB       float64
}

// StatsSummary represents min, max, and average for a metric
type StatsSummary struct {
	Min     float64
	Max     float64
	Average float64
}

// GetSummary returns a summary of collected statistics
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
	// Take snapshot of container references without holding main lock long
	sc.mutex.RLock()
	containerRefs := make([]*ContainerStats, 0, len(sc.containers))
	for _, containerStats := range sc.containers {
		containerRefs = append(containerRefs, containerStats)
	}
	sc.mutex.RUnlock()

	summaries := make([]ContainerStatsSummary, 0, len(containerRefs))

	for _, containerStats := range containerRefs {
		containerStats.mutex.RLock()
		stats := make([]StatsSample, len(containerStats.Stats))
		copy(stats, containerStats.Stats)
		containerName := containerStats.ContainerName
		containerStats.mutex.RUnlock()

		if len(stats) == 0 {
			continue
		}

		summary := ContainerStatsSummary{
			ContainerName: containerName,
			SampleCount:   len(stats),
		}

		// Calculate CPU stats
		cpuValues := make([]float64, len(stats))
		memoryValues := make([]float64, len(stats))

		for i, sample := range stats {
			cpuValues[i] = sample.CPUUsage
			memoryValues[i] = sample.MemoryMB
		}

		summary.CPU = calculateStatsSummary(cpuValues)
		summary.Memory = calculateStatsSummary(memoryValues)

		summaries = append(summaries, summary)
	}

	// Sort by container name for consistent output
	sort.Slice(summaries, func(i, j int) bool {
		return summaries[i].ContainerName < summaries[j].ContainerName
	})

	return summaries
}

// calculateStatsSummary calculates min, max, and average for a slice of values
func calculateStatsSummary(values []float64) StatsSummary {
	if len(values) == 0 {
		return StatsSummary{}
	}

	min := values[0]
	max := values[0]
	sum := 0.0

	for _, value := range values {
		if value < min {
			min = value
		}
		if value > max {
			max = value
		}
		sum += value
	}

	return StatsSummary{
		Min:     min,
		Max:     max,
		Average: sum / float64(len(values)),
	}
}

// PrintSummary prints the statistics summary to the console
func (sc *StatsCollector) PrintSummary() {
	summaries := sc.GetSummary()

	if len(summaries) == 0 {
		log.Printf("No container statistics collected")
		return
	}

	log.Printf("Container Resource Usage Summary:")
	log.Printf("================================")

	for _, summary := range summaries {
		log.Printf("Container: %s (%d samples)", summary.ContainerName, summary.SampleCount)
		log.Printf("  CPU Usage:    Min: %6.2f%%  Max: %6.2f%%  Avg: %6.2f%%",
			summary.CPU.Min, summary.CPU.Max, summary.CPU.Average)
		log.Printf("  Memory Usage: Min: %6.1f MB  Max: %6.1f MB  Avg: %6.1f MB",
			summary.Memory.Min, summary.Memory.Max, summary.Memory.Average)
		log.Printf("")
	}
}

// CheckMemoryLimits checks if any containers exceeded their memory limits
func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
	if hsLimitMB <= 0 && tsLimitMB <= 0 {
		return nil
	}

	summaries := sc.GetSummary()
	var violations []MemoryViolation

	for _, summary := range summaries {
		var limitMB float64
		if strings.HasPrefix(summary.ContainerName, "hs-") {
			limitMB = hsLimitMB
		} else if strings.HasPrefix(summary.ContainerName, "ts-") {
			limitMB = tsLimitMB
		} else {
			continue // Skip containers that don't match our patterns
		}

		if limitMB > 0 && summary.Memory.Max > limitMB {
			violations = append(violations, MemoryViolation{
				ContainerName: summary.ContainerName,
				MaxMemoryMB:   summary.Memory.Max,
				LimitMB:       limitMB,
			})
		}
	}

	return violations
}

// PrintSummaryAndCheckLimits prints the statistics summary and returns memory violations if any
func (sc *StatsCollector) PrintSummaryAndCheckLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation {
	sc.PrintSummary()
	return sc.CheckMemoryLimits(hsLimitMB, tsLimitMB)
}

// Close closes the stats collector and cleans up resources
func (sc *StatsCollector) Close() error {
	sc.StopCollection()
	return sc.client.Close()
}
@@ -19,7 +19,7 @@
   overlay = _: prev: let
     pkgs = nixpkgs.legacyPackages.${prev.system};
     buildGo = pkgs.buildGo124Module;
-    vendorHash = "sha256-S2GnCg2dyfjIyi5gXhVEuRs5Bop2JAhZcnhg1fu4/Gg=";
+    vendorHash = "sha256-83L2NMyOwKCHWqcowStJ7Ze/U9CJYhzleDRLrJNhX2g=";
   in {
     headscale = buildGo {
       pname = "headscale";
go.mod (27 lines changed)
@ -23,7 +23,6 @@ require (
|
|||||||
github.com/gorilla/mux v1.8.1
|
github.com/gorilla/mux v1.8.1
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.0
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.0
|
||||||
github.com/jagottsicher/termcolor v1.0.2
|
github.com/jagottsicher/termcolor v1.0.2
|
||||||
github.com/klauspost/compress v1.18.0
|
|
||||||
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25
|
||||||
github.com/ory/dockertest/v3 v3.12.0
|
github.com/ory/dockertest/v3 v3.12.0
|
||||||
github.com/philip-bui/grpc-zerolog v1.0.1
|
github.com/philip-bui/grpc-zerolog v1.0.1
|
||||||
@ -43,11 +42,11 @@ require (
|
|||||||
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97
|
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97
|
||||||
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
|
github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e
|
||||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||||
golang.org/x/crypto v0.39.0
|
golang.org/x/crypto v0.40.0
|
||||||
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
|
golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0
|
||||||
golang.org/x/net v0.41.0
|
golang.org/x/net v0.42.0
|
||||||
 	golang.org/x/oauth2 v0.30.0
-	golang.org/x/sync v0.15.0
+	golang.org/x/sync v0.16.0
 	google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822
 	google.golang.org/grpc v1.73.0
 	google.golang.org/protobuf v1.36.6
@@ -55,7 +54,7 @@ require (
 	gopkg.in/yaml.v3 v3.0.1
 	gorm.io/driver/postgres v1.6.0
 	gorm.io/gorm v1.30.0
-	tailscale.com v1.84.2
+	tailscale.com v1.84.3
 	zgo.at/zcache/v2 v2.2.0
 	zombiezen.com/go/postgrestest v1.0.1
 )
@@ -81,7 +80,7 @@ require (
 	modernc.org/libc v1.62.1 // indirect
 	modernc.org/mathutil v1.7.1 // indirect
 	modernc.org/memory v1.10.0 // indirect
-	modernc.org/sqlite v1.37.0 // indirect
+	modernc.org/sqlite v1.37.0
 )

 require (
@@ -166,6 +165,7 @@ require (
 	github.com/jmespath/go-jmespath v0.4.0 // indirect
 	github.com/jsimonetti/rtnetlink v1.4.1 // indirect
 	github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
+	github.com/klauspost/compress v1.18.0 // indirect
 	github.com/kr/pretty v0.3.1 // indirect
 	github.com/kr/text v0.2.0 // indirect
 	github.com/lib/pq v1.10.9 // indirect
@@ -231,14 +231,19 @@ require (
 	go.opentelemetry.io/otel/trace v1.36.0 // indirect
 	go.uber.org/multierr v1.11.0 // indirect
 	go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect
-	golang.org/x/mod v0.25.0 // indirect
+	golang.org/x/mod v0.26.0 // indirect
-	golang.org/x/sys v0.33.0 // indirect
+	golang.org/x/sys v0.34.0 // indirect
-	golang.org/x/term v0.32.0 // indirect
+	golang.org/x/term v0.33.0 // indirect
-	golang.org/x/text v0.26.0 // indirect
+	golang.org/x/text v0.27.0 // indirect
 	golang.org/x/time v0.10.0 // indirect
-	golang.org/x/tools v0.33.0 // indirect
+	golang.org/x/tools v0.35.0 // indirect
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
 	golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
 	gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 // indirect
 )

+tool (
+	golang.org/x/tools/cmd/stringer
+	tailscale.com/cmd/viewer
+)
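Aside (illustrative, not part of the diff): the tool block added above uses the go.mod tool directive introduced in Go 1.24, which pins developer tooling to the module so it can be run without a separate install step. A minimal sketch of the usual workflow, assuming a Go 1.24+ toolchain and a hypothetical -type argument for stringer:

    go get -tool golang.org/x/tools/cmd/stringer   # records the tool directive in go.mod
    go tool stringer -type=MyEnum                  # runs the version pinned by go.mod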
go.sum (34 changes)

@@ -555,8 +555,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
-golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
+golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
-golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
+golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
 golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM=
 golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8=
 golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8=
@@ -567,8 +567,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
+golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
-golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
+golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@@ -577,8 +577,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
 golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
 golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
-golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
+golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
-golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
+golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
 golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
 golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -587,8 +587,8 @@ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
+golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
-golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -615,8 +615,8 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
+golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
-golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -624,8 +624,8 @@ golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuX
 golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
 golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
 golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
-golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg=
+golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
-golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
+golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
@@ -633,8 +633,8 @@ golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
 golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
 golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
 golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
+golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
-golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
+golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
 golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
 golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -643,8 +643,8 @@ golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roY
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
 golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
 golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
-golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
+golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
-golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
+golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -714,6 +714,8 @@ software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB
 software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=
 tailscale.com v1.84.2 h1:v6aM4RWUgYiV52LRAx6ET+dlGnvO/5lnqPXb7/pMnR0=
 tailscale.com v1.84.2/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
+tailscale.com v1.84.3 h1:Ur9LMedSgicwbqpy5xn7t49G8490/s6rqAJOk5Q5AYE=
+tailscale.com v1.84.3/go.mod h1:6/S63NMAhmncYT/1zIPDJkvCuZwMw+JnUuOfSPNazpo=
 zgo.at/zcache/v2 v2.2.0 h1:K29/IPjMniZfveYE+IRXfrl11tMzHkIPuyGrfVZ2fGo=
 zgo.at/zcache/v2 v2.2.0/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk=
 zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4=
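Aside (illustrative): go.sum records the checksums for the module versions selected in go.mod, which is why the bumps above come in pairs (an h1: entry for the module contents and one for its go.mod). Assuming the standard Go toolchain, both files are refreshed together with:

    go mod tidy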
hscontrol/app.go (130 changes)

@@ -28,14 +28,15 @@ import (
 	derpServer "github.com/juanfont/headscale/hscontrol/derp/server"
 	"github.com/juanfont/headscale/hscontrol/dns"
 	"github.com/juanfont/headscale/hscontrol/mapper"
-	"github.com/juanfont/headscale/hscontrol/notifier"
 	"github.com/juanfont/headscale/hscontrol/state"
 	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/juanfont/headscale/hscontrol/types/change"
 	"github.com/juanfont/headscale/hscontrol/util"
 	zerolog "github.com/philip-bui/grpc-zerolog"
 	"github.com/pkg/profile"
 	zl "github.com/rs/zerolog"
 	"github.com/rs/zerolog/log"
+	"github.com/sasha-s/go-deadlock"
 	"golang.org/x/crypto/acme"
 	"golang.org/x/crypto/acme/autocert"
 	"golang.org/x/sync/errgroup"
@@ -64,6 +65,19 @@ var (
 )
 )

+var (
+	debugDeadlock        = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
+	debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
+)
+
+func init() {
+	deadlock.Opts.Disable = !debugDeadlock
+	if debugDeadlock {
+		deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
+		deadlock.Opts.PrintAllCurrentGoroutines = true
+	}
+}
+
 const (
 	AuthPrefix     = "Bearer "
 	updateInterval = 5 * time.Second
@@ -82,9 +96,8 @@ type Headscale struct {

 	// Things that generate changes
 	extraRecordMan *dns.ExtraRecordsMan
-	mapper         *mapper.Mapper
-	nodeNotifier   *notifier.Notifier
 	authProvider   AuthProvider
+	mapBatcher     mapper.Batcher

 	pollNetMapStreamWG sync.WaitGroup
 }
@@ -118,7 +131,6 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 		cfg:                cfg,
 		noisePrivateKey:    noisePrivateKey,
 		pollNetMapStreamWG: sync.WaitGroup{},
-		nodeNotifier:       notifier.NewNotifier(cfg),
 		state:              s,
 	}

@@ -136,12 +148,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 			return
 		}

-		// Send policy update notifications if needed
-		if policyChanged {
-			ctx := types.NotifyCtx(context.Background(), "ephemeral-gc-policy", node.Hostname)
-			app.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
-		}
+		app.Change(policyChanged)

 		log.Debug().Uint64("node.id", ni.Uint64()).Msgf("deleted ephemeral node")
 	})
 	app.ephemeralGC = ephemeralGC
@@ -153,10 +160,9 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 		defer cancel()
 		oidcProvider, err := NewAuthProviderOIDC(
 			ctx,
+			&app,
 			cfg.ServerURL,
 			&cfg.OIDC,
-			app.state,
-			app.nodeNotifier,
 		)
 		if err != nil {
 			if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
@@ -262,16 +268,18 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 			return

 		case <-expireTicker.C:
-			var update types.StateUpdate
+			var expiredNodeChanges []change.ChangeSet
 			var changed bool

-			lastExpiryCheck, update, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
+			lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)

 			if changed {
-				log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
+				log.Trace().Interface("changes", expiredNodeChanges).Msgf("expiring nodes")

-				ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
-				h.nodeNotifier.NotifyAll(ctx, update)
+				// Send the changes directly since they're already in the new format
+				for _, nodeChange := range expiredNodeChanges {
+					h.Change(nodeChange)
+				}
 			}

 		case <-derpTickerChan:
@@ -282,11 +290,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 				derpMap.Regions[region.RegionID] = &region
 			}

-			ctx := types.NotifyCtx(context.Background(), "derpmap-update", "na")
-			h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
-				Type:    types.StateDERPUpdated,
-				DERPMap: derpMap,
-			})
+			h.Change(change.DERPSet)

 		case records, ok := <-extraRecordsUpdate:
 			if !ok {
@@ -294,19 +298,16 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
 			}
 			h.cfg.TailcfgDNSConfig.ExtraRecords = records

-			ctx := types.NotifyCtx(context.Background(), "dns-extrarecord", "all")
-			// TODO(kradalby): We can probably do better than sending a full update here,
-			// but for now this will ensure that all of the nodes get the new records.
-			h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
+			h.Change(change.ExtraRecordsSet)
 		}
 	}
 }

 func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context,
-	req interface{},
+	req any,
 	info *grpc.UnaryServerInfo,
 	handler grpc.UnaryHandler,
-) (interface{}, error) {
+) (any, error) {
 	// Check if the request is coming from the on-server client.
 	// This is not secure, but it is to maintain maintainability
 	// with the "legacy" database-based client
@@ -484,58 +485,6 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
 	return router
 }

-// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
-// // Maybe we should attempt a new in memory state and not go via the DB?
-// // Maybe this should be implemented as an event bus?
-// // A bool is returned indicating if a full update was sent to all nodes
-// func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
-// 	users, err := db.ListUsers()
-// 	if err != nil {
-// 		return err
-// 	}
-
-// 	changed, err := polMan.SetUsers(users)
-// 	if err != nil {
-// 		return err
-// 	}
-
-// 	if changed {
-// 		ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
-// 		notif.NotifyAll(ctx, types.UpdateFull())
-// 	}
-
-// 	return nil
-// }
-
-// // TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
-// // Maybe we should attempt a new in memory state and not go via the DB?
-// // Maybe this should be implemented as an event bus?
-// // A bool is returned indicating if a full update was sent to all nodes
-// func nodesChangedHook(
-// 	db *db.HSDatabase,
-// 	polMan policy.PolicyManager,
-// 	notif *notifier.Notifier,
-// ) (bool, error) {
-// 	nodes, err := db.ListNodes()
-// 	if err != nil {
-// 		return false, err
-// 	}
-
-// 	filterChanged, err := polMan.SetNodes(nodes)
-// 	if err != nil {
-// 		return false, err
-// 	}
-
-// 	if filterChanged {
-// 		ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
-// 		notif.NotifyAll(ctx, types.UpdateFull())
-
-// 		return true, nil
-// 	}
-
-// 	return false, nil
-// }
-
 // Serve launches the HTTP and gRPC server service Headscale and the API.
 func (h *Headscale) Serve() error {
 	capver.CanOldCodeBeCleanedUp()
@@ -562,8 +511,9 @@ func (h *Headscale) Serve() error {
 		Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)).
 		Msg("Clients with a lower minimum version will be rejected")

-	// Fetch an initial DERP Map before we start serving
-	h.mapper = mapper.NewMapper(h.state, h.cfg, h.nodeNotifier)
+	h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
+	h.mapBatcher.Start()
+	defer h.mapBatcher.Close()

 	// TODO(kradalby): fix state part.
 	if h.cfg.DERP.ServerEnabled {
@@ -838,8 +788,12 @@ func (h *Headscale) Serve() error {
 				log.Info().
 					Msg("ACL policy successfully reloaded, notifying nodes of change")

-				ctx := types.NotifyCtx(context.Background(), "acl-sighup", "na")
-				h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
+				err = h.state.AutoApproveNodes()
+				if err != nil {
+					log.Error().Err(err).Msg("failed to approve routes after new policy")
+				}
+
+				h.Change(change.PolicySet)
 			}
 		default:
 			info := func(msg string) { log.Info().Msg(msg) }
@@ -865,7 +819,6 @@ func (h *Headscale) Serve() error {
 	}

 	info("closing node notifier")
-	h.nodeNotifier.Close()

 	info("waiting for netmap stream to close")
 	h.pollNetMapStreamWG.Wait()
@@ -1047,3 +1000,10 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {

 	return &machineKey, nil
 }
+
+// Change is used to send changes to nodes.
+// All change should be enqueued here and empty will be automatically
+// ignored.
+func (h *Headscale) Change(c change.ChangeSet) {
+	h.mapBatcher.AddWork(c)
+}
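Aside (illustrative fragment, not part of the commit): with the notifier removed, every update now flows through Headscale.Change into the map batcher. A minimal sketch of how a caller inside hscontrol is expected to use it, reusing only names that appear in this diff (h.state.DeleteNode returning a ChangeSet, change.PolicySet) and assuming, as the doc comment above says, that empty ChangeSets are ignored:

	// After mutating state, enqueue whatever ChangeSet the state layer returned.
	c, err := h.state.DeleteNode(node)
	if err != nil {
		return err
	}
	h.Change(c)                // node removal is fanned out by the batcher
	h.Change(change.PolicySet) // ask the batcher to recompute maps after a policy change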
@@ -10,6 +10,8 @@ import (
 	"time"

 	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/juanfont/headscale/hscontrol/types/change"
+
 	"gorm.io/gorm"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
@@ -32,6 +34,21 @@ func (h *Headscale) handleRegister(
 	}

 	if node != nil {
+		// If an existing node is trying to register with an auth key,
+		// we need to validate the auth key even for existing nodes
+		if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
+			resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
+			if err != nil {
+				// Preserve HTTPError types so they can be handled properly by the HTTP layer
+				var httpErr HTTPError
+				if errors.As(err, &httpErr) {
+					return nil, httpErr
+				}
+				return nil, fmt.Errorf("handling register with auth key for existing node: %w", err)
+			}
+			return resp, nil
+		}
+
 		resp, err := h.handleExistingNode(node, regReq, machineKey)
 		if err != nil {
 			return nil, fmt.Errorf("handling existing node: %w", err)
@@ -47,6 +64,11 @@ func (h *Headscale) handleRegister(
 	if regReq.Auth != nil && regReq.Auth.AuthKey != "" {
 		resp, err := h.handleRegisterWithAuthKey(regReq, machineKey)
 		if err != nil {
+			// Preserve HTTPError types so they can be handled properly by the HTTP layer
+			var httpErr HTTPError
+			if errors.As(err, &httpErr) {
+				return nil, httpErr
+			}
 			return nil, fmt.Errorf("handling register with auth key: %w", err)
 		}

@@ -66,11 +88,13 @@ func (h *Headscale) handleExistingNode(
 	regReq tailcfg.RegisterRequest,
 	machineKey key.MachinePublic,
 ) (*tailcfg.RegisterResponse, error) {
+
 	if node.MachineKey != machineKey {
 		return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil)
 	}

 	expired := node.IsExpired()
+
 	if !expired && !regReq.Expiry.IsZero() {
 		requestExpiry := regReq.Expiry

@@ -82,42 +106,26 @@ func (h *Headscale) handleExistingNode(
 		// If the request expiry is in the past, we consider it a logout.
 		if requestExpiry.Before(time.Now()) {
 			if node.IsEphemeral() {
-				policyChanged, err := h.state.DeleteNode(node)
+				c, err := h.state.DeleteNode(node)
 				if err != nil {
 					return nil, fmt.Errorf("deleting ephemeral node: %w", err)
 				}

-				// Send policy update notifications if needed
-				if policyChanged {
-					ctx := types.NotifyCtx(context.Background(), "auth-logout-ephemeral-policy", "na")
-					h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
-				} else {
-					ctx := types.NotifyCtx(context.Background(), "logout-ephemeral", "na")
-					h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
-				}
+				h.Change(c)

 				return nil, nil
 			}
 		}

-		n, policyChanged, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
+		_, c, err := h.state.SetNodeExpiry(node.ID, requestExpiry)
 		if err != nil {
 			return nil, fmt.Errorf("setting node expiry: %w", err)
 		}

-		// Send policy update notifications if needed
-		if policyChanged {
-			ctx := types.NotifyCtx(context.Background(), "auth-expiry-policy", "na")
-			h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
-		} else {
-			ctx := types.NotifyCtx(context.Background(), "logout-expiry", "na")
-			h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, requestExpiry), node.ID)
-		}
-
-		return nodeToRegisterResponse(n), nil
+		h.Change(c)
 	}

 	return nodeToRegisterResponse(node), nil
 }

 func nodeToRegisterResponse(node *types.Node) *tailcfg.RegisterResponse {
@@ -168,7 +176,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
 	regReq tailcfg.RegisterRequest,
 	machineKey key.MachinePublic,
 ) (*tailcfg.RegisterResponse, error) {
-	node, changed, err := h.state.HandleNodeFromPreAuthKey(
+	node, changed, policyChanged, err := h.state.HandleNodeFromPreAuthKey(
 		regReq,
 		machineKey,
 	)
@@ -184,6 +192,12 @@ func (h *Headscale) handleRegisterWithAuthKey(
 		return nil, err
 	}

+	// If node is nil, it means an ephemeral node was deleted during logout
+	if node == nil {
+		h.Change(changed)
+		return nil, nil
+	}
+
 	// This is a bit of a back and forth, but we have a bit of a chicken and egg
 	// dependency here.
 	// Because the way the policy manager works, we need to have the node
@@ -195,23 +209,22 @@ func (h *Headscale) handleRegisterWithAuthKey(
 	// ensure we send an update.
 	// This works, but might be another good candidate for doing some sort of
 	// eventbus.
-	routesChanged := h.state.AutoApproveRoutes(node)
+	// TODO(kradalby): This needs to be ran as part of the batcher maybe?
+	// now since we dont update the node/pol here anymore
+	routeChange := h.state.AutoApproveRoutes(node)
 	if _, _, err := h.state.SaveNode(node); err != nil {
 		return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
 	}

-	if routesChanged {
-		ctx := types.NotifyCtx(context.Background(), "node updated", node.Hostname)
-		h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
-	} else if changed {
-		ctx := types.NotifyCtx(context.Background(), "node created", node.Hostname)
-		h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
-	} else {
-		// Existing node re-registering without route changes
-		// Still need to notify peers about the node being active again
-		// Use UpdateFull to ensure all peers get complete peer maps
-		ctx := types.NotifyCtx(context.Background(), "node re-registered", node.Hostname)
-		h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
-	}
+	if routeChange && changed.Empty() {
+		changed = change.NodeAdded(node.ID)
+	}
+	h.Change(changed)
+
+	// If policy changed due to node registration, send a separate policy change
+	if policyChanged {
+		policyChange := change.PolicyChange()
+		h.Change(policyChange)
+	}

 	return &tailcfg.RegisterResponse{
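Aside (self-contained illustration): the blocks added above use errors.As to pull a typed error back out of a wrapped chain so the HTTP layer can keep its status code. The general shape of that idiom, with a toy error type standing in for headscale's HTTPError:

	package main

	import (
		"errors"
		"fmt"
	)

	// statusError is a stand-in for a typed error that carries an HTTP status.
	type statusError struct{ code int }

	func (e statusError) Error() string { return fmt.Sprintf("status %d", e.code) }

	func main() {
		base := statusError{code: 401}
		wrapped := fmt.Errorf("handling register: %w", base)

		var se statusError
		if errors.As(wrapped, &se) {
			fmt.Println("recovered status:", se.code) // prints 401
		}
	}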
@@ -1,5 +1,7 @@
 package capver

+//go:generate go run ../../tools/capver/main.go
+
 import (
 	"slices"
 	"sort"
@@ -10,7 +12,7 @@ import (
 	"tailscale.com/util/set"
 )

-const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 88
+const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 90

 // CanOldCodeBeCleanedUp is intended to be called on startup to see if
 // there are old code that can ble cleaned up, entries should contain
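Aside (illustrative): the //go:generate directive added above wires the capability-version tables below to the tools/capver generator, so the generated file can be refreshed with the standard generate command. Assuming it is run from the repository root with the Go toolchain available:

    go generate ./hscontrol/capver/...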
@@ -1,14 +1,10 @@
 package capver

-// Generated DO NOT EDIT
+//Generated DO NOT EDIT

 import "tailscale.com/tailcfg"

 var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
-	"v1.60.0": 87,
-	"v1.60.1": 87,
-	"v1.62.0": 88,
-	"v1.62.1": 88,
 	"v1.64.0": 90,
 	"v1.64.1": 90,
 	"v1.64.2": 90,
@@ -36,18 +32,21 @@ var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
 	"v1.80.3": 113,
 	"v1.82.0": 115,
 	"v1.82.5": 115,
+	"v1.84.0": 116,
+	"v1.84.1": 116,
+	"v1.84.2": 116,
 }


 var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
-	87:  "v1.60.0",
-	88:  "v1.62.0",
-	90:  "v1.64.0",
-	95:  "v1.66.0",
-	97:  "v1.68.0",
-	102: "v1.70.0",
-	104: "v1.72.0",
-	106: "v1.74.0",
-	109: "v1.78.0",
-	113: "v1.80.0",
-	115: "v1.82.0",
+	90:  "v1.64.0",
+	95:  "v1.66.0",
+	97:  "v1.68.0",
+	102: "v1.70.0",
+	104: "v1.72.0",
+	106: "v1.74.0",
+	109: "v1.78.0",
+	113: "v1.80.0",
+	115: "v1.82.0",
+	116: "v1.84.0",
 }
@@ -13,11 +13,10 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
 		stripV   bool
 		expected []string
 	}{
-		{3, false, []string{"v1.78", "v1.80", "v1.82"}},
-		{2, true, []string{"1.80", "1.82"}},
+		{3, false, []string{"v1.80", "v1.82", "v1.84"}},
+		{2, true, []string{"1.82", "1.84"}},
 		// Lazy way to see all supported versions
 		{10, true, []string{
-			"1.64",
 			"1.66",
 			"1.68",
 			"1.70",
@@ -27,6 +26,7 @@ func TestTailscaleLatestMajorMinor(t *testing.T) {
 			"1.78",
 			"1.80",
 			"1.82",
+			"1.84",
 		}},
 		{0, false, nil},
 	}
@@ -46,7 +46,6 @@ func TestCapVerMinimumTailscaleVersion(t *testing.T) {
 		input    tailcfg.CapabilityVersion
 		expected string
 	}{
-		{88, "v1.62.0"},
 		{90, "v1.64.0"},
 		{95, "v1.66.0"},
 		{106, "v1.74.0"},
@@ -7,7 +7,6 @@ import (
 	"os/exec"
 	"path/filepath"
 	"slices"
-	"sort"
 	"strings"
 	"testing"
 	"time"
@@ -362,8 +361,8 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
 				}

 				if diff := cmp.Diff(expectedKeys, keys, cmp.Comparer(func(a, b []string) bool {
-					sort.Sort(sort.StringSlice(a))
-					sort.Sort(sort.StringSlice(b))
+					slices.Sort(a)
+					slices.Sort(b)
 					return slices.Equal(a, b)
 				}), cmpopts.IgnoreFields(types.PreAuthKey{}, "User", "CreatedAt", "Reusable", "Ephemeral", "Used", "Expiration")); diff != "" {
 					t.Errorf("TestSQLiteMigrationAndDataValidation() pre-auth key tags migration mismatch (-want +got):\n%s", diff)
@@ -7,15 +7,19 @@ import (
 	"net/netip"
 	"slices"
 	"sort"
+	"strconv"
 	"sync"
+	"testing"
 	"time"

 	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/juanfont/headscale/hscontrol/types/change"
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/rs/zerolog/log"
 	"gorm.io/gorm"
 	"tailscale.com/tailcfg"
 	"tailscale.com/types/key"
+	"tailscale.com/types/ptr"
 )

 const (
@@ -39,9 +43,7 @@ var (
 // If no peer IDs are given, all peers are returned.
 // If at least one peer ID is given, only these peer nodes will be returned.
 func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
-	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
-		return ListPeers(rx, nodeID, peerIDs...)
-	})
+	return ListPeers(hsdb.DB, nodeID, peerIDs...)
 }

 // ListPeers returns peers of node, regardless of any Policy or if the node is expired.
@@ -66,9 +68,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types
 // ListNodes queries the database for either all nodes if no parameters are given
 // or for the given nodes if at least one node ID is given as parameter.
 func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
-	return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) {
-		return ListNodes(rx, nodeIDs...)
-	})
+	return ListNodes(hsdb.DB, nodeIDs...)
 }

 // ListNodes queries the database for either all nodes if no parameters are given
@@ -120,9 +120,7 @@ func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) {
 }

 func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) {
-	return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
-		return GetNodeByID(rx, id)
-	})
+	return GetNodeByID(hsdb.DB, id)
 }

 // GetNodeByID finds a Node by ID and returns the Node struct.
@@ -140,9 +138,7 @@ func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) {
 }

 func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) {
-	return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
-		return GetNodeByMachineKey(rx, machineKey)
-	})
+	return GetNodeByMachineKey(hsdb.DB, machineKey)
 }

 // GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct.
@@ -163,9 +159,7 @@ func GetNodeByMachineKey(
 }

 func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) {
-	return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) {
-		return GetNodeByNodeKey(rx, nodeKey)
-	})
+	return GetNodeByNodeKey(hsdb.DB, nodeKey)
 }

 // GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct.
@@ -352,8 +346,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
 	registrationMethod string,
 	ipv4 *netip.Addr,
 	ipv6 *netip.Addr,
-) (*types.Node, bool, error) {
-	var newNode bool
+) (*types.Node, change.ChangeSet, error) {
+	var nodeChange change.ChangeSet
 	node, err := Write(hsdb.DB, func(tx *gorm.DB) (*types.Node, error) {
 		if reg, ok := hsdb.regCache.Get(registrationID); ok {
 			if node, _ := GetNodeByNodeKey(tx, reg.Node.NodeKey); node == nil {
@@ -405,7 +399,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
 				}
 				close(reg.Registered)

-				newNode = true
+				nodeChange = change.NodeAdded(node.ID)

 				return node, err
 			} else {
@@ -415,6 +409,8 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
 					return nil, err
 				}

+				nodeChange = change.KeyExpiry(node.ID)
+
 				return node, nil
 			}
 		}
@@ -422,7 +418,7 @@ func (hsdb *HSDatabase) HandleNodeFromAuthPath(
 		return nil, ErrNodeNotFoundRegistrationCache
 	})

-	return node, newNode, err
+	return node, nodeChange, err
 }

 func (hsdb *HSDatabase) RegisterNode(node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
@@ -448,6 +444,7 @@ func RegisterNode(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Ad
 	if oldNode != nil && oldNode.UserID == node.UserID {
 		node.ID = oldNode.ID
 		node.GivenName = oldNode.GivenName
+		node.ApprovedRoutes = oldNode.ApprovedRoutes
 		ipv4 = oldNode.IPv4
 		ipv6 = oldNode.IPv6
 	}
@@ -594,17 +591,18 @@ func ensureUniqueGivenName(
 // containing the expired nodes, and a boolean indicating if any nodes were found.
 func ExpireExpiredNodes(tx *gorm.DB,
 	lastCheck time.Time,
-) (time.Time, types.StateUpdate, bool) {
+) (time.Time, []change.ChangeSet, bool) {
 	// use the time of the start of the function to ensure we
 	// dont miss some nodes by returning it _after_ we have
 	// checked everything.
 	started := time.Now()

 	expired := make([]*tailcfg.PeerChange, 0)
+	var updates []change.ChangeSet

 	nodes, err := ListNodes(tx)
 	if err != nil {
-		return time.Unix(0, 0), types.StateUpdate{}, false
+		return time.Unix(0, 0), nil, false
 	}
 	for _, node := range nodes {
 		if node.IsExpired() && node.Expiry.After(lastCheck) {
@@ -612,14 +610,15 @@ func ExpireExpiredNodes(tx *gorm.DB,
 				NodeID:    tailcfg.NodeID(node.ID),
 				KeyExpiry: node.Expiry,
 			})
+			updates = append(updates, change.KeyExpiry(node.ID))
 		}
 	}

 	if len(expired) > 0 {
-		return started, types.UpdatePeerPatch(expired...), true
+		return started, updates, true
 	}

-	return started, types.StateUpdate{}, false
+	return started, nil, false
 }

 // EphemeralGarbageCollector is a garbage collector that will delete nodes after
@@ -732,3 +731,114 @@ func (e *EphemeralGarbageCollector) Start() {
 		}
 	}
 }
+
+func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) *types.Node {
+	if !testing.Testing() {
+		panic("CreateNodeForTest can only be called during tests")
+	}
+
+	if user == nil {
+		panic("CreateNodeForTest requires a valid user")
+	}
+
+	nodeName := "testnode"
+	if len(hostname) > 0 && hostname[0] != "" {
+		nodeName = hostname[0]
+	}
+
+	// Create a preauth key for the node
+	pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
+	if err != nil {
+		panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
+	}
+
+	nodeKey := key.NewNode()
+	machineKey := key.NewMachine()
+	discoKey := key.NewDisco()
+
+	node := &types.Node{
+		MachineKey:     machineKey.Public(),
+		NodeKey:        nodeKey.Public(),
+		DiscoKey:       discoKey.Public(),
+		Hostname:       nodeName,
+		UserID:         user.ID,
+		RegisterMethod: util.RegisterMethodAuthKey,
+		AuthKeyID:      ptr.To(pak.ID),
+	}
+
+	err = hsdb.DB.Save(node).Error
+	if err != nil {
+		panic(fmt.Sprintf("failed to create test node: %v", err))
+	}
+
+	return node
+}
+
+func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node {
+	if !testing.Testing() {
+		panic("CreateRegisteredNodeForTest can only be called during tests")
+	}
+
+	node := hsdb.CreateNodeForTest(user, hostname...)
+
+	err := hsdb.DB.Transaction(func(tx *gorm.DB) error {
+		_, err := RegisterNode(tx, *node, nil, nil)
+		return err
+	})
+	if err != nil {
+		panic(fmt.Sprintf("failed to register test node: %v", err))
+	}
+
+	registeredNode, err := hsdb.GetNodeByID(node.ID)
+	if err != nil {
+		panic(fmt.Sprintf("failed to get registered test node: %v", err))
+	}
+
+	return registeredNode
+}
+
+func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
+	if !testing.Testing() {
+		panic("CreateNodesForTest can only be called during tests")
+	}
+
+	if user == nil {
+		panic("CreateNodesForTest requires a valid user")
+	}
+
+	prefix := "testnode"
+	if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
+		prefix = hostnamePrefix[0]
+	}
+
+	nodes := make([]*types.Node, count)
+	for i := range count {
+		hostname := prefix + "-" + strconv.Itoa(i)
+		nodes[i] = hsdb.CreateNodeForTest(user, hostname)
+	}
+
+	return nodes
+}
+
+func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node {
+	if !testing.Testing() {
+		panic("CreateRegisteredNodesForTest can only be called during tests")
+	}
+
+	if user == nil {
+		panic("CreateRegisteredNodesForTest requires a valid user")
+	}
+
+	prefix := "testnode"
+	if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" {
+		prefix = hostnamePrefix[0]
+	}
+
+	nodes := make([]*types.Node, count)
+	for i := range count {
+		hostname := prefix + "-" + strconv.Itoa(i)
+		nodes[i] = hsdb.CreateRegisteredNodeForTest(user, hostname)
	}
+
+	return nodes
+}
@ -6,7 +6,6 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -26,82 +25,36 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (s *Suite) TestGetNode(c *check.C) {
|
func (s *Suite) TestGetNode(c *check.C) {
|
||||||
user, err := db.CreateUser(types.User{Name: "test"})
|
user := db.CreateUserForTest("test")
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
_, err := db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
node := db.CreateNodeForTest(user, "testnode")
|
||||||
machineKey := key.NewMachine()
|
|
||||||
|
|
||||||
node := &types.Node{
|
|
||||||
ID: 0,
|
|
||||||
MachineKey: machineKey.Public(),
|
|
||||||
NodeKey: nodeKey.Public(),
|
|
||||||
Hostname: "testnode",
|
|
||||||
UserID: user.ID,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
AuthKeyID: ptr.To(pak.ID),
|
|
||||||
}
|
|
||||||
trx := db.DB.Save(node)
|
|
||||||
c.Assert(trx.Error, check.IsNil)
|
|
||||||
|
|
||||||
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
_, err = db.getNode(types.UserID(user.ID), "testnode")
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(node.Hostname, check.Equals, "testnode")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Suite) TestGetNodeByID(c *check.C) {
|
func (s *Suite) TestGetNodeByID(c *check.C) {
|
||||||
user, err := db.CreateUser(types.User{Name: "test"})
|
user := db.CreateUserForTest("test")
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
_, err := db.GetNodeByID(0)
|
||||||
c.Assert(err, check.IsNil)
|
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
|
||||||
c.Assert(err, check.NotNil)
|
c.Assert(err, check.NotNil)
|
||||||
|
|
||||||
nodeKey := key.NewNode()
|
node := db.CreateNodeForTest(user, "testnode")
|
||||||
machineKey := key.NewMachine()
|
|
||||||
|
|
||||||
node := types.Node{
|
retrievedNode, err := db.GetNodeByID(node.ID)
|
||||||
ID: 0,
|
|
||||||
MachineKey: machineKey.Public(),
|
|
||||||
NodeKey: nodeKey.Public(),
|
|
||||||
Hostname: "testnode",
|
|
||||||
UserID: user.ID,
|
|
||||||
RegisterMethod: util.RegisterMethodAuthKey,
|
|
||||||
AuthKeyID: ptr.To(pak.ID),
|
|
||||||
}
|
|
||||||
trx := db.DB.Save(&node)
|
|
||||||
c.Assert(trx.Error, check.IsNil)
|
|
||||||
|
|
||||||
_, err = db.GetNodeByID(0)
|
|
||||||
c.Assert(err, check.IsNil)
|
c.Assert(err, check.IsNil)
|
||||||
|
c.Assert(retrievedNode.Hostname, check.Equals, "testnode")
|
||||||
}
|
}
|
||||||
|
|
||||||
 func (s *Suite) TestHardDeleteNode(c *check.C) {
-	user, err := db.CreateUser(types.User{Name: "test"})
-	c.Assert(err, check.IsNil)
+	user := db.CreateUserForTest("test")
+	node := db.CreateNodeForTest(user, "testnode3")

-	nodeKey := key.NewNode()
-	machineKey := key.NewMachine()
+	err := db.DeleteNode(node)

-	node := types.Node{
-		ID:             0,
-		MachineKey:     machineKey.Public(),
-		NodeKey:        nodeKey.Public(),
-		Hostname:       "testnode3",
-		UserID:         user.ID,
-		RegisterMethod: util.RegisterMethodAuthKey,
-	}
-	trx := db.DB.Save(&node)
-	c.Assert(trx.Error, check.IsNil)

-	err = db.DeleteNode(&node)
 	c.Assert(err, check.IsNil)

 	_, err = db.getNode(types.UserID(user.ID), "testnode3")
@@ -109,42 +62,21 @@ func (s *Suite) TestHardDeleteNode(c *check.C) {
 }

 func (s *Suite) TestListPeers(c *check.C) {
-	user, err := db.CreateUser(types.User{Name: "test"})
-	c.Assert(err, check.IsNil)
+	user := db.CreateUserForTest("test")

-	pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
-	c.Assert(err, check.IsNil)
+	_, err := db.GetNodeByID(0)

-	_, err = db.GetNodeByID(0)
 	c.Assert(err, check.NotNil)

-	for index := range 11 {
-		nodeKey := key.NewNode()
-		machineKey := key.NewMachine()
+	nodes := db.CreateNodesForTest(user, 11, "testnode")

-		node := types.Node{
-			ID:             types.NodeID(index),
-			MachineKey:     machineKey.Public(),
-			NodeKey:        nodeKey.Public(),
-			Hostname:       "testnode" + strconv.Itoa(index),
-			UserID:         user.ID,
-			RegisterMethod: util.RegisterMethodAuthKey,
-			AuthKeyID:      ptr.To(pak.ID),
-		}
-		trx := db.DB.Save(&node)
-		c.Assert(trx.Error, check.IsNil)
-	}
+	firstNode := nodes[0]
+	peersOfFirstNode, err := db.ListPeers(firstNode.ID)

-	node0ByID, err := db.GetNodeByID(0)
 	c.Assert(err, check.IsNil)

-	peersOfNode0, err := db.ListPeers(node0ByID.ID)
-	c.Assert(err, check.IsNil)
-
-	c.Assert(len(peersOfNode0), check.Equals, 9)
-	c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2")
-	c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7")
-	c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10")
+	c.Assert(len(peersOfFirstNode), check.Equals, 10)
+	c.Assert(peersOfFirstNode[0].Hostname, check.Equals, "testnode-1")
+	c.Assert(peersOfFirstNode[5].Hostname, check.Equals, "testnode-6")
+	c.Assert(peersOfFirstNode[9].Hostname, check.Equals, "testnode-10")
 }

 func (s *Suite) TestExpireNode(c *check.C) {
@@ -807,13 +739,13 @@ func TestListPeers(t *testing.T) {
 	// No parameter means no filter, should return all peers
 	nodes, err = db.ListPeers(1)
 	require.NoError(t, err)
-	assert.Len(t, nodes, 1)
+	assert.Equal(t, 1, len(nodes))
 	assert.Equal(t, "test2", nodes[0].Hostname)

 	// Empty node list should return all peers
 	nodes, err = db.ListPeers(1, types.NodeIDs{}...)
 	require.NoError(t, err)
-	assert.Len(t, nodes, 1)
+	assert.Equal(t, 1, len(nodes))
 	assert.Equal(t, "test2", nodes[0].Hostname)

 	// No match in IDs should return empty list and no error
@@ -824,13 +756,13 @@ func TestListPeers(t *testing.T) {
 	// Partial match in IDs
 	nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...)
 	require.NoError(t, err)
-	assert.Len(t, nodes, 1)
+	assert.Equal(t, 1, len(nodes))
 	assert.Equal(t, "test2", nodes[0].Hostname)

 	// Several matched IDs, but node ID is still filtered out
 	nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...)
 	require.NoError(t, err)
-	assert.Len(t, nodes, 1)
+	assert.Equal(t, 1, len(nodes))
 	assert.Equal(t, "test2", nodes[0].Hostname)
 }

@@ -892,14 +824,14 @@ func TestListNodes(t *testing.T) {
 	// No parameter means no filter, should return all nodes
 	nodes, err = db.ListNodes()
 	require.NoError(t, err)
-	assert.Len(t, nodes, 2)
+	assert.Equal(t, 2, len(nodes))
 	assert.Equal(t, "test1", nodes[0].Hostname)
 	assert.Equal(t, "test2", nodes[1].Hostname)

 	// Empty node list should return all nodes
 	nodes, err = db.ListNodes(types.NodeIDs{}...)
 	require.NoError(t, err)
-	assert.Len(t, nodes, 2)
+	assert.Equal(t, 2, len(nodes))
 	assert.Equal(t, "test1", nodes[0].Hostname)
 	assert.Equal(t, "test2", nodes[1].Hostname)

@@ -911,13 +843,13 @@ func TestListNodes(t *testing.T) {
 	// Partial match in IDs
 	nodes, err = db.ListNodes(types.NodeIDs{2, 3}...)
 	require.NoError(t, err)
-	assert.Len(t, nodes, 1)
+	assert.Equal(t, 1, len(nodes))
 	assert.Equal(t, "test2", nodes[0].Hostname)

 	// Several matched IDs
 	nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...)
 	require.NoError(t, err)
-	assert.Len(t, nodes, 2)
+	assert.Equal(t, 2, len(nodes))
 	assert.Equal(t, "test1", nodes[0].Hostname)
 	assert.Equal(t, "test2", nodes[1].Hostname)
 }
@@ -109,9 +109,7 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
 }

 func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
-	return Read(hsdb.DB, func(rx *gorm.DB) (*types.PreAuthKey, error) {
-		return GetPreAuthKey(rx, key)
-	})
+	return GetPreAuthKey(hsdb.DB, key)
 }

 // GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible
@@ -155,11 +153,8 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {

 // MarkExpirePreAuthKey marks a PreAuthKey as expired.
 func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
-	if err := tx.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
-		return err
-	}
-
-	return nil
+	now := time.Now()
+	return tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID).Update("expiration", now).Error
 }

 func generateKey() (string, error) {
@@ -1,7 +1,7 @@
 package db

 import (
-	"sort"
+	"slices"
 	"testing"

 	"github.com/juanfont/headscale/hscontrol/types"
@@ -57,7 +57,7 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
 	listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
 	c.Assert(err, check.IsNil)
 	gotTags := listedPaks[0].Proto().GetAclTags()
-	sort.Sort(sort.StringSlice(gotTags))
+	slices.Sort(gotTags)
 	c.Assert(gotTags, check.DeepEquals, tags)
 }

@@ -3,6 +3,8 @@ package db
 import (
 	"errors"
 	"fmt"
+	"strconv"
+	"testing"

 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
@@ -110,9 +112,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
 }

 func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) {
-	return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) {
-		return GetUserByID(rx, uid)
-	})
+	return GetUserByID(hsdb.DB, uid)
 }

 func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) {
@@ -146,9 +146,7 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) {
 }

 func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
-	return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) {
-		return ListUsers(rx, where...)
-	})
+	return ListUsers(hsdb.DB, where...)
 }

 // ListUsers gets all the existing users.
@@ -217,3 +215,40 @@ func AssignNodeToUser(tx *gorm.DB, nodeID types.NodeID, uid types.UserID) error

 	return nil
 }
+
+func (hsdb *HSDatabase) CreateUserForTest(name ...string) *types.User {
+	if !testing.Testing() {
+		panic("CreateUserForTest can only be called during tests")
+	}
+
+	userName := "testuser"
+	if len(name) > 0 && name[0] != "" {
+		userName = name[0]
+	}
+
+	user, err := hsdb.CreateUser(types.User{Name: userName})
+	if err != nil {
+		panic(fmt.Sprintf("failed to create test user: %v", err))
+	}
+
+	return user
+}
+
+func (hsdb *HSDatabase) CreateUsersForTest(count int, namePrefix ...string) []*types.User {
+	if !testing.Testing() {
+		panic("CreateUsersForTest can only be called during tests")
+	}
+
+	prefix := "testuser"
+	if len(namePrefix) > 0 && namePrefix[0] != "" {
+		prefix = namePrefix[0]
+	}
+
+	users := make([]*types.User, count)
+	for i := range count {
+		name := prefix + "-" + strconv.Itoa(i)
+		users[i] = hsdb.CreateUserForTest(name)
+	}
+
+	return users
+}
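Aside (not part of the diff): a minimal sketch of how the new test helpers compose. `CreateUserForTest`, `CreateNodesForTest` and `ListPeers` are the names introduced or used above; the test name, the assumed `db` fixture (an initialised *HSDatabase from the package's test setup) and the peer-count assertion are illustrative assumptions only.

	func TestPeerHelpersSketch(t *testing.T) {
		user := db.CreateUserForTest("example")         // panics if called outside `go test`
		nodes := db.CreateNodesForTest(user, 3, "node") // node-0, node-1, node-2

		// ListPeers excludes the node itself, so 3 nodes yield 2 peers each.
		peers, err := db.ListPeers(nodes[0].ID)
		if err != nil {
			t.Fatal(err)
		}
		if len(peers) != 2 {
			t.Fatalf("expected 2 peers, got %d", len(peers))
		}
	}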
@@ -11,8 +11,7 @@ import (
 )

 func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
-	user, err := db.CreateUser(types.User{Name: "test"})
-	c.Assert(err, check.IsNil)
+	user := db.CreateUserForTest("test")
 	c.Assert(user.Name, check.Equals, "test")

 	users, err := db.ListUsers()
@@ -30,8 +29,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
 	err := db.DestroyUser(9998)
 	c.Assert(err, check.Equals, ErrUserNotFound)

-	user, err := db.CreateUser(types.User{Name: "test"})
-	c.Assert(err, check.IsNil)
+	user := db.CreateUserForTest("test")

 	pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
 	c.Assert(err, check.IsNil)
@@ -64,8 +62,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) {
 }

 func (s *Suite) TestRenameUser(c *check.C) {
-	userTest, err := db.CreateUser(types.User{Name: "test"})
-	c.Assert(err, check.IsNil)
+	userTest := db.CreateUserForTest("test")
 	c.Assert(userTest.Name, check.Equals, "test")

 	users, err := db.ListUsers()
@@ -86,8 +83,7 @@ func (s *Suite) TestRenameUser(c *check.C) {
 	err = db.RenameUser(99988, "test")
 	c.Assert(err, check.Equals, ErrUserNotFound)

-	userTest2, err := db.CreateUser(types.User{Name: "test2"})
-	c.Assert(err, check.IsNil)
+	userTest2 := db.CreateUserForTest("test2")
 	c.Assert(userTest2.Name, check.Equals, "test2")

 	want := "UNIQUE constraint failed"
@@ -98,11 +94,8 @@ func (s *Suite) TestRenameUser(c *check.C) {
 }

 func (s *Suite) TestSetMachineUser(c *check.C) {
-	oldUser, err := db.CreateUser(types.User{Name: "old"})
-	c.Assert(err, check.IsNil)
+	oldUser := db.CreateUserForTest("old")
+	newUser := db.CreateUserForTest("new")

-	newUser, err := db.CreateUser(types.User{Name: "new"})
-	c.Assert(err, check.IsNil)
-
 	pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil)
 	c.Assert(err, check.IsNil)
@@ -17,10 +17,6 @@ import (
 func (h *Headscale) debugHTTPServer() *http.Server {
 	debugMux := http.NewServeMux()
 	debug := tsweb.Debugger(debugMux)
-	debug.Handle("notifier", "Connected nodes in notifier", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		w.WriteHeader(http.StatusOK)
-		w.Write([]byte(h.nodeNotifier.String()))
-	}))
 	debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		config, err := json.MarshalIndent(h.cfg, "", " ")
 		if err != nil {
@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"io"
+	"maps"
 	"net/http"
 	"net/url"
 	"os"
@@ -72,9 +73,7 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
 	}

 	for _, derpMap := range derpMaps {
-		for id, region := range derpMap.Regions {
-			result.Regions[id] = region
-		}
+		maps.Copy(result.Regions, derpMap.Regions)
 	}

 	return &result
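Aside (not part of the diff): `maps.Copy` from the Go standard library (Go 1.21+) copies every key/value pair of the source map into the destination, overwriting keys that already exist, which matches the behaviour of the removed inner loop. A small self-contained illustration with made-up region IDs:

	package main

	import (
		"fmt"
		"maps"
	)

	func main() {
		dst := map[int]string{1: "nyc"}
		src := map[int]string{1: "fra", 2: "sin"}

		// Later DERP maps win on colliding region IDs, as the old loop did.
		maps.Copy(dst, src)

		fmt.Println(dst) // map[1:fra 2:sin]
	}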
@@ -1,3 +1,5 @@
+//go:generate buf generate --template ../buf.gen.yaml -o .. ../proto
+
 // nolint
 package hscontrol

@@ -27,6 +29,7 @@ import (
 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
 	"github.com/juanfont/headscale/hscontrol/state"
 	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/juanfont/headscale/hscontrol/types/change"
 	"github.com/juanfont/headscale/hscontrol/util"
 )

@@ -56,12 +59,14 @@ func (api headscaleV1APIServer) CreateUser(
 		return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
 	}

-	// Send policy update notifications if needed
+	c := change.UserAdded(types.UserID(user.ID))
 	if policyChanged {
-		ctx := types.NotifyCtx(context.Background(), "grpc-user-created", user.Name)
-		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
+		c.Change = change.Policy
 	}

+	api.h.Change(c)
+
 	return &v1.CreateUserResponse{User: user.Proto()}, nil
 }

@@ -81,8 +86,7 @@ func (api headscaleV1APIServer) RenameUser(

 	// Send policy update notifications if needed
 	if policyChanged {
-		ctx := types.NotifyCtx(context.Background(), "grpc-user-renamed", request.GetNewName())
-		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
+		api.h.Change(change.PolicyChange())
 	}

 	newUser, err := api.h.state.GetUserByName(request.GetNewName())
@@ -107,6 +111,8 @@ func (api headscaleV1APIServer) DeleteUser(
 		return nil, err
 	}

+	api.h.Change(change.UserRemoved(types.UserID(user.ID)))
+
 	return &v1.DeleteUserResponse{}, nil
 }

@@ -246,7 +252,7 @@ func (api headscaleV1APIServer) RegisterNode(
 		return nil, fmt.Errorf("looking up user: %w", err)
 	}

-	node, _, err := api.h.state.HandleNodeFromAuthPath(
+	node, nodeChange, err := api.h.state.HandleNodeFromAuthPath(
 		registrationId,
 		types.UserID(user.ID),
 		nil,
@@ -267,22 +273,13 @@
 	// ensure we send an update.
 	// This works, but might be another good candidate for doing some sort of
 	// eventbus.
-	routesChanged := api.h.state.AutoApproveRoutes(node)
-	_, policyChanged, err := api.h.state.SaveNode(node)
+	_ = api.h.state.AutoApproveRoutes(node)
+	_, _, err = api.h.state.SaveNode(node)
 	if err != nil {
 		return nil, fmt.Errorf("saving auto approved routes to node: %w", err)
 	}

-	// Send policy update notifications if needed (from SaveNode or route changes)
-	if policyChanged {
-		ctx := types.NotifyCtx(context.Background(), "grpc-nodes-change", "all")
-		api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
-	}
-
-	if routesChanged {
-		ctx = types.NotifyCtx(context.Background(), "web-node-login", node.Hostname)
-		api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerChanged(node.ID))
-	}
+	api.h.Change(nodeChange)

 	return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
 }
@@ -300,7 +297,7 @@ func (api headscaleV1APIServer) GetNode(

 	// Populate the online field based on
 	// currently connected nodes.
-	resp.Online = api.h.nodeNotifier.IsConnected(node.ID)
+	resp.Online = api.h.mapBatcher.IsConnected(node.ID)

 	return &v1.GetNodeResponse{Node: resp}, nil
 }
@ -316,21 +313,14 @@ func (api headscaleV1APIServer) SetTags(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
node, policyChanged, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
|
node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &v1.SetTagsResponse{
|
return &v1.SetTagsResponse{
|
||||||
Node: nil,
|
Node: nil,
|
||||||
}, status.Error(codes.InvalidArgument, err.Error())
|
}, status.Error(codes.InvalidArgument, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send policy update notifications if needed
|
api.h.Change(nodeChange)
|
||||||
if policyChanged {
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-tags", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx = types.NotifyCtx(ctx, "cli-settags", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
@ -362,23 +352,19 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
|||||||
tsaddr.SortPrefixes(routes)
|
tsaddr.SortPrefixes(routes)
|
||||||
routes = slices.Compact(routes)
|
routes = slices.Compact(routes)
|
||||||
|
|
||||||
node, policyChanged, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), routes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send policy update notifications if needed
|
routeChange := api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||||
if policyChanged {
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "grpc-routes-approved", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
|
||||||
}
|
|
||||||
|
|
||||||
if api.h.state.SetNodeRoutes(node.ID, node.SubnetRoutes()...) {
|
// Always propagate node changes from SetApprovedRoutes
|
||||||
ctx := types.NotifyCtx(ctx, "poll-primary-change", node.Hostname)
|
api.h.Change(nodeChange)
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
|
||||||
} else {
|
// If routes changed, propagate those changes too
|
||||||
ctx = types.NotifyCtx(ctx, "cli-approveroutes", node.Hostname)
|
if !routeChange.Empty() {
|
||||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
api.h.Change(routeChange)
|
||||||
}
|
}
|
||||||
|
|
||||||
proto := node.Proto()
|
proto := node.Proto()
|
||||||
@ -409,19 +395,12 @@ func (api headscaleV1APIServer) DeleteNode(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
policyChanged, err := api.h.state.DeleteNode(node)
|
nodeChange, err := api.h.state.DeleteNode(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send policy update notifications if needed
|
api.h.Change(nodeChange)
|
||||||
if policyChanged {
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-deleted", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx = types.NotifyCtx(ctx, "cli-deletenode", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdatePeerRemoved(node.ID))
|
|
||||||
|
|
||||||
return &v1.DeleteNodeResponse{}, nil
|
return &v1.DeleteNodeResponse{}, nil
|
||||||
}
|
}
|
||||||
@ -432,25 +411,13 @@ func (api headscaleV1APIServer) ExpireNode(
|
|||||||
) (*v1.ExpireNodeResponse, error) {
|
) (*v1.ExpireNodeResponse, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
node, policyChanged, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
|
node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send policy update notifications if needed
|
// TODO(kradalby): Ensure that both the selfupdate and peer updates are sent
|
||||||
if policyChanged {
|
api.h.Change(nodeChange)
|
||||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-expired", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyByNodeID(
|
|
||||||
ctx,
|
|
||||||
types.UpdateSelf(node.ID),
|
|
||||||
node.ID)
|
|
||||||
|
|
||||||
ctx = types.NotifyCtx(ctx, "cli-expirenode-peers", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdateExpire(node.ID, now), node.ID)
|
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
@ -464,22 +431,13 @@ func (api headscaleV1APIServer) RenameNode(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.RenameNodeRequest,
|
request *v1.RenameNodeRequest,
|
||||||
) (*v1.RenameNodeResponse, error) {
|
) (*v1.RenameNodeResponse, error) {
|
||||||
node, policyChanged, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
|
node, nodeChange, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send policy update notifications if needed
|
// TODO(kradalby): investigate if we need selfupdate
|
||||||
if policyChanged {
|
api.h.Change(nodeChange)
|
||||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-renamed", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-self", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyByNodeID(ctx, types.UpdateSelf(node.ID), node.ID)
|
|
||||||
|
|
||||||
ctx = types.NotifyCtx(ctx, "cli-renamenode-peers", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
@ -498,7 +456,7 @@ func (api headscaleV1APIServer) ListNodes(
|
|||||||
// probably be done once.
|
// probably be done once.
|
||||||
// TODO(kradalby): This should be done in one tx.
|
// TODO(kradalby): This should be done in one tx.
|
||||||
|
|
||||||
isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap()
|
IsConnected := api.h.mapBatcher.ConnectedMap()
|
||||||
if request.GetUser() != "" {
|
if request.GetUser() != "" {
|
||||||
user, err := api.h.state.GetUserByName(request.GetUser())
|
user, err := api.h.state.GetUserByName(request.GetUser())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -510,7 +468,7 @@ func (api headscaleV1APIServer) ListNodes(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
|
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -523,18 +481,18 @@ func (api headscaleV1APIServer) ListNodes(
|
|||||||
return nodes[i].ID < nodes[j].ID
|
return nodes[i].ID < nodes[j].ID
|
||||||
})
|
})
|
||||||
|
|
||||||
response := nodesToProto(api.h.state, isLikelyConnected, nodes)
|
response := nodesToProto(api.h.state, IsConnected, nodes)
|
||||||
return &v1.ListNodesResponse{Nodes: response}, nil
|
return &v1.ListNodesResponse{Nodes: response}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func nodesToProto(state *state.State, isLikelyConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
func nodesToProto(state *state.State, IsConnected *xsync.MapOf[types.NodeID, bool], nodes types.Nodes) []*v1.Node {
|
||||||
response := make([]*v1.Node, len(nodes))
|
response := make([]*v1.Node, len(nodes))
|
||||||
for index, node := range nodes {
|
for index, node := range nodes {
|
||||||
resp := node.Proto()
|
resp := node.Proto()
|
||||||
|
|
||||||
// Populate the online field based on
|
// Populate the online field based on
|
||||||
// currently connected nodes.
|
// currently connected nodes.
|
||||||
if val, ok := isLikelyConnected.Load(node.ID); ok && val {
|
if val, ok := IsConnected.Load(node.ID); ok && val {
|
||||||
resp.Online = true
|
resp.Online = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -556,24 +514,14 @@ func (api headscaleV1APIServer) MoveNode(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
request *v1.MoveNodeRequest,
|
request *v1.MoveNodeRequest,
|
||||||
) (*v1.MoveNodeResponse, error) {
|
) (*v1.MoveNodeResponse, error) {
|
||||||
node, policyChanged, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
|
node, nodeChange, err := api.h.state.AssignNodeToUser(types.NodeID(request.GetNodeId()), types.UserID(request.GetUser()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send policy update notifications if needed
|
// TODO(kradalby): Ensure the policy is also sent
|
||||||
if policyChanged {
|
// TODO(kradalby): ensure that both the selfupdate and peer updates are sent
|
||||||
ctx := types.NotifyCtx(context.Background(), "grpc-node-moved", node.Hostname)
|
api.h.Change(nodeChange)
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx = types.NotifyCtx(ctx, "cli-movenode-self", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyByNodeID(
|
|
||||||
ctx,
|
|
||||||
types.UpdateSelf(node.ID),
|
|
||||||
node.ID)
|
|
||||||
ctx = types.NotifyCtx(ctx, "cli-movenode", node.Hostname)
|
|
||||||
api.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
|
|
||||||
|
|
||||||
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
|
return &v1.MoveNodeResponse{Node: node.Proto()}, nil
|
||||||
}
|
}
|
||||||
@ -754,8 +702,7 @@ func (api headscaleV1APIServer) SetPolicy(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
|
api.h.Change(change.PolicyChange())
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
response := &v1.SetPolicyResponse{
|
response := &v1.SetPolicyResponse{
|
||||||
hscontrol/mapper/batcher.go | 155 lines (new file)
@@ -0,0 +1,155 @@
|
package mapper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/state"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||||
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/types/ptr"
|
||||||
|
)
|
||||||
|
|
||||||
|
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
|
||||||
|
|
||||||
|
// Batcher defines the common interface for all batcher implementations.
|
||||||
|
type Batcher interface {
|
||||||
|
Start()
|
||||||
|
Close()
|
||||||
|
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error
|
||||||
|
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool)
|
||||||
|
IsConnected(id types.NodeID) bool
|
||||||
|
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
||||||
|
AddWork(c change.ChangeSet)
|
||||||
|
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
|
||||||
|
return &LockFreeBatcher{
|
||||||
|
mapper: mapper,
|
||||||
|
workers: workers,
|
||||||
|
tick: time.NewTicker(batchTime),
|
||||||
|
|
||||||
|
// The size of this channel is arbitrary chosen, the sizing should be revisited.
|
||||||
|
workCh: make(chan work, workers*200),
|
||||||
|
nodes: xsync.NewMap[types.NodeID, *nodeConn](),
|
||||||
|
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||||
|
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBatcherAndMapper creates a Batcher implementation.
|
||||||
|
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
|
||||||
|
m := newMapper(cfg, state)
|
||||||
|
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
|
||||||
|
m.batcher = b
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// nodeConnection interface for different connection implementations.
|
||||||
|
type nodeConnection interface {
|
||||||
|
nodeID() types.NodeID
|
||||||
|
version() tailcfg.CapabilityVersion
|
||||||
|
send(data *tailcfg.MapResponse) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
|
||||||
|
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
|
||||||
|
if c.Empty() {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate inputs before processing
|
||||||
|
if nodeID == 0 {
|
||||||
|
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mapper == nil {
|
||||||
|
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var mapResp *tailcfg.MapResponse
|
||||||
|
var err error
|
||||||
|
|
||||||
|
switch c.Change {
|
||||||
|
case change.DERP:
|
||||||
|
mapResp, err = mapper.derpMapResponse(nodeID)
|
||||||
|
|
||||||
|
case change.NodeCameOnline, change.NodeWentOffline:
|
||||||
|
if c.IsSubnetRouter {
|
||||||
|
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
|
||||||
|
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||||
|
} else {
|
||||||
|
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
|
||||||
|
{
|
||||||
|
NodeID: c.NodeID.NodeID(),
|
||||||
|
Online: ptr.To(c.Change == change.NodeCameOnline),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case change.NodeNewOrUpdate:
|
||||||
|
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||||
|
|
||||||
|
case change.NodeRemove:
|
||||||
|
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
|
||||||
|
|
||||||
|
default:
|
||||||
|
// The following will always hit this:
|
||||||
|
// change.Full, change.Policy
|
||||||
|
mapResp, err = mapper.fullMapResponse(nodeID, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Is this necessary?
|
||||||
|
// Validate the generated map response - only check for nil response
|
||||||
|
// Note: mapResp.Node can be nil for peer updates, which is valid
|
||||||
|
if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
|
||||||
|
return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return mapResp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
|
||||||
|
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
|
||||||
|
if nc == nil {
|
||||||
|
return fmt.Errorf("nodeConnection is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := nc.nodeID()
|
||||||
|
data, err := generateMapResponse(nodeID, nc.version(), mapper, c)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if data == nil {
|
||||||
|
// No data to send is valid for some change types
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the map response
|
||||||
|
if err := nc.send(data); err != nil {
|
||||||
|
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// workResult represents the result of processing a change.
|
||||||
|
type workResult struct {
|
||||||
|
mapResponse *tailcfg.MapResponse
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// work represents a unit of work to be processed by workers.
|
||||||
|
type work struct {
|
||||||
|
c change.ChangeSet
|
||||||
|
nodeID types.NodeID
|
||||||
|
resultCh chan<- workResult // optional channel for synchronous operations
|
||||||
|
}
|
hscontrol/mapper/batcher_lockfree.go | 491 lines (new file)
@@ -0,0 +1,491 @@
|
package mapper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types/change"
|
||||||
|
"github.com/puzpuzpuz/xsync/v4"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/types/ptr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||||
|
type LockFreeBatcher struct {
|
||||||
|
tick *time.Ticker
|
||||||
|
mapper *mapper
|
||||||
|
workers int
|
||||||
|
|
||||||
|
// Lock-free concurrent maps
|
||||||
|
nodes *xsync.Map[types.NodeID, *nodeConn]
|
||||||
|
connected *xsync.Map[types.NodeID, *time.Time]
|
||||||
|
|
||||||
|
// Work queue channel
|
||||||
|
workCh chan work
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
// Batching state
|
||||||
|
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
|
||||||
|
batchMutex sync.RWMutex
|
||||||
|
|
||||||
|
// Metrics
|
||||||
|
totalNodes atomic.Int64
|
||||||
|
totalUpdates atomic.Int64
|
||||||
|
workQueuedCount atomic.Int64
|
||||||
|
workProcessed atomic.Int64
|
||||||
|
workErrors atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddNode registers a new node connection with the batcher and sends an initial map response.
|
||||||
|
// It creates or updates the node's connection data, validates the initial map generation,
|
||||||
|
// and notifies other nodes that this node has come online.
|
||||||
|
// TODO(kradalby): See if we can move the isRouter argument somewhere else.
|
||||||
|
func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool, version tailcfg.CapabilityVersion) error {
|
||||||
|
// First validate that we can generate initial map before doing anything else
|
||||||
|
fullSelfChange := change.FullSelf(id)
|
||||||
|
|
||||||
|
// TODO(kradalby): This should not be generated here, but rather in MapResponseFromChange.
|
||||||
|
// This currently means that the goroutine for the node connection will do the processing
|
||||||
|
// which means that we might have uncontrolled concurrency.
|
||||||
|
// When we use MapResponseFromChange, it will be processed by the same worker pool, causing
|
||||||
|
// it to be processed in a more controlled manner.
|
||||||
|
initialMap, err := generateMapResponse(id, version, b.mapper, fullSelfChange)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only after validation succeeds, create or update node connection
|
||||||
|
newConn := newNodeConn(id, c, version, b.mapper)
|
||||||
|
|
||||||
|
var conn *nodeConn
|
||||||
|
if existing, loaded := b.nodes.LoadOrStore(id, newConn); loaded {
|
||||||
|
// Update existing connection
|
||||||
|
existing.updateConnection(c, version)
|
||||||
|
conn = existing
|
||||||
|
} else {
|
||||||
|
b.totalNodes.Add(1)
|
||||||
|
conn = newConn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark as connected only after validation succeeds
|
||||||
|
b.connected.Store(id, nil) // nil = connected
|
||||||
|
|
||||||
|
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node connected to batcher")
|
||||||
|
|
||||||
|
// Send the validated initial map
|
||||||
|
if initialMap != nil {
|
||||||
|
if err := conn.send(initialMap); err != nil {
|
||||||
|
// Clean up the connection state on send failure
|
||||||
|
b.nodes.Delete(id)
|
||||||
|
b.connected.Delete(id)
|
||||||
|
return fmt.Errorf("failed to send initial map to node %d: %w", id, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify other nodes that this node came online
|
||||||
|
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeCameOnline, IsSubnetRouter: isRouter})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state.
|
||||||
|
// It validates the connection channel matches the current one, closes the connection,
|
||||||
|
// and notifies other nodes that this node has gone offline.
|
||||||
|
func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse, isRouter bool) {
|
||||||
|
// Check if this is the current connection and mark it as closed
|
||||||
|
if existing, ok := b.nodes.Load(id); ok {
|
||||||
|
if !existing.matchesChannel(c) {
|
||||||
|
log.Debug().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-current connection, ignoring")
|
||||||
|
return // Not the current connection, not an error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark the connection as closed to prevent further sends
|
||||||
|
if connData := existing.connData.Load(); connData != nil {
|
||||||
|
connData.closed.Store(true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info().Uint64("node.id", id.Uint64()).Bool("isRouter", isRouter).Msg("Node disconnected from batcher, marking as offline")
|
||||||
|
|
||||||
|
// Remove node and mark disconnected atomically
|
||||||
|
b.nodes.Delete(id)
|
||||||
|
b.connected.Store(id, ptr.To(time.Now()))
|
||||||
|
b.totalNodes.Add(-1)
|
||||||
|
|
||||||
|
// Notify other nodes that this node went offline
|
||||||
|
b.addWork(change.ChangeSet{NodeID: id, Change: change.NodeWentOffline, IsSubnetRouter: isRouter})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddWork queues a change to be processed by the batcher.
|
||||||
|
// Critical changes are processed immediately, while others are batched for efficiency.
|
||||||
|
func (b *LockFreeBatcher) AddWork(c change.ChangeSet) {
|
||||||
|
b.addWork(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *LockFreeBatcher) Start() {
|
||||||
|
b.ctx, b.cancel = context.WithCancel(context.Background())
|
||||||
|
go b.doWork()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *LockFreeBatcher) Close() {
|
||||||
|
if b.cancel != nil {
|
||||||
|
b.cancel()
|
||||||
|
}
|
||||||
|
close(b.workCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *LockFreeBatcher) doWork() {
|
||||||
|
log.Debug().Msg("batcher doWork loop started")
|
||||||
|
defer log.Debug().Msg("batcher doWork loop stopped")
|
||||||
|
|
||||||
|
for i := range b.workers {
|
||||||
|
go b.worker(i + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-b.tick.C:
|
||||||
|
// Process batched changes
|
||||||
|
b.processBatchedChanges()
|
||||||
|
case <-b.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *LockFreeBatcher) worker(workerID int) {
|
||||||
|
log.Debug().Int("workerID", workerID).Msg("batcher worker started")
|
||||||
|
defer log.Debug().Int("workerID", workerID).Msg("batcher worker stopped")
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case w, ok := <-b.workCh:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
startTime := time.Now()
|
||||||
|
b.workProcessed.Add(1)
|
||||||
|
|
||||||
|
// If the resultCh is set, it means that this is a work request
|
||||||
|
// where there is a blocking function waiting for the map that
|
||||||
|
// is being generated.
|
||||||
|
// This is used for synchronous map generation.
|
||||||
|
if w.resultCh != nil {
|
||||||
|
var result workResult
|
||||||
|
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||||
|
result.mapResponse, result.err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
|
||||||
|
if result.err != nil {
|
||||||
|
b.workErrors.Add(1)
|
||||||
|
log.Error().Err(result.err).
|
||||||
|
Int("workerID", workerID).
|
||||||
|
Uint64("node.id", w.nodeID.Uint64()).
|
||||||
|
Str("change", w.c.Change.String()).
|
||||||
|
Msg("failed to generate map response for synchronous work")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||||
|
b.workErrors.Add(1)
|
||||||
|
log.Error().Err(result.err).
|
||||||
|
Int("workerID", workerID).
|
||||||
|
Uint64("node.id", w.nodeID.Uint64()).
|
||||||
|
Msg("node not found for synchronous work")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send result
|
||||||
|
select {
|
||||||
|
case w.resultCh <- result:
|
||||||
|
case <-b.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
if duration > 100*time.Millisecond {
|
||||||
|
log.Warn().
|
||||||
|
Int("workerID", workerID).
|
||||||
|
Uint64("node.id", w.nodeID.Uint64()).
|
||||||
|
Str("change", w.c.Change.String()).
|
||||||
|
Dur("duration", duration).
|
||||||
|
Msg("slow synchronous work processing")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If resultCh is nil, this is an asynchronous work request
|
||||||
|
// that should be processed and sent to the node instead of
|
||||||
|
// returned to the caller.
|
||||||
|
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||||
|
// Check if this connection is still active before processing
|
||||||
|
if connData := nc.connData.Load(); connData != nil && connData.closed.Load() {
|
||||||
|
log.Debug().
|
||||||
|
Int("workerID", workerID).
|
||||||
|
Uint64("node.id", w.nodeID.Uint64()).
|
||||||
|
Str("change", w.c.Change.String()).
|
||||||
|
Msg("skipping work for closed connection")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := nc.change(w.c)
|
||||||
|
if err != nil {
|
||||||
|
b.workErrors.Add(1)
|
||||||
|
log.Error().Err(err).
|
||||||
|
Int("workerID", workerID).
|
||||||
|
Uint64("node.id", w.c.NodeID.Uint64()).
|
||||||
|
Str("change", w.c.Change.String()).
|
||||||
|
Msg("failed to apply change")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debug().
|
||||||
|
Int("workerID", workerID).
|
||||||
|
Uint64("node.id", w.nodeID.Uint64()).
|
||||||
|
Str("change", w.c.Change.String()).
|
||||||
|
Msg("node not found for asynchronous work - node may have disconnected")
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(startTime)
|
||||||
|
if duration > 100*time.Millisecond {
|
||||||
|
log.Warn().
|
||||||
|
Int("workerID", workerID).
|
||||||
|
Uint64("node.id", w.nodeID.Uint64()).
|
||||||
|
Str("change", w.c.Change.String()).
|
||||||
|
Dur("duration", duration).
|
||||||
|
Msg("slow asynchronous work processing")
|
||||||
|
}
|
||||||
|
|
||||||
|
case <-b.ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *LockFreeBatcher) addWork(c change.ChangeSet) {
|
||||||
|
// For critical changes that need immediate processing, send directly
|
||||||
|
if b.shouldProcessImmediately(c) {
|
||||||
|
if c.SelfUpdateOnly {
|
||||||
|
b.queueWork(work{c: c, nodeID: c.NodeID, resultCh: nil})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||||
|
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For non-critical changes, add to batch
|
||||||
|
b.addToBatch(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// queueWork safely queues work
|
||||||
|
func (b *LockFreeBatcher) queueWork(w work) {
|
||||||
|
b.workQueuedCount.Add(1)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case b.workCh <- w:
|
||||||
|
// Successfully queued
|
||||||
|
case <-b.ctx.Done():
|
||||||
|
// Batcher is shutting down
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// shouldProcessImmediately determines if a change should bypass batching
|
||||||
|
func (b *LockFreeBatcher) shouldProcessImmediately(c change.ChangeSet) bool {
|
||||||
|
// Process these changes immediately to avoid delaying critical functionality
|
||||||
|
switch c.Change {
|
||||||
|
case change.Full, change.NodeRemove, change.NodeCameOnline, change.NodeWentOffline, change.Policy:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addToBatch adds a change to the pending batch
|
||||||
|
func (b *LockFreeBatcher) addToBatch(c change.ChangeSet) {
|
||||||
|
b.batchMutex.Lock()
|
||||||
|
defer b.batchMutex.Unlock()
|
||||||
|
|
||||||
|
if c.SelfUpdateOnly {
|
||||||
|
changes, _ := b.pendingChanges.LoadOrStore(c.NodeID, []change.ChangeSet{})
|
||||||
|
changes = append(changes, c)
|
||||||
|
b.pendingChanges.Store(c.NodeID, changes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b.nodes.Range(func(nodeID types.NodeID, _ *nodeConn) bool {
|
||||||
|
if c.NodeID == nodeID && !c.AlsoSelf() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
|
||||||
|
changes = append(changes, c)
|
||||||
|
b.pendingChanges.Store(nodeID, changes)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// processBatchedChanges processes all pending batched changes
|
||||||
|
func (b *LockFreeBatcher) processBatchedChanges() {
|
||||||
|
b.batchMutex.Lock()
|
||||||
|
defer b.batchMutex.Unlock()
|
||||||
|
|
||||||
|
if b.pendingChanges == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process all pending changes
|
||||||
|
b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
|
||||||
|
if len(changes) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send all batched changes for this node
|
||||||
|
for _, c := range changes {
|
||||||
|
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the pending changes for this node
|
||||||
|
b.pendingChanges.Delete(nodeID)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsConnected is lock-free read.
|
||||||
|
func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool {
|
||||||
|
if val, ok := b.connected.Load(id); ok {
|
||||||
|
// nil means connected
|
||||||
|
return val == nil
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectedMap returns a lock-free map of all connected nodes.
|
||||||
|
func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
||||||
|
ret := xsync.NewMap[types.NodeID, bool]()
|
||||||
|
|
||||||
|
b.connected.Range(func(id types.NodeID, val *time.Time) bool {
|
||||||
|
// nil means connected
|
||||||
|
ret.Store(id, val == nil)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapResponseFromChange queues work to generate a map response and waits for the result.
|
||||||
|
// This allows synchronous map generation using the same worker pool.
|
||||||
|
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
|
||||||
|
resultCh := make(chan workResult, 1)
|
||||||
|
|
||||||
|
// Queue the work with a result channel using the safe queueing method
|
||||||
|
b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})
|
||||||
|
|
||||||
|
// Wait for the result
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
return result.mapResponse, result.err
|
||||||
|
case <-b.ctx.Done():
|
||||||
|
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectionData holds the channel and connection parameters.
|
||||||
|
type connectionData struct {
|
||||||
|
c chan<- *tailcfg.MapResponse
|
||||||
|
version tailcfg.CapabilityVersion
|
||||||
|
closed atomic.Bool // Track if this connection has been closed
|
||||||
|
}
|
||||||
|
|
||||||
|
// nodeConn described the node connection and its associated data.
|
||||||
|
type nodeConn struct {
|
||||||
|
id types.NodeID
|
||||||
|
mapper *mapper
|
||||||
|
|
||||||
|
// Atomic pointer to connection data - allows lock-free updates
|
||||||
|
connData atomic.Pointer[connectionData]
|
||||||
|
|
||||||
|
updateCount atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNodeConn(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion, mapper *mapper) *nodeConn {
|
||||||
|
nc := &nodeConn{
|
||||||
|
id: id,
|
||||||
|
mapper: mapper,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize connection data
|
||||||
|
data := &connectionData{
|
||||||
|
c: c,
|
||||||
|
version: version,
|
||||||
|
}
|
||||||
|
nc.connData.Store(data)
|
||||||
|
|
||||||
|
return nc
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateConnection atomically updates connection parameters.
|
||||||
|
func (nc *nodeConn) updateConnection(c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) {
|
||||||
|
newData := &connectionData{
|
||||||
|
c: c,
|
||||||
|
version: version,
|
||||||
|
}
|
||||||
|
nc.connData.Store(newData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchesChannel checks if the given channel matches current connection.
|
||||||
|
func (nc *nodeConn) matchesChannel(c chan<- *tailcfg.MapResponse) bool {
|
||||||
|
data := nc.connData.Load()
|
||||||
|
if data == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Compare channel pointers directly
|
||||||
|
return data.c == c
|
||||||
|
}
|
||||||
|
|
||||||
|
// compressAndVersion atomically reads connection settings.
|
||||||
|
func (nc *nodeConn) version() tailcfg.CapabilityVersion {
|
||||||
|
data := nc.connData.Load()
|
||||||
|
if data == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return data.version
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nc *nodeConn) nodeID() types.NodeID {
|
||||||
|
return nc.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (nc *nodeConn) change(c change.ChangeSet) error {
|
||||||
|
return handleNodeChange(nc, nc.mapper, c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// send sends data to the node's channel.
|
||||||
|
// The node will pick it up and send it to the HTTP handler.
|
||||||
|
func (nc *nodeConn) send(data *tailcfg.MapResponse) error {
|
||||||
|
connData := nc.connData.Load()
|
||||||
|
if connData == nil {
|
||||||
|
return fmt.Errorf("node %d: no connection data", nc.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if connection has been closed
|
||||||
|
if connData.closed.Load() {
|
||||||
|
return fmt.Errorf("node %d: connection closed", nc.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): We might need some sort of timeout here if the client is not reading
|
||||||
|
// the channel. That might mean that we are sending to a node that has gone offline, but
|
||||||
|
// the channel is still open.
|
||||||
|
connData.c <- data
|
||||||
|
nc.updateCount.Add(1)
|
||||||
|
return nil
|
||||||
|
}
|
1977
hscontrol/mapper/batcher_test.go
Normal file
1977
hscontrol/mapper/batcher_test.go
Normal file
File diff suppressed because it is too large
Load Diff
259
hscontrol/mapper/builder.go
Normal file
259
hscontrol/mapper/builder.go
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
package mapper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/types/views"
|
||||||
|
"tailscale.com/util/multierr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse
|
||||||
|
type MapResponseBuilder struct {
|
||||||
|
resp *tailcfg.MapResponse
|
||||||
|
mapper *mapper
|
||||||
|
nodeID types.NodeID
|
||||||
|
capVer tailcfg.CapabilityVersion
|
||||||
|
errs []error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMapResponseBuilder creates a new builder with basic fields set
|
||||||
|
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||||
|
now := time.Now()
|
||||||
|
return &MapResponseBuilder{
|
||||||
|
resp: &tailcfg.MapResponse{
|
||||||
|
KeepAlive: false,
|
||||||
|
ControlTime: &now,
|
||||||
|
},
|
||||||
|
mapper: m,
|
||||||
|
nodeID: nodeID,
|
||||||
|
errs: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addError adds an error to the builder's error list
|
||||||
|
func (b *MapResponseBuilder) addError(err error) {
|
||||||
|
if err != nil {
|
||||||
|
b.errs = append(b.errs, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// hasErrors returns true if the builder has accumulated any errors
|
||||||
|
func (b *MapResponseBuilder) hasErrors() bool {
|
||||||
|
return len(b.errs) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCapabilityVersion sets the capability version for the response
|
||||||
|
func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder {
|
||||||
|
b.capVer = capVer
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSelfNode adds the requesting node to the response
|
||||||
|
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||||
|
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
|
if err != nil {
|
||||||
|
b.addError(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
_, matchers := b.mapper.state.Filter()
|
||||||
|
tailnode, err := tailNode(
|
||||||
|
node.View(), b.capVer, b.mapper.state,
|
||||||
|
func(id types.NodeID) []netip.Prefix {
|
||||||
|
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||||
|
},
|
||||||
|
b.mapper.cfg)
|
||||||
|
if err != nil {
|
||||||
|
b.addError(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
b.resp.Node = tailnode
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDERPMap adds the DERP map to the response
|
||||||
|
func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder {
|
||||||
|
b.resp.DERPMap = b.mapper.state.DERPMap()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDomain adds the domain configuration
|
||||||
|
func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder {
|
||||||
|
b.resp.Domain = b.mapper.cfg.Domain()
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithCollectServicesDisabled sets the collect services flag to false
|
||||||
|
func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder {
|
||||||
|
b.resp.CollectServices.Set(false)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDebugConfig adds debug configuration
|
||||||
|
// It disables log tailing if the mapper's LogTail is not enabled
|
||||||
|
func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
||||||
|
b.resp.Debug = &tailcfg.Debug{
|
||||||
|
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSSHPolicy adds SSH policy configuration for the requesting node
|
||||||
|
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||||
|
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
|
if err != nil {
|
||||||
|
b.addError(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
sshPolicy, err := b.mapper.state.SSHPolicy(node.View())
|
||||||
|
if err != nil {
|
||||||
|
b.addError(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
b.resp.SSHPolicy = sshPolicy
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDNSConfig adds DNS configuration for the requesting node
|
||||||
|
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||||
|
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
|
if err != nil {
|
||||||
|
b.addError(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithUserProfiles adds user profiles for the requesting node and given peers
|
||||||
|
func (b *MapResponseBuilder) WithUserProfiles(peers types.Nodes) *MapResponseBuilder {
|
||||||
|
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
|
if err != nil {
|
||||||
|
b.addError(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
b.resp.UserProfiles = generateUserProfiles(node, peers)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPacketFilters adds packet filter rules based on policy
|
||||||
|
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||||
|
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
|
if err != nil {
|
||||||
|
b.addError(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
filter, _ := b.mapper.state.Filter()
|
||||||
|
|
||||||
|
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
|
||||||
|
// Currently, we do not send incremental package filters, however using the
|
||||||
|
// new PacketFilters field and "base" allows us to send a full update when we
|
||||||
|
// have to send an empty list, avoiding the hack in the else block.
|
||||||
|
b.resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||||
|
"base": policy.ReduceFilterRules(node.View(), filter),
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPeers adds full peer list with policy filtering (for full map response)
|
||||||
|
func (b *MapResponseBuilder) WithPeers(peers types.Nodes) *MapResponseBuilder {
|
||||||
|
|
||||||
|
tailPeers, err := b.buildTailPeers(peers)
|
||||||
|
if err != nil {
|
||||||
|
b.addError(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
b.resp.Peers = tailPeers
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPeerChanges adds changed peers with policy filtering (for incremental updates)
|
||||||
|
func (b *MapResponseBuilder) WithPeerChanges(peers types.Nodes) *MapResponseBuilder {
|
||||||
|
|
||||||
|
tailPeers, err := b.buildTailPeers(peers)
|
||||||
|
if err != nil {
|
||||||
|
b.addError(err)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
b.resp.PeersChanged = tailPeers
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildTailPeers converts types.Nodes to []tailcfg.Node with policy filtering and sorting
|
||||||
|
func (b *MapResponseBuilder) buildTailPeers(peers types.Nodes) ([]*tailcfg.Node, error) {
|
||||||
|
node, err := b.mapper.state.GetNodeByID(b.nodeID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
filter, matchers := b.mapper.state.Filter()
|
||||||
|
|
||||||
|
// If there are filter rules present, see if there are any nodes that cannot
|
||||||
|
// access each-other at all and remove them from the peers.
|
||||||
|
var changedViews views.Slice[types.NodeView]
|
||||||
|
if len(filter) > 0 {
|
||||||
|
changedViews = policy.ReduceNodes(node.View(), peers.ViewSlice(), matchers)
|
||||||
|
} else {
|
||||||
|
changedViews = peers.ViewSlice()
|
||||||
|
}
|
||||||
|
|
||||||
|
tailPeers, err := tailNodes(
|
||||||
|
changedViews, b.capVer, b.mapper.state,
|
||||||
|
func(id types.NodeID) []netip.Prefix {
|
||||||
|
return policy.ReduceRoutes(node.View(), b.mapper.state.GetNodePrimaryRoutes(id), matchers)
|
||||||
|
},
|
||||||
|
b.mapper.cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peers is always returned sorted by Node.ID.
|
||||||
|
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||||
|
return tailPeers[x].ID < tailPeers[y].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
return tailPeers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPeerChangedPatch adds peer change patches
|
||||||
|
func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder {
|
||||||
|
b.resp.PeersChangedPatch = changes
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithPeersRemoved adds removed peer IDs
|
||||||
|
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||||
|
|
||||||
|
var tailscaleIDs []tailcfg.NodeID
|
||||||
|
for _, id := range removedIDs {
|
||||||
|
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||||
|
}
|
||||||
|
b.resp.PeersRemoved = tailscaleIDs
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build finalizes the response and returns marshaled bytes
|
||||||
|
func (b *MapResponseBuilder) Build(messages ...string) (*tailcfg.MapResponse, error) {
|
||||||
|
if len(b.errs) > 0 {
|
||||||
|
return nil, multierr.New(b.errs...)
|
||||||
|
}
|
||||||
|
if debugDumpMapResponsePath != "" {
|
||||||
|
writeDebugMapResponse(b.resp, b.nodeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.resp, nil
|
||||||
|
}
|
347
hscontrol/mapper/builder_test.go
Normal file
347
hscontrol/mapper/builder_test.go
Normal file
@ -0,0 +1,347 @@
|
|||||||
|
package mapper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/state"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_Basic(t *testing.T) {
|
||||||
|
cfg := &types.Config{
|
||||||
|
BaseDomain: "example.com",
|
||||||
|
LogTail: types.LogTailConfig{
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID)
|
||||||
|
|
||||||
|
// Test basic builder creation
|
||||||
|
assert.NotNil(t, builder)
|
||||||
|
assert.Equal(t, nodeID, builder.nodeID)
|
||||||
|
assert.NotNil(t, builder.resp)
|
||||||
|
assert.False(t, builder.resp.KeepAlive)
|
||||||
|
assert.NotNil(t, builder.resp.ControlTime)
|
||||||
|
assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) {
|
||||||
|
cfg := &types.Config{}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
capVer := tailcfg.CapabilityVersion(42)
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithCapabilityVersion(capVer)
|
||||||
|
|
||||||
|
assert.Equal(t, capVer, builder.capVer)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_WithDomain(t *testing.T) {
|
||||||
|
domain := "test.example.com"
|
||||||
|
cfg := &types.Config{
|
||||||
|
ServerURL: "https://test.example.com",
|
||||||
|
BaseDomain: domain,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithDomain()
|
||||||
|
|
||||||
|
assert.Equal(t, domain, builder.resp.Domain)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) {
|
||||||
|
cfg := &types.Config{}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithCollectServicesDisabled()
|
||||||
|
|
||||||
|
value, isSet := builder.resp.CollectServices.Get()
|
||||||
|
assert.True(t, isSet)
|
||||||
|
assert.False(t, value)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_WithDebugConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
logTailEnabled bool
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LogTail enabled",
|
||||||
|
logTailEnabled: true,
|
||||||
|
expected: false, // DisableLogTail should be false when LogTail is enabled
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LogTail disabled",
|
||||||
|
logTailEnabled: false,
|
||||||
|
expected: true, // DisableLogTail should be true when LogTail is disabled
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
cfg := &types.Config{
|
||||||
|
LogTail: types.LogTailConfig{
|
||||||
|
Enabled: tt.logTailEnabled,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithDebugConfig()
|
||||||
|
|
||||||
|
require.NotNil(t, builder.resp.Debug)
|
||||||
|
assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) {
|
||||||
|
cfg := &types.Config{}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
changes := []*tailcfg.PeerChange{
|
||||||
|
{
|
||||||
|
NodeID: 123,
|
||||||
|
DERPRegion: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
NodeID: 456,
|
||||||
|
DERPRegion: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithPeerChangedPatch(changes)
|
||||||
|
|
||||||
|
assert.Equal(t, changes, builder.resp.PeersChangedPatch)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) {
|
||||||
|
cfg := &types.Config{}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
removedID1 := types.NodeID(123)
|
||||||
|
removedID2 := types.NodeID(456)
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithPeersRemoved(removedID1, removedID2)
|
||||||
|
|
||||||
|
expected := []tailcfg.NodeID{
|
||||||
|
removedID1.NodeID(),
|
||||||
|
removedID2.NodeID(),
|
||||||
|
}
|
||||||
|
assert.Equal(t, expected, builder.resp.PeersRemoved)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_ErrorHandling(t *testing.T) {
|
||||||
|
cfg := &types.Config{}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
|
// Simulate an error in the builder
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID)
|
||||||
|
builder.addError(assert.AnError)
|
||||||
|
|
||||||
|
// All subsequent calls should continue to work and accumulate errors
|
||||||
|
result := builder.
|
||||||
|
WithDomain().
|
||||||
|
WithCollectServicesDisabled().
|
||||||
|
WithDebugConfig()
|
||||||
|
|
||||||
|
assert.True(t, result.hasErrors())
|
||||||
|
assert.Len(t, result.errs, 1)
|
||||||
|
assert.Equal(t, assert.AnError, result.errs[0])
|
||||||
|
|
||||||
|
// Build should return the error
|
||||||
|
data, err := result.Build("none")
|
||||||
|
assert.Nil(t, data)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_ChainedCalls(t *testing.T) {
|
||||||
|
domain := "chained.example.com"
|
||||||
|
cfg := &types.Config{
|
||||||
|
ServerURL: "https://chained.example.com",
|
||||||
|
BaseDomain: domain,
|
||||||
|
LogTail: types.LogTailConfig{
|
||||||
|
Enabled: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
capVer := tailcfg.CapabilityVersion(99)
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithCapabilityVersion(capVer).
|
||||||
|
WithDomain().
|
||||||
|
WithCollectServicesDisabled().
|
||||||
|
WithDebugConfig()
|
||||||
|
|
||||||
|
// Verify all fields are set correctly
|
||||||
|
assert.Equal(t, capVer, builder.capVer)
|
||||||
|
assert.Equal(t, domain, builder.resp.Domain)
|
||||||
|
value, isSet := builder.resp.CollectServices.Get()
|
||||||
|
assert.True(t, isSet)
|
||||||
|
assert.False(t, value)
|
||||||
|
assert.NotNil(t, builder.resp.Debug)
|
||||||
|
assert.True(t, builder.resp.Debug.DisableLogTail)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) {
|
||||||
|
cfg := &types.Config{}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
removedID1 := types.NodeID(100)
|
||||||
|
removedID2 := types.NodeID(200)
|
||||||
|
|
||||||
|
// Test calling WithPeersRemoved multiple times
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithPeersRemoved(removedID1).
|
||||||
|
WithPeersRemoved(removedID2)
|
||||||
|
|
||||||
|
// Second call should overwrite the first
|
||||||
|
expected := []tailcfg.NodeID{removedID2.NodeID()}
|
||||||
|
assert.Equal(t, expected, builder.resp.PeersRemoved)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) {
|
||||||
|
cfg := &types.Config{}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithPeerChangedPatch([]*tailcfg.PeerChange{})
|
||||||
|
|
||||||
|
assert.Empty(t, builder.resp.PeersChangedPatch)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) {
|
||||||
|
cfg := &types.Config{}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithPeerChangedPatch(nil)
|
||||||
|
|
||||||
|
assert.Nil(t, builder.resp.PeersChangedPatch)
|
||||||
|
assert.False(t, builder.hasErrors())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
||||||
|
cfg := &types.Config{}
|
||||||
|
mockState := &state.State{}
|
||||||
|
m := &mapper{
|
||||||
|
cfg: cfg,
|
||||||
|
state: mockState,
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeID := types.NodeID(1)
|
||||||
|
|
||||||
|
// Create a builder and add multiple errors
|
||||||
|
builder := m.NewMapResponseBuilder(nodeID)
|
||||||
|
builder.addError(assert.AnError)
|
||||||
|
builder.addError(assert.AnError)
|
||||||
|
builder.addError(nil) // This should be ignored
|
||||||
|
|
||||||
|
// All subsequent calls should continue to work
|
||||||
|
result := builder.
|
||||||
|
WithDomain().
|
||||||
|
WithCollectServicesDisabled()
|
||||||
|
|
||||||
|
assert.True(t, result.hasErrors())
|
||||||
|
assert.Len(t, result.errs, 2) // nil error should be ignored
|
||||||
|
|
||||||
|
// Build should return a multierr
|
||||||
|
data, err := result.Build("none")
|
||||||
|
assert.Nil(t, data)
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
// The error should contain information about multiple errors
|
||||||
|
assert.Contains(t, err.Error(), "multiple errors")
|
||||||
|
}
|
@ -1,7 +1,6 @@
|
|||||||
package mapper
|
package mapper
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
@ -10,31 +9,21 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/state"
|
"github.com/juanfont/headscale/hscontrol/state"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
|
||||||
"github.com/klauspost/compress/zstd"
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
"tailscale.com/envknob"
|
"tailscale.com/envknob"
|
||||||
"tailscale.com/smallzstd"
|
|
||||||
"tailscale.com/tailcfg"
|
"tailscale.com/tailcfg"
|
||||||
"tailscale.com/types/dnstype"
|
"tailscale.com/types/dnstype"
|
||||||
"tailscale.com/types/views"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
nextDNSDoHPrefix = "https://dns.nextdns.io"
|
||||||
reservedResponseHeaderSize = 4
|
mapperIDLength = 8
|
||||||
mapperIDLength = 8
|
debugMapResponsePerm = 0o755
|
||||||
debugMapResponsePerm = 0o755
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
|
var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH")
|
||||||
@ -50,15 +39,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_
|
|||||||
// - Create a "minifier" that removes info not needed for the node
|
// - Create a "minifier" that removes info not needed for the node
|
||||||
// - some sort of batching, wait for 5 or 60 seconds before sending
|
// - some sort of batching, wait for 5 or 60 seconds before sending
|
||||||
|
|
||||||
type Mapper struct {
|
type mapper struct {
|
||||||
// Configuration
|
// Configuration
|
||||||
state *state.State
|
state *state.State
|
||||||
cfg *types.Config
|
cfg *types.Config
|
||||||
notif *notifier.Notifier
|
batcher Batcher
|
||||||
|
|
||||||
uid string
|
|
||||||
created time.Time
|
created time.Time
|
||||||
seq uint64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type patch struct {
|
type patch struct {
|
||||||
@ -66,41 +53,31 @@ type patch struct {
|
|||||||
change *tailcfg.PeerChange
|
change *tailcfg.PeerChange
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMapper(
|
func newMapper(
|
||||||
state *state.State,
|
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
notif *notifier.Notifier,
|
state *state.State,
|
||||||
) *Mapper {
|
) *mapper {
|
||||||
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||||
|
|
||||||
return &Mapper{
|
return &mapper{
|
||||||
state: state,
|
state: state,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
notif: notif,
|
|
||||||
|
|
||||||
uid: uid,
|
|
||||||
created: time.Now(),
|
created: time.Now(),
|
||||||
seq: 0,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Mapper) String() string {
|
|
||||||
return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created)
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateUserProfiles(
|
func generateUserProfiles(
|
||||||
node types.NodeView,
|
node *types.Node,
|
||||||
peers views.Slice[types.NodeView],
|
peers types.Nodes,
|
||||||
) []tailcfg.UserProfile {
|
) []tailcfg.UserProfile {
|
||||||
userMap := make(map[uint]*types.User)
|
userMap := make(map[uint]*types.User)
|
||||||
ids := make([]uint, 0, peers.Len()+1)
|
ids := make([]uint, 0, len(userMap))
|
||||||
user := node.User()
|
userMap[node.User.ID] = &node.User
|
||||||
userMap[user.ID] = &user
|
ids = append(ids, node.User.ID)
|
||||||
ids = append(ids, user.ID)
|
for _, peer := range peers {
|
||||||
for _, peer := range peers.All() {
|
userMap[peer.User.ID] = &peer.User
|
||||||
peerUser := peer.User()
|
ids = append(ids, peer.User.ID)
|
||||||
userMap[peerUser.ID] = &peerUser
|
|
||||||
ids = append(ids, peerUser.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
slices.Sort(ids)
|
slices.Sort(ids)
|
||||||
@ -117,7 +94,7 @@ func generateUserProfiles(
|
|||||||
|
|
||||||
func generateDNSConfig(
|
func generateDNSConfig(
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
node types.NodeView,
|
node *types.Node,
|
||||||
) *tailcfg.DNSConfig {
|
) *tailcfg.DNSConfig {
|
||||||
if cfg.TailcfgDNSConfig == nil {
|
if cfg.TailcfgDNSConfig == nil {
|
||||||
return nil
|
return nil
|
||||||
@ -137,17 +114,16 @@ func generateDNSConfig(
|
|||||||
//
|
//
|
||||||
// This will produce a resolver like:
|
// This will produce a resolver like:
|
||||||
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
|
||||||
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
||||||
for _, resolver := range resolvers {
|
for _, resolver := range resolvers {
|
||||||
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
|
||||||
attrs := url.Values{
|
attrs := url.Values{
|
||||||
"device_name": []string{node.Hostname()},
|
"device_name": []string{node.Hostname},
|
||||||
"device_model": []string{node.Hostinfo().OS()},
|
"device_model": []string{node.Hostinfo.OS},
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeIPs := node.IPs()
|
if len(node.IPs()) > 0 {
|
||||||
if len(nodeIPs) > 0 {
|
attrs.Add("device_ip", node.IPs()[0].String())
|
||||||
attrs.Add("device_ip", nodeIPs[0].String())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
|
||||||
@ -155,434 +131,151 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fullMapResponse creates a complete MapResponse for a node.
|
// fullMapResponse returns a MapResponse for the given node.
|
||||||
// It is a separate function to make testing easier.
|
func (m *mapper) fullMapResponse(
|
||||||
func (m *Mapper) fullMapResponse(
|
nodeID types.NodeID,
|
||||||
node types.NodeView,
|
|
||||||
peers views.Slice[types.NodeView],
|
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
|
messages ...string,
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
peers, err := m.listPeers(nodeID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = appendPeerChanges(
|
return m.NewMapResponseBuilder(nodeID).
|
||||||
resp,
|
WithCapabilityVersion(capVer).
|
||||||
true, // full change
|
WithSelfNode().
|
||||||
m.state,
|
WithDERPMap().
|
||||||
node,
|
WithDomain().
|
||||||
capVer,
|
WithCollectServicesDisabled().
|
||||||
peers,
|
WithDebugConfig().
|
||||||
m.cfg,
|
WithSSHPolicy().
|
||||||
)
|
WithDNSConfig().
|
||||||
if err != nil {
|
WithUserProfiles(peers).
|
||||||
return nil, err
|
WithPacketFilters().
|
||||||
}
|
WithPeers(peers).
|
||||||
|
Build(messages...)
|
||||||
return resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FullMapResponse returns a MapResponse for the given node.
|
func (m *mapper) derpMapResponse(
|
||||||
func (m *Mapper) FullMapResponse(
|
nodeID types.NodeID,
|
||||||
mapRequest tailcfg.MapRequest,
|
) (*tailcfg.MapResponse, error) {
|
||||||
node types.NodeView,
|
return m.NewMapResponseBuilder(nodeID).
|
||||||
messages ...string,
|
WithDERPMap().
|
||||||
) ([]byte, error) {
|
Build()
|
||||||
peers, err := m.ListPeers(node.ID())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := m.fullMapResponse(node, peers.ViewSlice(), mapRequest.Version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadOnlyMapResponse returns a MapResponse for the given node.
|
|
||||||
// Lite means that the peers has been omitted, this is intended
|
|
||||||
// to be used to answer MapRequests with OmitPeers set to true.
|
|
||||||
func (m *Mapper) ReadOnlyMapResponse(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
node types.NodeView,
|
|
||||||
messages ...string,
|
|
||||||
) ([]byte, error) {
|
|
||||||
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress, messages...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Mapper) KeepAliveResponse(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
node types.NodeView,
|
|
||||||
) ([]byte, error) {
|
|
||||||
resp := m.baseMapResponse()
|
|
||||||
resp.KeepAlive = true
|
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Mapper) DERPMapResponse(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
node types.NodeView,
|
|
||||||
derpMap *tailcfg.DERPMap,
|
|
||||||
) ([]byte, error) {
|
|
||||||
resp := m.baseMapResponse()
|
|
||||||
resp.DERPMap = derpMap
|
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Mapper) PeerChangedResponse(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
node types.NodeView,
|
|
||||||
changed map[types.NodeID]bool,
|
|
||||||
patches []*tailcfg.PeerChange,
|
|
||||||
messages ...string,
|
|
||||||
) ([]byte, error) {
|
|
||||||
var err error
|
|
||||||
resp := m.baseMapResponse()
|
|
||||||
|
|
||||||
var removedIDs []tailcfg.NodeID
|
|
||||||
var changedIDs []types.NodeID
|
|
||||||
for nodeID, nodeChanged := range changed {
|
|
||||||
if nodeChanged {
|
|
||||||
if nodeID != node.ID() {
|
|
||||||
changedIDs = append(changedIDs, nodeID)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
removedIDs = append(removedIDs, nodeID.NodeID())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
changedNodes := types.Nodes{}
|
|
||||||
if len(changedIDs) > 0 {
|
|
||||||
changedNodes, err = m.ListNodes(changedIDs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = appendPeerChanges(
|
|
||||||
&resp,
|
|
||||||
false, // partial change
|
|
||||||
m.state,
|
|
||||||
node,
|
|
||||||
mapRequest.Version,
|
|
||||||
changedNodes.ViewSlice(),
|
|
||||||
m.cfg,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp.PeersRemoved = removedIDs
|
|
||||||
|
|
||||||
// Sending patches as a part of a PeersChanged response
|
|
||||||
// is technically not suppose to be done, but they are
|
|
||||||
// applied after the PeersChanged. The patch list
|
|
||||||
// should _only_ contain Nodes that are not in the
|
|
||||||
// PeersChanged or PeersRemoved list and the caller
|
|
||||||
// should filter them out.
|
|
||||||
//
|
|
||||||
// From tailcfg docs:
|
|
||||||
// These are applied after Peers* above, but in practice the
|
|
||||||
// control server should only send these on their own, without
|
|
||||||
// the Peers* fields also set.
|
|
||||||
if patches != nil {
|
|
||||||
resp.PeersChangedPatch = patches
|
|
||||||
}
|
|
||||||
|
|
||||||
_, matchers := m.state.Filter()
|
|
||||||
// Add the node itself, it might have changed, and particularly
|
|
||||||
// if there are no patches or changes, this is a self update.
|
|
||||||
tailnode, err := tailNode(
|
|
||||||
node, mapRequest.Version, m.state,
|
|
||||||
func(id types.NodeID) []netip.Prefix {
|
|
||||||
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
|
|
||||||
},
|
|
||||||
m.cfg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resp.Node = tailnode
|
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerChangedPatchResponse creates a patch MapResponse with
|
// PeerChangedPatchResponse creates a patch MapResponse with
|
||||||
// incoming update from a state change.
|
// incoming update from a state change.
|
||||||
func (m *Mapper) PeerChangedPatchResponse(
|
func (m *mapper) peerChangedPatchResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
nodeID types.NodeID,
|
||||||
node types.NodeView,
|
|
||||||
changed []*tailcfg.PeerChange,
|
changed []*tailcfg.PeerChange,
|
||||||
) ([]byte, error) {
|
|
||||||
resp := m.baseMapResponse()
|
|
||||||
resp.PeersChangedPatch = changed
|
|
||||||
|
|
||||||
return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Mapper) marshalMapResponse(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
resp *tailcfg.MapResponse,
|
|
||||||
node types.NodeView,
|
|
||||||
compression string,
|
|
||||||
messages ...string,
|
|
||||||
) ([]byte, error) {
|
|
||||||
atomic.AddUint64(&m.seq, 1)
|
|
||||||
|
|
||||||
jsonBody, err := json.Marshal(resp)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if debugDumpMapResponsePath != "" {
|
|
||||||
data := map[string]any{
|
|
||||||
"Messages": messages,
|
|
||||||
"MapRequest": mapRequest,
|
|
||||||
"MapResponse": resp,
|
|
||||||
}
|
|
||||||
|
|
||||||
responseType := "keepalive"
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case resp.Peers != nil && len(resp.Peers) > 0:
|
|
||||||
responseType = "full"
|
|
||||||
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
|
||||||
responseType = "self"
|
|
||||||
case resp.PeersChanged != nil && len(resp.PeersChanged) > 0:
|
|
||||||
responseType = "changed"
|
|
||||||
case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0:
|
|
||||||
responseType = "patch"
|
|
||||||
case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0:
|
|
||||||
responseType = "removed"
|
|
||||||
}
|
|
||||||
|
|
||||||
body, err := json.MarshalIndent(data, "", " ")
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("marshalling map response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
perms := fs.FileMode(debugMapResponsePerm)
|
|
||||||
mPath := path.Join(debugDumpMapResponsePath, node.Hostname())
|
|
||||||
err = os.MkdirAll(mPath, perms)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
|
||||||
|
|
||||||
mapResponsePath := path.Join(
|
|
||||||
mPath,
|
|
||||||
fmt.Sprintf("%s-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType),
|
|
||||||
)
|
|
||||||
|
|
||||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
|
||||||
err = os.WriteFile(mapResponsePath, body, perms)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var respBody []byte
|
|
||||||
if compression == util.ZstdCompression {
|
|
||||||
respBody = zstdEncode(jsonBody)
|
|
||||||
} else {
|
|
||||||
respBody = jsonBody
|
|
||||||
}
|
|
||||||
|
|
||||||
data := make([]byte, reservedResponseHeaderSize)
|
|
||||||
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
|
|
||||||
data = append(data, respBody...)
|
|
||||||
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func zstdEncode(in []byte) []byte {
|
|
||||||
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
|
|
||||||
if !ok {
|
|
||||||
panic("invalid type in sync pool")
|
|
||||||
}
|
|
||||||
out := encoder.EncodeAll(in, nil)
|
|
||||||
_ = encoder.Close()
|
|
||||||
zstdEncoderPool.Put(encoder)
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
var zstdEncoderPool = &sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
encoder, err := smallzstd.NewEncoder(
|
|
||||||
nil,
|
|
||||||
zstd.WithEncoderLevel(zstd.SpeedFastest))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return encoder
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// baseMapResponse returns a tailcfg.MapResponse with
|
|
||||||
// KeepAlive false and ControlTime set to now.
|
|
||||||
func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
resp := tailcfg.MapResponse{
|
|
||||||
KeepAlive: false,
|
|
||||||
ControlTime: &now,
|
|
||||||
// TODO(kradalby): Implement PingRequest?
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
|
|
||||||
// baseWithConfigMapResponse returns a tailcfg.MapResponse struct
|
|
||||||
// with the basic configuration from headscale set.
|
|
||||||
// It is used in for bigger updates, such as full and lite, not
|
|
||||||
// incremental.
|
|
||||||
func (m *Mapper) baseWithConfigMapResponse(
|
|
||||||
node types.NodeView,
|
|
||||||
capVer tailcfg.CapabilityVersion,
|
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
resp := m.baseMapResponse()
|
return m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithPeerChangedPatch(changed).
|
||||||
|
Build()
|
||||||
|
}
|
||||||
|
|
||||||
_, matchers := m.state.Filter()
|
// peerChangeResponse returns a MapResponse with changed or added nodes.
|
||||||
tailnode, err := tailNode(
|
func (m *mapper) peerChangeResponse(
|
||||||
node, capVer, m.state,
|
nodeID types.NodeID,
|
||||||
func(id types.NodeID) []netip.Prefix {
|
capVer tailcfg.CapabilityVersion,
|
||||||
return policy.ReduceRoutes(node, m.state.GetNodePrimaryRoutes(id), matchers)
|
changedNodeID types.NodeID,
|
||||||
},
|
) (*tailcfg.MapResponse, error) {
|
||||||
m.cfg)
|
peers, err := m.listPeers(nodeID, changedNodeID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resp.Node = tailnode
|
|
||||||
|
|
||||||
resp.DERPMap = m.state.DERPMap()
|
return m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithCapabilityVersion(capVer).
|
||||||
resp.Domain = m.cfg.Domain()
|
WithSelfNode().
|
||||||
|
WithUserProfiles(peers).
|
||||||
// Do not instruct clients to collect services we do not
|
WithPeerChanges(peers).
|
||||||
// support or do anything with them
|
Build()
|
||||||
resp.CollectServices = "false"
|
|
||||||
|
|
||||||
resp.KeepAlive = false
|
|
||||||
|
|
||||||
resp.Debug = &tailcfg.Debug{
|
|
||||||
DisableLogTail: !m.cfg.LogTail.Enabled,
|
|
||||||
}
|
|
||||||
|
|
||||||
return &resp, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
|
||||||
|
func (m *mapper) peerRemovedResponse(
|
||||||
|
nodeID types.NodeID,
|
||||||
|
removedNodeID types.NodeID,
|
||||||
|
) (*tailcfg.MapResponse, error) {
|
||||||
|
return m.NewMapResponseBuilder(nodeID).
|
||||||
|
WithPeersRemoved(removedNodeID).
|
||||||
|
Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeDebugMapResponse(
|
||||||
|
resp *tailcfg.MapResponse,
|
||||||
|
nodeID types.NodeID,
|
||||||
|
messages ...string,
|
||||||
|
) {
|
||||||
|
data := map[string]any{
|
||||||
|
"Messages": messages,
|
||||||
|
"MapResponse": resp,
|
||||||
|
}
|
||||||
|
|
||||||
|
responseType := "keepalive"
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(resp.Peers) > 0:
|
||||||
|
responseType = "full"
|
||||||
|
case resp.Peers == nil && resp.PeersChanged == nil && resp.PeersChangedPatch == nil && resp.DERPMap == nil && !resp.KeepAlive:
|
||||||
|
responseType = "self"
|
||||||
|
case len(resp.PeersChanged) > 0:
|
||||||
|
responseType = "changed"
|
||||||
|
case len(resp.PeersChangedPatch) > 0:
|
||||||
|
responseType = "patch"
|
||||||
|
case len(resp.PeersRemoved) > 0:
|
||||||
|
responseType = "removed"
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := json.MarshalIndent(data, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
perms := fs.FileMode(debugMapResponsePerm)
|
||||||
|
mPath := path.Join(debugDumpMapResponsePath, nodeID.String())
|
||||||
|
err = os.MkdirAll(mPath, perms)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().Format("2006-01-02T15-04-05.999999999")
|
||||||
|
|
||||||
|
mapResponsePath := path.Join(
|
||||||
|
mPath,
|
||||||
|
fmt.Sprintf("%s-%s.json", now, responseType),
|
||||||
|
)
|
||||||
|
|
||||||
|
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||||
|
err = os.WriteFile(mapResponsePath, body, perms)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// listPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||||
// If no peer IDs are given, all peers are returned.
|
// If no peer IDs are given, all peers are returned.
|
||||||
// If at least one peer ID is given, only these peer nodes will be returned.
|
// If at least one peer ID is given, only these peer nodes will be returned.
|
||||||
func (m *Mapper) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
func (m *mapper) listPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||||
peers, err := m.state.ListPeers(nodeID, peerIDs...)
|
peers, err := m.state.ListPeers(nodeID, peerIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Add back online via batcher. This was removed
|
||||||
|
// to avoid a circular dependency between the mapper and the notification.
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
online := m.notif.IsLikelyConnected(peer.ID)
|
online := m.batcher.IsConnected(peer.ID)
|
||||||
peer.IsOnline = &online
|
peer.IsOnline = &online
|
||||||
}
|
}
|
||||||
|
|
||||||
return peers, nil
|
return peers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListNodes queries the database for either all nodes if no parameters are given
|
|
||||||
// or for the given nodes if at least one node ID is given as parameter.
|
|
||||||
func (m *Mapper) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
|
||||||
nodes, err := m.state.ListNodes(nodeIDs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, node := range nodes {
|
|
||||||
online := m.notif.IsLikelyConnected(node.ID)
|
|
||||||
node.IsOnline = &online
|
|
||||||
}
|
|
||||||
|
|
||||||
return nodes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// routeFilterFunc is a function that takes a node ID and returns a list of
|
// routeFilterFunc is a function that takes a node ID and returns a list of
|
||||||
// netip.Prefixes that are allowed for that node. It is used to filter routes
|
// netip.Prefixes that are allowed for that node. It is used to filter routes
|
||||||
// from the primary route manager to the node.
|
// from the primary route manager to the node.
|
||||||
type routeFilterFunc func(id types.NodeID) []netip.Prefix
|
type routeFilterFunc func(id types.NodeID) []netip.Prefix
|
||||||
|
|
||||||
// appendPeerChanges mutates a tailcfg.MapResponse with all the
|
|
||||||
// necessary changes when peers have changed.
|
|
||||||
func appendPeerChanges(
|
|
||||||
resp *tailcfg.MapResponse,
|
|
||||||
|
|
||||||
fullChange bool,
|
|
||||||
state *state.State,
|
|
||||||
node types.NodeView,
|
|
||||||
capVer tailcfg.CapabilityVersion,
|
|
||||||
changed views.Slice[types.NodeView],
|
|
||||||
cfg *types.Config,
|
|
||||||
) error {
|
|
||||||
filter, matchers := state.Filter()
|
|
||||||
|
|
||||||
sshPolicy, err := state.SSHPolicy(node)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If there are filter rules present, see if there are any nodes that cannot
|
|
||||||
// access each-other at all and remove them from the peers.
|
|
||||||
var reducedChanged views.Slice[types.NodeView]
|
|
||||||
if len(filter) > 0 {
|
|
||||||
reducedChanged = policy.ReduceNodes(node, changed, matchers)
|
|
||||||
} else {
|
|
||||||
reducedChanged = changed
|
|
||||||
}
|
|
||||||
|
|
||||||
profiles := generateUserProfiles(node, reducedChanged)
|
|
||||||
|
|
||||||
dnsConfig := generateDNSConfig(cfg, node)
|
|
||||||
|
|
||||||
tailPeers, err := tailNodes(
|
|
||||||
reducedChanged, capVer, state,
|
|
||||||
func(id types.NodeID) []netip.Prefix {
|
|
||||||
return policy.ReduceRoutes(node, state.GetNodePrimaryRoutes(id), matchers)
|
|
||||||
},
|
|
||||||
cfg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peers is always returned sorted by Node.ID.
|
|
||||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
|
||||||
return tailPeers[x].ID < tailPeers[y].ID
|
|
||||||
})
|
|
||||||
|
|
||||||
if fullChange {
|
|
||||||
resp.Peers = tailPeers
|
|
||||||
} else {
|
|
||||||
resp.PeersChanged = tailPeers
|
|
||||||
}
|
|
||||||
resp.DNSConfig = dnsConfig
|
|
||||||
resp.UserProfiles = profiles
|
|
||||||
resp.SSHPolicy = sshPolicy
|
|
||||||
|
|
||||||
// CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates)
|
|
||||||
// Currently, we do not send incremental package filters, however using the
|
|
||||||
// new PacketFilters field and "base" allows us to send a full update when we
|
|
||||||
// have to send an empty list, avoiding the hack in the else block.
|
|
||||||
resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
|
||||||
"base": policy.ReduceFilterRules(node, filter),
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -3,6 +3,7 @@ package mapper
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
@ -70,7 +71,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
|||||||
&types.Config{
|
&types.Config{
|
||||||
TailcfgDNSConfig: &dnsConfigOrig,
|
TailcfgDNSConfig: &dnsConfigOrig,
|
||||||
},
|
},
|
||||||
nodeInShared1.View(),
|
nodeInShared1,
|
||||||
)
|
)
|
||||||
|
|
||||||
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
|
||||||
@ -126,11 +127,8 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
|
|||||||
// Filter peers by the provided IDs
|
// Filter peers by the provided IDs
|
||||||
var filtered types.Nodes
|
var filtered types.Nodes
|
||||||
for _, peer := range m.peers {
|
for _, peer := range m.peers {
|
||||||
for _, id := range peerIDs {
|
if slices.Contains(peerIDs, peer.ID) {
|
||||||
if peer.ID == id {
|
filtered = append(filtered, peer)
|
||||||
filtered = append(filtered, peer)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,11 +150,8 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
|||||||
// Filter nodes by the provided IDs
|
// Filter nodes by the provided IDs
|
||||||
var filtered types.Nodes
|
var filtered types.Nodes
|
||||||
for _, node := range m.nodes {
|
for _, node := range m.nodes {
|
||||||
for _, id := range nodeIDs {
|
if slices.Contains(nodeIDs, node.ID) {
|
||||||
if node.ID == id {
|
filtered = append(filtered, node)
|
||||||
filtered = append(filtered, node)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
47
hscontrol/mapper/utils.go
Normal file
47
hscontrol/mapper/utils.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package mapper
|
||||||
|
|
||||||
|
import "tailscale.com/tailcfg"
|
||||||
|
|
||||||
|
// mergePatch takes the current patch and a newer patch
|
||||||
|
// and override any field that has changed.
|
||||||
|
func mergePatch(currPatch, newPatch *tailcfg.PeerChange) {
|
||||||
|
if newPatch.DERPRegion != 0 {
|
||||||
|
currPatch.DERPRegion = newPatch.DERPRegion
|
||||||
|
}
|
||||||
|
|
||||||
|
if newPatch.Cap != 0 {
|
||||||
|
currPatch.Cap = newPatch.Cap
|
||||||
|
}
|
||||||
|
|
||||||
|
if newPatch.CapMap != nil {
|
||||||
|
currPatch.CapMap = newPatch.CapMap
|
||||||
|
}
|
||||||
|
|
||||||
|
if newPatch.Endpoints != nil {
|
||||||
|
currPatch.Endpoints = newPatch.Endpoints
|
||||||
|
}
|
||||||
|
|
||||||
|
if newPatch.Key != nil {
|
||||||
|
currPatch.Key = newPatch.Key
|
||||||
|
}
|
||||||
|
|
||||||
|
if newPatch.KeySignature != nil {
|
||||||
|
currPatch.KeySignature = newPatch.KeySignature
|
||||||
|
}
|
||||||
|
|
||||||
|
if newPatch.DiscoKey != nil {
|
||||||
|
currPatch.DiscoKey = newPatch.DiscoKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if newPatch.Online != nil {
|
||||||
|
currPatch.Online = newPatch.Online
|
||||||
|
}
|
||||||
|
|
||||||
|
if newPatch.LastSeen != nil {
|
||||||
|
currPatch.LastSeen = newPatch.LastSeen
|
||||||
|
}
|
||||||
|
|
||||||
|
if newPatch.KeyExpiry != nil {
|
||||||
|
currPatch.KeyExpiry = newPatch.KeyExpiry
|
||||||
|
}
|
||||||
|
}
|
@ -221,7 +221,7 @@ func (ns *noiseServer) NoisePollNetMapHandler(
|
|||||||
|
|
||||||
ns.nodeKey = nv.NodeKey()
|
ns.nodeKey = nv.NodeKey()
|
||||||
|
|
||||||
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv)
|
sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct())
|
||||||
sess.tracef("a node sending a MapRequest with Noise protocol")
|
sess.tracef("a node sending a MapRequest with Noise protocol")
|
||||||
if !sess.isStreaming() {
|
if !sess.isStreaming() {
|
||||||
sess.serve()
|
sess.serve()
|
||||||
@ -279,28 +279,33 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
respBody, err := json.Marshal(registerResponse)
|
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||||
if err != nil {
|
writer.WriteHeader(http.StatusOK)
|
||||||
httpError(writer, err)
|
|
||||||
|
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
|
||||||
|
log.Error().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
// Ensure response is flushed to client
|
||||||
writer.WriteHeader(http.StatusOK)
|
if flusher, ok := writer.(http.Flusher); ok {
|
||||||
writer.Write(respBody)
|
flusher.Flush()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAndValidateNode retrieves the node from the database using the NodeKey
|
// getAndValidateNode retrieves the node from the database using the NodeKey
|
||||||
// and validates that it matches the MachineKey from the Noise session.
|
// and validates that it matches the MachineKey from the Noise session.
|
||||||
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
|
func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) {
|
||||||
nv, err := ns.headscale.state.GetNodeViewByNodeKey(mapRequest.NodeKey)
|
node, err := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
|
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil)
|
||||||
}
|
}
|
||||||
return types.NodeView{}, err
|
return types.NodeView{}, NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("lookup node: %s", err), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
nv := node.View()
|
||||||
|
|
||||||
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
|
// Validate that the MachineKey in the Noise session matches the one associated with the NodeKey.
|
||||||
if ns.machineKey != nv.MachineKey() {
|
if ns.machineKey != nv.MachineKey() {
|
||||||
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
|
return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil)
|
||||||
|
@ -1,68 +0,0 @@
package notifier

import (
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"tailscale.com/envknob"
)

const prometheusNamespace = "headscale"

var debugHighCardinalityMetrics = envknob.Bool("HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS")

var notifierUpdateSent *prometheus.CounterVec

func init() {
	if debugHighCardinalityMetrics {
		notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
			Namespace: prometheusNamespace,
			Name:      "notifier_update_sent_total",
			Help:      "total count of update sent on nodes channel",
		}, []string{"status", "type", "trigger", "id"})
	} else {
		notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{
			Namespace: prometheusNamespace,
			Name:      "notifier_update_sent_total",
			Help:      "total count of update sent on nodes channel",
		}, []string{"status", "type", "trigger"})
	}
}

var (
	notifierWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: prometheusNamespace,
		Name:      "notifier_waiters_for_lock",
		Help:      "gauge of waiters for the notifier lock",
	}, []string{"type", "action"})
	notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{
		Namespace: prometheusNamespace,
		Name:      "notifier_wait_for_lock_seconds",
		Help:      "histogram of time spent waiting for the notifier lock",
		Buckets:   []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10},
	}, []string{"action"})
	notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
		Namespace: prometheusNamespace,
		Name:      "notifier_update_received_total",
		Help:      "total count of updates received by notifier",
	}, []string{"type", "trigger"})
	notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
		Namespace: prometheusNamespace,
		Name:      "notifier_open_channels_total",
		Help:      "total count open channels in notifier",
	})
	notifierBatcherWaitersForLock = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: prometheusNamespace,
		Name:      "notifier_batcher_waiters_for_lock",
		Help:      "gauge of waiters for the notifier batcher lock",
	}, []string{"type", "action"})
	notifierBatcherChanges = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: prometheusNamespace,
		Name:      "notifier_batcher_changes_pending",
		Help:      "gauge of full changes pending in the notifier batcher",
	}, []string{})
	notifierBatcherPatches = promauto.NewGaugeVec(prometheus.GaugeOpts{
		Namespace: prometheusNamespace,
		Name:      "notifier_batcher_patches_pending",
		Help:      "gauge of patches pending in the notifier batcher",
	}, []string{})
)
@ -1,488 +0,0 @@
package notifier

import (
	"context"
	"fmt"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/puzpuzpuz/xsync/v4"
	"github.com/rs/zerolog/log"
	"github.com/sasha-s/go-deadlock"
	"tailscale.com/envknob"
	"tailscale.com/tailcfg"
	"tailscale.com/util/set"
)

var (
	debugDeadlock        = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
	debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
)

func init() {
	deadlock.Opts.Disable = !debugDeadlock
	if debugDeadlock {
		deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout()
		deadlock.Opts.PrintAllCurrentGoroutines = true
	}
}

type Notifier struct {
	l         deadlock.Mutex
	nodes     map[types.NodeID]chan<- types.StateUpdate
	connected *xsync.MapOf[types.NodeID, bool]
	b         *batcher
	cfg       *types.Config
	closed    bool
}

func NewNotifier(cfg *types.Config) *Notifier {
	n := &Notifier{
		nodes:     make(map[types.NodeID]chan<- types.StateUpdate),
		connected: xsync.NewMapOf[types.NodeID, bool](),
		cfg:       cfg,
		closed:    false,
	}
	b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
	n.b = b

	go b.doWork()

	return n
}

// Close stops the batcher and closes all channels.
func (n *Notifier) Close() {
	notifierWaitersForLock.WithLabelValues("lock", "close").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "close").Dec()

	n.closed = true
	n.b.close()

	// Close channels safely using the helper method
	for nodeID, c := range n.nodes {
		n.safeCloseChannel(nodeID, c)
	}

	// Clear node map after closing channels
	n.nodes = make(map[types.NodeID]chan<- types.StateUpdate)
}

// safeCloseChannel closes a channel and panic recovers if already closed.
func (n *Notifier) safeCloseChannel(nodeID types.NodeID, c chan<- types.StateUpdate) {
	defer func() {
		if r := recover(); r != nil {
			log.Error().
				Uint64("node.id", nodeID.Uint64()).
				Any("recover", r).
				Msg("recovered from panic when closing channel in Close()")
		}
	}()
	close(c)
}

func (n *Notifier) tracef(nID types.NodeID, msg string, args ...any) {
	log.Trace().
		Uint64("node.id", nID.Uint64()).
		Int("open_chans", len(n.nodes)).Msgf(msg, args...)
}

func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "add").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "add").Dec()
	notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds())

	if n.closed {
		return
	}

	// If a channel exists, it means the node has opened a new
	// connection. Close the old channel and replace it.
	if curr, ok := n.nodes[nodeID]; ok {
		n.tracef(nodeID, "channel present, closing and replacing")
		// Use the safeCloseChannel helper in a goroutine to avoid deadlocks
		// if/when someone is waiting to send on this channel
		go func(ch chan<- types.StateUpdate) {
			n.safeCloseChannel(nodeID, ch)
		}(curr)
	}

	n.nodes[nodeID] = c
	n.connected.Store(nodeID, true)

	n.tracef(nodeID, "added new channel")
	notifierNodeUpdateChans.Inc()
}

// RemoveNode removes a node and a given channel from the notifier.
// It checks that the channel is the same as currently being updated
// and ignores the removal if it is not.
// RemoveNode reports if the node/chan was removed.
func (n *Notifier) RemoveNode(nodeID types.NodeID, c chan<- types.StateUpdate) bool {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "remove").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "remove").Dec()
	notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds())

	if n.closed {
		return true
	}

	if len(n.nodes) == 0 {
		return true
	}

	// If the channel exist, but it does not belong
	// to the caller, ignore.
	if curr, ok := n.nodes[nodeID]; ok {
		if curr != c {
			n.tracef(nodeID, "channel has been replaced, not removing")
			return false
		}
	}

	delete(n.nodes, nodeID)
	n.connected.Store(nodeID, false)

	n.tracef(nodeID, "removed channel")
	notifierNodeUpdateChans.Dec()

	return true
}

// IsConnected reports if a node is connected to headscale and has a
// poll session open.
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
	notifierWaitersForLock.WithLabelValues("lock", "conncheck").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "conncheck").Dec()

	if val, ok := n.connected.Load(nodeID); ok {
		return val
	}

	return false
}

// IsLikelyConnected reports if a node is connected to headscale and has a
// poll session open, but doesn't lock, so might be wrong.
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
	if val, ok := n.connected.Load(nodeID); ok {
		return val
	}
	return false
}

// LikelyConnectedMap returns a thread safe map of connected nodes.
func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] {
	return n.connected
}

func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
	n.NotifyWithIgnore(ctx, update)
}

func (n *Notifier) NotifyWithIgnore(
	ctx context.Context,
	update types.StateUpdate,
	ignoreNodeIDs ...types.NodeID,
) {
	if n.closed {
		return
	}

	notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
	n.b.addOrPassthrough(update)
}

func (n *Notifier) NotifyByNodeID(
	ctx context.Context,
	update types.StateUpdate,
	nodeID types.NodeID,
) {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "notify").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "notify").Dec()
	notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds())

	if n.closed {
		return
	}

	if c, ok := n.nodes[nodeID]; ok {
		select {
		case <-ctx.Done():
			log.Error().
				Err(ctx.Err()).
				Uint64("node.id", nodeID.Uint64()).
				Any("origin", types.NotifyOriginKey.Value(ctx)).
				Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
				Msgf("update not sent, context cancelled")
			if debugHighCardinalityMetrics {
				notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
			} else {
				notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
			}

			return
		case c <- update:
			n.tracef(nodeID, "update successfully sent on chan, origin: %s, origin-hostname: %s", ctx.Value("origin"), ctx.Value("hostname"))
			if debugHighCardinalityMetrics {
				notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx), nodeID.String()).Inc()
			} else {
				notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
			}
		}
	}
}

func (n *Notifier) sendAll(update types.StateUpdate) {
	start := time.Now()
	notifierWaitersForLock.WithLabelValues("lock", "send-all").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "send-all").Dec()
	notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())

	if n.closed {
		return
	}

	for id, c := range n.nodes {
		// Whenever an update is sent to all nodes, there is a chance that the node
		// has disconnected and the goroutine that was supposed to consume the update
		// has shut down the channel and is waiting for the lock held here in RemoveNode.
		// This means that there is potential for a deadlock which would stop all updates
		// going out to clients. This timeout prevents that from happening by moving on to the
		// next node if the context is cancelled. After sendAll releases the lock, the add/remove
		// call will succeed and the update will go to the correct nodes on the next call.
		ctx, cancel := context.WithTimeout(context.Background(), n.cfg.Tuning.NotifierSendTimeout)
		defer cancel()
		select {
		case <-ctx.Done():
			log.Error().
				Err(ctx.Err()).
				Uint64("node.id", id.Uint64()).
				Msgf("update not sent, context cancelled")
			if debugHighCardinalityMetrics {
				notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all", id.String()).Inc()
			} else {
				notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), "send-all").Inc()
			}

			return
		case c <- update:
			if debugHighCardinalityMetrics {
				notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all", id.String()).Inc()
			} else {
				notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
			}
		}
	}
}

func (n *Notifier) String() string {
	notifierWaitersForLock.WithLabelValues("lock", "string").Inc()
	n.l.Lock()
	defer n.l.Unlock()
	notifierWaitersForLock.WithLabelValues("lock", "string").Dec()

	var b strings.Builder
	fmt.Fprintf(&b, "chans (%d):\n", len(n.nodes))

	var keys []types.NodeID
	n.connected.Range(func(key types.NodeID, value bool) bool {
		keys = append(keys, key)
		return true
	})
	sort.Slice(keys, func(i, j int) bool {
		return keys[i] < keys[j]
	})

	for _, key := range keys {
		fmt.Fprintf(&b, "\t%d: %p\n", key, n.nodes[key])
	}

	b.WriteString("\n")
	fmt.Fprintf(&b, "connected (%d):\n", len(n.nodes))

	for _, key := range keys {
		val, _ := n.connected.Load(key)
		fmt.Fprintf(&b, "\t%d: %t\n", key, val)
	}

	return b.String()
}

type batcher struct {
	tick *time.Ticker

	mu sync.Mutex

	cancelCh chan struct{}

	changedNodeIDs set.Slice[types.NodeID]
	nodesChanged   bool
	patches        map[types.NodeID]tailcfg.PeerChange
	patchesChanged bool

	n *Notifier
}

func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
	return &batcher{
		tick:     time.NewTicker(batchTime),
		cancelCh: make(chan struct{}),
		patches:  make(map[types.NodeID]tailcfg.PeerChange),
		n:        n,
	}
}

func (b *batcher) close() {
	b.cancelCh <- struct{}{}
}

// addOrPassthrough adds the update to the batcher, if it is not a
// type that is currently batched, it will be sent immediately.
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
	notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Inc()
	b.mu.Lock()
	defer b.mu.Unlock()
	notifierBatcherWaitersForLock.WithLabelValues("lock", "add").Dec()

	switch update.Type {
	case types.StatePeerChanged:
		b.changedNodeIDs.Add(update.ChangeNodes...)
		b.nodesChanged = true
		notifierBatcherChanges.WithLabelValues().Set(float64(b.changedNodeIDs.Len()))

	case types.StatePeerChangedPatch:
		for _, newPatch := range update.ChangePatches {
			if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
				overwritePatch(&curr, newPatch)
				b.patches[types.NodeID(newPatch.NodeID)] = curr
			} else {
				b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
			}
		}
		b.patchesChanged = true
		notifierBatcherPatches.WithLabelValues().Set(float64(len(b.patches)))

	default:
		b.n.sendAll(update)
	}
}

// flush sends all the accumulated patches to all
// nodes in the notifier.
func (b *batcher) flush() {
	notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Inc()
	b.mu.Lock()
	defer b.mu.Unlock()
	notifierBatcherWaitersForLock.WithLabelValues("lock", "flush").Dec()

	if b.nodesChanged || b.patchesChanged {
		var patches []*tailcfg.PeerChange
		// If a node is getting a full update from a change
		// node update, then the patch can be dropped.
		for nodeID, patch := range b.patches {
			if b.changedNodeIDs.Contains(nodeID) {
				delete(b.patches, nodeID)
			} else {
				patches = append(patches, &patch)
			}
		}

		changedNodes := b.changedNodeIDs.Slice().AsSlice()
		sort.Slice(changedNodes, func(i, j int) bool {
			return changedNodes[i] < changedNodes[j]
		})

		if b.changedNodeIDs.Slice().Len() > 0 {
			update := types.UpdatePeerChanged(changedNodes...)

			b.n.sendAll(update)
		}

		if len(patches) > 0 {
			patchUpdate := types.UpdatePeerPatch(patches...)

			b.n.sendAll(patchUpdate)
		}

		b.changedNodeIDs = set.Slice[types.NodeID]{}
		notifierBatcherChanges.WithLabelValues().Set(0)
		b.nodesChanged = false
		b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
		notifierBatcherPatches.WithLabelValues().Set(0)
		b.patchesChanged = false
	}
}

func (b *batcher) doWork() {
	for {
		select {
		case <-b.cancelCh:
			return
		case <-b.tick.C:
			b.flush()
		}
	}
}

// overwritePatch takes the current patch and a newer patch
// and override any field that has changed.
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
	if newPatch.DERPRegion != 0 {
		currPatch.DERPRegion = newPatch.DERPRegion
	}

	if newPatch.Cap != 0 {
		currPatch.Cap = newPatch.Cap
	}

	if newPatch.CapMap != nil {
		currPatch.CapMap = newPatch.CapMap
	}

	if newPatch.Endpoints != nil {
		currPatch.Endpoints = newPatch.Endpoints
	}

	if newPatch.Key != nil {
		currPatch.Key = newPatch.Key
	}

	if newPatch.KeySignature != nil {
		currPatch.KeySignature = newPatch.KeySignature
	}

	if newPatch.DiscoKey != nil {
		currPatch.DiscoKey = newPatch.DiscoKey
	}

	if newPatch.Online != nil {
		currPatch.Online = newPatch.Online
	}

	if newPatch.LastSeen != nil {
		currPatch.LastSeen = newPatch.LastSeen
	}

	if newPatch.KeyExpiry != nil {
		currPatch.KeyExpiry = newPatch.KeyExpiry
	}
}
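Since the whole notifier package is removed by this commit, a brief usage sketch of the API it provided may help when reading the call sites that change further down. This is a minimal fragment based only on the deleted code above; the tuning values are placeholders, not values from this repository.

// Sketch: how a poll session used the removed Notifier (fragment, assumed values).
cfg := &types.Config{
	Tuning: types.Tuning{
		BatchChangeDelay:    time.Second, // placeholder
		NotifierSendTimeout: time.Second, // placeholder
	},
}

n := notifier.NewNotifier(cfg)
defer n.Close()

ch := make(chan types.StateUpdate, 30)
n.AddNode(1, ch)          // register the node's update channel
defer n.RemoveNode(1, ch) // only removed if the channel is still the current one

// Updates are batched or passed straight through depending on their type.
n.NotifyWithIgnore(context.Background(), types.UpdateFull())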
@ -1,342 +0,0 @@
package notifier

import (
	"fmt"
	"math/rand"
	"net/netip"
	"slices"
	"sort"
	"sync"
	"testing"
	"time"

	"github.com/google/go-cmp/cmp"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"tailscale.com/tailcfg"
)

func TestBatcher(t *testing.T) {
	tests := []struct {
		name    string
		updates []types.StateUpdate
		want    []types.StateUpdate
	}{
		{
			name:    "full-passthrough",
			updates: []types.StateUpdate{{Type: types.StateFullUpdate}},
			want:    []types.StateUpdate{{Type: types.StateFullUpdate}},
		},
		{
			name:    "derp-passthrough",
			updates: []types.StateUpdate{{Type: types.StateDERPUpdated}},
			want:    []types.StateUpdate{{Type: types.StateDERPUpdated}},
		},
		{
			name: "single-node-update",
			updates: []types.StateUpdate{
				{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{2}},
			},
			want: []types.StateUpdate{
				{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{2}},
			},
		},
		{
			name: "merge-node-update",
			updates: []types.StateUpdate{
				{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{2, 4}},
				{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{2, 3}},
			},
			want: []types.StateUpdate{
				{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{2, 3, 4}},
			},
		},
		{
			name: "single-patch-update",
			updates: []types.StateUpdate{
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{NodeID: 2, DERPRegion: 5},
					},
				},
			},
			want: []types.StateUpdate{
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{NodeID: 2, DERPRegion: 5},
					},
				},
			},
		},
		{
			name: "merge-patch-to-same-node-update",
			updates: []types.StateUpdate{
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{NodeID: 2, DERPRegion: 5},
					},
				},
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{NodeID: 2, DERPRegion: 6},
					},
				},
			},
			want: []types.StateUpdate{
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{NodeID: 2, DERPRegion: 6},
					},
				},
			},
		},
		{
			name: "merge-patch-to-multiple-node-update",
			updates: []types.StateUpdate{
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{
							NodeID: 3,
							Endpoints: []netip.AddrPort{
								netip.MustParseAddrPort("1.1.1.1:9090"),
							},
						},
					},
				},
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{
							NodeID: 3,
							Endpoints: []netip.AddrPort{
								netip.MustParseAddrPort("1.1.1.1:9090"),
								netip.MustParseAddrPort("2.2.2.2:8080"),
							},
						},
					},
				},
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{NodeID: 4, DERPRegion: 6},
					},
				},
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{NodeID: 4, Cap: tailcfg.CapabilityVersion(54)},
					},
				},
			},
			want: []types.StateUpdate{
				{
					Type: types.StatePeerChangedPatch,
					ChangePatches: []*tailcfg.PeerChange{
						{
							NodeID: 3,
							Endpoints: []netip.AddrPort{
								netip.MustParseAddrPort("1.1.1.1:9090"),
								netip.MustParseAddrPort("2.2.2.2:8080"),
							},
						},
						{
							NodeID:     4,
							DERPRegion: 6,
							Cap:        tailcfg.CapabilityVersion(54),
						},
					},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			n := NewNotifier(&types.Config{
				Tuning: types.Tuning{
					// We will call flush manually for the tests,
					// so do not run the worker.
					BatchChangeDelay: time.Hour,

					// Since we do not load the config, we won't get the
					// default, so set it manually so we dont time out
					// and have flakes.
					NotifierSendTimeout: time.Second,
				},
			})

			ch := make(chan types.StateUpdate, 30)
			defer close(ch)
			n.AddNode(1, ch)
			defer n.RemoveNode(1, ch)

			for _, u := range tt.updates {
				n.NotifyAll(t.Context(), u)
			}

			n.b.flush()

			var got []types.StateUpdate
			for len(ch) > 0 {
				out := <-ch
				got = append(got, out)
			}

			// Make the inner order stable for comparison.
			for _, u := range got {
				slices.Sort(u.ChangeNodes)
				sort.Slice(u.ChangePatches, func(i, j int) bool {
					return u.ChangePatches[i].NodeID < u.ChangePatches[j].NodeID
				})
			}

			if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
				t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
			}
		})
	}
}

// TestIsLikelyConnectedRaceCondition tests for a race condition in IsLikelyConnected
// Multiple goroutines calling AddNode and RemoveNode cause panics when trying to
// close a channel that was already closed, which can happen when a node changes
// network transport quickly (eg mobile->wifi) and reconnects whilst also disconnecting.
func TestIsLikelyConnectedRaceCondition(t *testing.T) {
	// mock config for the notifier
	cfg := &types.Config{
		Tuning: types.Tuning{
			NotifierSendTimeout:            1 * time.Second,
			BatchChangeDelay:               1 * time.Second,
			NodeMapSessionBufferedChanSize: 30,
		},
	}

	notifier := NewNotifier(cfg)
	defer notifier.Close()

	nodeID := types.NodeID(1)
	updateChan := make(chan types.StateUpdate, 10)

	var wg sync.WaitGroup

	// Number of goroutines to spawn for concurrent access
	concurrentAccessors := 100
	iterations := 100

	// Add node to notifier
	notifier.AddNode(nodeID, updateChan)

	// Track errors
	errChan := make(chan string, concurrentAccessors*iterations)

	// Start goroutines to cause a race
	wg.Add(concurrentAccessors)
	for i := range concurrentAccessors {
		go func(routineID int) {
			defer wg.Done()

			for range iterations {
				// Simulate race by having some goroutines check IsLikelyConnected
				// while others add/remove the node
				switch routineID % 3 {
				case 0:
					// This goroutine checks connection status
					isConnected := notifier.IsLikelyConnected(nodeID)
					if isConnected != true && isConnected != false {
						errChan <- fmt.Sprintf("Invalid connection status: %v", isConnected)
					}
				case 1:
					// This goroutine removes the node
					notifier.RemoveNode(nodeID, updateChan)
				default:
					// This goroutine adds the node back
					notifier.AddNode(nodeID, updateChan)
				}

				// Small random delay to increase chance of races
				time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
			}
		}(i)
	}

	wg.Wait()
	close(errChan)

	// Collate errors
	var errors []string
	for err := range errChan {
		errors = append(errors, err)
	}

	if len(errors) > 0 {
		t.Errorf("Detected %d race condition errors: %v", len(errors), errors)
	}
}
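The race-condition test above is only meaningful with the race detector enabled; before this commit, a typical local invocation of the package's tests would have been along these lines (standard Go tooling, package path taken from the import paths above):

	go test -race ./hscontrol/notifier/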
@ -16,9 +16,8 @@ import (
 	"github.com/coreos/go-oidc/v3/oidc"
 	"github.com/gorilla/mux"
 	"github.com/juanfont/headscale/hscontrol/db"
-	"github.com/juanfont/headscale/hscontrol/notifier"
-	"github.com/juanfont/headscale/hscontrol/state"
 	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/juanfont/headscale/hscontrol/types/change"
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/rs/zerolog/log"
 	"golang.org/x/oauth2"

@ -56,11 +55,10 @@ type RegistrationInfo struct {
 }
 
 type AuthProviderOIDC struct {
+	h                 *Headscale
 	serverURL         string
 	cfg               *types.OIDCConfig
-	state             *state.State
 	registrationCache *zcache.Cache[string, RegistrationInfo]
-	notifier          *notifier.Notifier
 
 	oidcProvider *oidc.Provider
 	oauth2Config *oauth2.Config

@ -68,10 +66,9 @@ type AuthProviderOIDC struct {
 
 func NewAuthProviderOIDC(
 	ctx context.Context,
+	h *Headscale,
 	serverURL string,
 	cfg *types.OIDCConfig,
-	state *state.State,
-	notif *notifier.Notifier,
 ) (*AuthProviderOIDC, error) {
 	var err error
 	// grab oidc config if it hasn't been already

@ -94,11 +91,10 @@ func NewAuthProviderOIDC(
 	)
 
 	return &AuthProviderOIDC{
+		h:                 h,
 		serverURL:         serverURL,
 		cfg:               cfg,
-		state:             state,
 		registrationCache: registrationCache,
-		notifier:          notif,
 
 		oidcProvider: oidcProvider,
 		oauth2Config: oauth2Config,

@ -318,8 +314,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 
 	// Send policy update notifications if needed
 	if policyChanged {
-		ctx := types.NotifyCtx(context.Background(), "oidc-user-created", user.Name)
-		a.notifier.NotifyAll(ctx, types.UpdateFull())
+		a.h.Change(change.PolicyChange())
 	}
 
 	// TODO(kradalby): Is this comment right?

@ -360,8 +355,6 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
 	// Neither node nor machine key was found in the state cache meaning
 	// that we could not reauth nor register the node.
 	httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil))
-
-	return
 }
 
 func extractCodeAndStateParamFromRequest(

@ -490,12 +483,14 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 	var err error
 	var newUser bool
 	var policyChanged bool
-	user, err = a.state.GetUserByOIDCIdentifier(claims.Identifier())
+	user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
 	if err != nil && !errors.Is(err, db.ErrUserNotFound) {
 		return nil, false, fmt.Errorf("creating or updating user: %w", err)
 	}
 
 	// if the user is still not found, create a new empty user.
+	// TODO(kradalby): This might cause us to not have an ID below which
+	// is a problem.
 	if user == nil {
 		newUser = true
 		user = &types.User{}

@ -504,12 +499,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
 	user.FromClaim(claims)
 
 	if newUser {
-		user, policyChanged, err = a.state.CreateUser(*user)
+		user, policyChanged, err = a.h.state.CreateUser(*user)
 		if err != nil {
			return nil, false, fmt.Errorf("creating user: %w", err)
 		}
 	} else {
-		_, policyChanged, err = a.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
+		_, policyChanged, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
 			*u = *user
 			return nil
 		})

@ -526,7 +521,7 @@ func (a *AuthProviderOIDC) handleRegistration(
 	registrationID types.RegistrationID,
 	expiry time.Time,
 ) (bool, error) {
-	node, newNode, err := a.state.HandleNodeFromAuthPath(
+	node, nodeChange, err := a.h.state.HandleNodeFromAuthPath(
 		registrationID,
 		types.UserID(user.ID),
 		&expiry,

@ -547,31 +542,20 @@ func (a *AuthProviderOIDC) handleRegistration(
 	// ensure we send an update.
 	// This works, but might be another good candidate for doing some sort of
 	// eventbus.
-	routesChanged := a.state.AutoApproveRoutes(node)
-	_, policyChanged, err := a.state.SaveNode(node)
+	_ = a.h.state.AutoApproveRoutes(node)
+	_, policyChange, err := a.h.state.SaveNode(node)
 	if err != nil {
 		return false, fmt.Errorf("saving auto approved routes to node: %w", err)
 	}
 
-	// Send policy update notifications if needed (from SaveNode or route changes)
-	if policyChanged {
-		ctx := types.NotifyCtx(context.Background(), "oidc-nodes-change", "all")
-		a.notifier.NotifyAll(ctx, types.UpdateFull())
+	// Policy updates are full and take precedence over node changes.
+	if !policyChange.Empty() {
+		a.h.Change(policyChange)
+	} else {
+		a.h.Change(nodeChange)
 	}
 
-	if routesChanged {
-		ctx := types.NotifyCtx(context.Background(), "oidc-expiry-self", node.Hostname)
-		a.notifier.NotifyByNodeID(
-			ctx,
-			types.UpdateSelf(node.ID),
-			node.ID,
-		)
-
-		ctx = types.NotifyCtx(context.Background(), "oidc-expiry-peers", node.Hostname)
-		a.notifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
-	}
-
-	return newNode, nil
+	return !nodeChange.Empty(), nil
 }
 
 // TODO(kradalby):

@ -113,6 +113,17 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
 				}
 			}
 		}
 
+		// Also check approved subnet routes - nodes should have access
+		// to subnets they're approved to route traffic for.
+		subnetRoutes := node.SubnetRoutes()
+
+		for _, subnetRoute := range subnetRoutes {
+			if expanded.OverlapsPrefix(subnetRoute) {
+				dests = append(dests, dest)
+				continue DEST_LOOP
+			}
+		}
 	}
 
 	if len(dests) > 0 {
@ -142,16 +153,23 @@ func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
 			newApproved = append(newApproved, route)
 		}
 	}
-	if newApproved != nil {
-		newApproved = append(newApproved, node.ApprovedRoutes...)
-		tsaddr.SortPrefixes(newApproved)
-		newApproved = slices.Compact(newApproved)
-		newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
+
+	// Only modify ApprovedRoutes if we have new routes to approve.
+	// This prevents clearing existing approved routes when nodes
+	// temporarily don't have announced routes during policy changes.
+	if len(newApproved) > 0 {
+		combined := append(newApproved, node.ApprovedRoutes...)
+		tsaddr.SortPrefixes(combined)
+		combined = slices.Compact(combined)
+		combined = lo.Filter(combined, func(route netip.Prefix, index int) bool {
 			return route.IsValid()
 		})
-		node.ApprovedRoutes = newApproved
 
-		return true
+		// Only update if the routes actually changed
+		if !slices.Equal(node.ApprovedRoutes, combined) {
+			node.ApprovedRoutes = combined
+			return true
+		}
 	}
 
 	return false
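The new guard in AutoApproveRoutes follows a common merge-and-compare pattern: combine the slices, normalize them, and only write back when the result actually differs. A small standalone sketch of that pattern is below; it is an illustration with made-up prefixes, not headscale code, and it only assumes the tsaddr.SortPrefixes helper already used in the hunk above.

package main

import (
	"fmt"
	"net/netip"
	"slices"

	"tailscale.com/net/tsaddr"
)

func main() {
	approved := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
	newlyApproved := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}

	// Merge, sort, and deduplicate the candidate route set.
	combined := append(newlyApproved, approved...)
	tsaddr.SortPrefixes(combined)
	combined = slices.Compact(combined)

	// Only replace the stored routes when something actually changed.
	if !slices.Equal(approved, combined) {
		approved = combined
	}
	fmt.Println(approved) // unchanged here: [10.0.0.0/24]
}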
@ -56,10 +56,13 @@ func (pol *Policy) compileFilterRules(
 		}
 
 		if ips == nil {
+			log.Debug().Msgf("destination resolved to nil ips: %v", dest)
 			continue
 		}
 
-		for _, pref := range ips.Prefixes() {
+		prefixes := ips.Prefixes()
+
+		for _, pref := range prefixes {
 			for _, port := range dest.Ports {
 				pr := tailcfg.NetPortRange{
 					IP: pref.String(),
@ -103,6 +106,8 @@ func (pol *Policy) compileSSHPolicy(
 		return nil, nil
 	}
 
+	log.Trace().Msgf("compiling SSH policy for node %q", node.Hostname())
+
 	var rules []*tailcfg.SSHRule
 
 	for index, rule := range pol.SSHs {
@ -137,7 +142,8 @@ func (pol *Policy) compileSSHPolicy(
 		var principals []*tailcfg.SSHPrincipal
 		srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
 		if err != nil {
-			log.Trace().Err(err).Msgf("resolving source ips")
+			log.Trace().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule)
+			continue // Skip this rule if we can't resolve sources
 		}
 
 		for addr := range util.IPSetAddrIter(srcIPs) {
@ -70,7 +70,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
 	// TODO(kradalby): This could potentially be optimized by only clearing the
 	// policies for nodes that have changed. Particularly if the only difference is
 	// that nodes has been added or removed.
-	defer clear(pm.sshPolicyMap)
+	clear(pm.sshPolicyMap)
 
 	filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
 	if err != nil {
@ -1730,7 +1730,7 @@ func (u SSHUser) MarshalJSON() ([]byte, error) {
 // In addition to unmarshalling, it will also validate the policy.
 // This is the only entrypoint of reading a policy from a file or other source.
 func unmarshalPolicy(b []byte) (*Policy, error) {
-	if b == nil || len(b) == 0 {
+	if len(b) == 0 {
 		return nil, nil
 	}
 
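The simplified empty-policy check works because, in Go, the length of a nil slice is defined to be zero, so the explicit nil test added nothing:

	var b []byte             // a nil slice
	fmt.Println(len(b) == 0) // prints true: len of a nil slice is 0, so `len(b) == 0` already covers the nil case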
@ -2,20 +2,20 @@ package hscontrol
 
 import (
 	"context"
+	"encoding/binary"
+	"encoding/json"
 	"fmt"
 	"math/rand/v2"
 	"net/http"
-	"net/netip"
-	"slices"
 	"time"
 
-	"github.com/juanfont/headscale/hscontrol/mapper"
 	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/juanfont/headscale/hscontrol/types/change"
+	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/rs/zerolog/log"
 	"github.com/sasha-s/go-deadlock"
-	xslices "golang.org/x/exp/slices"
-	"tailscale.com/net/tsaddr"
 	"tailscale.com/tailcfg"
+	"tailscale.com/util/zstdframe"
 )
 
 const (
@ -31,18 +31,17 @@ type mapSession struct {
 	req    tailcfg.MapRequest
 	ctx    context.Context
 	capVer tailcfg.CapabilityVersion
-	mapper *mapper.Mapper
 
 	cancelChMu deadlock.Mutex
 
-	ch           chan types.StateUpdate
+	ch           chan *tailcfg.MapResponse
 	cancelCh     chan struct{}
 	cancelChOpen bool
 
 	keepAlive       time.Duration
 	keepAliveTicker *time.Ticker
 
-	node types.NodeView
+	node *types.Node
 	w    http.ResponseWriter
 
 	warnf func(string, ...any)
@ -55,18 +54,9 @@ func (h *Headscale) newMapSession(
 	ctx context.Context,
 	req tailcfg.MapRequest,
 	w http.ResponseWriter,
-	nv types.NodeView,
+	node *types.Node,
 ) *mapSession {
-	warnf, infof, tracef, errf := logPollFuncView(req, nv)
-
-	var updateChan chan types.StateUpdate
-	if req.Stream {
-		// Use a buffered channel in case a node is not fully ready
-		// to receive a message to make sure we dont block the entire
-		// notifier.
-		updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
-		updateChan <- types.UpdateFull()
-	}
+	warnf, infof, tracef, errf := logPollFunc(req, node)
 
 	ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
 
@ -75,11 +65,10 @@ func (h *Headscale) newMapSession(
 		ctx:    ctx,
 		req:    req,
 		w:      w,
-		node:   nv,
+		node:   node,
 		capVer: req.Version,
-		mapper: h.mapper,
 
-		ch:           updateChan,
+		ch:           make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize),
 		cancelCh:     make(chan struct{}),
 		cancelChOpen: true,
 
@ -95,15 +84,11 @@ func (h *Headscale) newMapSession(
 }
 
 func (m *mapSession) isStreaming() bool {
-	return m.req.Stream && !m.req.ReadOnly
+	return m.req.Stream
 }
 
 func (m *mapSession) isEndpointUpdate() bool {
-	return !m.req.Stream && !m.req.ReadOnly && m.req.OmitPeers
-}
-
-func (m *mapSession) isReadOnlyUpdate() bool {
-	return !m.req.Stream && m.req.OmitPeers && m.req.ReadOnly
+	return !m.req.Stream && m.req.OmitPeers
 }
 
 func (m *mapSession) resetKeepAlive() {
@ -112,25 +97,22 @@ func (m *mapSession) resetKeepAlive() {
 
 func (m *mapSession) beforeServeLongPoll() {
 	if m.node.IsEphemeral() {
-		m.h.ephemeralGC.Cancel(m.node.ID())
+		m.h.ephemeralGC.Cancel(m.node.ID)
 	}
 }
 
 func (m *mapSession) afterServeLongPoll() {
 	if m.node.IsEphemeral() {
-		m.h.ephemeralGC.Schedule(m.node.ID(), m.h.cfg.EphemeralNodeInactivityTimeout)
+		m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout)
 	}
 }
 
 // serve handles non-streaming requests.
 func (m *mapSession) serve() {
-	// TODO(kradalby): A set todos to harden:
-	// - func to tell the stream to die, readonly -> false, !stream && omitpeers -> false, true
-
 	// This is the mechanism where the node gives us information about its
 	// current configuration.
 	//
-	// If OmitPeers is true, Stream is false, and ReadOnly is false,
+	// If OmitPeers is true and Stream is false
 	// then the server will let clients update their endpoints without
 	// breaking existing long-polling (Stream == true) connections.
 	// In this case, the server can omit the entire response; the client
@ -138,26 +120,18 @@ func (m *mapSession) serve() {
 	//
 	// This is what Tailscale calls a Lite update, the client ignores
 	// the response and just wants a 200.
-	// !req.stream && !req.ReadOnly && req.OmitPeers
-	//
-	// TODO(kradalby): remove ReadOnly when we only support capVer 68+
+	// !req.stream && req.OmitPeers
 	if m.isEndpointUpdate() {
-		m.handleEndpointUpdate()
+		c, err := m.h.state.UpdateNodeFromMapRequest(m.node, m.req)
+		if err != nil {
+			httpError(m.w, err)
+			return
+		}
 
-		return
-	}
+		m.h.Change(c)
 
-	// ReadOnly is whether the client just wants to fetch the
-	// MapResponse, without updating their Endpoints. The
-	// Endpoints field will be ignored and LastSeen will not be
-	// updated and peers will not be notified of changes.
-	//
-	// The intended use is for clients to discover the DERP map at
-	// start-up before their first real endpoint update.
-	if m.isReadOnlyUpdate() {
-		m.handleReadOnlyRequest()
-
-		return
+		m.w.WriteHeader(http.StatusOK)
+		mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
 	}
 }
@ -175,23 +149,15 @@ func (m *mapSession) serveLongPoll() {
 		close(m.cancelCh)
 		m.cancelChMu.Unlock()
 
-		// only update node status if the node channel was removed.
-		// in principal, it will be removed, but the client rapidly
-		// reconnects, the channel might be of another connection.
-		// In that case, it is not closed and the node is still online.
-		if m.h.nodeNotifier.RemoveNode(m.node.ID(), m.ch) {
-			// TODO(kradalby): This can likely be made more effective, but likely most
-			// nodes has access to the same routes, so it might not be a big deal.
-			change, err := m.h.state.Disconnect(m.node.ID())
-			if err != nil {
-				m.errf(err, "Failed to disconnect node %s", m.node.Hostname())
-			}
-
-			if change {
-				ctx := types.NotifyCtx(context.Background(), "poll-primary-change", m.node.Hostname())
-				m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
-			}
+		// TODO(kradalby): This can likely be made more effective, but likely most
+		// nodes has access to the same routes, so it might not be a big deal.
+		disconnectChange, err := m.h.state.Disconnect(m.node)
+		if err != nil {
+			m.errf(err, "Failed to disconnect node %s", m.node.Hostname)
 		}
 
+		m.h.Change(disconnectChange)
+
+		m.h.mapBatcher.RemoveNode(m.node.ID, m.ch, m.node.IsSubnetRouter())
+
 		m.afterServeLongPoll()
 		m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch)
@ -201,21 +167,30 @@ func (m *mapSession) serveLongPoll() {
 	m.h.pollNetMapStreamWG.Add(1)
 	defer m.h.pollNetMapStreamWG.Done()
 
-	m.h.state.Connect(m.node.ID())
-
-	// Upgrade the writer to a ResponseController
-	rc := http.NewResponseController(m.w)
-
-	// Longpolling will break if there is a write timeout,
-	// so it needs to be disabled.
-	rc.SetWriteDeadline(time.Time{})
-
-	ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname()))
+	ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
 	defer cancel()
 
 	m.keepAliveTicker = time.NewTicker(m.keepAlive)
 
-	m.h.nodeNotifier.AddNode(m.node.ID(), m.ch)
+	// Add node to batcher BEFORE sending Connect change to prevent race condition
+	// where the change is sent before the node is in the batcher's node map
+	if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.node.IsSubnetRouter(), m.capVer); err != nil {
+		m.errf(err, "failed to add node to batcher")
+		// Send empty response to client to fail fast for invalid/non-existent nodes
+		select {
+		case m.ch <- &tailcfg.MapResponse{}:
+		default:
+			// Channel might be closed
+		}
+		return
+	}
+
+	// Now send the Connect change - the batcher handles NodeCameOnline internally
+	// but we still need to update routes and other state-level changes
+	connectChange := m.h.state.Connect(m.node)
+	if !connectChange.Empty() && connectChange.Change != change.NodeCameOnline {
+		m.h.Change(connectChange)
+	}
 
 	m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch)
 
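The fail-fast path above uses a select with a default case so the error response never blocks the handler when the channel has no room. As a generic sketch of that non-blocking send, with a simplified channel rather than headscale's own session plumbing:

	// Non-blocking send: deliver if there is buffer space, otherwise move on.
	ch := make(chan *tailcfg.MapResponse, 1)
	select {
	case ch <- &tailcfg.MapResponse{}:
		// receiver will pick it up
	default:
		// channel is full; do not block the handler
	}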
@ -236,290 +211,94 @@ func (m *mapSession) serveLongPoll() {

// Consume updates sent to node
case update, ok := <-m.ch:
+m.tracef("received update from channel, ok: %t", ok)
if !ok {
m.tracef("update channel closed, streaming session is likely being replaced")
return
}

-// If the node has been removed from headscale, close the stream
-if slices.Contains(update.Removed, m.node.ID()) {
-m.tracef("node removed, closing stream")
+if err := m.writeMap(update); err != nil {
+m.errf(err, "cannot write update to client")
return
}

-m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
-mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
+m.tracef("update sent")
+m.resetKeepAlive()

-var data []byte
-var err error
-var lastMessage string

-// Ensure the node view is updated, for example, there
-// might have been a hostinfo update in a sidechannel
-// which contains data needed to generate a map response.
-m.node, err = m.h.state.GetNodeViewByID(m.node.ID())
-if err != nil {
-m.errf(err, "Could not get machine from db")
-return
-}

-updateType := "full"
-switch update.Type {
-case types.StateFullUpdate:
-m.tracef("Sending Full MapResponse")
-data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
-case types.StatePeerChanged:
-changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
-for _, nodeID := range update.ChangeNodes {
-changed[nodeID] = true
-}
-lastMessage = update.Message
-m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
-data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
-updateType = "change"
-case types.StatePeerChangedPatch:
-m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
-data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
-updateType = "patch"
-case types.StatePeerRemoved:
-changed := make(map[types.NodeID]bool, len(update.Removed))
-for _, nodeID := range update.Removed {
-changed[nodeID] = false
-}
-m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
-data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
-updateType = "remove"
-case types.StateSelfUpdate:
-lastMessage = update.Message
-m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
-// create the map so an empty (self) update is sent
-data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
-updateType = "remove"
-case types.StateDERPUpdated:
-m.tracef("Sending DERPUpdate MapResponse")
-data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.state.DERPMap())
-updateType = "derp"
-}

-if err != nil {
-m.errf(err, "Could not get the create map update")
-return
-}

-// Only send update if there is change
-if data != nil {
-startWrite := time.Now()
-_, err = m.w.Write(data)
-if err != nil {
-mapResponseSent.WithLabelValues("error", updateType).Inc()
-m.errf(err, "could not write the map response(%s), for mapSession: %p", update.Type.String(), m)
-return
-}

-err = rc.Flush()
-if err != nil {
-mapResponseSent.WithLabelValues("error", updateType).Inc()
-m.errf(err, "flushing the map response to client, for mapSession: %p", m)
-return
-}

-log.Trace().Str("node", m.node.Hostname()).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey().String()).Msg("finished writing mapresp to node")

-if debugHighCardinalityMetrics {
-mapResponseLastSentSeconds.WithLabelValues(updateType, m.node.ID().String()).Set(float64(time.Now().Unix()))
-}
-mapResponseSent.WithLabelValues("ok", updateType).Inc()
-m.tracef("update sent")
-m.resetKeepAlive()
-}

case <-m.keepAliveTicker.C:
-data, err := m.mapper.KeepAliveResponse(m.req, m.node)
-if err != nil {
-m.errf(err, "Error generating the keep alive msg")
-mapResponseSent.WithLabelValues("error", "keepalive").Inc()
-return
-}
-_, err = m.w.Write(data)
-if err != nil {
-m.errf(err, "Cannot write keep alive message")
-mapResponseSent.WithLabelValues("error", "keepalive").Inc()
-return
-}
-err = rc.Flush()
-if err != nil {
-m.errf(err, "flushing keep alive to client, for mapSession: %p", m)
-mapResponseSent.WithLabelValues("error", "keepalive").Inc()
+if err := m.writeMap(&keepAlive); err != nil {
+m.errf(err, "cannot write keep alive")
return
}

if debugHighCardinalityMetrics {
-mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID().String()).Set(float64(time.Now().Unix()))
+mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix()))
}
mapResponseSent.WithLabelValues("ok", "keepalive").Inc()
}
}
}
-func (m *mapSession) handleEndpointUpdate() {
-m.tracef("received endpoint update")
-
-// Get fresh node state from database for accurate route calculations
-node, err := m.h.state.GetNodeByID(m.node.ID())
-if err != nil {
-m.errf(err, "Failed to get fresh node from database for endpoint update")
-http.Error(m.w, "", http.StatusInternalServerError)
-mapResponseEndpointUpdates.WithLabelValues("error").Inc()
-return
-}
-
-change := m.node.PeerChangeFromMapRequest(m.req)
-
-online := m.h.nodeNotifier.IsLikelyConnected(m.node.ID())
-change.Online = &online
-
-node.ApplyPeerChange(&change)
-
-sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, m.req.Hostinfo)
-
-// The node might not set NetInfo if it has not changed and if
-// the full HostInfo object is overwritten, the information is lost.
-// If there is no NetInfo, keep the previous one.
-// From 1.66 the client only sends it if changed:
-// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
-// TODO(kradalby): evaluate if we need better comparing of hostinfo
-// before we take the changes.
-if m.req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
-m.req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
-}
-node.Hostinfo = m.req.Hostinfo
-
-logTracePeerChange(node.Hostname, sendUpdate, &change)
-
-// If there is no changes and nothing to save,
-// return early.
-if peerChangeEmpty(change) && !sendUpdate {
-mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
-return
-}
-
-// Auto approve any routes that have been defined in policy as
-// auto approved. Check if this actually changed the node.
-routesAutoApproved := m.h.state.AutoApproveRoutes(node)
-
-// Always update routes for connected nodes to handle reconnection scenarios
-// where routes need to be restored to the primary routes system
-routesToSet := node.SubnetRoutes()
-
-if m.h.state.SetNodeRoutes(node.ID, routesToSet...) {
-ctx := types.NotifyCtx(m.ctx, "poll-primary-change", node.Hostname)
-m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
-} else if routesChanged {
-// Only send peer changed notification if routes actually changed
-ctx := types.NotifyCtx(m.ctx, "cli-approveroutes", node.Hostname)
-m.h.nodeNotifier.NotifyWithIgnore(ctx, types.UpdatePeerChanged(node.ID), node.ID)
-
-// TODO(kradalby): I am not sure if we need this?
-// Send an update to the node itself with to ensure it
-// has an updated packetfilter allowing the new route
-// if it is defined in the ACL.
-ctx = types.NotifyCtx(m.ctx, "poll-nodeupdate-self-hostinfochange", node.Hostname)
-m.h.nodeNotifier.NotifyByNodeID(
-ctx,
-types.UpdateSelf(node.ID),
-node.ID)
-}
-
-// If routes were auto-approved, we need to save the node to persist the changes
-if routesAutoApproved {
-if _, _, err := m.h.state.SaveNode(node); err != nil {
-m.errf(err, "Failed to save auto-approved routes to node")
-http.Error(m.w, "", http.StatusInternalServerError)
-mapResponseEndpointUpdates.WithLabelValues("error").Inc()
-return
-}
-}
-
-// Check if there has been a change to Hostname and update them
-// in the database. Then send a Changed update
-// (containing the whole node object) to peers to inform about
-// the hostname change.
-node.ApplyHostnameFromHostInfo(m.req.Hostinfo)
-
-_, policyChanged, err := m.h.state.SaveNode(node)
-if err != nil {
-m.errf(err, "Failed to persist/update node in the database")
-http.Error(m.w, "", http.StatusInternalServerError)
-mapResponseEndpointUpdates.WithLabelValues("error").Inc()
-
-return
-}
-
-// Send policy update notifications if needed
-if policyChanged {
-ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-policy", node.Hostname)
-m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
-}
-
-ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-peers-patch", node.Hostname)
-m.h.nodeNotifier.NotifyWithIgnore(
-ctx,
-types.UpdatePeerChanged(node.ID),
-node.ID,
-)
-
-m.w.WriteHeader(http.StatusOK)
-mapResponseEndpointUpdates.WithLabelValues("ok").Inc()
-}
+// writeMap writes the map response to the client.
+// It handles compression if requested and any headers that need to be set.
+// It also handles flushing the response if the ResponseWriter
+// implements http.Flusher.
+func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
+jsonBody, err := json.Marshal(msg)
+if err != nil {
+return fmt.Errorf("marshalling map response: %w", err)
+}
+
+if m.req.Compress == util.ZstdCompression {
+jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
+}
+
+data := make([]byte, reservedResponseHeaderSize)
+binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
+data = append(data, jsonBody...)
+
+startWrite := time.Now()
+
+_, err = m.w.Write(data)
+if err != nil {
+return err
+}
+
+if m.isStreaming() {
+if f, ok := m.w.(http.Flusher); ok {
+f.Flush()
+} else {
+m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush")
+}
+}
+
+log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node")
+
+return nil
+}
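The new writeMap helper frames every response the same way: a JSON-encoded tailcfg.MapResponse, optionally zstd-compressed when the client asked for it, preceded by a little-endian length prefix (the diff's reservedResponseHeaderSize bytes) so the client can split the long-poll stream back into individual messages. A minimal sketch of that framing, assuming no compression is requested and a 4-byte header:

// Sketch only: mirrors the framing done by writeMap, without compression.
func frameMapResponse(msg *tailcfg.MapResponse) ([]byte, error) {
	body, err := json.Marshal(msg)
	if err != nil {
		return nil, err
	}
	out := make([]byte, 4) // length prefix; writeMap uses reservedResponseHeaderSize
	binary.LittleEndian.PutUint32(out, uint32(len(body)))
	return append(out, body...), nil
}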
-func (m *mapSession) handleReadOnlyRequest() {
-m.tracef("Client asked for a lite update, responding without peers")
-
-mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
-if err != nil {
-m.errf(err, "Failed to create MapResponse")
-http.Error(m.w, "", http.StatusInternalServerError)
-mapResponseReadOnly.WithLabelValues("error").Inc()
-return
-}
-
-m.w.Header().Set("Content-Type", "application/json; charset=utf-8")
-m.w.WriteHeader(http.StatusOK)
-_, err = m.w.Write(mapResp)
-if err != nil {
-m.errf(err, "Failed to write response")
-mapResponseReadOnly.WithLabelValues("error").Inc()
-return
-}
-
-m.w.WriteHeader(http.StatusOK)
-mapResponseReadOnly.WithLabelValues("ok").Inc()
+var keepAlive = tailcfg.MapResponse{
+KeepAlive: true,
}
-func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) {
+func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
-trace := log.Trace().Uint64("node.id", uint64(change.NodeID)).Str("hostname", hostname)
+trace := log.Trace().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)

-if change.Key != nil {
+if peerChange.Key != nil {
-trace = trace.Str("node_key", change.Key.ShortString())
+trace = trace.Str("node_key", peerChange.Key.ShortString())
}

-if change.DiscoKey != nil {
+if peerChange.DiscoKey != nil {
-trace = trace.Str("disco_key", change.DiscoKey.ShortString())
+trace = trace.Str("disco_key", peerChange.DiscoKey.ShortString())
}

-if change.Online != nil {
+if peerChange.Online != nil {
-trace = trace.Bool("online", *change.Online)
+trace = trace.Bool("online", *peerChange.Online)
}

-if change.Endpoints != nil {
+if peerChange.Endpoints != nil {
-eps := make([]string, len(change.Endpoints))
+eps := make([]string, len(peerChange.Endpoints))
-for idx, ep := range change.Endpoints {
+for idx, ep := range peerChange.Endpoints {
eps[idx] = ep.String()
}
@ -530,21 +309,11 @@ func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.Pe
trace = trace.Bool("hostinfo_changed", hostinfoChange)
}

-if change.DERPRegion != 0 {
+if peerChange.DERPRegion != 0 {
-trace = trace.Int("derp_region", change.DERPRegion)
+trace = trace.Int("derp_region", peerChange.DERPRegion)
}

-trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received")
+trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received")
-}
-
-func peerChangeEmpty(chng tailcfg.PeerChange) bool {
-return chng.Key == nil &&
-chng.DiscoKey == nil &&
-chng.Online == nil &&
-chng.Endpoints == nil &&
-chng.DERPRegion == 0 &&
-chng.LastSeen == nil &&
-chng.KeyExpiry == nil
}
|
|
||||||
func logPollFunc(
|
func logPollFunc(
|
||||||
@ -554,7 +323,6 @@ func logPollFunc(
|
|||||||
return func(msg string, a ...any) {
|
return func(msg string, a ...any) {
|
||||||
log.Warn().
|
log.Warn().
|
||||||
Caller().
|
Caller().
|
||||||
Bool("readOnly", mapRequest.ReadOnly).
|
|
||||||
Bool("omitPeers", mapRequest.OmitPeers).
|
Bool("omitPeers", mapRequest.OmitPeers).
|
||||||
Bool("stream", mapRequest.Stream).
|
Bool("stream", mapRequest.Stream).
|
||||||
Uint64("node.id", node.ID.Uint64()).
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
@ -564,7 +332,6 @@ func logPollFunc(
|
|||||||
func(msg string, a ...any) {
|
func(msg string, a ...any) {
|
||||||
log.Info().
|
log.Info().
|
||||||
Caller().
|
Caller().
|
||||||
Bool("readOnly", mapRequest.ReadOnly).
|
|
||||||
Bool("omitPeers", mapRequest.OmitPeers).
|
Bool("omitPeers", mapRequest.OmitPeers).
|
||||||
Bool("stream", mapRequest.Stream).
|
Bool("stream", mapRequest.Stream).
|
||||||
Uint64("node.id", node.ID.Uint64()).
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
@ -574,7 +341,6 @@ func logPollFunc(
|
|||||||
func(msg string, a ...any) {
|
func(msg string, a ...any) {
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Caller().
|
Caller().
|
||||||
Bool("readOnly", mapRequest.ReadOnly).
|
|
||||||
Bool("omitPeers", mapRequest.OmitPeers).
|
Bool("omitPeers", mapRequest.OmitPeers).
|
||||||
Bool("stream", mapRequest.Stream).
|
Bool("stream", mapRequest.Stream).
|
||||||
Uint64("node.id", node.ID.Uint64()).
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
@ -584,7 +350,6 @@ func logPollFunc(
|
|||||||
func(err error, msg string, a ...any) {
|
func(err error, msg string, a ...any) {
|
||||||
log.Error().
|
log.Error().
|
||||||
Caller().
|
Caller().
|
||||||
Bool("readOnly", mapRequest.ReadOnly).
|
|
||||||
Bool("omitPeers", mapRequest.OmitPeers).
|
Bool("omitPeers", mapRequest.OmitPeers).
|
||||||
Bool("stream", mapRequest.Stream).
|
Bool("stream", mapRequest.Stream).
|
||||||
Uint64("node.id", node.ID.Uint64()).
|
Uint64("node.id", node.ID.Uint64()).
|
||||||
@ -593,91 +358,3 @@ func logPollFunc(
|
|||||||
Msgf(msg, a...)
|
Msgf(msg, a...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func logPollFuncView(
|
|
||||||
mapRequest tailcfg.MapRequest,
|
|
||||||
nodeView types.NodeView,
|
|
||||||
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
|
|
||||||
return func(msg string, a ...any) {
|
|
||||||
log.Warn().
|
|
||||||
Caller().
|
|
||||||
Bool("readOnly", mapRequest.ReadOnly).
|
|
||||||
Bool("omitPeers", mapRequest.OmitPeers).
|
|
||||||
Bool("stream", mapRequest.Stream).
|
|
||||||
Uint64("node.id", nodeView.ID().Uint64()).
|
|
||||||
Str("node", nodeView.Hostname()).
|
|
||||||
Msgf(msg, a...)
|
|
||||||
},
|
|
||||||
func(msg string, a ...any) {
|
|
||||||
log.Info().
|
|
||||||
Caller().
|
|
||||||
Bool("readOnly", mapRequest.ReadOnly).
|
|
||||||
Bool("omitPeers", mapRequest.OmitPeers).
|
|
||||||
Bool("stream", mapRequest.Stream).
|
|
||||||
Uint64("node.id", nodeView.ID().Uint64()).
|
|
||||||
Str("node", nodeView.Hostname()).
|
|
||||||
Msgf(msg, a...)
|
|
||||||
},
|
|
||||||
func(msg string, a ...any) {
|
|
||||||
log.Trace().
|
|
||||||
Caller().
|
|
||||||
Bool("readOnly", mapRequest.ReadOnly).
|
|
||||||
Bool("omitPeers", mapRequest.OmitPeers).
|
|
||||||
Bool("stream", mapRequest.Stream).
|
|
||||||
Uint64("node.id", nodeView.ID().Uint64()).
|
|
||||||
Str("node", nodeView.Hostname()).
|
|
||||||
Msgf(msg, a...)
|
|
||||||
},
|
|
||||||
func(err error, msg string, a ...any) {
|
|
||||||
log.Error().
|
|
||||||
Caller().
|
|
||||||
Bool("readOnly", mapRequest.ReadOnly).
|
|
||||||
Bool("omitPeers", mapRequest.OmitPeers).
|
|
||||||
Bool("stream", mapRequest.Stream).
|
|
||||||
Uint64("node.id", nodeView.ID().Uint64()).
|
|
||||||
Str("node", nodeView.Hostname()).
|
|
||||||
Err(err).
|
|
||||||
Msgf(msg, a...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// hostInfoChanged reports if hostInfo has changed in two ways,
|
|
||||||
// - first bool reports if an update needs to be sent to nodes
|
|
||||||
// - second reports if there has been changes to routes
|
|
||||||
// the caller can then use this info to save and update nodes
|
|
||||||
// and routes as needed.
|
|
||||||
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
|
||||||
if old.Equal(new) {
|
|
||||||
return false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if old == nil && new != nil {
|
|
||||||
return true, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Routes
|
|
||||||
oldRoutes := make([]netip.Prefix, 0)
|
|
||||||
if old != nil {
|
|
||||||
oldRoutes = old.RoutableIPs
|
|
||||||
}
|
|
||||||
newRoutes := new.RoutableIPs
|
|
||||||
|
|
||||||
tsaddr.SortPrefixes(oldRoutes)
|
|
||||||
tsaddr.SortPrefixes(newRoutes)
|
|
||||||
|
|
||||||
if !xslices.Equal(oldRoutes, newRoutes) {
|
|
||||||
return true, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Services is mostly useful for discovery and not critical,
|
|
||||||
// except for peerapi, which is how nodes talk to each other.
|
|
||||||
// If peerapi was not part of the initial mapresponse, we
|
|
||||||
// need to make sure its sent out later as it is needed for
|
|
||||||
// Taildrop.
|
|
||||||
// TODO(kradalby): Length comparison is a bit naive, replace.
|
|
||||||
if len(old.Services) != len(new.Services) {
|
|
||||||
return true, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, false
|
|
||||||
}
|
|
||||||
|
@ -17,10 +17,13 @@ import (
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
+"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
+xslices "golang.org/x/exp/slices"
"gorm.io/gorm"
+"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
@ -46,12 +49,6 @@ type State struct {
// cfg holds the current Headscale configuration
cfg *types.Config

-// in-memory data, protected by mu
-// nodes contains the current set of registered nodes
-nodes types.Nodes
-// users contains the current set of users/namespaces
-users types.Users
-
// subsystem keeping state
// db provides persistent storage and database operations
db *hsdb.HSDatabase

@ -113,9 +110,6 @@ func NewState(cfg *types.Config) (*State, error) {
return &State{
cfg: cfg,

-nodes: nodes,
-users: users,
-
db: db,
ipAlloc: ipAlloc,
// TODO(kradalby): Update DERPMap
@ -215,6 +209,7 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
s.mu.Lock()
defer s.mu.Unlock()

if err := s.db.DB.Save(&user).Error; err != nil {
return nil, false, fmt.Errorf("creating user: %w", err)
}

@ -226,6 +221,18 @@ func (s *State) CreateUser(user types.User) (*types.User, bool, error) {
return &user, false, fmt.Errorf("failed to update policy manager after user creation: %w", err)
}

+// Even if the policy manager doesn't detect a filter change, SSH policies
+// might now be resolvable when they weren't before. If there are existing
+// nodes, we should send a policy change to ensure they get updated SSH policies.
+if !policyChanged {
+nodes, err := s.ListNodes()
+if err == nil && len(nodes) > 0 {
+policyChanged = true
+}
+}
+
+log.Info().Str("user", user.Name).Bool("policyChanged", policyChanged).Msg("User created, policy manager updated")
+
// TODO(kradalby): implement the user in-memory cache

return &user, policyChanged, nil
@ -329,7 +336,7 @@ func (s *State) CreateNode(node *types.Node) (*types.Node, bool, error) {
}

// updateNodeTx performs a database transaction to update a node and refresh the policy manager.
-func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, bool, error) {
+func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) error) (*types.Node, change.ChangeSet, error) {
s.mu.Lock()
defer s.mu.Unlock()
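The wrappers built on top of updateNodeTx all follow the same shape: run the transaction, then narrow the generic ChangeSet unless the policy manager reported a full (policy-level) change. For instance, the SetNodeExpiry wrapper later in this diff does, in essence:

n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
	return hsdb.NodeSetExpiry(tx, nodeID, expiry)
})
if err != nil {
	return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
}
if !c.IsFull() {
	// Not a policy-level change: narrow it to a key-expiry change for this node only.
	c = change.KeyExpiry(nodeID)
}
return n, c, nil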
@ -350,72 +357,100 @@ func (s *State) updateNodeTx(nodeID types.NodeID, updateFn func(tx *gorm.DB) err
|
|||||||
return node, nil
|
return node, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, change.EmptySet, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if policy manager needs updating
|
// Check if policy manager needs updating
|
||||||
policyChanged, err := s.updatePolicyManagerNodes()
|
policyChanged, err := s.updatePolicyManagerNodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return node, false, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node update: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): implement the node in-memory cache
|
// TODO(kradalby): implement the node in-memory cache
|
||||||
|
|
||||||
return node, policyChanged, nil
|
var c change.ChangeSet
|
||||||
|
if policyChanged {
|
||||||
|
c = change.PolicyChange()
|
||||||
|
} else {
|
||||||
|
// Basic node change without specific details since this is a generic update
|
||||||
|
c = change.NodeAdded(node.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return node, c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveNode persists an existing node to the database and updates the policy manager.
|
// SaveNode persists an existing node to the database and updates the policy manager.
|
||||||
func (s *State) SaveNode(node *types.Node) (*types.Node, bool, error) {
|
func (s *State) SaveNode(node *types.Node) (*types.Node, change.ChangeSet, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
if err := s.db.DB.Save(node).Error; err != nil {
|
if err := s.db.DB.Save(node).Error; err != nil {
|
||||||
return nil, false, fmt.Errorf("saving node: %w", err)
|
return nil, change.EmptySet, fmt.Errorf("saving node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if policy manager needs updating
|
// Check if policy manager needs updating
|
||||||
policyChanged, err := s.updatePolicyManagerNodes()
|
policyChanged, err := s.updatePolicyManagerNodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return node, false, fmt.Errorf("failed to update policy manager after node save: %w", err)
|
return node, change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(kradalby): implement the node in-memory cache
|
// TODO(kradalby): implement the node in-memory cache
|
||||||
|
|
||||||
return node, policyChanged, nil
|
if policyChanged {
|
||||||
|
return node, change.PolicyChange(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return node, change.EmptySet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteNode permanently removes a node and cleans up associated resources.
|
// DeleteNode permanently removes a node and cleans up associated resources.
|
||||||
// Returns whether policies changed and any error. This operation is irreversible.
|
// Returns whether policies changed and any error. This operation is irreversible.
|
||||||
func (s *State) DeleteNode(node *types.Node) (bool, error) {
|
func (s *State) DeleteNode(node *types.Node) (change.ChangeSet, error) {
|
||||||
err := s.db.DeleteNode(node)
|
err := s.db.DeleteNode(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return change.EmptySet, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c := change.NodeRemoved(node.ID)
|
||||||
|
|
||||||
// Check if policy manager needs updating after node deletion
|
// Check if policy manager needs updating after node deletion
|
||||||
policyChanged, err := s.updatePolicyManagerNodes()
|
policyChanged, err := s.updatePolicyManagerNodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
|
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return policyChanged, nil
|
if policyChanged {
|
||||||
|
c = change.PolicyChange()
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *State) Connect(id types.NodeID) {
|
func (s *State) Connect(node *types.Node) change.ChangeSet {
|
||||||
|
c := change.NodeOnline(node.ID)
|
||||||
|
routeChange := s.primaryRoutes.SetRoutes(node.ID, node.SubnetRoutes()...)
|
||||||
|
|
||||||
|
if routeChange {
|
||||||
|
c = change.NodeAdded(node.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *State) Disconnect(id types.NodeID) (bool, error) {
|
func (s *State) Disconnect(node *types.Node) (change.ChangeSet, error) {
|
||||||
// TODO(kradalby): This node should update the in memory state
|
c := change.NodeOffline(node.ID)
|
||||||
_, polChanged, err := s.SetLastSeen(id, time.Now())
|
|
||||||
|
_, _, err := s.SetLastSeen(node.ID, time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("disconnecting node: %w", err)
|
return c, fmt.Errorf("disconnecting node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
changed := s.primaryRoutes.SetRoutes(id)
|
if routeChange := s.primaryRoutes.SetRoutes(node.ID); routeChange {
|
||||||
|
c = change.PolicyChange()
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(kradalby): the returned change should be more nuanced allowing us to
|
// TODO(kradalby): This node should update the in memory state
|
||||||
// send more directed updates.
|
return c, nil
|
||||||
return changed || polChanged, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNodeByID retrieves a node by ID.
|
// GetNodeByID retrieves a node by ID.
|
||||||
@ -475,45 +510,93 @@ func (s *State) ListEphemeralNodes() (types.Nodes, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetNodeExpiry updates the expiration time for a node.
|
// SetNodeExpiry updates the expiration time for a node.
|
||||||
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, bool, error) {
|
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (*types.Node, change.ChangeSet, error) {
|
||||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||||
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
|
return hsdb.NodeSetExpiry(tx, nodeID, expiry)
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, change.EmptySet, fmt.Errorf("setting node expiry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.IsFull() {
|
||||||
|
c = change.KeyExpiry(nodeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNodeTags assigns tags to a node for use in access control policies.
|
// SetNodeTags assigns tags to a node for use in access control policies.
|
||||||
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, bool, error) {
|
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (*types.Node, change.ChangeSet, error) {
|
||||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||||
return hsdb.SetTags(tx, nodeID, tags)
|
return hsdb.SetTags(tx, nodeID, tags)
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, change.EmptySet, fmt.Errorf("setting node tags: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.IsFull() {
|
||||||
|
c = change.NodeAdded(nodeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
|
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
|
||||||
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, bool, error) {
|
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (*types.Node, change.ChangeSet, error) {
|
||||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||||
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
|
return hsdb.SetApprovedRoutes(tx, nodeID, routes)
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, change.EmptySet, fmt.Errorf("setting approved routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update primary routes after changing approved routes
|
||||||
|
routeChange := s.primaryRoutes.SetRoutes(nodeID, n.SubnetRoutes()...)
|
||||||
|
|
||||||
|
if routeChange || !c.IsFull() {
|
||||||
|
c = change.PolicyChange()
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RenameNode changes the display name of a node.
|
// RenameNode changes the display name of a node.
|
||||||
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, bool, error) {
|
func (s *State) RenameNode(nodeID types.NodeID, newName string) (*types.Node, change.ChangeSet, error) {
|
||||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||||
return hsdb.RenameNode(tx, nodeID, newName)
|
return hsdb.RenameNode(tx, nodeID, newName)
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, change.EmptySet, fmt.Errorf("renaming node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.IsFull() {
|
||||||
|
c = change.NodeAdded(nodeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
|
// SetLastSeen updates when a node was last seen, used for connectivity monitoring.
|
||||||
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, bool, error) {
|
func (s *State) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) (*types.Node, change.ChangeSet, error) {
|
||||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||||
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
|
return hsdb.SetLastSeen(tx, nodeID, lastSeen)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AssignNodeToUser transfers a node to a different user.
|
// AssignNodeToUser transfers a node to a different user.
|
||||||
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, bool, error) {
|
func (s *State) AssignNodeToUser(nodeID types.NodeID, userID types.UserID) (*types.Node, change.ChangeSet, error) {
|
||||||
return s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
n, c, err := s.updateNodeTx(nodeID, func(tx *gorm.DB) error {
|
||||||
return hsdb.AssignNodeToUser(tx, nodeID, userID)
|
return hsdb.AssignNodeToUser(tx, nodeID, userID)
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, change.EmptySet, fmt.Errorf("assigning node to user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.IsFull() {
|
||||||
|
c = change.NodeAdded(nodeID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
|
// BackfillNodeIPs assigns IP addresses to nodes that don't have them.
|
||||||
@ -523,7 +606,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {

// ExpireExpiredNodes finds and processes expired nodes since the last check.
// Returns next check time, state update with expired nodes, and whether any were found.
-func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, types.StateUpdate, bool) {
+func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) {
return hsdb.ExpireExpiredNodes(s.db.DB, lastCheck)
}
@ -568,8 +651,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
}

// SetNodeRoutes sets the primary routes for a node.
-func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) bool {
+func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet {
-return s.primaryRoutes.SetRoutes(nodeID, routes...)
+if s.primaryRoutes.SetRoutes(nodeID, routes...) {
+// Route changes affect packet filters for all nodes, so trigger a policy change
+// to ensure filters are regenerated across the entire network
+return change.PolicyChange()
+}
+
+return change.EmptySet
}

// GetNodePrimaryRoutes returns the primary routes for a node.
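Returning a change.ChangeSet instead of a bare bool lets callers decide how widely to fan an update out. A minimal sketch of how a caller might react to the set returned by SetNodeRoutes; the notifyAll and notifyNode helpers are hypothetical stand-ins for the batcher, not part of this diff:

// Sketch only, assuming hypothetical notifyAll/notifyNode helpers.
func applyRouteChange(s *State, id types.NodeID, routes ...netip.Prefix) {
	c := s.SetNodeRoutes(id, routes...)
	switch {
	case c.Empty():
		// nothing changed, nothing to send
	case c.IsFull():
		notifyAll(c) // policy-level change: every node needs a fresh map
	default:
		notifyNode(c.NodeID, c) // targeted update for a single node
	}
}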
@ -653,10 +742,10 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
userID types.UserID,
|
userID types.UserID,
|
||||||
expiry *time.Time,
|
expiry *time.Time,
|
||||||
registrationMethod string,
|
registrationMethod string,
|
||||||
) (*types.Node, bool, error) {
|
) (*types.Node, change.ChangeSet, error) {
|
||||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, change.EmptySet, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.db.HandleNodeFromAuthPath(
|
return s.db.HandleNodeFromAuthPath(
|
||||||
@ -672,12 +761,15 @@ func (s *State) HandleNodeFromAuthPath(
|
|||||||
func (s *State) HandleNodeFromPreAuthKey(
|
func (s *State) HandleNodeFromPreAuthKey(
|
||||||
regReq tailcfg.RegisterRequest,
|
regReq tailcfg.RegisterRequest,
|
||||||
machineKey key.MachinePublic,
|
machineKey key.MachinePublic,
|
||||||
) (*types.Node, bool, error) {
|
) (*types.Node, change.ChangeSet, bool, error) {
|
||||||
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, change.EmptySet, false, err
|
||||||
|
}
|
||||||
|
|
||||||
err = pak.Validate()
|
err = pak.Validate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, change.EmptySet, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeToRegister := types.Node{
|
nodeToRegister := types.Node{
|
||||||
@ -698,22 +790,13 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
AuthKeyID: &pak.ID,
|
AuthKeyID: &pak.ID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// For auth key registration, ensure we don't keep an expired node
|
if !regReq.Expiry.IsZero() {
|
||||||
// This is especially important for re-registration after logout
|
|
||||||
if !regReq.Expiry.IsZero() && regReq.Expiry.After(time.Now()) {
|
|
||||||
nodeToRegister.Expiry = ®Req.Expiry
|
nodeToRegister.Expiry = ®Req.Expiry
|
||||||
} else if !regReq.Expiry.IsZero() {
|
|
||||||
// If client is sending an expired time (e.g., after logout),
|
|
||||||
// don't set expiry so the node won't be considered expired
|
|
||||||
log.Debug().
|
|
||||||
Time("requested_expiry", regReq.Expiry).
|
|
||||||
Str("node", regReq.Hostinfo.Hostname).
|
|
||||||
Msg("Ignoring expired expiry time from auth key registration")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ipv4, ipv6, err := s.ipAlloc.Next()
|
ipv4, ipv6, err := s.ipAlloc.Next()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, fmt.Errorf("allocating IPs: %w", err)
|
return nil, change.EmptySet, false, fmt.Errorf("allocating IPs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
node, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||||
@ -735,18 +818,38 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||||||
return node, nil
|
return node, nil
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, fmt.Errorf("writing node to database: %w", err)
|
return nil, change.EmptySet, false, fmt.Errorf("writing node to database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this is a logout request for an ephemeral node
|
||||||
|
if !regReq.Expiry.IsZero() && regReq.Expiry.Before(time.Now()) && pak.Ephemeral {
|
||||||
|
// This is a logout request for an ephemeral node, delete it immediately
|
||||||
|
c, err := s.DeleteNode(node)
|
||||||
|
if err != nil {
|
||||||
|
return nil, change.EmptySet, false, fmt.Errorf("deleting ephemeral node during logout: %w", err)
|
||||||
|
}
|
||||||
|
return nil, c, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if policy manager needs updating
|
// Check if policy manager needs updating
|
||||||
// This is necessary because we just created a new node.
|
// This is necessary because we just created a new node.
|
||||||
// We need to ensure that the policy manager is aware of this new node.
|
// We need to ensure that the policy manager is aware of this new node.
|
||||||
policyChanged, err := s.updatePolicyManagerNodes()
|
// Also update users to ensure all users are known when evaluating policies.
|
||||||
|
usersChanged, err := s.updatePolicyManagerUsers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, fmt.Errorf("failed to update policy manager after node registration: %w", err)
|
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager users after node registration: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return node, policyChanged, nil
|
nodesChanged, err := s.updatePolicyManagerNodes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, change.EmptySet, false, fmt.Errorf("failed to update policy manager nodes after node registration: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
policyChanged := usersChanged || nodesChanged
|
||||||
|
|
||||||
|
c := change.NodeAdded(node.ID)
|
||||||
|
|
||||||
|
return node, c, policyChanged, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
|
// AllocateNextIPs allocates the next available IPv4 and IPv6 addresses.
|
||||||
@ -766,11 +869,15 @@ func (s *State) updatePolicyManagerUsers() (bool, error) {
|
|||||||
return false, fmt.Errorf("listing users for policy update: %w", err)
|
return false, fmt.Errorf("listing users for policy update: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug().Int("userCount", len(users)).Msg("Updating policy manager with users")
|
||||||
|
|
||||||
changed, err := s.polMan.SetUsers(users)
|
changed, err := s.polMan.SetUsers(users)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("updating policy manager users: %w", err)
|
return false, fmt.Errorf("updating policy manager users: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug().Bool("changed", changed).Msg("Policy manager users updated")
|
||||||
|
|
||||||
return changed, nil
|
return changed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -835,3 +942,125 @@ func (s *State) autoApproveNodes() error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): This should just take the node ID?
|
||||||
|
func (s *State) UpdateNodeFromMapRequest(node *types.Node, req tailcfg.MapRequest) (change.ChangeSet, error) {
|
||||||
|
// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes,
|
||||||
|
// which means we could shortcut the whole change thing if there are no other important updates.
|
||||||
|
peerChange := node.PeerChangeFromMapRequest(req)
|
||||||
|
|
||||||
|
node.ApplyPeerChange(&peerChange)
|
||||||
|
|
||||||
|
sendUpdate, routesChanged := hostInfoChanged(node.Hostinfo, req.Hostinfo)
|
||||||
|
|
||||||
|
// The node might not set NetInfo if it has not changed and if
|
||||||
|
// the full HostInfo object is overwritten, the information is lost.
|
||||||
|
// If there is no NetInfo, keep the previous one.
|
||||||
|
// From 1.66 the client only sends it if changed:
|
||||||
|
// https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2
|
||||||
|
// TODO(kradalby): evaluate if we need better comparing of hostinfo
|
||||||
|
// before we take the changes.
|
||||||
|
if req.Hostinfo.NetInfo == nil && node.Hostinfo != nil {
|
||||||
|
req.Hostinfo.NetInfo = node.Hostinfo.NetInfo
|
||||||
|
}
|
||||||
|
node.Hostinfo = req.Hostinfo
|
||||||
|
|
||||||
|
// If there is no changes and nothing to save,
|
||||||
|
// return early.
|
||||||
|
if peerChangeEmpty(peerChange) && !sendUpdate {
|
||||||
|
// mapResponseEndpointUpdates.WithLabelValues("noop").Inc()
|
||||||
|
return change.EmptySet, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c := change.EmptySet
|
||||||
|
|
||||||
|
// Check if the Hostinfo of the node has changed.
|
||||||
|
// If it has changed, check if there has been a change to
|
||||||
|
// the routable IPs of the host and update them in
|
||||||
|
// the database. Then send a Changed update
|
||||||
|
// (containing the whole node object) to peers to inform about
|
||||||
|
// the route change.
|
||||||
|
// If the hostinfo has changed, but not the routes, just update
|
||||||
|
// hostinfo and let the function continue.
|
||||||
|
if routesChanged {
|
||||||
|
// Auto approve any routes that have been defined in policy as
|
||||||
|
// auto approved. Check if this actually changed the node.
|
||||||
|
_ = s.AutoApproveRoutes(node)
|
||||||
|
|
||||||
|
// Update the routes of the given node in the route manager to
|
||||||
|
// see if an update needs to be sent.
|
||||||
|
c = s.SetNodeRoutes(node.ID, node.SubnetRoutes()...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if there has been a change to Hostname and update them
|
||||||
|
// in the database. Then send a Changed update
|
||||||
|
// (containing the whole node object) to peers to inform about
|
||||||
|
// the hostname change.
|
||||||
|
node.ApplyHostnameFromHostInfo(req.Hostinfo)
|
||||||
|
|
||||||
|
_, policyChange, err := s.SaveNode(node)
|
||||||
|
if err != nil {
|
||||||
|
return change.EmptySet, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if policyChange.IsFull() {
|
||||||
|
c = policyChange
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Empty() {
|
||||||
|
c = change.NodeAdded(node.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hostInfoChanged reports if hostInfo has changed in two ways,
|
||||||
|
// - first bool reports if an update needs to be sent to nodes
|
||||||
|
// - second reports if there has been changes to routes
|
||||||
|
// the caller can then use this info to save and update nodes
|
||||||
|
// and routes as needed.
|
||||||
|
func hostInfoChanged(old, new *tailcfg.Hostinfo) (bool, bool) {
|
||||||
|
if old.Equal(new) {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if old == nil && new != nil {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Routes
|
||||||
|
oldRoutes := make([]netip.Prefix, 0)
|
||||||
|
if old != nil {
|
||||||
|
oldRoutes = old.RoutableIPs
|
||||||
|
}
|
||||||
|
newRoutes := new.RoutableIPs
|
||||||
|
|
||||||
|
tsaddr.SortPrefixes(oldRoutes)
|
||||||
|
tsaddr.SortPrefixes(newRoutes)
|
||||||
|
|
||||||
|
if !xslices.Equal(oldRoutes, newRoutes) {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Services is mostly useful for discovery and not critical,
|
||||||
|
// except for peerapi, which is how nodes talk to each other.
|
||||||
|
// If peerapi was not part of the initial mapresponse, we
|
||||||
|
// need to make sure its sent out later as it is needed for
|
||||||
|
// Taildrop.
|
||||||
|
// TODO(kradalby): Length comparison is a bit naive, replace.
|
||||||
|
if len(old.Services) != len(new.Services) {
|
||||||
|
return true, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
|
||||||
|
return peerChange.Key == nil &&
|
||||||
|
peerChange.DiscoKey == nil &&
|
||||||
|
peerChange.Online == nil &&
|
||||||
|
peerChange.Endpoints == nil &&
|
||||||
|
peerChange.DERPRegion == 0 &&
|
||||||
|
peerChange.LastSeen == nil &&
|
||||||
|
peerChange.KeyExpiry == nil
|
||||||
|
}
|
||||||
|
183
hscontrol/types/change/change.go
Normal file
183
hscontrol/types/change/change.go
Normal file
@ -0,0 +1,183 @@
|
|||||||
|
//go:generate go tool stringer -type=Change
|
||||||
|
package change
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
NodeID = types.NodeID
|
||||||
|
UserID = types.UserID
|
||||||
|
)
|
||||||
|
|
||||||
|
type Change int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ChangeUnknown Change = 0
|
||||||
|
|
||||||
|
// Deprecated: Use specific change instead
|
||||||
|
// Full is a legacy change to ensure places where we
|
||||||
|
// have not yet determined the specific update, can send.
|
||||||
|
Full Change = 9
|
||||||
|
|
||||||
|
// Server changes.
|
||||||
|
Policy Change = 11
|
||||||
|
DERP Change = 12
|
||||||
|
ExtraRecords Change = 13
|
||||||
|
|
||||||
|
// Node changes.
|
||||||
|
NodeCameOnline Change = 21
|
||||||
|
NodeWentOffline Change = 22
|
||||||
|
NodeRemove Change = 23
|
||||||
|
NodeKeyExpiry Change = 24
|
||||||
|
NodeNewOrUpdate Change = 25
|
||||||
|
|
||||||
|
// User changes.
|
||||||
|
UserNewOrUpdate Change = 51
|
||||||
|
UserRemove Change = 52
|
||||||
|
)
|
||||||
|
|
||||||
|
// AlsoSelf reports whether this change should also be sent to the node itself.
|
||||||
|
func (c Change) AlsoSelf() bool {
|
||||||
|
switch c {
|
||||||
|
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChangeSet struct {
|
||||||
|
Change Change
|
||||||
|
|
||||||
|
// SelfUpdateOnly indicates that this change should only be sent
|
||||||
|
// to the node itself, and not to other nodes.
|
||||||
|
// This is used for changes that are not relevant to other nodes.
|
||||||
|
// NodeID must be set if this is true.
|
||||||
|
SelfUpdateOnly bool
|
||||||
|
|
||||||
|
// NodeID if set, is the ID of the node that is being changed.
|
||||||
|
// It must be set if this is a node change.
|
||||||
|
NodeID types.NodeID
|
||||||
|
|
||||||
|
// UserID if set, is the ID of the user that is being changed.
|
||||||
|
// It must be set if this is a user change.
|
||||||
|
UserID types.UserID
|
||||||
|
|
||||||
|
// IsSubnetRouter indicates whether the node is a subnet router.
|
||||||
|
IsSubnetRouter bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChangeSet) Validate() error {
|
||||||
|
if c.Change >= NodeCameOnline || c.Change <= NodeNewOrUpdate {
|
||||||
|
if c.NodeID == 0 {
|
||||||
|
return errors.New("ChangeSet.NodeID must be set for node updates")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Change >= UserNewOrUpdate || c.Change <= UserRemove {
|
||||||
|
if c.UserID == 0 {
|
||||||
|
return errors.New("ChangeSet.UserID must be set for user updates")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty reports whether the ChangeSet is empty, meaning it does not
|
||||||
|
// represent any change.
|
||||||
|
func (c ChangeSet) Empty() bool {
|
||||||
|
return c.Change == ChangeUnknown && c.NodeID == 0 && c.UserID == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFull reports whether the ChangeSet represents a full update.
|
||||||
|
func (c ChangeSet) IsFull() bool {
|
||||||
|
return c.Change == Full || c.Change == Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ChangeSet) AlsoSelf() bool {
|
||||||
|
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
|
||||||
|
// so we consider it as a change that should be sent to all nodes.
|
||||||
|
if c.NodeID == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return c.Change.AlsoSelf() || c.SelfUpdateOnly
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
EmptySet = ChangeSet{Change: ChangeUnknown}
|
||||||
|
FullSet = ChangeSet{Change: Full}
|
||||||
|
DERPSet = ChangeSet{Change: DERP}
|
||||||
|
PolicySet = ChangeSet{Change: Policy}
|
||||||
|
ExtraRecordsSet = ChangeSet{Change: ExtraRecords}
|
||||||
|
)
|
||||||
|
|
||||||
|
func FullSelf(id types.NodeID) ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: Full,
|
||||||
|
SelfUpdateOnly: true,
|
||||||
|
NodeID: id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NodeAdded(id types.NodeID) ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: NodeNewOrUpdate,
|
||||||
|
NodeID: id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NodeRemoved(id types.NodeID) ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: NodeRemove,
|
||||||
|
NodeID: id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NodeOnline(id types.NodeID) ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: NodeCameOnline,
|
||||||
|
NodeID: id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NodeOffline(id types.NodeID) ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: NodeWentOffline,
|
||||||
|
NodeID: id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func KeyExpiry(id types.NodeID) ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: NodeKeyExpiry,
|
||||||
|
NodeID: id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UserAdded(id types.UserID) ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: UserNewOrUpdate,
|
||||||
|
UserID: id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func UserRemoved(id types.UserID) ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: UserRemove,
|
||||||
|
UserID: id,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func PolicyChange() ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: Policy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func DERPChange() ChangeSet {
|
||||||
|
return ChangeSet{
|
||||||
|
Change: DERP,
|
||||||
|
}
|
||||||
|
}
|
57
hscontrol/types/change/change_string.go
Normal file
57
hscontrol/types/change/change_string.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
// Code generated by "stringer -type=Change"; DO NOT EDIT.
|
||||||
|
|
||||||
|
package change
|
||||||
|
|
||||||
|
import "strconv"
|
||||||
|
|
||||||
|
func _() {
|
||||||
|
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||||
|
// Re-run the stringer command to generate them again.
|
||||||
|
var x [1]struct{}
|
||||||
|
_ = x[ChangeUnknown-0]
|
||||||
|
_ = x[Full-9]
|
||||||
|
_ = x[Policy-11]
|
||||||
|
_ = x[DERP-12]
|
||||||
|
_ = x[ExtraRecords-13]
|
||||||
|
_ = x[NodeCameOnline-21]
|
||||||
|
_ = x[NodeWentOffline-22]
|
||||||
|
_ = x[NodeRemove-23]
|
||||||
|
_ = x[NodeKeyExpiry-24]
|
||||||
|
_ = x[NodeNewOrUpdate-25]
|
||||||
|
_ = x[UserNewOrUpdate-51]
|
||||||
|
_ = x[UserRemove-52]
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
_Change_name_0 = "ChangeUnknown"
|
||||||
|
_Change_name_1 = "Full"
|
||||||
|
_Change_name_2 = "PolicyDERPExtraRecords"
|
||||||
|
_Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdate"
|
||||||
|
_Change_name_4 = "UserNewOrUpdateUserRemove"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_Change_index_2 = [...]uint8{0, 6, 10, 22}
|
||||||
|
_Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67}
|
||||||
|
_Change_index_4 = [...]uint8{0, 15, 25}
|
||||||
|
)
|
||||||
|
|
||||||
|
func (i Change) String() string {
|
||||||
|
switch {
|
||||||
|
case i == 0:
|
||||||
|
return _Change_name_0
|
||||||
|
case i == 9:
|
||||||
|
return _Change_name_1
|
||||||
|
case 11 <= i && i <= 13:
|
||||||
|
i -= 11
|
||||||
|
return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]]
|
||||||
|
case 21 <= i && i <= 25:
|
||||||
|
i -= 21
|
||||||
|
return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]]
|
||||||
|
case 51 <= i && i <= 52:
|
||||||
|
i -= 51
|
||||||
|
return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]]
|
||||||
|
default:
|
||||||
|
return "Change(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||||
|
}
|
||||||
|
}
|
@ -1,16 +1,16 @@
-//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
+//go:generate go tool viewer --type=User,Node,PreAuthKey

package types

+//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey
+
import (
-"context"
"errors"
"fmt"
+"runtime"
"time"

"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
-"tailscale.com/util/ctxkey"
)

const (
@ -150,18 +150,6 @@ func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
}
}

-var (
-NotifyOriginKey = ctxkey.New("notify.origin", "")
-NotifyHostnameKey = ctxkey.New("notify.hostname", "")
-)
-
-func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
-ctx2, _ := context.WithTimeout(ctx, 3*time.Second)
-ctx2 = NotifyOriginKey.WithValue(ctx2, origin)
-ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname)
-return ctx2
-}
-
const RegistrationIDLength = 24

type RegistrationID string
@ -199,3 +187,20 @@ type RegisterNode struct {
Node Node
Registered chan *Node
}

+// DefaultBatcherWorkers returns the default number of batcher workers.
+// Default to 3/4 of CPU cores, minimum 1, no maximum.
+func DefaultBatcherWorkers() int {
+return DefaultBatcherWorkersFor(runtime.NumCPU())
+}
+
+// DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count.
+// Default to 3/4 of CPU cores, minimum 1, no maximum.
+func DefaultBatcherWorkersFor(cpuCount int) int {
+defaultWorkers := (cpuCount * 3) / 4
+if defaultWorkers < 1 {
+defaultWorkers = 1
+}
+
+return defaultWorkers
+}
|
||||||
|
36
hscontrol/types/common_test.go
Normal file
@@ -0,0 +1,36 @@
package types

import (
	"testing"
)

func TestDefaultBatcherWorkersFor(t *testing.T) {
	tests := []struct {
		cpuCount int
		expected int
	}{
		{1, 1},   // (1*3)/4 = 0, should be minimum 1
		{2, 1},   // (2*3)/4 = 1
		{4, 3},   // (4*3)/4 = 3
		{8, 6},   // (8*3)/4 = 6
		{12, 9},  // (12*3)/4 = 9
		{16, 12}, // (16*3)/4 = 12
		{20, 15}, // (20*3)/4 = 15
		{24, 18}, // (24*3)/4 = 18
	}

	for _, test := range tests {
		result := DefaultBatcherWorkersFor(test.cpuCount)
		if result != test.expected {
			t.Errorf("DefaultBatcherWorkersFor(%d) = %d, expected %d", test.cpuCount, result, test.expected)
		}
	}
}

func TestDefaultBatcherWorkers(t *testing.T) {
	// Just verify it returns a valid value (>= 1)
	result := DefaultBatcherWorkers()
	if result < 1 {
		t.Errorf("DefaultBatcherWorkers() = %d, expected value >= 1", result)
	}
}
@@ -234,6 +234,7 @@ type Tuning struct {
 	NotifierSendTimeout            time.Duration
 	BatchChangeDelay               time.Duration
 	NodeMapSessionBufferedChanSize int
+	BatcherWorkers                 int
 }
 
 func validatePKCEMethod(method string) error {
@@ -991,6 +992,12 @@ func LoadServerConfig() (*Config, error) {
 			NodeMapSessionBufferedChanSize: viper.GetInt(
 				"tuning.node_mapsession_buffered_chan_size",
 			),
+			BatcherWorkers: func() int {
+				if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
+					return workers
+				}
+				return DefaultBatcherWorkers()
+			}(),
 		},
 	}, nil
 }
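A note on the config wiring above (not part of the diff): BatcherWorkers only honors tuning.batcher_workers when the configured value is greater than zero; anything else falls back to the CPU-based default. A minimal, self-contained sketch of that fallback rule, with a local copy of the 3/4-of-CPU helper (the numbers in main are illustrative):

// Illustrative sketch only: the batcher worker fallback rule used above.
package main

import "fmt"

// defaultBatcherWorkersFor mirrors DefaultBatcherWorkersFor: 3/4 of the CPU count, minimum 1.
func defaultBatcherWorkersFor(cpuCount int) int {
	workers := (cpuCount * 3) / 4
	if workers < 1 {
		workers = 1
	}
	return workers
}

// batcherWorkers applies the same rule as the viper closure above:
// a positive configured value wins, otherwise use the CPU-based default.
func batcherWorkers(configured, cpuCount int) int {
	if configured > 0 {
		return configured
	}
	return defaultBatcherWorkersFor(cpuCount)
}

func main() {
	fmt.Println(batcherWorkers(0, 8))  // 6: tuning.batcher_workers unset, (8*3)/4
	fmt.Println(batcherWorkers(12, 8)) // 12: explicit setting wins
}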
@@ -431,6 +431,11 @@ func (node *Node) SubnetRoutes() []netip.Prefix {
 	return routes
 }
 
+// IsSubnetRouter reports if the node has any subnet routes.
+func (node *Node) IsSubnetRouter() bool {
+	return len(node.SubnetRoutes()) > 0
+}
+
 func (node *Node) String() string {
 	return node.Hostname
 }
@@ -669,6 +674,13 @@ func (v NodeView) SubnetRoutes() []netip.Prefix {
 	return v.ж.SubnetRoutes()
 }
 
+func (v NodeView) IsSubnetRouter() bool {
+	if !v.Valid() {
+		return false
+	}
+	return v.ж.IsSubnetRouter()
+}
+
 func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) {
 	if !v.Valid() {
 		return
@@ -1,17 +1,16 @@
 package types
 
 import (
-	"fmt"
 	"time"
 
 	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
+	"github.com/rs/zerolog/log"
 	"google.golang.org/protobuf/types/known/timestamppb"
 )
 
 type PAKError string
 
 func (e PAKError) Error() string { return string(e) }
-func (e PAKError) Unwrap() error { return fmt.Errorf("preauth key error: %w", e) }
 
 // PreAuthKey describes a pre-authorization key usable in a particular user.
 type PreAuthKey struct {
@@ -60,6 +59,21 @@ func (pak *PreAuthKey) Validate() error {
 	if pak == nil {
 		return PAKError("invalid authkey")
 	}
+
+	log.Debug().
+		Str("key", pak.Key).
+		Bool("hasExpiration", pak.Expiration != nil).
+		Time("expiration", func() time.Time {
+			if pak.Expiration != nil {
+				return *pak.Expiration
+			}
+			return time.Time{}
+		}()).
+		Time("now", time.Now()).
+		Bool("reusable", pak.Reusable).
+		Bool("used", pak.Used).
+		Msg("PreAuthKey.Validate: checking key")
 
 	if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
 		return PAKError("authkey expired")
 	}
@@ -5,6 +5,8 @@ import (
 	"testing"
 
 	"github.com/stretchr/testify/assert"
+	"tailscale.com/util/dnsname"
+	"tailscale.com/util/must"
 )
 
 func TestCheckForFQDNRules(t *testing.T) {
@@ -102,59 +104,16 @@ func TestConvertWithFQDNRules(t *testing.T) {
 func TestMagicDNSRootDomains100(t *testing.T) {
 	domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("100.64.0.0/10"))
 
-	found := false
-	for _, domain := range domains {
-		if domain == "64.100.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	assert.True(t, found)
-
-	found = false
-	for _, domain := range domains {
-		if domain == "100.100.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	assert.True(t, found)
-
-	found = false
-	for _, domain := range domains {
-		if domain == "127.100.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	assert.True(t, found)
+	assert.Contains(t, domains, must.Get(dnsname.ToFQDN("64.100.in-addr.arpa.")))
+	assert.Contains(t, domains, must.Get(dnsname.ToFQDN("100.100.in-addr.arpa.")))
+	assert.Contains(t, domains, must.Get(dnsname.ToFQDN("127.100.in-addr.arpa.")))
 }
 
 func TestMagicDNSRootDomains172(t *testing.T) {
 	domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("172.16.0.0/16"))
 
-	found := false
-	for _, domain := range domains {
-		if domain == "0.16.172.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	assert.True(t, found)
-
-	found = false
-	for _, domain := range domains {
-		if domain == "255.16.172.in-addr.arpa." {
-			found = true
-
-			break
-		}
-	}
-	assert.True(t, found)
+	assert.Contains(t, domains, must.Get(dnsname.ToFQDN("0.16.172.in-addr.arpa.")))
+	assert.Contains(t, domains, must.Get(dnsname.ToFQDN("255.16.172.in-addr.arpa.")))
 }
 
 // Happens when netmask is a multiple of 4 bits (sounds likely).
@@ -143,7 +143,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
 
 		// Parse latencies
 		for j := 5; j <= 7; j++ {
-			if matches[j] != "" {
+			if j < len(matches) && matches[j] != "" {
 				ms, err := strconv.ParseFloat(matches[j], 64)
 				if err != nil {
 					return Traceroute{}, fmt.Errorf("parsing latency: %w", err)
@@ -88,7 +88,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
 		var err error
 		listNodes, err = headscale.ListNodes()
 		assert.NoError(ct, err)
-		assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match before logout count")
+		assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count")
 	}, 20*time.Second, 1*time.Second)
 
 	for _, node := range listNodes {
@@ -123,7 +123,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
 		var err error
 		listNodes, err = headscale.ListNodes()
 		assert.NoError(ct, err)
-		assert.Equal(ct, nodeCountBeforeLogout, len(listNodes), "Node count should match after HTTPS reconnection")
+		assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match after HTTPS reconnection")
 	}, 30*time.Second, 2*time.Second)
 
 	for _, node := range listNodes {
@@ -161,7 +161,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
 	}
 
 	listNodes, err = headscale.ListNodes()
-	require.Equal(t, nodeCountBeforeLogout, len(listNodes))
+	require.Len(t, listNodes, nodeCountBeforeLogout)
 	for _, node := range listNodes {
 		assertLastSeenSet(t, node)
 	}
@@ -355,7 +355,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
 		"--user",
 		strconv.FormatUint(userMap[userName].GetId(), 10),
 		"expire",
-		key.Key,
+		key.GetKey(),
 	})
 	assertNoErr(t, err)
 
@@ -147,3 +147,9 @@ func DockerAllowNetworkAdministration(config *docker.HostConfig) {
 	config.CapAdd = append(config.CapAdd, "NET_ADMIN")
 	config.Privileged = true
 }
+
+// DockerMemoryLimit sets memory limit and disables OOM kill for containers.
+func DockerMemoryLimit(config *docker.HostConfig) {
+	config.Memory = 2 * 1024 * 1024 * 1024 // 2GB in bytes
+	config.OOMKillDisable = true
+}
|
@ -883,6 +883,10 @@ func TestNodeOnlineStatus(t *testing.T) {
|
|||||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||||
status, err := client.Status()
|
status, err := client.Status()
|
||||||
assert.NoError(ct, err)
|
assert.NoError(ct, err)
|
||||||
|
if status == nil {
|
||||||
|
assert.Fail(ct, "status is nil")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
for _, peerKey := range status.Peers() {
|
for _, peerKey := range status.Peers() {
|
||||||
peerStatus := status.Peer[peerKey]
|
peerStatus := status.Peer[peerKey]
|
||||||
@ -984,16 +988,11 @@ func TestPingAllByIPManyUpDown(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Wait for sync and successful pings after nodes come back up
|
// Wait for sync and successful pings after nodes come back up
|
||||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
err = scenario.WaitForTailscaleSync()
|
||||||
err = scenario.WaitForTailscaleSync()
|
assert.NoError(t, err)
|
||||||
assert.NoError(ct, err)
|
|
||||||
|
|
||||||
success := pingAllHelper(t, allClients, allAddrs)
|
|
||||||
assert.Greater(ct, success, 0, "Nodes should be able to ping after coming back up")
|
|
||||||
}, 30*time.Second, 2*time.Second)
|
|
||||||
|
|
||||||
success := pingAllHelper(t, allClients, allAddrs)
|
success := pingAllHelper(t, allClients, allAddrs)
|
||||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -260,7 +260,9 @@ func WithDERPConfig(derpMap tailcfg.DERPMap) Option {
 func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option {
 	return func(hsic *HeadscaleInContainer) {
 		hsic.env["HEADSCALE_TUNING_BATCH_CHANGE_DELAY"] = batchTimeout.String()
-		hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(mapSessionChanSize)
+		hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(
+			mapSessionChanSize,
+		)
 	}
 }
 
@@ -279,9 +281,15 @@ func WithDebugPort(port int) Option {
 
 // buildEntrypoint builds the container entrypoint command based on configuration.
 func (hsic *HeadscaleInContainer) buildEntrypoint() []string {
-	debugCmd := fmt.Sprintf("/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve", hsic.debugPort)
+	debugCmd := fmt.Sprintf(
+		"/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve",
+		hsic.debugPort,
+	)
 
-	entrypoint := fmt.Sprintf("/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30", debugCmd)
+	entrypoint := fmt.Sprintf(
+		"/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30",
+		debugCmd,
+	)
 
 	return []string{"/bin/bash", "-c", entrypoint}
 }
@@ -448,7 +456,11 @@ func New(
 
 	hsic.container = container
 
-	log.Printf("Debug ports for %s: delve=%s, metrics/pprof=49090\n", hsic.hostname, hsic.GetHostDebugPort())
+	log.Printf(
+		"Debug ports for %s: delve=%s, metrics/pprof=49090\n",
+		hsic.hostname,
+		hsic.GetHostDebugPort(),
+	)
 
 	// Write the CA certificates to the container
 	for i, cert := range hsic.caCerts {
@@ -684,14 +696,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
 		return nil
 	}
 
-	// First, let's see what files are actually in /tmp
-	tmpListing, err := t.Execute([]string{"ls", "-la", "/tmp/"})
-	if err != nil {
-		log.Printf("Warning: could not list /tmp directory: %v", err)
-	} else {
-		log.Printf("Contents of /tmp in container %s:\n%s", t.hostname, tmpListing)
-	}
-
 	// Also check for any .sqlite files
 	sqliteFiles, err := t.Execute([]string{"find", "/tmp", "-name", "*.sqlite*", "-type", "f"})
 	if err != nil {
@@ -718,12 +722,6 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
 		return errors.New("database file exists but has no schema (empty database)")
 	}
 
-	// Show a preview of the schema (first 500 chars)
-	schemaPreview := schemaCheck
-	if len(schemaPreview) > 500 {
-		schemaPreview = schemaPreview[:500] + "..."
-	}
-
 	tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3")
 	if err != nil {
 		return fmt.Errorf("failed to fetch database file: %w", err)
@@ -740,7 +738,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
 			return fmt.Errorf("failed to read tar header: %w", err)
 		}
 
-		log.Printf("Found file in tar: %s (type: %d, size: %d)", header.Name, header.Typeflag, header.Size)
+		log.Printf(
+			"Found file in tar: %s (type: %d, size: %d)",
+			header.Name,
+			header.Typeflag,
+			header.Size,
+		)
 
 		// Extract the first regular file we find
 		if header.Typeflag == tar.TypeReg {
@@ -756,11 +759,20 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
 				return fmt.Errorf("failed to copy database file: %w", err)
 			}
 
-			log.Printf("Extracted database file: %s (%d bytes written, header claimed %d bytes)", dbPath, written, header.Size)
+			log.Printf(
+				"Extracted database file: %s (%d bytes written, header claimed %d bytes)",
+				dbPath,
+				written,
+				header.Size,
+			)
 
 			// Check if we actually wrote something
 			if written == 0 {
-				return fmt.Errorf("database file is empty (size: %d, header size: %d)", written, header.Size)
+				return fmt.Errorf(
+					"database file is empty (size: %d, header size: %d)",
+					written,
+					header.Size,
+				)
 			}
 
 			return nil
@@ -871,7 +883,15 @@ func (t *HeadscaleInContainer) WaitForRunning() error {
 func (t *HeadscaleInContainer) CreateUser(
 	user string,
 ) (*v1.User, error) {
-	command := []string{"headscale", "users", "create", user, fmt.Sprintf("--email=%s@test.no", user), "--output", "json"}
+	command := []string{
+		"headscale",
+		"users",
+		"create",
+		user,
+		fmt.Sprintf("--email=%s@test.no", user),
+		"--output",
+		"json",
+	}
 
 	result, _, err := dockertestutil.ExecuteCommand(
 		t.container,
@@ -1182,13 +1202,18 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
 		[]string{},
 	)
 	if err != nil {
-		return nil, fmt.Errorf("failed to execute list node command: %w", err)
+		return nil, fmt.Errorf(
+			"failed to execute approve routes command (node %d, routes %v): %w",
+			id,
+			routes,
+			err,
+		)
 	}
 
 	var node *v1.Node
 	err = json.Unmarshal([]byte(result), &node)
 	if err != nil {
-		return nil, fmt.Errorf("failed to unmarshal nodes: %w", err)
+		return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err)
 	}
 
 	return node, nil
@@ -310,7 +310,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	// Enable route on node 1
 	t.Logf("Enabling route on subnet router 1, no HA")
 	_, err = headscale.ApproveRoutes(
-		1,
+		MustFindNode(subRouter1.Hostname(), nodes).GetId(),
 		[]netip.Prefix{pref},
 	)
 	require.NoError(t, err)
@@ -366,7 +366,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	// Enable route on node 2, now we will have a HA subnet router
 	t.Logf("Enabling route on subnet router 2, now HA, subnetrouter 1 is primary, 2 is standby")
 	_, err = headscale.ApproveRoutes(
-		2,
+		MustFindNode(subRouter2.Hostname(), nodes).GetId(),
 		[]netip.Prefix{pref},
 	)
 	require.NoError(t, err)
@@ -422,7 +422,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	// be enabled.
 	t.Logf("Enabling route on subnet router 3, now HA, subnetrouter 1 is primary, 2 and 3 is standby")
 	_, err = headscale.ApproveRoutes(
-		3,
+		MustFindNode(subRouter3.Hostname(), nodes).GetId(),
 		[]netip.Prefix{pref},
 	)
 	require.NoError(t, err)
@@ -639,7 +639,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 
 	t.Logf("disabling route in subnet router r3 (%s)", subRouter3.Hostname())
 	t.Logf("expecting route to failover to r1 (%s), which is still available with r2", subRouter1.Hostname())
-	_, err = headscale.ApproveRoutes(nodes[2].GetId(), []netip.Prefix{})
+	_, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{})
 
 	time.Sleep(5 * time.Second)
 
@@ -647,9 +647,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	require.NoError(t, err)
 	assert.Len(t, nodes, 6)
 
-	requireNodeRouteCount(t, nodes[0], 1, 1, 1)
-	requireNodeRouteCount(t, nodes[1], 1, 1, 0)
-	requireNodeRouteCount(t, nodes[2], 1, 0, 0)
+	requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1)
+	requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0)
+	requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
 
 	// Verify that the route is announced from subnet router 1
 	clientStatus, err = client.Status()
@@ -684,7 +684,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	// Disable the route of subnet router 1, making it failover to 2
 	t.Logf("disabling route in subnet router r1 (%s)", subRouter1.Hostname())
 	t.Logf("expecting route to failover to r2 (%s)", subRouter2.Hostname())
-	_, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{})
+	_, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{})
 
 	time.Sleep(5 * time.Second)
 
@@ -692,9 +692,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	require.NoError(t, err)
 	assert.Len(t, nodes, 6)
 
-	requireNodeRouteCount(t, nodes[0], 1, 0, 0)
-	requireNodeRouteCount(t, nodes[1], 1, 1, 1)
-	requireNodeRouteCount(t, nodes[2], 1, 0, 0)
+	requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0)
+	requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
+	requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
 
 	// Verify that the route is announced from subnet router 1
 	clientStatus, err = client.Status()
@@ -729,9 +729,10 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	// enable the route of subnet router 1, no change expected
 	t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
 	t.Logf("both online, expecting r2 (%s) to still be primary (no flapping)", subRouter2.Hostname())
+	r1Node := MustFindNode(subRouter1.Hostname(), nodes)
 	_, err = headscale.ApproveRoutes(
-		nodes[0].GetId(),
-		util.MustStringsToPrefixes(nodes[0].GetAvailableRoutes()),
+		r1Node.GetId(),
+		util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()),
 	)
 
 	time.Sleep(5 * time.Second)
@@ -740,9 +741,9 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	require.NoError(t, err)
 	assert.Len(t, nodes, 6)
 
-	requireNodeRouteCount(t, nodes[0], 1, 1, 0)
-	requireNodeRouteCount(t, nodes[1], 1, 1, 1)
-	requireNodeRouteCount(t, nodes[2], 1, 0, 0)
+	requireNodeRouteCount(t, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0)
+	requireNodeRouteCount(t, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1)
+	requireNodeRouteCount(t, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0)
 
 	// Verify that the route is announced from subnet router 1
 	clientStatus, err = client.Status()
|
|||||||
|
|
||||||
s.userToNetwork = userToNetwork
|
s.userToNetwork = userToNetwork
|
||||||
|
|
||||||
if spec.OIDCUsers != nil && len(spec.OIDCUsers) != 0 {
|
if len(spec.OIDCUsers) != 0 {
|
||||||
ttl := defaultAccessTTL
|
ttl := defaultAccessTTL
|
||||||
if spec.OIDCAccessTTL != 0 {
|
if spec.OIDCAccessTTL != 0 {
|
||||||
ttl = spec.OIDCAccessTTL
|
ttl = spec.OIDCAccessTTL
|
||||||
|
@@ -370,10 +370,12 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
 }
 
 func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
+	t.Helper()
 	return doSSHWithRetry(t, client, peer, true)
 }
 
 func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) {
+	t.Helper()
 	return doSSHWithRetry(t, client, peer, false)
 }
 
@@ -319,6 +319,7 @@ func New(
 			dockertestutil.DockerRestartPolicy,
 			dockertestutil.DockerAllowLocalIPv6,
 			dockertestutil.DockerAllowNetworkAdministration,
+			dockertestutil.DockerMemoryLimit,
 		)
 	case "unstable":
 		tailscaleOptions.Repository = "tailscale/tailscale"
@@ -329,6 +330,7 @@ func New(
 			dockertestutil.DockerRestartPolicy,
 			dockertestutil.DockerAllowLocalIPv6,
 			dockertestutil.DockerAllowNetworkAdministration,
+			dockertestutil.DockerMemoryLimit,
 		)
 	default:
 		tailscaleOptions.Repository = "tailscale/tailscale"
@@ -339,6 +341,7 @@ func New(
 			dockertestutil.DockerRestartPolicy,
 			dockertestutil.DockerAllowLocalIPv6,
 			dockertestutil.DockerAllowNetworkAdministration,
+			dockertestutil.DockerMemoryLimit,
 		)
 	}
 
@@ -22,11 +22,11 @@ import (
 
 const (
 	// derpPingTimeout defines the timeout for individual DERP ping operations
-	// Used in DERP connectivity tests to verify relay server communication
+	// Used in DERP connectivity tests to verify relay server communication.
 	derpPingTimeout = 2 * time.Second
 
 	// derpPingCount defines the number of ping attempts for DERP connectivity tests
-	// Higher count provides better reliability assessment of DERP connectivity
+	// Higher count provides better reliability assessment of DERP connectivity.
 	derpPingCount = 10
 )
 
@@ -321,7 +321,7 @@ func assertValidNetcheck(t *testing.T, client TailscaleClient) {
 // before executing commands that depend on network state propagation.
 //
 // Timeout: 10 seconds with exponential backoff
-// Use cases: DNS resolution, route propagation, policy updates
+// Use cases: DNS resolution, route propagation, policy updates.
 func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) {
 	t.Helper()
 
@@ -361,10 +361,10 @@ func isSelfClient(client TailscaleClient, addr string) bool {
 }
 
 func dockertestMaxWait() time.Duration {
-	wait := 120 * time.Second //nolint
+	wait := 300 * time.Second //nolint
 
 	if util.IsCI() {
-		wait = 300 * time.Second //nolint
+		wait = 600 * time.Second //nolint
 	}
 
 	return wait
@@ -6,7 +6,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
-	"log"
 	"net/http"
 	"os"
 	"regexp"
@@ -21,7 +20,7 @@ import (
 const (
 	releasesURL = "https://api.github.com/repos/tailscale/tailscale/releases"
 	rawFileURL  = "https://github.com/tailscale/tailscale/raw/refs/tags/%s/tailcfg/tailcfg.go"
-	outputFile  = "../capver_generated.go"
+	outputFile  = "../../hscontrol/capver/capver_generated.go"
 )
 
 type Release struct {
@@ -105,7 +104,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
 	sortedVersions := xmaps.Keys(versions)
 	sort.Strings(sortedVersions)
 	for _, version := range sortedVersions {
-		file.WriteString(fmt.Sprintf("\t\"%s\": %d,\n", version, versions[version]))
+		fmt.Fprintf(file, "\t\"%s\": %d,\n", version, versions[version])
 	}
 	file.WriteString("}\n")
 
@@ -115,16 +114,13 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
 	capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string)
 	for _, v := range sortedVersions {
 		cap := versions[v]
-		log.Printf("cap for v: %d, %s", cap, v)
 
 		// If it is already set, skip and continue,
 		// we only want the first tailscale vsion per
 		// capability vsion.
 		if _, ok := capVarToTailscaleVer[cap]; ok {
-			log.Printf("Skipping %d, %s", cap, v)
 			continue
 		}
-		log.Printf("Storing %d, %s", cap, v)
 		capVarToTailscaleVer[cap] = v
 	}
 
@@ -133,7 +129,7 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
 		return capsSorted[i] < capsSorted[j]
 	})
 	for _, capVer := range capsSorted {
-		file.WriteString(fmt.Sprintf("\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer]))
+		fmt.Fprintf(file, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer])
	}
 	file.WriteString("}\n")
 