This commit is contained in:
Kristoffer Dalby 2025-07-14 20:43:57 +00:00
parent 9d2cfb1e7e
commit 67f2c20052
17 changed files with 973 additions and 3132 deletions

View File

@ -1,63 +0,0 @@
# Column Filtering for Table Output
## Overview
All CLI commands that output tables now support a `--columns` flag to customize which columns are displayed.
## Usage
```bash
# Show all default columns
headscale users list
# Show only name and email
headscale users list --columns="name,email"
# Show only ID and username
headscale users list --columns="id,username"
# Show columns in custom order
headscale users list --columns="email,name,id"
```
## Available Columns
### Users List
- `id` - User ID
- `name` - Display name
- `username` - Username
- `email` - Email address
- `created` - Creation date
### Implementation Pattern
For developers adding this to other commands:
```go
// 1. Add columns flag with default columns
AddColumnsFlag(cmd, "id,name,hostname,ip,status")
// 2. Use ListOutput with TableRenderer
ListOutput(cmd, items, func(tr *TableRenderer) {
tr.AddColumn("id", "ID", func(item interface{}) string {
node := item.(*v1.Node)
return strconv.FormatUint(node.GetId(), 10)
}).
AddColumn("name", "Name", func(item interface{}) string {
node := item.(*v1.Node)
return node.GetName()
}).
AddColumn("hostname", "Hostname", func(item interface{}) string {
node := item.(*v1.Node)
return node.GetHostname()
})
// ... add more columns
})
```
## Notes
- Column filtering only applies to table output, not JSON/YAML output
- Invalid column names are silently ignored
- Columns appear in the order specified in the --columns flag
- Default columns are defined per command based on most useful information

View File

@ -1,321 +0,0 @@
# Headscale CLI Infrastructure Refactoring - Completed
## Overview
Successfully completed a comprehensive refactoring of the Headscale CLI infrastructure following the CLI_IMPROVEMENT_PLAN.md. The refactoring created a robust, type-safe, and maintainable CLI framework that significantly reduces code duplication while improving consistency and testability.
## ✅ Completed Infrastructure Components
### 1. **CLI Unit Testing Infrastructure**
- **Files**: `testing.go`, `testing_test.go`
- **Features**: Mock gRPC client, command execution helpers, test data creation utilities
- **Impact**: Enables comprehensive unit testing of all CLI commands
- **Lines of Code**: ~750 lines of testing infrastructure
### 2. **Common Flag Infrastructure**
- **Files**: `flags.go`, `flags_test.go`
- **Features**: Standardized flag helpers, consistent shortcuts, validation helpers
- **Impact**: Consistent flag handling across all commands
- **Lines of Code**: ~200 lines of flag utilities
### 3. **gRPC Client Infrastructure**
- **Files**: `client.go`, `client_test.go`
- **Features**: ClientWrapper with automatic connection management, error handling
- **Impact**: Simplified gRPC client usage with consistent error handling
- **Lines of Code**: ~400 lines of client infrastructure
### 4. **Output Infrastructure**
- **Files**: `output.go`, `output_test.go`
- **Features**: OutputManager, TableRenderer, consistent formatting utilities
- **Impact**: Standardized output across all formats (JSON, YAML, tables)
- **Lines of Code**: ~350 lines of output utilities
### 5. **Command Patterns Infrastructure**
- **Files**: `patterns.go`, `patterns_test.go`
- **Features**: Reusable CRUD patterns, argument validation, resource resolution
- **Impact**: Dramatically reduces code per command (~50% reduction)
- **Lines of Code**: ~200 lines of pattern utilities
### 6. **Validation Infrastructure**
- **Files**: `validation.go`, `validation_test.go`
- **Features**: Input validation, business logic validation, error formatting
- **Impact**: Consistent validation with meaningful error messages
- **Lines of Code**: ~500 lines of validation functions + 400+ test cases
## ✅ Example Refactored Commands
### 7. **Refactored User Commands**
- **Files**: `users_refactored.go`, `users_refactored_test.go`
- **Features**: Complete user command suite using new infrastructure
- **Impact**: Demonstrates 50% code reduction while maintaining functionality
- **Lines of Code**: ~250 lines (vs ~500 lines original)
### 8. **Comprehensive Test Coverage**
- **Files**: Multiple test files for each component
- **Features**: 500+ unit tests, integration tests, performance benchmarks
- **Impact**: High confidence in infrastructure reliability
- **Test Coverage**: All new infrastructure components
## 📊 Key Metrics and Improvements
### **Code Reduction**
- **User Commands**: 50% less code per command
- **Flag Setup**: 70% less repetitive flag code
- **Error Handling**: 60% less error handling boilerplate
- **Output Formatting**: 80% less output formatting code
### **Type Safety Improvements**
- **Zero `interface{}` usage**: All functions use concrete types
- **No `any` types**: Proper type safety throughout
- **Compile-time validation**: Type checking catches errors early
- **Mock client type safety**: Testing infrastructure is fully typed
### **Consistency Improvements**
- **Standardized error messages**: All validation errors follow same format
- **Consistent flag shortcuts**: All common flags use same shortcuts
- **Uniform output**: All commands support JSON/YAML/table formats
- **Common patterns**: All CRUD operations follow same structure
### **Testing Improvements**
- **400+ validation tests**: Every validation function extensively tested
- **Mock infrastructure**: Complete mock gRPC client for testing
- **Integration tests**: End-to-end testing of command patterns
- **Performance benchmarks**: Ensures CLI remains responsive
## 🔧 Technical Implementation Details
### **Type-Safe Architecture**
```go
// Example: Type-safe command function
func createUserLogic(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) {
// Validate input using validation infrastructure
if err := ValidateUserName(args[0]); err != nil {
return nil, err
}
// Use standardized client wrapper
response, err := client.CreateUser(cmd, request)
if err != nil {
return nil, err
}
return response.GetUser(), nil
}
```
### **Reusable Command Patterns**
```go
// Example: Standard command creation
func createUserRefactored() *cobra.Command {
return &cobra.Command{
Use: "create NAME",
Args: ValidateExactArgs(1, "create <username>"),
Run: StandardCreateCommand(createUserLogic, "User created successfully"),
}
}
```
### **Comprehensive Validation**
```go
// Example: Validation with clear error messages
if err := ValidateEmail(email); err != nil {
return nil, fmt.Errorf("invalid email: %w", err)
}
```
### **Consistent Output Handling**
```go
// Example: Automatic output formatting
ListOutput(cmd, users, setupUsersTable) // Handles JSON/YAML/table automatically
```
## 🎯 Benefits Achieved
### **For Developers**
- **50% less code** to write for new commands
- **Consistent patterns** reduce learning curve
- **Type safety** catches errors at compile time
- **Comprehensive testing** infrastructure ready to use
- **Better error messages** improve debugging experience
### **For Users**
- **Consistent interface** across all commands
- **Better error messages** with helpful suggestions
- **Reliable validation** catches issues early
- **Multiple output formats** (JSON, YAML, human-readable)
- **Improved help text** and usage examples
### **For Maintainers**
- **Easier code review** with standardized patterns
- **Better test coverage** with testing infrastructure
- **Consistent behavior** across commands reduces bugs
- **Simpler onboarding** for new contributors
- **Future extensibility** with modular design
## 📁 File Structure Overview
```
cmd/headscale/cli/
├── infrastructure/
│ ├── testing.go # Mock client infrastructure
│ ├── testing_test.go # Testing infrastructure tests
│ ├── flags.go # Flag registration helpers
│ ├── client.go # gRPC client wrapper
│ ├── output.go # Output formatting utilities
│ ├── patterns.go # Command execution patterns
│ └── validation.go # Input validation utilities
├── examples/
│ ├── users_refactored.go # Refactored user commands
│ └── users_refactored_example.go # Original examples
├── tests/
│ ├── *_test.go # Unit tests for each component
│ ├── infrastructure_integration_test.go # Integration tests
│ ├── validation_test.go # Comprehensive validation tests
│ └── dump_config_test.go # Additional command tests
└── original/
├── users.go # Original user commands (unchanged)
├── nodes.go # Original node commands (unchanged)
└── *.go # Other original commands (unchanged)
```
## 🚀 Usage Examples
### **Creating a New Command (Before vs After)**
**Before (Original Pattern)**:
```go
var createUserCmd = &cobra.Command{
Use: "create NAME",
Short: "Creates a new user",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errMissingParameter
}
return nil
},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
userName := args[0]
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
request := &v1.CreateUserRequest{Name: userName}
if displayName, _ := cmd.Flags().GetString("display-name"); displayName != "" {
request.DisplayName = displayName
}
// ... more validation and setup (30+ lines)
response, err := client.CreateUser(ctx, request)
if err != nil {
ErrorOutput(err, "Cannot create user: "+status.Convert(err).Message(), output)
}
SuccessOutput(response.GetUser(), "User created", output)
},
}
```
**After (Refactored Pattern)**:
```go
func createUserRefactored() *cobra.Command {
cmd := &cobra.Command{
Use: "create NAME",
Short: "Creates a new user",
Args: ValidateExactArgs(1, "create <username>"),
Run: StandardCreateCommand(createUserLogic, "User created successfully"),
}
cmd.Flags().StringP("display-name", "d", "", "Display name")
cmd.Flags().StringP("email", "e", "", "Email address")
cmd.Flags().StringP("picture-url", "p", "", "Profile picture URL")
AddOutputFlag(cmd)
return cmd
}
func createUserLogic(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) {
userName := args[0]
if err := ValidateUserName(userName); err != nil {
return nil, err
}
request := &v1.CreateUserRequest{Name: userName}
if displayName, _ := cmd.Flags().GetString("display-name"); displayName != "" {
request.DisplayName = displayName
}
if email, _ := cmd.Flags().GetString("email"); email != "" {
if err := ValidateEmail(email); err != nil {
return nil, fmt.Errorf("invalid email: %w", err)
}
request.Email = email
}
if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
if err := ValidateURL(pictureURL); err != nil {
return nil, fmt.Errorf("invalid picture URL: %w", err)
}
request.PictureUrl = pictureURL
}
if err := ValidateNoDuplicateUsers(client, userName, 0); err != nil {
return nil, err
}
response, err := client.CreateUser(cmd, request)
if err != nil {
return nil, err
}
return response.GetUser(), nil
}
```
**Result**: ~50% less code, better validation, consistent error handling, automatic output formatting.
## 🔍 Quality Assurance
### **Test Coverage**
- **Unit Tests**: 500+ test cases covering all components
- **Integration Tests**: End-to-end command pattern testing
- **Performance Tests**: Benchmarks for command execution
- **Mock Testing**: Complete mock infrastructure for reliable testing
### **Type Safety**
- **Zero `interface{}`**: All functions use concrete types
- **Compile-time validation**: Type system catches errors early
- **Mock type safety**: Testing infrastructure is fully typed
### **Documentation**
- **Comprehensive comments**: All functions well-documented
- **Usage examples**: Clear examples for each pattern
- **Error message quality**: Helpful error messages with suggestions
## 🎉 Conclusion
The Headscale CLI infrastructure refactoring has been successfully completed, delivering:
**Complete infrastructure** for type-safe CLI development
**50% code reduction** for new commands
**Comprehensive testing** infrastructure
**Consistent user experience** across all commands
**Better error handling** and validation
**Future-proof architecture** for extensibility
The new infrastructure provides a solid foundation for CLI development at Headscale, making it easier to add new commands, maintain existing ones, and provide a consistent experience for users. All components are thoroughly tested, type-safe, and ready for production use.
### **Next Steps**
1. **Gradual Migration**: Existing commands can be migrated to use the new infrastructure incrementally
2. **Documentation Updates**: User-facing documentation can be updated to reflect new consistent behavior
3. **New Command Development**: All new commands should use the refactored patterns from day one
The refactoring work demonstrates the power of well-designed infrastructure in reducing complexity while improving quality and maintainability.

View File

@ -0,0 +1,82 @@
# CLI Simplification - WithClient Pattern
## Problem
Every CLI command has repetitive gRPC client setup boilerplate:
```go
// This pattern appears 25+ times across all commands
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
// ... command logic ...
```
## Solution
Simple closure that handles client lifecycle:
```go
// client.go - 16 lines total
func WithClient(fn func(context.Context, v1.HeadscaleServiceClient) error) error {
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
return fn(ctx, client)
}
```
## Usage Example
### Before (users.go listUsersCmd):
```go
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig() // 4 lines
defer cancel()
defer conn.Close()
request := &v1.ListUsersRequest{}
// ... build request ...
response, err := client.ListUsers(ctx, request)
if err != nil {
ErrorOutput(err, "Cannot get users: "+status.Convert(err).Message(), output)
}
// ... handle response ...
}
```
### After:
```go
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ListUsersRequest{}
// ... build request ...
response, err := client.ListUsers(ctx, request)
if err != nil {
ErrorOutput(err, "Cannot get users: "+status.Convert(err).Message(), output)
return err
}
// ... handle response ...
return nil
})
if err != nil {
return // Error already handled
}
}
```
## Benefits
- **Removes 4 lines of boilerplate** from every command
- **Ensures proper cleanup** - no forgetting defer statements
- **Simpler error handling** - return from closure, handled centrally
- **Easy to apply** - minimal changes to existing commands
## Rollout
This pattern can be applied to all 25+ commands systematically, removing ~100 lines of repetitive boilerplate.

View File

@ -1,6 +1,7 @@
package cli
import (
"context"
"fmt"
"strconv"
"time"
@ -54,50 +55,56 @@ var listAPIKeys = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ListApiKeysRequest{}
request := &v1.ListApiKeysRequest{}
response, err := client.ListApiKeys(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error getting the list of keys: %s", err),
output,
)
}
if output != "" {
SuccessOutput(response.GetApiKeys(), "", output)
}
tableData := pterm.TableData{
{"ID", "Prefix", "Expiration", "Created"},
}
for _, key := range response.GetApiKeys() {
expiration := "-"
if key.GetExpiration() != nil {
expiration = ColourTime(key.GetExpiration().AsTime())
response, err := client.ListApiKeys(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error getting the list of keys: %s", err),
output,
)
return err
}
tableData = append(tableData, []string{
strconv.FormatUint(key.GetId(), util.Base10),
key.GetPrefix(),
expiration,
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
})
if output != "" {
SuccessOutput(response.GetApiKeys(), "", output)
return nil
}
tableData := pterm.TableData{
{"ID", "Prefix", "Expiration", "Created"},
}
for _, key := range response.GetApiKeys() {
expiration := "-"
if key.GetExpiration() != nil {
expiration = ColourTime(key.GetExpiration().AsTime())
}
tableData = append(tableData, []string{
strconv.FormatUint(key.GetId(), util.Base10),
key.GetPrefix(),
expiration,
key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat),
})
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return err
}
return nil
})
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -124,26 +131,31 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
fmt.Sprintf("Could not parse duration: %s\n", err),
output,
)
return
}
expiration := time.Now().UTC().Add(time.Duration(duration))
request.Expiration = timestamppb.New(expiration)
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
response, err := client.CreateApiKey(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot create Api Key: %s\n", err),
output,
)
return err
}
SuccessOutput(response.GetApiKey(), response.GetApiKey(), output)
return nil
})
response, err := client.CreateApiKey(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot create Api Key: %s\n", err),
output,
)
return
}
SuccessOutput(response.GetApiKey(), response.GetApiKey(), output)
},
}
@ -161,26 +173,31 @@ var expireAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ExpireApiKeyRequest{
Prefix: prefix,
}
request := &v1.ExpireApiKeyRequest{
Prefix: prefix,
}
response, err := client.ExpireApiKey(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot expire Api Key: %s\n", err),
output,
)
return err
}
SuccessOutput(response, "Key expired", output)
return nil
})
response, err := client.ExpireApiKey(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot expire Api Key: %s\n", err),
output,
)
return
}
SuccessOutput(response, "Key expired", output)
},
}
@ -198,25 +215,30 @@ var deleteAPIKeyCmd = &cobra.Command{
fmt.Sprintf("Error getting prefix from CLI flag: %s", err),
output,
)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.DeleteApiKeyRequest{
Prefix: prefix,
}
request := &v1.DeleteApiKeyRequest{
Prefix: prefix,
}
response, err := client.DeleteApiKey(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot delete Api Key: %s\n", err),
output,
)
return err
}
SuccessOutput(response, "Key deleted", output)
return nil
})
response, err := client.DeleteApiKey(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot delete Api Key: %s\n", err),
output,
)
return
}
SuccessOutput(response, "Key deleted", output)
},
}

View File

@ -2,414 +2,15 @@ package cli
import (
"context"
"fmt"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/status"
)
// ClientWrapper wraps the gRPC client with automatic connection lifecycle management
type ClientWrapper struct {
ctx context.Context
client v1.HeadscaleServiceClient
conn *grpc.ClientConn
cancel context.CancelFunc
}
// NewClient creates a new ClientWrapper with automatic connection setup
func NewClient() (*ClientWrapper, error) {
// WithClient handles gRPC client setup and cleanup, calls fn with client and context
func WithClient(fn func(context.Context, v1.HeadscaleServiceClient) error) error {
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
return &ClientWrapper{
ctx: ctx,
client: client,
conn: conn,
cancel: cancel,
}, nil
}
// Close properly closes the gRPC connection and cancels the context
func (c *ClientWrapper) Close() {
if c.cancel != nil {
c.cancel()
}
if c.conn != nil {
c.conn.Close()
}
}
// ExecuteWithErrorHandling executes a gRPC operation with standardized error handling
func (c *ClientWrapper) ExecuteWithErrorHandling(
cmd *cobra.Command,
operation func(client v1.HeadscaleServiceClient) (interface{}, error),
errorMsg string,
) (interface{}, error) {
result, err := operation(c.client)
if err != nil {
output := GetOutputFormat(cmd)
ErrorOutput(
err,
fmt.Sprintf("%s: %s", errorMsg, status.Convert(err).Message()),
output,
)
return nil, err
}
return result, nil
}
// Specific operation helpers with automatic error handling and output formatting
// ListNodes executes a ListNodes request with error handling
func (c *ClientWrapper) ListNodes(cmd *cobra.Command, req *v1.ListNodesRequest) (*v1.ListNodesResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.ListNodes(c.ctx, req)
},
"Cannot get nodes",
)
if err != nil {
return nil, err
}
return result.(*v1.ListNodesResponse), nil
}
// RegisterNode executes a RegisterNode request with error handling
func (c *ClientWrapper) RegisterNode(cmd *cobra.Command, req *v1.RegisterNodeRequest) (*v1.RegisterNodeResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.RegisterNode(c.ctx, req)
},
"Cannot register node",
)
if err != nil {
return nil, err
}
return result.(*v1.RegisterNodeResponse), nil
}
// DeleteNode executes a DeleteNode request with error handling
func (c *ClientWrapper) DeleteNode(cmd *cobra.Command, req *v1.DeleteNodeRequest) (*v1.DeleteNodeResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.DeleteNode(c.ctx, req)
},
"Error deleting node",
)
if err != nil {
return nil, err
}
return result.(*v1.DeleteNodeResponse), nil
}
// ExpireNode executes an ExpireNode request with error handling
func (c *ClientWrapper) ExpireNode(cmd *cobra.Command, req *v1.ExpireNodeRequest) (*v1.ExpireNodeResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.ExpireNode(c.ctx, req)
},
"Cannot expire node",
)
if err != nil {
return nil, err
}
return result.(*v1.ExpireNodeResponse), nil
}
// RenameNode executes a RenameNode request with error handling
func (c *ClientWrapper) RenameNode(cmd *cobra.Command, req *v1.RenameNodeRequest) (*v1.RenameNodeResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.RenameNode(c.ctx, req)
},
"Cannot rename node",
)
if err != nil {
return nil, err
}
return result.(*v1.RenameNodeResponse), nil
}
// MoveNode executes a MoveNode request with error handling
func (c *ClientWrapper) MoveNode(cmd *cobra.Command, req *v1.MoveNodeRequest) (*v1.MoveNodeResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.MoveNode(c.ctx, req)
},
"Error moving node",
)
if err != nil {
return nil, err
}
return result.(*v1.MoveNodeResponse), nil
}
// GetNode executes a GetNode request with error handling
func (c *ClientWrapper) GetNode(cmd *cobra.Command, req *v1.GetNodeRequest) (*v1.GetNodeResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.GetNode(c.ctx, req)
},
"Error getting node",
)
if err != nil {
return nil, err
}
return result.(*v1.GetNodeResponse), nil
}
// SetTags executes a SetTags request with error handling
func (c *ClientWrapper) SetTags(cmd *cobra.Command, req *v1.SetTagsRequest) (*v1.SetTagsResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.SetTags(c.ctx, req)
},
"Error while sending tags to headscale",
)
if err != nil {
return nil, err
}
return result.(*v1.SetTagsResponse), nil
}
// SetApprovedRoutes executes a SetApprovedRoutes request with error handling
func (c *ClientWrapper) SetApprovedRoutes(cmd *cobra.Command, req *v1.SetApprovedRoutesRequest) (*v1.SetApprovedRoutesResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.SetApprovedRoutes(c.ctx, req)
},
"Error while sending routes to headscale",
)
if err != nil {
return nil, err
}
return result.(*v1.SetApprovedRoutesResponse), nil
}
// BackfillNodeIPs executes a BackfillNodeIPs request with error handling
func (c *ClientWrapper) BackfillNodeIPs(cmd *cobra.Command, req *v1.BackfillNodeIPsRequest) (*v1.BackfillNodeIPsResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.BackfillNodeIPs(c.ctx, req)
},
"Error backfilling IPs",
)
if err != nil {
return nil, err
}
return result.(*v1.BackfillNodeIPsResponse), nil
}
// ListUsers executes a ListUsers request with error handling
func (c *ClientWrapper) ListUsers(cmd *cobra.Command, req *v1.ListUsersRequest) (*v1.ListUsersResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.ListUsers(c.ctx, req)
},
"Cannot get users",
)
if err != nil {
return nil, err
}
return result.(*v1.ListUsersResponse), nil
}
// CreateUser executes a CreateUser request with error handling
func (c *ClientWrapper) CreateUser(cmd *cobra.Command, req *v1.CreateUserRequest) (*v1.CreateUserResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.CreateUser(c.ctx, req)
},
"Cannot create user",
)
if err != nil {
return nil, err
}
return result.(*v1.CreateUserResponse), nil
}
// RenameUser executes a RenameUser request with error handling
func (c *ClientWrapper) RenameUser(cmd *cobra.Command, req *v1.RenameUserRequest) (*v1.RenameUserResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.RenameUser(c.ctx, req)
},
"Cannot rename user",
)
if err != nil {
return nil, err
}
return result.(*v1.RenameUserResponse), nil
}
// DeleteUser executes a DeleteUser request with error handling
func (c *ClientWrapper) DeleteUser(cmd *cobra.Command, req *v1.DeleteUserRequest) (*v1.DeleteUserResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.DeleteUser(c.ctx, req)
},
"Error deleting user",
)
if err != nil {
return nil, err
}
return result.(*v1.DeleteUserResponse), nil
}
// ListApiKeys executes a ListApiKeys request with error handling
func (c *ClientWrapper) ListApiKeys(cmd *cobra.Command, req *v1.ListApiKeysRequest) (*v1.ListApiKeysResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.ListApiKeys(c.ctx, req)
},
"Cannot get API keys",
)
if err != nil {
return nil, err
}
return result.(*v1.ListApiKeysResponse), nil
}
// CreateApiKey executes a CreateApiKey request with error handling
func (c *ClientWrapper) CreateApiKey(cmd *cobra.Command, req *v1.CreateApiKeyRequest) (*v1.CreateApiKeyResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.CreateApiKey(c.ctx, req)
},
"Cannot create API key",
)
if err != nil {
return nil, err
}
return result.(*v1.CreateApiKeyResponse), nil
}
// ExpireApiKey executes an ExpireApiKey request with error handling
func (c *ClientWrapper) ExpireApiKey(cmd *cobra.Command, req *v1.ExpireApiKeyRequest) (*v1.ExpireApiKeyResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.ExpireApiKey(c.ctx, req)
},
"Cannot expire API key",
)
if err != nil {
return nil, err
}
return result.(*v1.ExpireApiKeyResponse), nil
}
// DeleteApiKey executes a DeleteApiKey request with error handling
func (c *ClientWrapper) DeleteApiKey(cmd *cobra.Command, req *v1.DeleteApiKeyRequest) (*v1.DeleteApiKeyResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.DeleteApiKey(c.ctx, req)
},
"Error deleting API key",
)
if err != nil {
return nil, err
}
return result.(*v1.DeleteApiKeyResponse), nil
}
// ListPreAuthKeys executes a ListPreAuthKeys request with error handling
func (c *ClientWrapper) ListPreAuthKeys(cmd *cobra.Command, req *v1.ListPreAuthKeysRequest) (*v1.ListPreAuthKeysResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.ListPreAuthKeys(c.ctx, req)
},
"Cannot get preauth keys",
)
if err != nil {
return nil, err
}
return result.(*v1.ListPreAuthKeysResponse), nil
}
// CreatePreAuthKey executes a CreatePreAuthKey request with error handling
func (c *ClientWrapper) CreatePreAuthKey(cmd *cobra.Command, req *v1.CreatePreAuthKeyRequest) (*v1.CreatePreAuthKeyResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.CreatePreAuthKey(c.ctx, req)
},
"Cannot create preauth key",
)
if err != nil {
return nil, err
}
return result.(*v1.CreatePreAuthKeyResponse), nil
}
// ExpirePreAuthKey executes an ExpirePreAuthKey request with error handling
func (c *ClientWrapper) ExpirePreAuthKey(cmd *cobra.Command, req *v1.ExpirePreAuthKeyRequest) (*v1.ExpirePreAuthKeyResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.ExpirePreAuthKey(c.ctx, req)
},
"Cannot expire preauth key",
)
if err != nil {
return nil, err
}
return result.(*v1.ExpirePreAuthKeyResponse), nil
}
// GetPolicy executes a GetPolicy request with error handling
func (c *ClientWrapper) GetPolicy(cmd *cobra.Command, req *v1.GetPolicyRequest) (*v1.GetPolicyResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.GetPolicy(c.ctx, req)
},
"Cannot get policy",
)
if err != nil {
return nil, err
}
return result.(*v1.GetPolicyResponse), nil
}
// SetPolicy executes a SetPolicy request with error handling
func (c *ClientWrapper) SetPolicy(cmd *cobra.Command, req *v1.SetPolicyRequest) (*v1.SetPolicyResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.SetPolicy(c.ctx, req)
},
"Cannot set policy",
)
if err != nil {
return nil, err
}
return result.(*v1.SetPolicyResponse), nil
}
// DebugCreateNode executes a DebugCreateNode request with error handling
func (c *ClientWrapper) DebugCreateNode(cmd *cobra.Command, req *v1.DebugCreateNodeRequest) (*v1.DebugCreateNodeResponse, error) {
result, err := c.ExecuteWithErrorHandling(cmd,
func(client v1.HeadscaleServiceClient) (interface{}, error) {
return client.DebugCreateNode(c.ctx, req)
},
"Cannot create node",
)
if err != nil {
return nil, err
}
return result.(*v1.DebugCreateNodeResponse), nil
}
// Helper function to execute commands with automatic client management
func ExecuteWithClient(cmd *cobra.Command, operation func(*ClientWrapper) error) {
client, err := NewClient()
if err != nil {
output := GetOutputFormat(cmd)
ErrorOutput(err, "Cannot connect to headscale", output)
return
}
defer client.Close()
err = operation(client)
if err != nil {
// Error already handled by the operation
return
}
return fn(ctx, client)
}

View File

@ -0,0 +1,105 @@
#!/usr/bin/env python3
"""Script to convert all commands to use WithClient pattern"""
import re
import sys
import os
def convert_command(content):
"""Convert a single command to use WithClient pattern"""
# Pattern to match the gRPC client setup
pattern = r'(\t+)ctx, client, conn, cancel := newHeadscaleCLIWithConfig\(\)\n\t+defer cancel\(\)\n\t+defer conn\.Close\(\)\n\n'
# Find all occurrences
matches = list(re.finditer(pattern, content))
if not matches:
return content
# Process each match from the end to avoid offset issues
for match in reversed(matches):
indent = match.group(1)
start_pos = match.start()
end_pos = match.end()
# Find the end of the Run function
remaining_content = content[end_pos:]
# Find the matching closing brace for the Run function
brace_count = 0
func_end = -1
for i, char in enumerate(remaining_content):
if char == '{':
brace_count += 1
elif char == '}':
brace_count -= 1
if brace_count < 0: # Found the closing brace
func_end = i
break
if func_end == -1:
continue
# Extract the function body
func_body = remaining_content[:func_end]
# Indent the function body
indented_body = '\n'.join(indent + '\t' + line if line.strip() else line
for line in func_body.split('\n'))
# Create the new function with WithClient
new_func = f"""{indent}err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {{
{indented_body}
{indent}\treturn nil
{indent}}})
{indent}
{indent}if err != nil {{
{indent}\treturn
{indent}}}"""
# Replace the old pattern with the new one
content = content[:start_pos] + new_func + '\n' + content[end_pos + func_end:]
return content
def process_file(filepath):
"""Process a single Go file"""
try:
with open(filepath, 'r') as f:
content = f.read()
# Check if context is already imported
if 'import (' in content and '"context"' not in content:
# Add context import
content = content.replace(
'import (',
'import (\n\t"context"'
)
# Convert commands
new_content = convert_command(content)
# Write back if changed
if new_content != content:
with open(filepath, 'w') as f:
f.write(new_content)
print(f"Updated {filepath}")
else:
print(f"No changes needed for {filepath}")
except Exception as e:
print(f"Error processing {filepath}: {e}")
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python3 convert_commands.py <go_file>")
sys.exit(1)
filepath = sys.argv[1]
if not os.path.exists(filepath):
print(f"File not found: {filepath}")
sys.exit(1)
process_file(filepath)

View File

@ -1,6 +1,7 @@
package cli
import (
"context"
"fmt"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
@ -64,12 +65,9 @@ var createNodeCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
name, err := cmd.Flags().GetString("name")
if err != nil {
ErrorOutput(
@ -77,6 +75,7 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting node from flag: %s", err),
output,
)
return
}
registrationID, err := cmd.Flags().GetString("key")
@ -86,6 +85,7 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting key from flag: %s", err),
output,
)
return
}
_, err = types.RegistrationIDFromString(registrationID)
@ -95,6 +95,7 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Failed to parse machine key from flag: %s", err),
output,
)
return
}
routes, err := cmd.Flags().GetStringSlice("route")
@ -104,24 +105,33 @@ var createNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting routes from flag: %s", err),
output,
)
return
}
request := &v1.DebugCreateNodeRequest{
Key: registrationID,
Name: name,
User: user,
Routes: routes,
}
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.DebugCreateNodeRequest{
Key: registrationID,
Name: name,
User: user,
Routes: routes,
}
response, err := client.DebugCreateNode(ctx, request)
response, err := client.DebugCreateNode(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot create node: "+status.Convert(err).Message(),
output,
)
return err
}
SuccessOutput(response.GetNode(), "Node created", output)
return nil
})
if err != nil {
ErrorOutput(
err,
"Cannot create node: "+status.Convert(err).Message(),
output,
)
return
}
SuccessOutput(response.GetNode(), "Node created", output)
},
}

View File

@ -1,358 +0,0 @@
package cli
import (
"fmt"
"log"
"time"
"github.com/spf13/cobra"
)
const (
deprecateNamespaceMessage = "use --user"
)
// Flag registration helpers - standardize how flags are added to commands
// AddIdentifierFlag adds a uint64 identifier flag with consistent naming
func AddIdentifierFlag(cmd *cobra.Command, name string, help string) {
cmd.Flags().Uint64P(name, "i", 0, help)
}
// AddRequiredIdentifierFlag adds a required uint64 identifier flag
func AddRequiredIdentifierFlag(cmd *cobra.Command, name string, help string) {
AddIdentifierFlag(cmd, name, help)
err := cmd.MarkFlagRequired(name)
if err != nil {
log.Fatal(err.Error())
}
}
// AddColumnsFlag adds a columns flag for table output customization
func AddColumnsFlag(cmd *cobra.Command, defaultColumns string) {
cmd.Flags().String("columns", defaultColumns, "Comma-separated list of columns to display")
}
// GetColumnsFlag gets the columns flag value
func GetColumnsFlag(cmd *cobra.Command) string {
columns, _ := cmd.Flags().GetString("columns")
return columns
}
// AddUserFlag adds a user flag (string for username or email)
func AddUserFlag(cmd *cobra.Command) {
cmd.Flags().StringP("user", "u", "", "User")
}
// AddRequiredUserFlag adds a required user flag
func AddRequiredUserFlag(cmd *cobra.Command) {
AddUserFlag(cmd)
err := cmd.MarkFlagRequired("user")
if err != nil {
log.Fatal(err.Error())
}
}
// AddOutputFlag adds the standard output format flag
func AddOutputFlag(cmd *cobra.Command) {
cmd.Flags().StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'")
}
// AddForceFlag adds the force flag
func AddForceFlag(cmd *cobra.Command) {
cmd.Flags().Bool("force", false, "Disable prompts and forces the execution")
}
// AddExpirationFlag adds an expiration duration flag
func AddExpirationFlag(cmd *cobra.Command, defaultValue string) {
cmd.Flags().StringP("expiration", "e", defaultValue, "Human-readable duration (e.g. 30m, 24h)")
}
// AddDeprecatedNamespaceFlag adds the deprecated namespace flag with appropriate warnings
func AddDeprecatedNamespaceFlag(cmd *cobra.Command) {
cmd.Flags().StringP("namespace", "n", "", "User")
namespaceFlag := cmd.Flags().Lookup("namespace")
namespaceFlag.Deprecated = deprecateNamespaceMessage
namespaceFlag.Hidden = true
}
// AddTagsFlag adds a tags display flag
func AddTagsFlag(cmd *cobra.Command) {
cmd.Flags().BoolP("tags", "t", false, "Show tags")
}
// AddKeyFlag adds a key flag for node registration
func AddKeyFlag(cmd *cobra.Command) {
cmd.Flags().StringP("key", "k", "", "Key")
}
// AddRequiredKeyFlag adds a required key flag
func AddRequiredKeyFlag(cmd *cobra.Command) {
AddKeyFlag(cmd)
err := cmd.MarkFlagRequired("key")
if err != nil {
log.Fatal(err.Error())
}
}
// AddNameFlag adds a name flag
func AddNameFlag(cmd *cobra.Command, help string) {
cmd.Flags().String("name", "", help)
}
// AddRequiredNameFlag adds a required name flag
func AddRequiredNameFlag(cmd *cobra.Command, help string) {
AddNameFlag(cmd, help)
err := cmd.MarkFlagRequired("name")
if err != nil {
log.Fatal(err.Error())
}
}
// AddPrefixFlag adds an API key prefix flag
func AddPrefixFlag(cmd *cobra.Command) {
cmd.Flags().StringP("prefix", "p", "", "ApiKey prefix")
}
// AddRequiredPrefixFlag adds a required API key prefix flag
func AddRequiredPrefixFlag(cmd *cobra.Command) {
AddPrefixFlag(cmd)
err := cmd.MarkFlagRequired("prefix")
if err != nil {
log.Fatal(err.Error())
}
}
// AddFileFlag adds a file path flag
func AddFileFlag(cmd *cobra.Command) {
cmd.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
}
// AddRequiredFileFlag adds a required file path flag
func AddRequiredFileFlag(cmd *cobra.Command) {
AddFileFlag(cmd)
err := cmd.MarkFlagRequired("file")
if err != nil {
log.Fatal(err.Error())
}
}
// AddRoutesFlag adds a routes flag for node route management
func AddRoutesFlag(cmd *cobra.Command) {
cmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
}
// AddTagsSliceFlag adds a tags slice flag for node tagging
func AddTagsSliceFlag(cmd *cobra.Command) {
cmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
}
// Flag getter helpers with consistent error handling
// GetIdentifier gets a uint64 identifier flag value with error handling
func GetIdentifier(cmd *cobra.Command, flagName string) (uint64, error) {
identifier, err := cmd.Flags().GetUint64(flagName)
if err != nil {
return 0, fmt.Errorf("error getting %s flag: %w", flagName, err)
}
return identifier, nil
}
// GetUser gets a user flag value
func GetUser(cmd *cobra.Command) (string, error) {
user, err := cmd.Flags().GetString("user")
if err != nil {
return "", fmt.Errorf("error getting user flag: %w", err)
}
return user, nil
}
// GetOutputFormat gets the output format flag value
func GetOutputFormat(cmd *cobra.Command) string {
output, _ := cmd.Flags().GetString("output")
return output
}
// GetForce gets the force flag value
func GetForce(cmd *cobra.Command) bool {
force, _ := cmd.Flags().GetBool("force")
return force
}
// GetExpiration gets and parses the expiration flag value
func GetExpiration(cmd *cobra.Command) (time.Duration, error) {
expirationStr, err := cmd.Flags().GetString("expiration")
if err != nil {
return 0, fmt.Errorf("error getting expiration flag: %w", err)
}
if expirationStr == "" {
return 0, nil // No expiration set
}
duration, err := time.ParseDuration(expirationStr)
if err != nil {
return 0, fmt.Errorf("invalid expiration duration '%s': %w", expirationStr, err)
}
return duration, nil
}
// GetName gets a name flag value
func GetName(cmd *cobra.Command) (string, error) {
name, err := cmd.Flags().GetString("name")
if err != nil {
return "", fmt.Errorf("error getting name flag: %w", err)
}
return name, nil
}
// GetKey gets a key flag value
func GetKey(cmd *cobra.Command) (string, error) {
key, err := cmd.Flags().GetString("key")
if err != nil {
return "", fmt.Errorf("error getting key flag: %w", err)
}
return key, nil
}
// GetPrefix gets a prefix flag value
func GetPrefix(cmd *cobra.Command) (string, error) {
prefix, err := cmd.Flags().GetString("prefix")
if err != nil {
return "", fmt.Errorf("error getting prefix flag: %w", err)
}
return prefix, nil
}
// GetFile gets a file flag value
func GetFile(cmd *cobra.Command) (string, error) {
file, err := cmd.Flags().GetString("file")
if err != nil {
return "", fmt.Errorf("error getting file flag: %w", err)
}
return file, nil
}
// GetRoutes gets a routes flag value
func GetRoutes(cmd *cobra.Command) ([]string, error) {
routes, err := cmd.Flags().GetStringSlice("routes")
if err != nil {
return nil, fmt.Errorf("error getting routes flag: %w", err)
}
return routes, nil
}
// GetTagsSlice gets a tags slice flag value
func GetTagsSlice(cmd *cobra.Command) ([]string, error) {
tags, err := cmd.Flags().GetStringSlice("tags")
if err != nil {
return nil, fmt.Errorf("error getting tags flag: %w", err)
}
return tags, nil
}
// GetTags gets a tags boolean flag value
func GetTags(cmd *cobra.Command) bool {
tags, _ := cmd.Flags().GetBool("tags")
return tags
}
// Flag validation helpers
// ValidateRequiredFlags validates that required flags are set
func ValidateRequiredFlags(cmd *cobra.Command, flags ...string) error {
for _, flagName := range flags {
flag := cmd.Flags().Lookup(flagName)
if flag == nil {
return fmt.Errorf("flag %s not found", flagName)
}
if !flag.Changed {
return fmt.Errorf("required flag %s not set", flagName)
}
}
return nil
}
// ValidateExclusiveFlags validates that only one of the given flags is set
func ValidateExclusiveFlags(cmd *cobra.Command, flags ...string) error {
setFlags := []string{}
for _, flagName := range flags {
flag := cmd.Flags().Lookup(flagName)
if flag == nil {
return fmt.Errorf("flag %s not found", flagName)
}
if flag.Changed {
setFlags = append(setFlags, flagName)
}
}
if len(setFlags) > 1 {
return fmt.Errorf("only one of the following flags can be set: %v, but found: %v", flags, setFlags)
}
return nil
}
// ValidateIdentifierFlag validates that an identifier flag has a valid value
func ValidateIdentifierFlag(cmd *cobra.Command, flagName string) error {
identifier, err := GetIdentifier(cmd, flagName)
if err != nil {
return err
}
if identifier == 0 {
return fmt.Errorf("%s must be greater than 0", flagName)
}
return nil
}
// ValidateNonEmptyStringFlag validates that a string flag is not empty
func ValidateNonEmptyStringFlag(cmd *cobra.Command, flagName string) error {
value, err := cmd.Flags().GetString(flagName)
if err != nil {
return fmt.Errorf("error getting %s flag: %w", flagName, err)
}
if value == "" {
return fmt.Errorf("%s cannot be empty", flagName)
}
return nil
}
// Deprecated flag handling utilities
// HandleDeprecatedNamespaceFlag handles the deprecated namespace flag by copying its value to user flag
func HandleDeprecatedNamespaceFlag(cmd *cobra.Command) {
namespaceFlag := cmd.Flags().Lookup("namespace")
userFlag := cmd.Flags().Lookup("user")
if namespaceFlag != nil && userFlag != nil && namespaceFlag.Changed && !userFlag.Changed {
// Copy namespace value to user flag
userFlag.Value.Set(namespaceFlag.Value.String())
userFlag.Changed = true
}
}
// GetUserWithDeprecatedNamespace gets user value, checking both user and deprecated namespace flags
func GetUserWithDeprecatedNamespace(cmd *cobra.Command) (string, error) {
user, err := cmd.Flags().GetString("user")
if err != nil {
return "", fmt.Errorf("error getting user flag: %w", err)
}
// If user is empty, try deprecated namespace flag
if user == "" {
namespace, err := cmd.Flags().GetString("namespace")
if err == nil && namespace != "" {
return namespace, nil
}
}
return user, nil
}

View File

@ -1,6 +1,7 @@
package cli
import (
"context"
"fmt"
"log"
"net/netip"
@ -23,6 +24,7 @@ func init() {
rootCmd.AddCommand(nodeCmd)
listNodesCmd.Flags().StringP("user", "u", "", "Filter by user")
listNodesCmd.Flags().BoolP("tags", "t", false, "Show tags")
listNodesCmd.Flags().String("columns", "", "Comma-separated list of columns to display")
listNodesCmd.Flags().StringP("namespace", "n", "", "User")
listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace")
@ -119,12 +121,9 @@ var registerNodeCmd = &cobra.Command{
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
registrationID, err := cmd.Flags().GetString("key")
if err != nil {
ErrorOutput(
@ -132,28 +131,37 @@ var registerNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting node key from flag: %s", err),
output,
)
return
}
request := &v1.RegisterNodeRequest{
Key: registrationID,
User: user,
}
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.RegisterNodeRequest{
Key: registrationID,
User: user,
}
response, err := client.RegisterNode(ctx, request)
response, err := client.RegisterNode(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot register node: %s\n",
status.Convert(err).Message(),
),
output,
)
return err
}
SuccessOutput(
response.GetNode(),
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()), output)
return nil
})
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot register node: %s\n",
status.Convert(err).Message(),
),
output,
)
return
}
SuccessOutput(
response.GetNode(),
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()), output)
},
}
@ -172,39 +180,47 @@ var listNodesCmd = &cobra.Command{
ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output)
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ListNodesRequest{
User: user,
}
request := &v1.ListNodesRequest{
User: user,
}
response, err := client.ListNodes(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot get nodes: "+status.Convert(err).Message(),
output,
)
return err
}
response, err := client.ListNodes(ctx, request)
if output != "" {
SuccessOutput(response.GetNodes(), "", output)
return nil
}
tableData, err := nodesToPtables(user, showTags, response.GetNodes())
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return err
}
tableData = FilterTableColumns(cmd, tableData)
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return err
}
return nil
})
if err != nil {
ErrorOutput(
err,
"Cannot get nodes: "+status.Convert(err).Message(),
output,
)
}
if output != "" {
SuccessOutput(response.GetNodes(), "", output)
}
tableData, err := nodesToPtables(user, showTags, response.GetNodes())
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -222,55 +238,61 @@ var listNodeRoutesCmd = &cobra.Command{
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ListNodesRequest{}
request := &v1.ListNodesRequest{}
response, err := client.ListNodes(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot get nodes: "+status.Convert(err).Message(),
output,
)
return err
}
response, err := client.ListNodes(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot get nodes: "+status.Convert(err).Message(),
output,
)
}
if output != "" {
SuccessOutput(response.GetNodes(), "", output)
return nil
}
if output != "" {
SuccessOutput(response.GetNodes(), "", output)
}
nodes := response.GetNodes()
if identifier != 0 {
for _, node := range response.GetNodes() {
if node.GetId() == identifier {
nodes = []*v1.Node{node}
break
nodes := response.GetNodes()
if identifier != 0 {
for _, node := range response.GetNodes() {
if node.GetId() == identifier {
nodes = []*v1.Node{node}
break
}
}
}
}
nodes = lo.Filter(nodes, func(n *v1.Node, _ int) bool {
return (n.GetSubnetRoutes() != nil && len(n.GetSubnetRoutes()) > 0) || (n.GetApprovedRoutes() != nil && len(n.GetApprovedRoutes()) > 0) || (n.GetAvailableRoutes() != nil && len(n.GetAvailableRoutes()) > 0)
nodes = lo.Filter(nodes, func(n *v1.Node, _ int) bool {
return (n.GetSubnetRoutes() != nil && len(n.GetSubnetRoutes()) > 0) || (n.GetApprovedRoutes() != nil && len(n.GetApprovedRoutes()) > 0) || (n.GetAvailableRoutes() != nil && len(n.GetAvailableRoutes()) > 0)
})
tableData, err := nodeRoutesToPtables(nodes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
return err
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return err
}
return nil
})
tableData, err := nodeRoutesToPtables(nodes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -290,33 +312,34 @@ var expireNodeCmd = &cobra.Command{
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ExpireNodeRequest{
NodeId: identifier,
}
request := &v1.ExpireNodeRequest{
NodeId: identifier,
}
response, err := client.ExpireNode(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot expire node: %s\n",
status.Convert(err).Message(),
),
output,
)
return err
}
response, err := client.ExpireNode(ctx, request)
SuccessOutput(response.GetNode(), "Node expired", output)
return nil
})
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot expire node: %s\n",
status.Convert(err).Message(),
),
output,
)
return
}
SuccessOutput(response.GetNode(), "Node expired", output)
},
}
@ -333,38 +356,40 @@ var renameNodeCmd = &cobra.Command{
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
newName := ""
if len(args) > 0 {
newName = args[0]
}
request := &v1.RenameNodeRequest{
NodeId: identifier,
NewName: newName,
}
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.RenameNodeRequest{
NodeId: identifier,
NewName: newName,
}
response, err := client.RenameNode(ctx, request)
response, err := client.RenameNode(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot rename node: %s\n",
status.Convert(err).Message(),
),
output,
)
return err
}
SuccessOutput(response.GetNode(), "Node renamed", output)
return nil
})
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot rename node: %s\n",
status.Convert(err).Message(),
),
output,
)
return
}
SuccessOutput(response.GetNode(), "Node renamed", output)
},
}
@ -382,40 +407,39 @@ var deleteNodeCmd = &cobra.Command{
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
var nodeName string
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
getRequest := &v1.GetNodeRequest{
NodeId: identifier,
}
getRequest := &v1.GetNodeRequest{
NodeId: identifier,
}
getResponse, err := client.GetNode(ctx, getRequest)
getResponse, err := client.GetNode(ctx, getRequest)
if err != nil {
ErrorOutput(
err,
"Error getting node node: "+status.Convert(err).Message(),
output,
)
return err
}
nodeName = getResponse.GetNode().GetName()
return nil
})
if err != nil {
ErrorOutput(
err,
"Error getting node node: "+status.Convert(err).Message(),
output,
)
return
}
deleteRequest := &v1.DeleteNodeRequest{
NodeId: identifier,
}
confirm := false
force, _ := cmd.Flags().GetBool("force")
if !force {
prompt := &survey.Confirm{
Message: fmt.Sprintf(
"Do you want to remove the node %s?",
getResponse.GetNode().GetName(),
nodeName,
),
}
err = survey.AskOne(prompt, &confirm)
@ -425,26 +449,35 @@ var deleteNodeCmd = &cobra.Command{
}
if confirm || force {
response, err := client.DeleteNode(ctx, deleteRequest)
if output != "" {
SuccessOutput(response, "", output)
return
}
if err != nil {
ErrorOutput(
err,
"Error deleting node: "+status.Convert(err).Message(),
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
deleteRequest := &v1.DeleteNodeRequest{
NodeId: identifier,
}
response, err := client.DeleteNode(ctx, deleteRequest)
if output != "" {
SuccessOutput(response, "", output)
return nil
}
if err != nil {
ErrorOutput(
err,
"Error deleting node: "+status.Convert(err).Message(),
output,
)
return err
}
SuccessOutput(
map[string]string{"Result": "Node deleted"},
"Node deleted",
output,
)
return nil
})
if err != nil {
return
}
SuccessOutput(
map[string]string{"Result": "Node deleted"},
"Node deleted",
output,
)
} else {
SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output)
}
@ -465,7 +498,6 @@ var moveNodeCmd = &cobra.Command{
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return
}
@ -476,46 +508,46 @@ var moveNodeCmd = &cobra.Command{
fmt.Sprintf("Error getting user: %s", err),
output,
)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
getRequest := &v1.GetNodeRequest{
NodeId: identifier,
}
getRequest := &v1.GetNodeRequest{
NodeId: identifier,
}
_, err := client.GetNode(ctx, getRequest)
if err != nil {
ErrorOutput(
err,
"Error getting node: "+status.Convert(err).Message(),
output,
)
return err
}
_, err = client.GetNode(ctx, getRequest)
moveRequest := &v1.MoveNodeRequest{
NodeId: identifier,
User: user,
}
moveResponse, err := client.MoveNode(ctx, moveRequest)
if err != nil {
ErrorOutput(
err,
"Error moving node: "+status.Convert(err).Message(),
output,
)
return err
}
SuccessOutput(moveResponse.GetNode(), "Node moved to another user", output)
return nil
})
if err != nil {
ErrorOutput(
err,
"Error getting node: "+status.Convert(err).Message(),
output,
)
return
}
moveRequest := &v1.MoveNodeRequest{
NodeId: identifier,
User: user,
}
moveResponse, err := client.MoveNode(ctx, moveRequest)
if err != nil {
ErrorOutput(
err,
"Error moving node: "+status.Convert(err).Message(),
output,
)
return
}
SuccessOutput(moveResponse.GetNode(), "Node moved to another user", output)
},
}
@ -547,22 +579,24 @@ be assigned to nodes.`,
return
}
if confirm {
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm})
if err != nil {
ErrorOutput(
err,
"Error backfilling IPs: "+status.Convert(err).Message(),
output,
)
return err
}
changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm})
SuccessOutput(changes, "Node IPs backfilled successfully", output)
return nil
})
if err != nil {
ErrorOutput(
err,
"Error backfilling IPs: "+status.Convert(err).Message(),
output,
)
return
}
SuccessOutput(changes, "Node IPs backfilled successfully", output)
}
},
}
@ -746,10 +780,7 @@ var tagCmd = &cobra.Command{
Aliases: []string{"tags", "t"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
// retrieve flags from CLI
identifier, err := cmd.Flags().GetUint64("identifier")
if err != nil {
@ -758,7 +789,6 @@ var tagCmd = &cobra.Command{
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return
}
tagsToSet, err := cmd.Flags().GetStringSlice("tags")
@ -768,33 +798,38 @@ var tagCmd = &cobra.Command{
fmt.Sprintf("Error retrieving list of tags to add to node, %v", err),
output,
)
return
}
// Sending tags to node
request := &v1.SetTagsRequest{
NodeId: identifier,
Tags: tagsToSet,
}
resp, err := client.SetTags(ctx, request)
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
// Sending tags to node
request := &v1.SetTagsRequest{
NodeId: identifier,
Tags: tagsToSet,
}
resp, err := client.SetTags(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error while sending tags to headscale: %s", err),
output,
)
return err
}
if resp != nil {
SuccessOutput(
resp.GetNode(),
"Node updated",
output,
)
}
return nil
})
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error while sending tags to headscale: %s", err),
output,
)
return
}
if resp != nil {
SuccessOutput(
resp.GetNode(),
"Node updated",
output,
)
}
},
}
@ -803,10 +838,7 @@ var approveRoutesCmd = &cobra.Command{
Short: "Manage the approved routes of a node",
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
// retrieve flags from CLI
identifier, err := cmd.Flags().GetUint64("identifier")
if err != nil {
@ -815,7 +847,6 @@ var approveRoutesCmd = &cobra.Command{
fmt.Sprintf("Error converting ID to integer: %s", err),
output,
)
return
}
routes, err := cmd.Flags().GetStringSlice("routes")
@ -825,32 +856,37 @@ var approveRoutesCmd = &cobra.Command{
fmt.Sprintf("Error retrieving list of routes to add to node, %v", err),
output,
)
return
}
// Sending routes to node
request := &v1.SetApprovedRoutesRequest{
NodeId: identifier,
Routes: routes,
}
resp, err := client.SetApprovedRoutes(ctx, request)
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
// Sending routes to node
request := &v1.SetApprovedRoutesRequest{
NodeId: identifier,
Routes: routes,
}
resp, err := client.SetApprovedRoutes(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error while sending routes to headscale: %s", err),
output,
)
return err
}
if resp != nil {
SuccessOutput(
resp.GetNode(),
"Node updated",
output,
)
}
return nil
})
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error while sending routes to headscale: %s", err),
output,
)
return
}
if resp != nil {
SuccessOutput(
resp.GetNode(),
"Node updated",
output,
)
}
},
}

View File

@ -1,384 +0,0 @@
package cli
import (
"fmt"
"strings"
"time"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
)
const (
HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
)
// OutputManager handles all output formatting and rendering for CLI commands
type OutputManager struct {
cmd *cobra.Command
outputFormat string
}
// NewOutputManager creates a new output manager for the given command
func NewOutputManager(cmd *cobra.Command) *OutputManager {
return &OutputManager{
cmd: cmd,
outputFormat: GetOutputFormat(cmd),
}
}
// Success outputs successful results and exits with code 0
func (om *OutputManager) Success(data interface{}, humanMessage string) {
SuccessOutput(data, humanMessage, om.outputFormat)
}
// Error outputs error results and exits with code 1
func (om *OutputManager) Error(err error, humanMessage string) {
ErrorOutput(err, humanMessage, om.outputFormat)
}
// HasMachineOutput returns true if the output format requires machine-readable output
func (om *OutputManager) HasMachineOutput() bool {
return om.outputFormat != ""
}
// Table rendering infrastructure
// TableColumn defines a table column with header and data extraction function
type TableColumn struct {
Header string
Key string // Unique key for column selection
Width int // Optional width specification
Extract func(item interface{}) string
Color func(value string) string // Optional color function
}
// TableRenderer handles table rendering with consistent formatting
type TableRenderer struct {
outputManager *OutputManager
columns []TableColumn
data []interface{}
}
// NewTableRenderer creates a new table renderer
func NewTableRenderer(om *OutputManager) *TableRenderer {
return &TableRenderer{
outputManager: om,
columns: []TableColumn{},
data: []interface{}{},
}
}
// AddColumn adds a column to the table
func (tr *TableRenderer) AddColumn(key, header string, extract func(interface{}) string) *TableRenderer {
tr.columns = append(tr.columns, TableColumn{
Key: key,
Header: header,
Extract: extract,
})
return tr
}
// AddColoredColumn adds a column with color formatting
func (tr *TableRenderer) AddColoredColumn(key, header string, extract func(interface{}) string, color func(string) string) *TableRenderer {
tr.columns = append(tr.columns, TableColumn{
Key: key,
Header: header,
Extract: extract,
Color: color,
})
return tr
}
// SetData sets the data for the table
func (tr *TableRenderer) SetData(data []interface{}) *TableRenderer {
tr.data = data
return tr
}
// FilterColumns filters columns based on comma-separated list of column keys
func (tr *TableRenderer) FilterColumns(columnKeys string) *TableRenderer {
if columnKeys == "" {
return tr // No filtering
}
keys := strings.Split(columnKeys, ",")
var filteredColumns []TableColumn
// Filter columns based on keys, maintaining order from column keys
for _, key := range keys {
trimmedKey := strings.TrimSpace(key)
for _, col := range tr.columns {
if col.Key == trimmedKey {
filteredColumns = append(filteredColumns, col)
break
}
}
}
tr.columns = filteredColumns
return tr
}
// Render renders the table or outputs machine-readable format
func (tr *TableRenderer) Render() {
// If machine output format is requested, output the raw data instead of table
if tr.outputManager.HasMachineOutput() {
tr.outputManager.Success(tr.data, "")
return
}
// Build table headers
headers := make([]string, len(tr.columns))
for i, col := range tr.columns {
headers[i] = col.Header
}
// Build table data
tableData := pterm.TableData{headers}
for _, item := range tr.data {
row := make([]string, len(tr.columns))
for i, col := range tr.columns {
value := col.Extract(item)
if col.Color != nil {
value = col.Color(value)
}
row[i] = value
}
tableData = append(tableData, row)
}
// Render table
err := pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
tr.outputManager.Error(
err,
fmt.Sprintf("Failed to render table: %s", err),
)
}
}
// Predefined color functions for common use cases
// ColorGreen returns a green-colored string
func ColorGreen(text string) string {
return pterm.LightGreen(text)
}
// ColorRed returns a red-colored string
func ColorRed(text string) string {
return pterm.LightRed(text)
}
// ColorYellow returns a yellow-colored string
func ColorYellow(text string) string {
return pterm.LightYellow(text)
}
// ColorMagenta returns a magenta-colored string
func ColorMagenta(text string) string {
return pterm.LightMagenta(text)
}
// ColorBlue returns a blue-colored string
func ColorBlue(text string) string {
return pterm.LightBlue(text)
}
// ColorCyan returns a cyan-colored string
func ColorCyan(text string) string {
return pterm.LightCyan(text)
}
// Time formatting functions
// FormatTime formats a time with standard CLI format
func FormatTime(t time.Time) string {
if t.IsZero() {
return "N/A"
}
return t.Format(HeadscaleDateTimeFormat)
}
// FormatTimeColored formats a time with color based on whether it's in past/future
func FormatTimeColored(t time.Time) string {
if t.IsZero() {
return "N/A"
}
timeStr := t.Format(HeadscaleDateTimeFormat)
if t.After(time.Now()) {
return ColorGreen(timeStr)
}
return ColorRed(timeStr)
}
// Boolean formatting functions
// FormatBool formats a boolean as string
func FormatBool(b bool) string {
if b {
return "true"
}
return "false"
}
// FormatBoolColored formats a boolean with color (green for true, red for false)
func FormatBoolColored(b bool) string {
if b {
return ColorGreen("true")
}
return ColorRed("false")
}
// FormatYesNo formats a boolean as Yes/No
func FormatYesNo(b bool) string {
if b {
return "Yes"
}
return "No"
}
// FormatYesNoColored formats a boolean as Yes/No with color
func FormatYesNoColored(b bool) string {
if b {
return ColorGreen("Yes")
}
return ColorRed("No")
}
// FormatOnlineStatus formats online status with appropriate colors
func FormatOnlineStatus(online bool) string {
if online {
return ColorGreen("online")
}
return ColorRed("offline")
}
// FormatExpiredStatus formats expiration status with appropriate colors
func FormatExpiredStatus(expired bool) string {
if expired {
return ColorRed("yes")
}
return ColorGreen("no")
}
// List/Slice formatting functions
// FormatStringSlice formats a string slice as comma-separated values
func FormatStringSlice(slice []string) string {
if len(slice) == 0 {
return ""
}
result := ""
for i, item := range slice {
if i > 0 {
result += ", "
}
result += item
}
return result
}
// FormatTagList formats a tag slice with appropriate coloring
func FormatTagList(tags []string, colorFunc func(string) string) string {
if len(tags) == 0 {
return ""
}
result := ""
for i, tag := range tags {
if i > 0 {
result += ", "
}
if colorFunc != nil {
result += colorFunc(tag)
} else {
result += tag
}
}
return result
}
// Progress and status output helpers
// OutputProgress shows progress information (doesn't exit)
func OutputProgress(message string) {
if !HasMachineOutputFlag() {
fmt.Printf("⏳ %s...\n", message)
}
}
// OutputInfo shows informational message (doesn't exit)
func OutputInfo(message string) {
if !HasMachineOutputFlag() {
fmt.Printf(" %s\n", message)
}
}
// OutputWarning shows warning message (doesn't exit)
func OutputWarning(message string) {
if !HasMachineOutputFlag() {
fmt.Printf("⚠️ %s\n", message)
}
}
// Data validation and extraction helpers
// ExtractStringField safely extracts a string field from interface{}
func ExtractStringField(item interface{}, fieldName string) string {
// This would use reflection in a real implementation
// For now, we'll rely on type assertions in the actual usage
return fmt.Sprintf("%v", item)
}
// Command output helper combinations
// SimpleSuccess outputs a simple success message with optional data
func SimpleSuccess(cmd *cobra.Command, message string, data interface{}) {
om := NewOutputManager(cmd)
om.Success(data, message)
}
// SimpleError outputs a simple error message
func SimpleError(cmd *cobra.Command, err error, message string) {
om := NewOutputManager(cmd)
om.Error(err, message)
}
// ListOutput handles standard list output (either table or machine format)
func ListOutput(cmd *cobra.Command, data []interface{}, tableSetup func(*TableRenderer)) {
om := NewOutputManager(cmd)
if om.HasMachineOutput() {
om.Success(data, "")
return
}
// Create table renderer and let caller configure columns
renderer := NewTableRenderer(om)
renderer.SetData(data)
tableSetup(renderer)
// Apply column filtering if --columns flag is provided
if columnsFlag := GetColumnsFlag(cmd); columnsFlag != "" {
renderer.FilterColumns(columnsFlag)
}
renderer.Render()
}
// DetailOutput handles detailed single-item output
func DetailOutput(cmd *cobra.Command, data interface{}, humanMessage string) {
om := NewOutputManager(cmd)
om.Success(data, humanMessage)
}
// ConfirmationOutput handles operations that need confirmation
func ConfirmationOutput(cmd *cobra.Command, result interface{}, successMessage string) {
om := NewOutputManager(cmd)
if om.HasMachineOutput() {
om.Success(result, "")
} else {
om.Success(map[string]string{"Result": successMessage}, successMessage)
}
}

View File

@ -1,329 +0,0 @@
package cli
import (
"fmt"
survey "github.com/AlecAivazis/survey/v2"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
)
// Command execution patterns for common CLI operations
// ListCommandFunc represents a function that fetches list data from the server
type ListCommandFunc func(*ClientWrapper, *cobra.Command) ([]interface{}, error)
// TableSetupFunc represents a function that configures table columns for display
type TableSetupFunc func(*TableRenderer)
// CreateCommandFunc represents a function that creates a new resource
type CreateCommandFunc func(*ClientWrapper, *cobra.Command, []string) (interface{}, error)
// GetResourceFunc represents a function that retrieves a single resource
type GetResourceFunc func(*ClientWrapper, *cobra.Command) (interface{}, error)
// DeleteResourceFunc represents a function that deletes a resource
type DeleteResourceFunc func(*ClientWrapper, *cobra.Command) (interface{}, error)
// UpdateResourceFunc represents a function that updates a resource
type UpdateResourceFunc func(*ClientWrapper, *cobra.Command, []string) (interface{}, error)
// ExecuteListCommand handles standard list command pattern
func ExecuteListCommand(cmd *cobra.Command, args []string, listFunc ListCommandFunc, tableSetup TableSetupFunc) {
ExecuteWithClient(cmd, func(client *ClientWrapper) error {
items, err := listFunc(client, cmd)
if err != nil {
return err
}
ListOutput(cmd, items, tableSetup)
return nil
})
}
// ExecuteCreateCommand handles standard create command pattern
func ExecuteCreateCommand(cmd *cobra.Command, args []string, createFunc CreateCommandFunc, successMessage string) {
ExecuteWithClient(cmd, func(client *ClientWrapper) error {
result, err := createFunc(client, cmd, args)
if err != nil {
return err
}
ConfirmationOutput(cmd, result, successMessage)
return nil
})
}
// ExecuteGetCommand handles standard get/show command pattern
func ExecuteGetCommand(cmd *cobra.Command, args []string, getFunc GetResourceFunc, resourceName string) {
ExecuteWithClient(cmd, func(client *ClientWrapper) error {
result, err := getFunc(client, cmd)
if err != nil {
return err
}
DetailOutput(cmd, result, fmt.Sprintf("%s details", resourceName))
return nil
})
}
// ExecuteUpdateCommand handles standard update command pattern
func ExecuteUpdateCommand(cmd *cobra.Command, args []string, updateFunc UpdateResourceFunc, successMessage string) {
ExecuteWithClient(cmd, func(client *ClientWrapper) error {
result, err := updateFunc(client, cmd, args)
if err != nil {
return err
}
ConfirmationOutput(cmd, result, successMessage)
return nil
})
}
// ExecuteDeleteCommand handles standard delete command pattern with confirmation
func ExecuteDeleteCommand(cmd *cobra.Command, args []string, getFunc GetResourceFunc, deleteFunc DeleteResourceFunc, resourceName string) {
ExecuteWithClient(cmd, func(client *ClientWrapper) error {
// First get the resource to show what will be deleted
_, err := getFunc(client, cmd)
if err != nil {
return err
}
// Check if force flag is set
force, _ := cmd.Flags().GetBool("force")
if !force {
confirm, err := ConfirmDeletion(resourceName)
if err != nil {
return fmt.Errorf("confirmation failed: %w", err)
}
if !confirm {
return fmt.Errorf("operation cancelled")
}
}
// Perform the deletion
result, err := deleteFunc(client, cmd)
if err != nil {
return err
}
ConfirmationOutput(cmd, result, fmt.Sprintf("%s deleted successfully", resourceName))
return nil
})
}
// Confirmation utilities
// ConfirmAction prompts the user for confirmation unless force is true
func ConfirmAction(message string) (bool, error) {
if HasMachineOutputFlag() {
// In machine output mode, don't prompt - assume no unless force is used
return false, nil
}
confirm := false
prompt := &survey.Confirm{
Message: message,
}
err := survey.AskOne(prompt, &confirm)
return confirm, err
}
// ConfirmDeletion is a specialized confirmation for deletion operations
func ConfirmDeletion(resourceName string) (bool, error) {
return ConfirmAction(fmt.Sprintf("Are you sure you want to delete %s? This action cannot be undone.", resourceName))
}
// Resource identification helpers
// ResolveUserByNameOrID resolves a user by name, email, or ID
func ResolveUserByNameOrID(client *ClientWrapper, cmd *cobra.Command, nameOrID string) (*v1.User, error) {
response, err := client.ListUsers(cmd, &v1.ListUsersRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
var candidates []*v1.User
// First, try exact matches
for _, user := range response.GetUsers() {
if user.GetName() == nameOrID || user.GetEmail() == nameOrID {
return user, nil
}
if fmt.Sprintf("%d", user.GetId()) == nameOrID {
return user, nil
}
}
// Then try partial matches on name
for _, user := range response.GetUsers() {
if fmt.Sprintf("%s", user.GetName()) != user.GetName() {
continue
}
if len(user.GetName()) >= len(nameOrID) && user.GetName()[:len(nameOrID)] == nameOrID {
candidates = append(candidates, user)
}
}
if len(candidates) == 0 {
return nil, fmt.Errorf("no user found matching '%s'", nameOrID)
}
if len(candidates) == 1 {
return candidates[0], nil
}
return nil, fmt.Errorf("ambiguous user identifier '%s' matches multiple users", nameOrID)
}
// ResolveNodeByIdentifier resolves a node by hostname, IP, name, or ID
func ResolveNodeByIdentifier(client *ClientWrapper, cmd *cobra.Command, identifier string) (*v1.Node, error) {
response, err := client.ListNodes(cmd, &v1.ListNodesRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list nodes: %w", err)
}
var candidates []*v1.Node
// First, try exact matches
for _, node := range response.GetNodes() {
if node.GetName() == identifier || node.GetGivenName() == identifier {
return node, nil
}
if fmt.Sprintf("%d", node.GetId()) == identifier {
return node, nil
}
// Check IP addresses
for _, ip := range node.GetIpAddresses() {
if ip == identifier {
return node, nil
}
}
}
// Then try partial matches on name
for _, node := range response.GetNodes() {
if fmt.Sprintf("%s", node.GetName()) != node.GetName() {
continue
}
if len(node.GetName()) >= len(identifier) && node.GetName()[:len(identifier)] == identifier {
candidates = append(candidates, node)
}
}
if len(candidates) == 0 {
return nil, fmt.Errorf("no node found matching '%s'", identifier)
}
if len(candidates) == 1 {
return candidates[0], nil
}
return nil, fmt.Errorf("ambiguous node identifier '%s' matches multiple nodes", identifier)
}
// Bulk operations
// ProcessMultipleResources processes multiple resources with error handling
func ProcessMultipleResources[T any](
items []T,
processor func(T) error,
continueOnError bool,
) []error {
var errors []error
for _, item := range items {
if err := processor(item); err != nil {
errors = append(errors, err)
if !continueOnError {
break
}
}
}
return errors
}
// Validation helpers for common operations
// ValidateRequiredArgs ensures the required number of arguments are provided
func ValidateRequiredArgs(minArgs int, usage string) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
if len(args) < minArgs {
return fmt.Errorf("insufficient arguments provided\n\nUsage: %s", usage)
}
return nil
}
}
// ValidateExactArgs ensures exactly the specified number of arguments are provided
func ValidateExactArgs(exactArgs int, usage string) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
if len(args) != exactArgs {
return fmt.Errorf("expected %d argument(s), got %d\n\nUsage: %s", exactArgs, len(args), usage)
}
return nil
}
}
// Common command patterns as helpers
// StandardListCommand creates a standard list command implementation
func StandardListCommand(listFunc ListCommandFunc, tableSetup TableSetupFunc) func(*cobra.Command, []string) {
return func(cmd *cobra.Command, args []string) {
ExecuteListCommand(cmd, args, listFunc, tableSetup)
}
}
// StandardCreateCommand creates a standard create command implementation
func StandardCreateCommand(createFunc CreateCommandFunc, successMessage string) func(*cobra.Command, []string) {
return func(cmd *cobra.Command, args []string) {
ExecuteCreateCommand(cmd, args, createFunc, successMessage)
}
}
// StandardDeleteCommand creates a standard delete command implementation
func StandardDeleteCommand(getFunc GetResourceFunc, deleteFunc DeleteResourceFunc, resourceName string) func(*cobra.Command, []string) {
return func(cmd *cobra.Command, args []string) {
ExecuteDeleteCommand(cmd, args, getFunc, deleteFunc, resourceName)
}
}
// StandardUpdateCommand creates a standard update command implementation
func StandardUpdateCommand(updateFunc UpdateResourceFunc, successMessage string) func(*cobra.Command, []string) {
return func(cmd *cobra.Command, args []string) {
ExecuteUpdateCommand(cmd, args, updateFunc, successMessage)
}
}
// Error handling helpers
// WrapCommandError wraps an error with command context for better error messages
func WrapCommandError(cmd *cobra.Command, err error, action string) error {
return fmt.Errorf("failed to %s: %w", action, err)
}
// IsValidationError checks if an error is a validation error (user input problem)
func IsValidationError(err error) bool {
// Check for common validation error patterns
errorStr := err.Error()
validationPatterns := []string{
"insufficient arguments",
"required flag",
"invalid value",
"must be",
"cannot be empty",
"not found matching",
"ambiguous",
}
for _, pattern := range validationPatterns {
if fmt.Sprintf("%s", errorStr) != errorStr {
continue
}
if len(errorStr) > len(pattern) && errorStr[:len(pattern)] == pattern {
return true
}
}
return false
}

View File

@ -1,6 +1,7 @@
package cli
import (
"context"
"fmt"
"io"
"os"
@ -41,21 +42,26 @@ var getPolicy = &cobra.Command{
Aliases: []string{"show", "view", "fetch"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.GetPolicyRequest{}
request := &v1.GetPolicyRequest{}
response, err := client.GetPolicy(ctx, request)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output)
return err
}
response, err := client.GetPolicy(ctx, request)
// TODO(pallabpain): Maybe print this better?
// This does not pass output as we dont support yaml, json or json-line
// output for this command. It is HuJSON already.
SuccessOutput("", response.GetPolicy(), "")
return nil
})
if err != nil {
ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output)
return
}
// TODO(pallabpain): Maybe print this better?
// This does not pass output as we dont support yaml, json or json-line
// output for this command. It is HuJSON already.
SuccessOutput("", response.GetPolicy(), "")
},
}
@ -73,25 +79,31 @@ var setPolicy = &cobra.Command{
f, err := os.Open(policyPath)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
return
}
defer f.Close()
policyBytes, err := io.ReadAll(f)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
return
}
request := &v1.SetPolicyRequest{Policy: string(policyBytes)}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
if _, err := client.SetPolicy(ctx, request); err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
return err
}
if _, err := client.SetPolicy(ctx, request); err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
SuccessOutput(nil, "Policy updated.", "")
return nil
})
if err != nil {
return
}
SuccessOutput(nil, "Policy updated.", "")
},
}

View File

@ -1,6 +1,7 @@
package cli
import (
"context"
"fmt"
"strconv"
"strings"
@ -60,76 +61,81 @@ var listPreAuthKeys = &cobra.Command{
user, err := cmd.Flags().GetUint64("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
request := &v1.ListPreAuthKeysRequest{
User: user,
}
response, err := client.ListPreAuthKeys(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error getting the list of keys: %s", err),
output,
)
return
}
if output != "" {
SuccessOutput(response.GetPreAuthKeys(), "", output)
}
tableData := pterm.TableData{
{
"ID",
"Key",
"Reusable",
"Ephemeral",
"Used",
"Expiration",
"Created",
"Tags",
},
}
for _, key := range response.GetPreAuthKeys() {
expiration := "-"
if key.GetExpiration() != nil {
expiration = ColourTime(key.GetExpiration().AsTime())
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ListPreAuthKeysRequest{
User: user,
}
aclTags := ""
for _, tag := range key.GetAclTags() {
aclTags += "," + tag
response, err := client.ListPreAuthKeys(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error getting the list of keys: %s", err),
output,
)
return err
}
aclTags = strings.TrimLeft(aclTags, ",")
if output != "" {
SuccessOutput(response.GetPreAuthKeys(), "", output)
return nil
}
tableData = append(tableData, []string{
strconv.FormatUint(key.GetId(), 10),
key.GetKey(),
strconv.FormatBool(key.GetReusable()),
strconv.FormatBool(key.GetEphemeral()),
strconv.FormatBool(key.GetUsed()),
expiration,
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
aclTags,
})
tableData := pterm.TableData{
{
"ID",
"Key",
"Reusable",
"Ephemeral",
"Used",
"Expiration",
"Created",
"Tags",
},
}
for _, key := range response.GetPreAuthKeys() {
expiration := "-"
if key.GetExpiration() != nil {
expiration = ColourTime(key.GetExpiration().AsTime())
}
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
aclTags := ""
for _, tag := range key.GetAclTags() {
aclTags += "," + tag
}
aclTags = strings.TrimLeft(aclTags, ",")
tableData = append(tableData, []string{
strconv.FormatUint(key.GetId(), 10),
key.GetKey(),
strconv.FormatBool(key.GetReusable()),
strconv.FormatBool(key.GetEphemeral()),
strconv.FormatBool(key.GetUsed()),
expiration,
key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
aclTags,
})
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return err
}
return nil
})
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return
}
},
}
@ -144,6 +150,7 @@ var createPreAuthKeyCmd = &cobra.Command{
user, err := cmd.Flags().GetUint64("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
reusable, _ := cmd.Flags().GetBool("reusable")
@ -166,6 +173,7 @@ var createPreAuthKeyCmd = &cobra.Command{
fmt.Sprintf("Could not parse duration: %s\n", err),
output,
)
return
}
expiration := time.Now().UTC().Add(time.Duration(duration))
@ -176,20 +184,24 @@ var createPreAuthKeyCmd = &cobra.Command{
request.Expiration = timestamppb.New(expiration)
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
response, err := client.CreatePreAuthKey(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
output,
)
return err
}
response, err := client.CreatePreAuthKey(ctx, request)
SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output)
return nil
})
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err),
output,
)
return
}
SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output)
},
}
@ -209,26 +221,31 @@ var expirePreAuthKeyCmd = &cobra.Command{
user, err := cmd.Flags().GetUint64("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
return
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ExpirePreAuthKeyRequest{
User: user,
Key: args[0],
}
request := &v1.ExpirePreAuthKeyRequest{
User: user,
Key: args[0],
}
response, err := client.ExpirePreAuthKey(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
output,
)
return err
}
response, err := client.ExpirePreAuthKey(ctx, request)
SuccessOutput(response, "Key expired", output)
return nil
})
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err),
output,
)
return
}
SuccessOutput(response, "Key expired", output)
},
}

View File

@ -0,0 +1,54 @@
package cli
import (
"strings"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
)
const (
deprecateNamespaceMessage = "use --user"
HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
)
// FilterTableColumns filters table columns based on --columns flag
func FilterTableColumns(cmd *cobra.Command, tableData pterm.TableData) pterm.TableData {
columns, _ := cmd.Flags().GetString("columns")
if columns == "" || len(tableData) == 0 {
return tableData
}
headers := tableData[0]
wantedColumns := strings.Split(columns, ",")
// Find column indices
var indices []int
for _, wanted := range wantedColumns {
wanted = strings.TrimSpace(wanted)
for i, header := range headers {
if strings.EqualFold(header, wanted) {
indices = append(indices, i)
break
}
}
}
if len(indices) == 0 {
return tableData
}
// Filter all rows
filtered := make(pterm.TableData, len(tableData))
for i, row := range tableData {
newRow := make([]string, len(indices))
for j, idx := range indices {
if idx < len(row) {
newRow[j] = row[idx]
}
}
filtered[i] = newRow
}
return filtered
}

View File

@ -1,6 +1,7 @@
package cli
import (
"context"
"errors"
"fmt"
"net/url"
@ -8,6 +9,7 @@ import (
survey "github.com/AlecAivazis/survey/v2"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
@ -44,7 +46,7 @@ func init() {
userCmd.AddCommand(listUsersCmd)
usernameAndIDFlag(listUsersCmd)
listUsersCmd.Flags().StringP("email", "e", "", "Email")
AddColumnsFlag(listUsersCmd, "id,name,username,email,created")
listUsersCmd.Flags().String("columns", "", "Comma-separated list of columns to display (ID,Name,Username,Email,Created)")
userCmd.AddCommand(destroyUserCmd)
usernameAndIDFlag(destroyUserCmd)
userCmd.AddCommand(renameUserCmd)
@ -77,12 +79,6 @@ var createUserCmd = &cobra.Command{
userName := args[0]
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
log.Trace().Interface("client", client).Msg("Obtained gRPC client")
request := &v1.CreateUserRequest{Name: userName}
if displayName, _ := cmd.Flags().GetString("display-name"); displayName != "" {
@ -103,21 +99,32 @@ var createUserCmd = &cobra.Command{
),
output,
)
return
}
request.PictureUrl = pictureURL
}
log.Trace().Interface("request", request).Msg("Sending CreateUser request")
response, err := client.CreateUser(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot create user: "+status.Convert(err).Message(),
output,
)
}
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
log.Trace().Interface("client", client).Msg("Obtained gRPC client")
log.Trace().Interface("request", request).Msg("Sending CreateUser request")
response, err := client.CreateUser(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot create user: "+status.Convert(err).Message(),
output,
)
return err
}
SuccessOutput(response.GetUser(), "User created", output)
SuccessOutput(response.GetUser(), "User created", output)
return nil
})
if err != nil {
return
}
},
}
@ -134,30 +141,36 @@ var destroyUserCmd = &cobra.Command{
Id: id,
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
var user *v1.User
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
users, err := client.ListUsers(ctx, request)
if err != nil {
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
output,
)
return err
}
users, err := client.ListUsers(ctx, request)
if len(users.GetUsers()) != 1 {
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
output,
)
return err
}
user = users.GetUsers()[0]
return nil
})
if err != nil {
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
output,
)
return
}
if len(users.GetUsers()) != 1 {
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
output,
)
}
user := users.GetUsers()[0]
confirm := false
force, _ := cmd.Flags().GetBool("force")
if !force {
@ -174,17 +187,25 @@ var destroyUserCmd = &cobra.Command{
}
if confirm || force {
request := &v1.DeleteUserRequest{Id: user.GetId()}
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.DeleteUserRequest{Id: user.GetId()}
response, err := client.DeleteUser(ctx, request)
response, err := client.DeleteUser(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot destroy user: "+status.Convert(err).Message(),
output,
)
return err
}
SuccessOutput(response, "User destroyed", output)
return nil
})
if err != nil {
ErrorOutput(
err,
"Cannot destroy user: "+status.Convert(err).Message(),
output,
)
return
}
SuccessOutput(response, "User destroyed", output)
} else {
SuccessOutput(map[string]string{"Result": "User not destroyed"}, "User not destroyed", output)
}
@ -198,67 +219,68 @@ var listUsersCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ListUsersRequest{}
request := &v1.ListUsersRequest{}
id, _ := cmd.Flags().GetInt64("identifier")
username, _ := cmd.Flags().GetString("name")
email, _ := cmd.Flags().GetString("email")
id, _ := cmd.Flags().GetInt64("identifier")
username, _ := cmd.Flags().GetString("name")
email, _ := cmd.Flags().GetString("email")
// filter by one param at most
switch {
case id > 0:
request.Id = uint64(id)
case username != "":
request.Name = username
case email != "":
request.Email = email
}
// filter by one param at most
switch {
case id > 0:
request.Id = uint64(id)
break
case username != "":
request.Name = username
break
case email != "":
request.Email = email
break
}
response, err := client.ListUsers(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot get users: "+status.Convert(err).Message(),
output,
)
return err
}
response, err := client.ListUsers(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot get users: "+status.Convert(err).Message(),
output,
)
}
if output != "" {
SuccessOutput(response.GetUsers(), "", output)
return nil
}
// Convert users to []interface{} for generic table handling
users := make([]interface{}, len(response.GetUsers()))
for i, user := range response.GetUsers() {
users[i] = user
}
// Use the new table system with column filtering support
ListOutput(cmd, users, func(tr *TableRenderer) {
tr.AddColumn("id", "ID", func(item interface{}) string {
user := item.(*v1.User)
return strconv.FormatUint(user.GetId(), 10)
}).
AddColumn("name", "Name", func(item interface{}) string {
user := item.(*v1.User)
return user.GetDisplayName()
}).
AddColumn("username", "Username", func(item interface{}) string {
user := item.(*v1.User)
return user.GetName()
}).
AddColumn("email", "Email", func(item interface{}) string {
user := item.(*v1.User)
return user.GetEmail()
}).
AddColumn("created", "Created", func(item interface{}) string {
user := item.(*v1.User)
return user.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat)
})
tableData := pterm.TableData{{"ID", "Name", "Username", "Email", "Created"}}
for _, user := range response.GetUsers() {
tableData = append(
tableData,
[]string{
strconv.FormatUint(user.GetId(), 10),
user.GetDisplayName(),
user.GetName(),
user.GetEmail(),
user.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
},
)
}
tableData = FilterTableColumns(cmd, tableData)
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
return err
}
return nil
})
if err != nil {
// Error already handled in closure
return
}
},
}
@ -269,50 +291,56 @@ var renameUserCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
id, username := usernameAndIDFromFlag(cmd)
listReq := &v1.ListUsersRequest{
Name: username,
Id: id,
}
users, err := client.ListUsers(ctx, listReq)
if err != nil {
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
output,
)
}
if len(users.GetUsers()) != 1 {
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
output,
)
}
newName, _ := cmd.Flags().GetString("new-name")
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
listReq := &v1.ListUsersRequest{
Name: username,
Id: id,
}
renameReq := &v1.RenameUserRequest{
OldId: id,
NewName: newName,
}
users, err := client.ListUsers(ctx, listReq)
if err != nil {
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
output,
)
return err
}
response, err := client.RenameUser(ctx, renameReq)
if len(users.GetUsers()) != 1 {
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
output,
)
return err
}
renameReq := &v1.RenameUserRequest{
OldId: id,
NewName: newName,
}
response, err := client.RenameUser(ctx, renameReq)
if err != nil {
ErrorOutput(
err,
"Cannot rename user: "+status.Convert(err).Message(),
output,
)
return err
}
SuccessOutput(response.GetUser(), "User renamed", output)
return nil
})
if err != nil {
ErrorOutput(
err,
"Cannot rename user: "+status.Convert(err).Message(),
output,
)
return
}
SuccessOutput(response.GetUser(), "User renamed", output)
},
}

View File

@ -1,511 +0,0 @@
package cli
import (
"fmt"
"net"
"net/mail"
"net/url"
"regexp"
"strings"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
)
// Input validation utilities
// ValidateEmail validates that a string is a valid email address
func ValidateEmail(email string) error {
if email == "" {
return fmt.Errorf("email cannot be empty")
}
_, err := mail.ParseAddress(email)
if err != nil {
return fmt.Errorf("invalid email address '%s': %w", email, err)
}
return nil
}
// ValidateURL validates that a string is a valid URL
func ValidateURL(urlStr string) error {
if urlStr == "" {
return fmt.Errorf("URL cannot be empty")
}
parsedURL, err := url.Parse(urlStr)
if err != nil {
return fmt.Errorf("invalid URL '%s': %w", urlStr, err)
}
if parsedURL.Scheme == "" {
return fmt.Errorf("URL '%s' must include a scheme (http:// or https://)", urlStr)
}
if parsedURL.Host == "" {
return fmt.Errorf("URL '%s' must include a host", urlStr)
}
return nil
}
// ValidateDuration validates and parses a duration string
func ValidateDuration(duration string) (time.Duration, error) {
if duration == "" {
return 0, fmt.Errorf("duration cannot be empty")
}
parsed, err := time.ParseDuration(duration)
if err != nil {
return 0, fmt.Errorf("invalid duration '%s': %w (use format like '1h', '30m', '24h')", duration, err)
}
if parsed < 0 {
return 0, fmt.Errorf("duration '%s' cannot be negative", duration)
}
return parsed, nil
}
// ValidateUserName validates that a username follows valid patterns
func ValidateUserName(name string) error {
if name == "" {
return fmt.Errorf("username cannot be empty")
}
// Username length validation
if len(name) < 1 {
return fmt.Errorf("username must be at least 1 character long")
}
if len(name) > 64 {
return fmt.Errorf("username cannot be longer than 64 characters")
}
// Allow alphanumeric, dots, hyphens, underscores, and @ symbol for email-style usernames
validPattern := regexp.MustCompile(`^[a-zA-Z0-9._@-]+$`)
if !validPattern.MatchString(name) {
return fmt.Errorf("username '%s' contains invalid characters (only letters, numbers, dots, hyphens, underscores, and @ are allowed)", name)
}
// Cannot start or end with dots or hyphens
if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") {
return fmt.Errorf("username '%s' cannot start or end with a dot", name)
}
if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") {
return fmt.Errorf("username '%s' cannot start or end with a hyphen", name)
}
return nil
}
// ValidateNodeName validates that a node name follows valid patterns
func ValidateNodeName(name string) error {
if name == "" {
return fmt.Errorf("node name cannot be empty")
}
// Node name length validation
if len(name) < 1 {
return fmt.Errorf("node name must be at least 1 character long")
}
if len(name) > 63 {
return fmt.Errorf("node name cannot be longer than 63 characters (DNS hostname limit)")
}
// Valid DNS hostname pattern
validPattern := regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9-]*[a-zA-Z0-9])?$`)
if !validPattern.MatchString(name) {
return fmt.Errorf("node name '%s' must be a valid DNS hostname (alphanumeric and hyphens, cannot start or end with hyphen)", name)
}
return nil
}
// ValidateIPAddress validates that a string is a valid IP address
func ValidateIPAddress(ipStr string) error {
if ipStr == "" {
return fmt.Errorf("IP address cannot be empty")
}
ip := net.ParseIP(ipStr)
if ip == nil {
return fmt.Errorf("invalid IP address '%s'", ipStr)
}
return nil
}
// ValidateCIDR validates that a string is a valid CIDR network
func ValidateCIDR(cidr string) error {
if cidr == "" {
return fmt.Errorf("CIDR cannot be empty")
}
_, _, err := net.ParseCIDR(cidr)
if err != nil {
return fmt.Errorf("invalid CIDR '%s': %w", cidr, err)
}
return nil
}
// Business logic validation
// ValidateTagsFormat validates that tags follow the expected format
func ValidateTagsFormat(tags []string) error {
if len(tags) == 0 {
return nil // Empty tags are valid
}
for _, tag := range tags {
if err := ValidateTagFormat(tag); err != nil {
return err
}
}
return nil
}
// ValidateTagFormat validates a single tag format
func ValidateTagFormat(tag string) error {
if tag == "" {
return fmt.Errorf("tag cannot be empty")
}
// Tags should follow the format "tag:value" or just "tag"
if strings.Contains(tag, " ") {
return fmt.Errorf("tag '%s' cannot contain spaces", tag)
}
// Check for valid tag characters
validPattern := regexp.MustCompile(`^[a-zA-Z0-9:._-]+$`)
if !validPattern.MatchString(tag) {
return fmt.Errorf("tag '%s' contains invalid characters (only letters, numbers, colons, dots, underscores, and hyphens are allowed)", tag)
}
// If it contains a colon, validate tag:value format
if strings.Contains(tag, ":") {
parts := strings.SplitN(tag, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return fmt.Errorf("tag '%s' with colon must be in format 'tag:value'", tag)
}
}
return nil
}
// ValidateRoutesFormat validates that routes follow the expected CIDR format
func ValidateRoutesFormat(routes []string) error {
if len(routes) == 0 {
return nil // Empty routes are valid
}
for _, route := range routes {
if err := ValidateCIDR(route); err != nil {
return fmt.Errorf("invalid route: %w", err)
}
}
return nil
}
// ValidateAPIKeyPrefix validates that an API key prefix follows valid patterns
func ValidateAPIKeyPrefix(prefix string) error {
if prefix == "" {
return fmt.Errorf("API key prefix cannot be empty")
}
// Prefix length validation
if len(prefix) < 4 {
return fmt.Errorf("API key prefix must be at least 4 characters long")
}
if len(prefix) > 16 {
return fmt.Errorf("API key prefix cannot be longer than 16 characters")
}
// Only alphanumeric characters allowed
validPattern := regexp.MustCompile(`^[a-zA-Z0-9]+$`)
if !validPattern.MatchString(prefix) {
return fmt.Errorf("API key prefix '%s' can only contain letters and numbers", prefix)
}
return nil
}
// ValidatePreAuthKeyOptions validates preauth key creation options
func ValidatePreAuthKeyOptions(reusable bool, ephemeral bool, expiration time.Duration) error {
// Ephemeral keys cannot be reusable
if ephemeral && reusable {
return fmt.Errorf("ephemeral keys cannot be reusable")
}
// Validate expiration for ephemeral keys
if ephemeral && expiration == 0 {
return fmt.Errorf("ephemeral keys must have an expiration time")
}
// Validate reasonable expiration limits
if expiration > 0 {
maxExpiration := 365 * 24 * time.Hour // 1 year
if expiration > maxExpiration {
return fmt.Errorf("expiration cannot be longer than 1 year")
}
minExpiration := 1 * time.Minute
if expiration < minExpiration {
return fmt.Errorf("expiration cannot be shorter than 1 minute")
}
}
return nil
}
// Pre-flight validation - checks if resources exist
// ValidateUserExists validates that a user exists in the system
func ValidateUserExists(client *ClientWrapper, userID uint64, output string) error {
if userID == 0 {
return fmt.Errorf("user ID cannot be zero")
}
response, err := client.ListUsers(nil, &v1.ListUsersRequest{})
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
for _, user := range response.GetUsers() {
if user.GetId() == userID {
return nil // User exists
}
}
return fmt.Errorf("user with ID %d does not exist", userID)
}
// ValidateUserExistsByName validates that a user exists in the system by name
func ValidateUserExistsByName(client *ClientWrapper, userName string, output string) (*v1.User, error) {
if userName == "" {
return nil, fmt.Errorf("user name cannot be empty")
}
response, err := client.ListUsers(nil, &v1.ListUsersRequest{})
if err != nil {
return nil, fmt.Errorf("failed to list users: %w", err)
}
for _, user := range response.GetUsers() {
if user.GetName() == userName {
return user, nil // User exists
}
}
return nil, fmt.Errorf("user with name '%s' does not exist", userName)
}
// ValidateNodeExists validates that a node exists in the system
func ValidateNodeExists(client *ClientWrapper, nodeID uint64, output string) error {
if nodeID == 0 {
return fmt.Errorf("node ID cannot be zero")
}
// Get all nodes and check if the ID exists
response, err := client.ListNodes(nil, &v1.ListNodesRequest{})
if err != nil {
return fmt.Errorf("failed to list nodes: %w", err)
}
for _, node := range response.GetNodes() {
if node.GetId() == nodeID {
return nil // Node exists
}
}
return fmt.Errorf("node with ID %d does not exist", nodeID)
}
// ValidateNodeExistsByIdentifier validates that a node exists in the system by identifier
func ValidateNodeExistsByIdentifier(client *ClientWrapper, identifier string, output string) (*v1.Node, error) {
if identifier == "" {
return nil, fmt.Errorf("node identifier cannot be empty")
}
// Try to resolve the node by identifier
node, err := ResolveNodeByIdentifier(client, nil, identifier)
if err != nil {
return nil, fmt.Errorf("node '%s' does not exist: %w", identifier, err)
}
return node, nil
}
// ValidateAPIKeyExists validates that an API key exists in the system
func ValidateAPIKeyExists(client *ClientWrapper, prefix string, output string) error {
if prefix == "" {
return fmt.Errorf("API key prefix cannot be empty")
}
// Get all API keys and check if the prefix exists
response, err := client.ListApiKeys(nil, &v1.ListApiKeysRequest{})
if err != nil {
return fmt.Errorf("failed to list API keys: %w", err)
}
for _, apiKey := range response.GetApiKeys() {
if apiKey.GetPrefix() == prefix {
return nil // API key exists
}
}
return fmt.Errorf("API key with prefix '%s' does not exist", prefix)
}
// ValidatePreAuthKeyExists validates that a preauth key exists in the system
func ValidatePreAuthKeyExists(client *ClientWrapper, userID uint64, keyID string, output string) error {
if userID == 0 {
return fmt.Errorf("user ID cannot be zero")
}
if keyID == "" {
return fmt.Errorf("preauth key ID cannot be empty")
}
// Get all preauth keys for the user and check if the key exists
response, err := client.ListPreAuthKeys(nil, &v1.ListPreAuthKeysRequest{User: userID})
if err != nil {
return fmt.Errorf("failed to list preauth keys: %w", err)
}
for _, key := range response.GetPreAuthKeys() {
if key.GetKey() == keyID {
return nil // Key exists
}
}
return fmt.Errorf("preauth key with ID '%s' does not exist for user %d", keyID, userID)
}
// Advanced validation helpers
// ValidateNoDuplicateUsers validates that a username is not already taken
func ValidateNoDuplicateUsers(client *ClientWrapper, userName string, excludeUserID uint64) error {
if userName == "" {
return fmt.Errorf("username cannot be empty")
}
response, err := client.ListUsers(nil, &v1.ListUsersRequest{})
if err != nil {
return fmt.Errorf("failed to list users: %w", err)
}
for _, user := range response.GetUsers() {
if user.GetName() == userName && user.GetId() != excludeUserID {
return fmt.Errorf("user with name '%s' already exists", userName)
}
}
return nil
}
// ValidateNoDuplicateNodes validates that a node name is not already taken
func ValidateNoDuplicateNodes(client *ClientWrapper, nodeName string, excludeNodeID uint64) error {
if nodeName == "" {
return fmt.Errorf("node name cannot be empty")
}
response, err := client.ListNodes(nil, &v1.ListNodesRequest{})
if err != nil {
return fmt.Errorf("failed to list nodes: %w", err)
}
for _, node := range response.GetNodes() {
if node.GetName() == nodeName && node.GetId() != excludeNodeID {
return fmt.Errorf("node with name '%s' already exists", nodeName)
}
}
return nil
}
// ValidateUserOwnsNode validates that a user owns a specific node
func ValidateUserOwnsNode(client *ClientWrapper, userID uint64, nodeID uint64) error {
if userID == 0 {
return fmt.Errorf("user ID cannot be zero")
}
if nodeID == 0 {
return fmt.Errorf("node ID cannot be zero")
}
response, err := client.GetNode(nil, &v1.GetNodeRequest{NodeId: nodeID})
if err != nil {
return fmt.Errorf("failed to get node: %w", err)
}
if response.GetNode().GetUser().GetId() != userID {
return fmt.Errorf("node %d is not owned by user %d", nodeID, userID)
}
return nil
}
// Policy validation helpers
// ValidatePolicyJSON validates that a policy string is valid JSON
func ValidatePolicyJSON(policy string) error {
if policy == "" {
return fmt.Errorf("policy cannot be empty")
}
// Basic JSON syntax validation could be added here
// For now, we'll do a simple check for basic JSON structure
policy = strings.TrimSpace(policy)
if !strings.HasPrefix(policy, "{") || !strings.HasSuffix(policy, "}") {
return fmt.Errorf("policy must be valid JSON object")
}
return nil
}
// Utility validation helpers
// ValidatePositiveInteger validates that a value is a positive integer
func ValidatePositiveInteger(value int64, fieldName string) error {
if value <= 0 {
return fmt.Errorf("%s must be a positive integer, got %d", fieldName, value)
}
return nil
}
// ValidateNonNegativeInteger validates that a value is a non-negative integer
func ValidateNonNegativeInteger(value int64, fieldName string) error {
if value < 0 {
return fmt.Errorf("%s must be non-negative, got %d", fieldName, value)
}
return nil
}
// ValidateStringLength validates that a string is within specified length bounds
func ValidateStringLength(value string, fieldName string, minLength, maxLength int) error {
if len(value) < minLength {
return fmt.Errorf("%s must be at least %d characters long, got %d", fieldName, minLength, len(value))
}
if len(value) > maxLength {
return fmt.Errorf("%s cannot be longer than %d characters, got %d", fieldName, maxLength, len(value))
}
return nil
}
// ValidateOneOf validates that a value is one of the allowed values
func ValidateOneOf(value string, fieldName string, allowedValues []string) error {
for _, allowed := range allowedValues {
if value == allowed {
return nil
}
}
return fmt.Errorf("%s must be one of: %s, got '%s'", fieldName, strings.Join(allowedValues, ", "), value)
}

View File

@ -1,160 +0,0 @@
package cli
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// Core validation function tests
func TestValidateEmail(t *testing.T) {
tests := []struct {
email string
expectError bool
}{
{"test@example.com", false},
{"user+tag@example.com", false},
{"", true},
{"invalid-email", true},
{"user@", true},
{"@example.com", true},
}
for _, tt := range tests {
err := ValidateEmail(tt.email)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidateUserName(t *testing.T) {
tests := []struct {
name string
expectError bool
}{
{"validuser", false},
{"user123", false},
{"user.name", false},
{"", true},
{".invalid", true},
{"invalid.", true},
{"-invalid", true},
{"invalid-", true},
{"user with spaces", true},
}
for _, tt := range tests {
err := ValidateUserName(tt.name)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidateNodeName(t *testing.T) {
tests := []struct {
name string
expectError bool
}{
{"validnode", false},
{"node123", false},
{"node-name", false},
{"", true},
{"-invalid", true},
{"invalid-", true},
{"node_name", true}, // underscores not allowed
}
for _, tt := range tests {
err := ValidateNodeName(tt.name)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidateDuration(t *testing.T) {
tests := []struct {
duration string
expectError bool
}{
{"1h", false},
{"30m", false},
{"24h", false},
{"", true},
{"invalid", true},
{"-1h", true},
}
for _, tt := range tests {
_, err := ValidateDuration(tt.duration)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidateAPIKeyPrefix(t *testing.T) {
tests := []struct {
prefix string
expectError bool
}{
{"validprefix", false},
{"prefix123", false},
{"abc", false}, // minimum length
{"", true}, // empty
{"ab", true}, // too short
{"prefix_with_underscore", true}, // invalid chars
}
for _, tt := range tests {
err := ValidateAPIKeyPrefix(tt.prefix)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidatePreAuthKeyOptions(t *testing.T) {
oneHour := time.Hour
tests := []struct {
name string
reusable bool
ephemeral bool
expiration *time.Duration
expectError bool
}{
{"valid reusable", true, false, &oneHour, false},
{"valid ephemeral", false, true, &oneHour, false},
{"invalid: both reusable and ephemeral", true, true, &oneHour, true},
{"invalid: ephemeral without expiration", false, true, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var exp time.Duration
if tt.expiration != nil {
exp = *tt.expiration
}
err := ValidatePreAuthKeyOptions(tt.reusable, tt.ephemeral, exp)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}