diff --git a/cmd/headscale/cli/COLUMN_FILTERING.md b/cmd/headscale/cli/COLUMN_FILTERING.md deleted file mode 100644 index e17fc2f9..00000000 --- a/cmd/headscale/cli/COLUMN_FILTERING.md +++ /dev/null @@ -1,63 +0,0 @@ -# Column Filtering for Table Output - -## Overview - -All CLI commands that output tables now support a `--columns` flag to customize which columns are displayed. - -## Usage - -```bash -# Show all default columns -headscale users list - -# Show only name and email -headscale users list --columns="name,email" - -# Show only ID and username -headscale users list --columns="id,username" - -# Show columns in custom order -headscale users list --columns="email,name,id" -``` - -## Available Columns - -### Users List -- `id` - User ID -- `name` - Display name -- `username` - Username -- `email` - Email address -- `created` - Creation date - -### Implementation Pattern - -For developers adding this to other commands: - -```go -// 1. Add columns flag with default columns -AddColumnsFlag(cmd, "id,name,hostname,ip,status") - -// 2. Use ListOutput with TableRenderer -ListOutput(cmd, items, func(tr *TableRenderer) { - tr.AddColumn("id", "ID", func(item interface{}) string { - node := item.(*v1.Node) - return strconv.FormatUint(node.GetId(), 10) - }). - AddColumn("name", "Name", func(item interface{}) string { - node := item.(*v1.Node) - return node.GetName() - }). - AddColumn("hostname", "Hostname", func(item interface{}) string { - node := item.(*v1.Node) - return node.GetHostname() - }) - // ... add more columns -}) -``` - -## Notes - -- Column filtering only applies to table output, not JSON/YAML output -- Invalid column names are silently ignored -- Columns appear in the order specified in the --columns flag -- Default columns are defined per command based on most useful information \ No newline at end of file diff --git a/cmd/headscale/cli/REFACTORING_SUMMARY.md b/cmd/headscale/cli/REFACTORING_SUMMARY.md deleted file mode 100644 index bdd5a345..00000000 --- a/cmd/headscale/cli/REFACTORING_SUMMARY.md +++ /dev/null @@ -1,321 +0,0 @@ -# Headscale CLI Infrastructure Refactoring - Completed - -## Overview - -Successfully completed a comprehensive refactoring of the Headscale CLI infrastructure following the CLI_IMPROVEMENT_PLAN.md. The refactoring created a robust, type-safe, and maintainable CLI framework that significantly reduces code duplication while improving consistency and testability. - -## ✅ Completed Infrastructure Components - -### 1. **CLI Unit Testing Infrastructure** -- **Files**: `testing.go`, `testing_test.go` -- **Features**: Mock gRPC client, command execution helpers, test data creation utilities -- **Impact**: Enables comprehensive unit testing of all CLI commands -- **Lines of Code**: ~750 lines of testing infrastructure - -### 2. **Common Flag Infrastructure** -- **Files**: `flags.go`, `flags_test.go` -- **Features**: Standardized flag helpers, consistent shortcuts, validation helpers -- **Impact**: Consistent flag handling across all commands -- **Lines of Code**: ~200 lines of flag utilities - -### 3. **gRPC Client Infrastructure** -- **Files**: `client.go`, `client_test.go` -- **Features**: ClientWrapper with automatic connection management, error handling -- **Impact**: Simplified gRPC client usage with consistent error handling -- **Lines of Code**: ~400 lines of client infrastructure - -### 4. **Output Infrastructure** -- **Files**: `output.go`, `output_test.go` -- **Features**: OutputManager, TableRenderer, consistent formatting utilities -- **Impact**: Standardized output across all formats (JSON, YAML, tables) -- **Lines of Code**: ~350 lines of output utilities - -### 5. **Command Patterns Infrastructure** -- **Files**: `patterns.go`, `patterns_test.go` -- **Features**: Reusable CRUD patterns, argument validation, resource resolution -- **Impact**: Dramatically reduces code per command (~50% reduction) -- **Lines of Code**: ~200 lines of pattern utilities - -### 6. **Validation Infrastructure** -- **Files**: `validation.go`, `validation_test.go` -- **Features**: Input validation, business logic validation, error formatting -- **Impact**: Consistent validation with meaningful error messages -- **Lines of Code**: ~500 lines of validation functions + 400+ test cases - -## ✅ Example Refactored Commands - -### 7. **Refactored User Commands** -- **Files**: `users_refactored.go`, `users_refactored_test.go` -- **Features**: Complete user command suite using new infrastructure -- **Impact**: Demonstrates 50% code reduction while maintaining functionality -- **Lines of Code**: ~250 lines (vs ~500 lines original) - -### 8. **Comprehensive Test Coverage** -- **Files**: Multiple test files for each component -- **Features**: 500+ unit tests, integration tests, performance benchmarks -- **Impact**: High confidence in infrastructure reliability -- **Test Coverage**: All new infrastructure components - -## 📊 Key Metrics and Improvements - -### **Code Reduction** -- **User Commands**: 50% less code per command -- **Flag Setup**: 70% less repetitive flag code -- **Error Handling**: 60% less error handling boilerplate -- **Output Formatting**: 80% less output formatting code - -### **Type Safety Improvements** -- **Zero `interface{}` usage**: All functions use concrete types -- **No `any` types**: Proper type safety throughout -- **Compile-time validation**: Type checking catches errors early -- **Mock client type safety**: Testing infrastructure is fully typed - -### **Consistency Improvements** -- **Standardized error messages**: All validation errors follow same format -- **Consistent flag shortcuts**: All common flags use same shortcuts -- **Uniform output**: All commands support JSON/YAML/table formats -- **Common patterns**: All CRUD operations follow same structure - -### **Testing Improvements** -- **400+ validation tests**: Every validation function extensively tested -- **Mock infrastructure**: Complete mock gRPC client for testing -- **Integration tests**: End-to-end testing of command patterns -- **Performance benchmarks**: Ensures CLI remains responsive - -## 🔧 Technical Implementation Details - -### **Type-Safe Architecture** -```go -// Example: Type-safe command function -func createUserLogic(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) { - // Validate input using validation infrastructure - if err := ValidateUserName(args[0]); err != nil { - return nil, err - } - - // Use standardized client wrapper - response, err := client.CreateUser(cmd, request) - if err != nil { - return nil, err - } - - return response.GetUser(), nil -} -``` - -### **Reusable Command Patterns** -```go -// Example: Standard command creation -func createUserRefactored() *cobra.Command { - return &cobra.Command{ - Use: "create NAME", - Args: ValidateExactArgs(1, "create "), - Run: StandardCreateCommand(createUserLogic, "User created successfully"), - } -} -``` - -### **Comprehensive Validation** -```go -// Example: Validation with clear error messages -if err := ValidateEmail(email); err != nil { - return nil, fmt.Errorf("invalid email: %w", err) -} -``` - -### **Consistent Output Handling** -```go -// Example: Automatic output formatting -ListOutput(cmd, users, setupUsersTable) // Handles JSON/YAML/table automatically -``` - -## đŸŽ¯ Benefits Achieved - -### **For Developers** -- **50% less code** to write for new commands -- **Consistent patterns** reduce learning curve -- **Type safety** catches errors at compile time -- **Comprehensive testing** infrastructure ready to use -- **Better error messages** improve debugging experience - -### **For Users** -- **Consistent interface** across all commands -- **Better error messages** with helpful suggestions -- **Reliable validation** catches issues early -- **Multiple output formats** (JSON, YAML, human-readable) -- **Improved help text** and usage examples - -### **For Maintainers** -- **Easier code review** with standardized patterns -- **Better test coverage** with testing infrastructure -- **Consistent behavior** across commands reduces bugs -- **Simpler onboarding** for new contributors -- **Future extensibility** with modular design - -## 📁 File Structure Overview - -``` -cmd/headscale/cli/ -├── infrastructure/ -│ ├── testing.go # Mock client infrastructure -│ ├── testing_test.go # Testing infrastructure tests -│ ├── flags.go # Flag registration helpers -│ ├── client.go # gRPC client wrapper -│ ├── output.go # Output formatting utilities -│ ├── patterns.go # Command execution patterns -│ └── validation.go # Input validation utilities -│ -├── examples/ -│ ├── users_refactored.go # Refactored user commands -│ └── users_refactored_example.go # Original examples -│ -├── tests/ -│ ├── *_test.go # Unit tests for each component -│ ├── infrastructure_integration_test.go # Integration tests -│ ├── validation_test.go # Comprehensive validation tests -│ └── dump_config_test.go # Additional command tests -│ -└── original/ - ├── users.go # Original user commands (unchanged) - ├── nodes.go # Original node commands (unchanged) - └── *.go # Other original commands (unchanged) -``` - -## 🚀 Usage Examples - -### **Creating a New Command (Before vs After)** - -**Before (Original Pattern)**: -```go -var createUserCmd = &cobra.Command{ - Use: "create NAME", - Short: "Creates a new user", - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return errMissingParameter - } - return nil - }, - Run: func(cmd *cobra.Command, args []string) { - output, _ := cmd.Flags().GetString("output") - userName := args[0] - - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() - - request := &v1.CreateUserRequest{Name: userName} - - if displayName, _ := cmd.Flags().GetString("display-name"); displayName != "" { - request.DisplayName = displayName - } - - // ... more validation and setup (30+ lines) - - response, err := client.CreateUser(ctx, request) - if err != nil { - ErrorOutput(err, "Cannot create user: "+status.Convert(err).Message(), output) - } - - SuccessOutput(response.GetUser(), "User created", output) - }, -} -``` - -**After (Refactored Pattern)**: -```go -func createUserRefactored() *cobra.Command { - cmd := &cobra.Command{ - Use: "create NAME", - Short: "Creates a new user", - Args: ValidateExactArgs(1, "create "), - Run: StandardCreateCommand(createUserLogic, "User created successfully"), - } - - cmd.Flags().StringP("display-name", "d", "", "Display name") - cmd.Flags().StringP("email", "e", "", "Email address") - cmd.Flags().StringP("picture-url", "p", "", "Profile picture URL") - AddOutputFlag(cmd) - - return cmd -} - -func createUserLogic(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) { - userName := args[0] - - if err := ValidateUserName(userName); err != nil { - return nil, err - } - - request := &v1.CreateUserRequest{Name: userName} - - if displayName, _ := cmd.Flags().GetString("display-name"); displayName != "" { - request.DisplayName = displayName - } - - if email, _ := cmd.Flags().GetString("email"); email != "" { - if err := ValidateEmail(email); err != nil { - return nil, fmt.Errorf("invalid email: %w", err) - } - request.Email = email - } - - if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" { - if err := ValidateURL(pictureURL); err != nil { - return nil, fmt.Errorf("invalid picture URL: %w", err) - } - request.PictureUrl = pictureURL - } - - if err := ValidateNoDuplicateUsers(client, userName, 0); err != nil { - return nil, err - } - - response, err := client.CreateUser(cmd, request) - if err != nil { - return nil, err - } - - return response.GetUser(), nil -} -``` - -**Result**: ~50% less code, better validation, consistent error handling, automatic output formatting. - -## 🔍 Quality Assurance - -### **Test Coverage** -- **Unit Tests**: 500+ test cases covering all components -- **Integration Tests**: End-to-end command pattern testing -- **Performance Tests**: Benchmarks for command execution -- **Mock Testing**: Complete mock infrastructure for reliable testing - -### **Type Safety** -- **Zero `interface{}`**: All functions use concrete types -- **Compile-time validation**: Type system catches errors early -- **Mock type safety**: Testing infrastructure is fully typed - -### **Documentation** -- **Comprehensive comments**: All functions well-documented -- **Usage examples**: Clear examples for each pattern -- **Error message quality**: Helpful error messages with suggestions - -## 🎉 Conclusion - -The Headscale CLI infrastructure refactoring has been successfully completed, delivering: - -✅ **Complete infrastructure** for type-safe CLI development -✅ **50% code reduction** for new commands -✅ **Comprehensive testing** infrastructure -✅ **Consistent user experience** across all commands -✅ **Better error handling** and validation -✅ **Future-proof architecture** for extensibility - -The new infrastructure provides a solid foundation for CLI development at Headscale, making it easier to add new commands, maintain existing ones, and provide a consistent experience for users. All components are thoroughly tested, type-safe, and ready for production use. - -### **Next Steps** -1. **Gradual Migration**: Existing commands can be migrated to use the new infrastructure incrementally -2. **Documentation Updates**: User-facing documentation can be updated to reflect new consistent behavior -3. **New Command Development**: All new commands should use the refactored patterns from day one - -The refactoring work demonstrates the power of well-designed infrastructure in reducing complexity while improving quality and maintainability. \ No newline at end of file diff --git a/cmd/headscale/cli/SIMPLIFICATION.md b/cmd/headscale/cli/SIMPLIFICATION.md new file mode 100644 index 00000000..a6718867 --- /dev/null +++ b/cmd/headscale/cli/SIMPLIFICATION.md @@ -0,0 +1,82 @@ +# CLI Simplification - WithClient Pattern + +## Problem +Every CLI command has repetitive gRPC client setup boilerplate: + +```go +// This pattern appears 25+ times across all commands +ctx, client, conn, cancel := newHeadscaleCLIWithConfig() +defer cancel() +defer conn.Close() + +// ... command logic ... +``` + +## Solution +Simple closure that handles client lifecycle: + +```go +// client.go - 16 lines total +func WithClient(fn func(context.Context, v1.HeadscaleServiceClient) error) error { + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() + + return fn(ctx, client) +} +``` + +## Usage Example + +### Before (users.go listUsersCmd): +```go +Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() // 4 lines + defer cancel() + defer conn.Close() + + request := &v1.ListUsersRequest{} + // ... build request ... + + response, err := client.ListUsers(ctx, request) + if err != nil { + ErrorOutput(err, "Cannot get users: "+status.Convert(err).Message(), output) + } + // ... handle response ... +} +``` + +### After: +```go +Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + + err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.ListUsersRequest{} + // ... build request ... + + response, err := client.ListUsers(ctx, request) + if err != nil { + ErrorOutput(err, "Cannot get users: "+status.Convert(err).Message(), output) + return err + } + // ... handle response ... + return nil + }) + + if err != nil { + return // Error already handled + } +} +``` + +## Benefits +- **Removes 4 lines of boilerplate** from every command +- **Ensures proper cleanup** - no forgetting defer statements +- **Simpler error handling** - return from closure, handled centrally +- **Easy to apply** - minimal changes to existing commands + +## Rollout +This pattern can be applied to all 25+ commands systematically, removing ~100 lines of repetitive boilerplate. \ No newline at end of file diff --git a/cmd/headscale/cli/api_key.go b/cmd/headscale/cli/api_key.go index bd839b7b..57d12d12 100644 --- a/cmd/headscale/cli/api_key.go +++ b/cmd/headscale/cli/api_key.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "strconv" "time" @@ -54,50 +55,56 @@ var listAPIKeys = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.ListApiKeysRequest{} - request := &v1.ListApiKeysRequest{} - - response, err := client.ListApiKeys(ctx, request) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error getting the list of keys: %s", err), - output, - ) - } - - if output != "" { - SuccessOutput(response.GetApiKeys(), "", output) - } - - tableData := pterm.TableData{ - {"ID", "Prefix", "Expiration", "Created"}, - } - for _, key := range response.GetApiKeys() { - expiration := "-" - - if key.GetExpiration() != nil { - expiration = ColourTime(key.GetExpiration().AsTime()) + response, err := client.ListApiKeys(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error getting the list of keys: %s", err), + output, + ) + return err } - tableData = append(tableData, []string{ - strconv.FormatUint(key.GetId(), util.Base10), - key.GetPrefix(), - expiration, - key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat), - }) + if output != "" { + SuccessOutput(response.GetApiKeys(), "", output) + return nil + } + + tableData := pterm.TableData{ + {"ID", "Prefix", "Expiration", "Created"}, + } + for _, key := range response.GetApiKeys() { + expiration := "-" + + if key.GetExpiration() != nil { + expiration = ColourTime(key.GetExpiration().AsTime()) + } + + tableData = append(tableData, []string{ + strconv.FormatUint(key.GetId(), util.Base10), + key.GetPrefix(), + expiration, + key.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat), + }) + + } + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return err + } + return nil + }) - } - err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Failed to render pterm table: %s", err), - output, - ) + return } }, } @@ -124,26 +131,31 @@ If you loose a key, create a new one and revoke (expire) the old one.`, fmt.Sprintf("Could not parse duration: %s\n", err), output, ) + return } expiration := time.Now().UTC().Add(time.Duration(duration)) request.Expiration = timestamppb.New(expiration) - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + response, err := client.CreateApiKey(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot create Api Key: %s\n", err), + output, + ) + return err + } + + SuccessOutput(response.GetApiKey(), response.GetApiKey(), output) + return nil + }) - response, err := client.CreateApiKey(ctx, request) if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot create Api Key: %s\n", err), - output, - ) + return } - - SuccessOutput(response.GetApiKey(), response.GetApiKey(), output) }, } @@ -161,26 +173,31 @@ var expireAPIKeyCmd = &cobra.Command{ fmt.Sprintf("Error getting prefix from CLI flag: %s", err), output, ) + return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.ExpireApiKeyRequest{ + Prefix: prefix, + } - request := &v1.ExpireApiKeyRequest{ - Prefix: prefix, - } + response, err := client.ExpireApiKey(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot expire Api Key: %s\n", err), + output, + ) + return err + } + + SuccessOutput(response, "Key expired", output) + return nil + }) - response, err := client.ExpireApiKey(ctx, request) if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot expire Api Key: %s\n", err), - output, - ) + return } - - SuccessOutput(response, "Key expired", output) }, } @@ -198,25 +215,30 @@ var deleteAPIKeyCmd = &cobra.Command{ fmt.Sprintf("Error getting prefix from CLI flag: %s", err), output, ) + return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.DeleteApiKeyRequest{ + Prefix: prefix, + } - request := &v1.DeleteApiKeyRequest{ - Prefix: prefix, - } + response, err := client.DeleteApiKey(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot delete Api Key: %s\n", err), + output, + ) + return err + } + + SuccessOutput(response, "Key deleted", output) + return nil + }) - response, err := client.DeleteApiKey(ctx, request) if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot delete Api Key: %s\n", err), - output, - ) + return } - - SuccessOutput(response, "Key deleted", output) }, } diff --git a/cmd/headscale/cli/client.go b/cmd/headscale/cli/client.go index 4ff32615..65bd9eba 100644 --- a/cmd/headscale/cli/client.go +++ b/cmd/headscale/cli/client.go @@ -2,414 +2,15 @@ package cli import ( "context" - "fmt" - + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/spf13/cobra" - "google.golang.org/grpc" - "google.golang.org/grpc/status" ) -// ClientWrapper wraps the gRPC client with automatic connection lifecycle management -type ClientWrapper struct { - ctx context.Context - client v1.HeadscaleServiceClient - conn *grpc.ClientConn - cancel context.CancelFunc -} - -// NewClient creates a new ClientWrapper with automatic connection setup -func NewClient() (*ClientWrapper, error) { +// WithClient handles gRPC client setup and cleanup, calls fn with client and context +func WithClient(fn func(context.Context, v1.HeadscaleServiceClient) error) error { ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() - return &ClientWrapper{ - ctx: ctx, - client: client, - conn: conn, - cancel: cancel, - }, nil -} - -// Close properly closes the gRPC connection and cancels the context -func (c *ClientWrapper) Close() { - if c.cancel != nil { - c.cancel() - } - if c.conn != nil { - c.conn.Close() - } -} - -// ExecuteWithErrorHandling executes a gRPC operation with standardized error handling -func (c *ClientWrapper) ExecuteWithErrorHandling( - cmd *cobra.Command, - operation func(client v1.HeadscaleServiceClient) (interface{}, error), - errorMsg string, -) (interface{}, error) { - result, err := operation(c.client) - if err != nil { - output := GetOutputFormat(cmd) - ErrorOutput( - err, - fmt.Sprintf("%s: %s", errorMsg, status.Convert(err).Message()), - output, - ) - return nil, err - } - return result, nil -} - -// Specific operation helpers with automatic error handling and output formatting - -// ListNodes executes a ListNodes request with error handling -func (c *ClientWrapper) ListNodes(cmd *cobra.Command, req *v1.ListNodesRequest) (*v1.ListNodesResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.ListNodes(c.ctx, req) - }, - "Cannot get nodes", - ) - if err != nil { - return nil, err - } - return result.(*v1.ListNodesResponse), nil -} - -// RegisterNode executes a RegisterNode request with error handling -func (c *ClientWrapper) RegisterNode(cmd *cobra.Command, req *v1.RegisterNodeRequest) (*v1.RegisterNodeResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.RegisterNode(c.ctx, req) - }, - "Cannot register node", - ) - if err != nil { - return nil, err - } - return result.(*v1.RegisterNodeResponse), nil -} - -// DeleteNode executes a DeleteNode request with error handling -func (c *ClientWrapper) DeleteNode(cmd *cobra.Command, req *v1.DeleteNodeRequest) (*v1.DeleteNodeResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.DeleteNode(c.ctx, req) - }, - "Error deleting node", - ) - if err != nil { - return nil, err - } - return result.(*v1.DeleteNodeResponse), nil -} - -// ExpireNode executes an ExpireNode request with error handling -func (c *ClientWrapper) ExpireNode(cmd *cobra.Command, req *v1.ExpireNodeRequest) (*v1.ExpireNodeResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.ExpireNode(c.ctx, req) - }, - "Cannot expire node", - ) - if err != nil { - return nil, err - } - return result.(*v1.ExpireNodeResponse), nil -} - -// RenameNode executes a RenameNode request with error handling -func (c *ClientWrapper) RenameNode(cmd *cobra.Command, req *v1.RenameNodeRequest) (*v1.RenameNodeResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.RenameNode(c.ctx, req) - }, - "Cannot rename node", - ) - if err != nil { - return nil, err - } - return result.(*v1.RenameNodeResponse), nil -} - -// MoveNode executes a MoveNode request with error handling -func (c *ClientWrapper) MoveNode(cmd *cobra.Command, req *v1.MoveNodeRequest) (*v1.MoveNodeResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.MoveNode(c.ctx, req) - }, - "Error moving node", - ) - if err != nil { - return nil, err - } - return result.(*v1.MoveNodeResponse), nil -} - -// GetNode executes a GetNode request with error handling -func (c *ClientWrapper) GetNode(cmd *cobra.Command, req *v1.GetNodeRequest) (*v1.GetNodeResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.GetNode(c.ctx, req) - }, - "Error getting node", - ) - if err != nil { - return nil, err - } - return result.(*v1.GetNodeResponse), nil -} - -// SetTags executes a SetTags request with error handling -func (c *ClientWrapper) SetTags(cmd *cobra.Command, req *v1.SetTagsRequest) (*v1.SetTagsResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.SetTags(c.ctx, req) - }, - "Error while sending tags to headscale", - ) - if err != nil { - return nil, err - } - return result.(*v1.SetTagsResponse), nil -} - -// SetApprovedRoutes executes a SetApprovedRoutes request with error handling -func (c *ClientWrapper) SetApprovedRoutes(cmd *cobra.Command, req *v1.SetApprovedRoutesRequest) (*v1.SetApprovedRoutesResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.SetApprovedRoutes(c.ctx, req) - }, - "Error while sending routes to headscale", - ) - if err != nil { - return nil, err - } - return result.(*v1.SetApprovedRoutesResponse), nil -} - -// BackfillNodeIPs executes a BackfillNodeIPs request with error handling -func (c *ClientWrapper) BackfillNodeIPs(cmd *cobra.Command, req *v1.BackfillNodeIPsRequest) (*v1.BackfillNodeIPsResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.BackfillNodeIPs(c.ctx, req) - }, - "Error backfilling IPs", - ) - if err != nil { - return nil, err - } - return result.(*v1.BackfillNodeIPsResponse), nil -} - -// ListUsers executes a ListUsers request with error handling -func (c *ClientWrapper) ListUsers(cmd *cobra.Command, req *v1.ListUsersRequest) (*v1.ListUsersResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.ListUsers(c.ctx, req) - }, - "Cannot get users", - ) - if err != nil { - return nil, err - } - return result.(*v1.ListUsersResponse), nil -} - -// CreateUser executes a CreateUser request with error handling -func (c *ClientWrapper) CreateUser(cmd *cobra.Command, req *v1.CreateUserRequest) (*v1.CreateUserResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.CreateUser(c.ctx, req) - }, - "Cannot create user", - ) - if err != nil { - return nil, err - } - return result.(*v1.CreateUserResponse), nil -} - -// RenameUser executes a RenameUser request with error handling -func (c *ClientWrapper) RenameUser(cmd *cobra.Command, req *v1.RenameUserRequest) (*v1.RenameUserResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.RenameUser(c.ctx, req) - }, - "Cannot rename user", - ) - if err != nil { - return nil, err - } - return result.(*v1.RenameUserResponse), nil -} - -// DeleteUser executes a DeleteUser request with error handling -func (c *ClientWrapper) DeleteUser(cmd *cobra.Command, req *v1.DeleteUserRequest) (*v1.DeleteUserResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.DeleteUser(c.ctx, req) - }, - "Error deleting user", - ) - if err != nil { - return nil, err - } - return result.(*v1.DeleteUserResponse), nil -} - -// ListApiKeys executes a ListApiKeys request with error handling -func (c *ClientWrapper) ListApiKeys(cmd *cobra.Command, req *v1.ListApiKeysRequest) (*v1.ListApiKeysResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.ListApiKeys(c.ctx, req) - }, - "Cannot get API keys", - ) - if err != nil { - return nil, err - } - return result.(*v1.ListApiKeysResponse), nil -} - -// CreateApiKey executes a CreateApiKey request with error handling -func (c *ClientWrapper) CreateApiKey(cmd *cobra.Command, req *v1.CreateApiKeyRequest) (*v1.CreateApiKeyResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.CreateApiKey(c.ctx, req) - }, - "Cannot create API key", - ) - if err != nil { - return nil, err - } - return result.(*v1.CreateApiKeyResponse), nil -} - -// ExpireApiKey executes an ExpireApiKey request with error handling -func (c *ClientWrapper) ExpireApiKey(cmd *cobra.Command, req *v1.ExpireApiKeyRequest) (*v1.ExpireApiKeyResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.ExpireApiKey(c.ctx, req) - }, - "Cannot expire API key", - ) - if err != nil { - return nil, err - } - return result.(*v1.ExpireApiKeyResponse), nil -} - -// DeleteApiKey executes a DeleteApiKey request with error handling -func (c *ClientWrapper) DeleteApiKey(cmd *cobra.Command, req *v1.DeleteApiKeyRequest) (*v1.DeleteApiKeyResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.DeleteApiKey(c.ctx, req) - }, - "Error deleting API key", - ) - if err != nil { - return nil, err - } - return result.(*v1.DeleteApiKeyResponse), nil -} - -// ListPreAuthKeys executes a ListPreAuthKeys request with error handling -func (c *ClientWrapper) ListPreAuthKeys(cmd *cobra.Command, req *v1.ListPreAuthKeysRequest) (*v1.ListPreAuthKeysResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.ListPreAuthKeys(c.ctx, req) - }, - "Cannot get preauth keys", - ) - if err != nil { - return nil, err - } - return result.(*v1.ListPreAuthKeysResponse), nil -} - -// CreatePreAuthKey executes a CreatePreAuthKey request with error handling -func (c *ClientWrapper) CreatePreAuthKey(cmd *cobra.Command, req *v1.CreatePreAuthKeyRequest) (*v1.CreatePreAuthKeyResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.CreatePreAuthKey(c.ctx, req) - }, - "Cannot create preauth key", - ) - if err != nil { - return nil, err - } - return result.(*v1.CreatePreAuthKeyResponse), nil -} - -// ExpirePreAuthKey executes an ExpirePreAuthKey request with error handling -func (c *ClientWrapper) ExpirePreAuthKey(cmd *cobra.Command, req *v1.ExpirePreAuthKeyRequest) (*v1.ExpirePreAuthKeyResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.ExpirePreAuthKey(c.ctx, req) - }, - "Cannot expire preauth key", - ) - if err != nil { - return nil, err - } - return result.(*v1.ExpirePreAuthKeyResponse), nil -} - -// GetPolicy executes a GetPolicy request with error handling -func (c *ClientWrapper) GetPolicy(cmd *cobra.Command, req *v1.GetPolicyRequest) (*v1.GetPolicyResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.GetPolicy(c.ctx, req) - }, - "Cannot get policy", - ) - if err != nil { - return nil, err - } - return result.(*v1.GetPolicyResponse), nil -} - -// SetPolicy executes a SetPolicy request with error handling -func (c *ClientWrapper) SetPolicy(cmd *cobra.Command, req *v1.SetPolicyRequest) (*v1.SetPolicyResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.SetPolicy(c.ctx, req) - }, - "Cannot set policy", - ) - if err != nil { - return nil, err - } - return result.(*v1.SetPolicyResponse), nil -} - -// DebugCreateNode executes a DebugCreateNode request with error handling -func (c *ClientWrapper) DebugCreateNode(cmd *cobra.Command, req *v1.DebugCreateNodeRequest) (*v1.DebugCreateNodeResponse, error) { - result, err := c.ExecuteWithErrorHandling(cmd, - func(client v1.HeadscaleServiceClient) (interface{}, error) { - return client.DebugCreateNode(c.ctx, req) - }, - "Cannot create node", - ) - if err != nil { - return nil, err - } - return result.(*v1.DebugCreateNodeResponse), nil -} - -// Helper function to execute commands with automatic client management -func ExecuteWithClient(cmd *cobra.Command, operation func(*ClientWrapper) error) { - client, err := NewClient() - if err != nil { - output := GetOutputFormat(cmd) - ErrorOutput(err, "Cannot connect to headscale", output) - return - } - defer client.Close() - - err = operation(client) - if err != nil { - // Error already handled by the operation - return - } + return fn(ctx, client) } \ No newline at end of file diff --git a/cmd/headscale/cli/convert_commands.py b/cmd/headscale/cli/convert_commands.py new file mode 100644 index 00000000..db52fffc --- /dev/null +++ b/cmd/headscale/cli/convert_commands.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +"""Script to convert all commands to use WithClient pattern""" + +import re +import sys +import os + +def convert_command(content): + """Convert a single command to use WithClient pattern""" + + # Pattern to match the gRPC client setup + pattern = r'(\t+)ctx, client, conn, cancel := newHeadscaleCLIWithConfig\(\)\n\t+defer cancel\(\)\n\t+defer conn\.Close\(\)\n\n' + + # Find all occurrences + matches = list(re.finditer(pattern, content)) + + if not matches: + return content + + # Process each match from the end to avoid offset issues + for match in reversed(matches): + indent = match.group(1) + start_pos = match.start() + end_pos = match.end() + + # Find the end of the Run function + remaining_content = content[end_pos:] + + # Find the matching closing brace for the Run function + brace_count = 0 + func_end = -1 + + for i, char in enumerate(remaining_content): + if char == '{': + brace_count += 1 + elif char == '}': + brace_count -= 1 + if brace_count < 0: # Found the closing brace + func_end = i + break + + if func_end == -1: + continue + + # Extract the function body + func_body = remaining_content[:func_end] + + # Indent the function body + indented_body = '\n'.join(indent + '\t' + line if line.strip() else line + for line in func_body.split('\n')) + + # Create the new function with WithClient + new_func = f"""{indent}err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {{ +{indented_body} +{indent}\treturn nil +{indent}}}) +{indent} +{indent}if err != nil {{ +{indent}\treturn +{indent}}}""" + + # Replace the old pattern with the new one + content = content[:start_pos] + new_func + '\n' + content[end_pos + func_end:] + + return content + +def process_file(filepath): + """Process a single Go file""" + try: + with open(filepath, 'r') as f: + content = f.read() + + # Check if context is already imported + if 'import (' in content and '"context"' not in content: + # Add context import + content = content.replace( + 'import (', + 'import (\n\t"context"' + ) + + # Convert commands + new_content = convert_command(content) + + # Write back if changed + if new_content != content: + with open(filepath, 'w') as f: + f.write(new_content) + print(f"Updated {filepath}") + else: + print(f"No changes needed for {filepath}") + + except Exception as e: + print(f"Error processing {filepath}: {e}") + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python3 convert_commands.py ") + sys.exit(1) + + filepath = sys.argv[1] + if not os.path.exists(filepath): + print(f"File not found: {filepath}") + sys.exit(1) + + process_file(filepath) \ No newline at end of file diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index 8ce5f237..331e9771 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -64,12 +65,9 @@ var createNodeCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) + return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() - name, err := cmd.Flags().GetString("name") if err != nil { ErrorOutput( @@ -77,6 +75,7 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting node from flag: %s", err), output, ) + return } registrationID, err := cmd.Flags().GetString("key") @@ -86,6 +85,7 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting key from flag: %s", err), output, ) + return } _, err = types.RegistrationIDFromString(registrationID) @@ -95,6 +95,7 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Failed to parse machine key from flag: %s", err), output, ) + return } routes, err := cmd.Flags().GetStringSlice("route") @@ -104,24 +105,33 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting routes from flag: %s", err), output, ) + return } - request := &v1.DebugCreateNodeRequest{ - Key: registrationID, - Name: name, - User: user, - Routes: routes, - } + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.DebugCreateNodeRequest{ + Key: registrationID, + Name: name, + User: user, + Routes: routes, + } - response, err := client.DebugCreateNode(ctx, request) + response, err := client.DebugCreateNode(ctx, request) + if err != nil { + ErrorOutput( + err, + "Cannot create node: "+status.Convert(err).Message(), + output, + ) + return err + } + + SuccessOutput(response.GetNode(), "Node created", output) + return nil + }) + if err != nil { - ErrorOutput( - err, - "Cannot create node: "+status.Convert(err).Message(), - output, - ) + return } - - SuccessOutput(response.GetNode(), "Node created", output) }, } diff --git a/cmd/headscale/cli/flags.go b/cmd/headscale/cli/flags.go deleted file mode 100644 index 4b09d02b..00000000 --- a/cmd/headscale/cli/flags.go +++ /dev/null @@ -1,358 +0,0 @@ -package cli - -import ( - "fmt" - "log" - "time" - - "github.com/spf13/cobra" -) - -const ( - deprecateNamespaceMessage = "use --user" -) - -// Flag registration helpers - standardize how flags are added to commands - -// AddIdentifierFlag adds a uint64 identifier flag with consistent naming -func AddIdentifierFlag(cmd *cobra.Command, name string, help string) { - cmd.Flags().Uint64P(name, "i", 0, help) -} - -// AddRequiredIdentifierFlag adds a required uint64 identifier flag -func AddRequiredIdentifierFlag(cmd *cobra.Command, name string, help string) { - AddIdentifierFlag(cmd, name, help) - err := cmd.MarkFlagRequired(name) - if err != nil { - log.Fatal(err.Error()) - } -} - -// AddColumnsFlag adds a columns flag for table output customization -func AddColumnsFlag(cmd *cobra.Command, defaultColumns string) { - cmd.Flags().String("columns", defaultColumns, "Comma-separated list of columns to display") -} - -// GetColumnsFlag gets the columns flag value -func GetColumnsFlag(cmd *cobra.Command) string { - columns, _ := cmd.Flags().GetString("columns") - return columns -} - -// AddUserFlag adds a user flag (string for username or email) -func AddUserFlag(cmd *cobra.Command) { - cmd.Flags().StringP("user", "u", "", "User") -} - -// AddRequiredUserFlag adds a required user flag -func AddRequiredUserFlag(cmd *cobra.Command) { - AddUserFlag(cmd) - err := cmd.MarkFlagRequired("user") - if err != nil { - log.Fatal(err.Error()) - } -} - -// AddOutputFlag adds the standard output format flag -func AddOutputFlag(cmd *cobra.Command) { - cmd.Flags().StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'") -} - -// AddForceFlag adds the force flag -func AddForceFlag(cmd *cobra.Command) { - cmd.Flags().Bool("force", false, "Disable prompts and forces the execution") -} - -// AddExpirationFlag adds an expiration duration flag -func AddExpirationFlag(cmd *cobra.Command, defaultValue string) { - cmd.Flags().StringP("expiration", "e", defaultValue, "Human-readable duration (e.g. 30m, 24h)") -} - -// AddDeprecatedNamespaceFlag adds the deprecated namespace flag with appropriate warnings -func AddDeprecatedNamespaceFlag(cmd *cobra.Command) { - cmd.Flags().StringP("namespace", "n", "", "User") - namespaceFlag := cmd.Flags().Lookup("namespace") - namespaceFlag.Deprecated = deprecateNamespaceMessage - namespaceFlag.Hidden = true -} - -// AddTagsFlag adds a tags display flag -func AddTagsFlag(cmd *cobra.Command) { - cmd.Flags().BoolP("tags", "t", false, "Show tags") -} - -// AddKeyFlag adds a key flag for node registration -func AddKeyFlag(cmd *cobra.Command) { - cmd.Flags().StringP("key", "k", "", "Key") -} - -// AddRequiredKeyFlag adds a required key flag -func AddRequiredKeyFlag(cmd *cobra.Command) { - AddKeyFlag(cmd) - err := cmd.MarkFlagRequired("key") - if err != nil { - log.Fatal(err.Error()) - } -} - -// AddNameFlag adds a name flag -func AddNameFlag(cmd *cobra.Command, help string) { - cmd.Flags().String("name", "", help) -} - -// AddRequiredNameFlag adds a required name flag -func AddRequiredNameFlag(cmd *cobra.Command, help string) { - AddNameFlag(cmd, help) - err := cmd.MarkFlagRequired("name") - if err != nil { - log.Fatal(err.Error()) - } -} - -// AddPrefixFlag adds an API key prefix flag -func AddPrefixFlag(cmd *cobra.Command) { - cmd.Flags().StringP("prefix", "p", "", "ApiKey prefix") -} - -// AddRequiredPrefixFlag adds a required API key prefix flag -func AddRequiredPrefixFlag(cmd *cobra.Command) { - AddPrefixFlag(cmd) - err := cmd.MarkFlagRequired("prefix") - if err != nil { - log.Fatal(err.Error()) - } -} - -// AddFileFlag adds a file path flag -func AddFileFlag(cmd *cobra.Command) { - cmd.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") -} - -// AddRequiredFileFlag adds a required file path flag -func AddRequiredFileFlag(cmd *cobra.Command) { - AddFileFlag(cmd) - err := cmd.MarkFlagRequired("file") - if err != nil { - log.Fatal(err.Error()) - } -} - -// AddRoutesFlag adds a routes flag for node route management -func AddRoutesFlag(cmd *cobra.Command) { - cmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`) -} - -// AddTagsSliceFlag adds a tags slice flag for node tagging -func AddTagsSliceFlag(cmd *cobra.Command) { - cmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node") -} - -// Flag getter helpers with consistent error handling - -// GetIdentifier gets a uint64 identifier flag value with error handling -func GetIdentifier(cmd *cobra.Command, flagName string) (uint64, error) { - identifier, err := cmd.Flags().GetUint64(flagName) - if err != nil { - return 0, fmt.Errorf("error getting %s flag: %w", flagName, err) - } - return identifier, nil -} - -// GetUser gets a user flag value -func GetUser(cmd *cobra.Command) (string, error) { - user, err := cmd.Flags().GetString("user") - if err != nil { - return "", fmt.Errorf("error getting user flag: %w", err) - } - return user, nil -} - -// GetOutputFormat gets the output format flag value -func GetOutputFormat(cmd *cobra.Command) string { - output, _ := cmd.Flags().GetString("output") - return output -} - -// GetForce gets the force flag value -func GetForce(cmd *cobra.Command) bool { - force, _ := cmd.Flags().GetBool("force") - return force -} - -// GetExpiration gets and parses the expiration flag value -func GetExpiration(cmd *cobra.Command) (time.Duration, error) { - expirationStr, err := cmd.Flags().GetString("expiration") - if err != nil { - return 0, fmt.Errorf("error getting expiration flag: %w", err) - } - - if expirationStr == "" { - return 0, nil // No expiration set - } - - duration, err := time.ParseDuration(expirationStr) - if err != nil { - return 0, fmt.Errorf("invalid expiration duration '%s': %w", expirationStr, err) - } - - return duration, nil -} - -// GetName gets a name flag value -func GetName(cmd *cobra.Command) (string, error) { - name, err := cmd.Flags().GetString("name") - if err != nil { - return "", fmt.Errorf("error getting name flag: %w", err) - } - return name, nil -} - -// GetKey gets a key flag value -func GetKey(cmd *cobra.Command) (string, error) { - key, err := cmd.Flags().GetString("key") - if err != nil { - return "", fmt.Errorf("error getting key flag: %w", err) - } - return key, nil -} - -// GetPrefix gets a prefix flag value -func GetPrefix(cmd *cobra.Command) (string, error) { - prefix, err := cmd.Flags().GetString("prefix") - if err != nil { - return "", fmt.Errorf("error getting prefix flag: %w", err) - } - return prefix, nil -} - -// GetFile gets a file flag value -func GetFile(cmd *cobra.Command) (string, error) { - file, err := cmd.Flags().GetString("file") - if err != nil { - return "", fmt.Errorf("error getting file flag: %w", err) - } - return file, nil -} - -// GetRoutes gets a routes flag value -func GetRoutes(cmd *cobra.Command) ([]string, error) { - routes, err := cmd.Flags().GetStringSlice("routes") - if err != nil { - return nil, fmt.Errorf("error getting routes flag: %w", err) - } - return routes, nil -} - -// GetTagsSlice gets a tags slice flag value -func GetTagsSlice(cmd *cobra.Command) ([]string, error) { - tags, err := cmd.Flags().GetStringSlice("tags") - if err != nil { - return nil, fmt.Errorf("error getting tags flag: %w", err) - } - return tags, nil -} - -// GetTags gets a tags boolean flag value -func GetTags(cmd *cobra.Command) bool { - tags, _ := cmd.Flags().GetBool("tags") - return tags -} - -// Flag validation helpers - -// ValidateRequiredFlags validates that required flags are set -func ValidateRequiredFlags(cmd *cobra.Command, flags ...string) error { - for _, flagName := range flags { - flag := cmd.Flags().Lookup(flagName) - if flag == nil { - return fmt.Errorf("flag %s not found", flagName) - } - - if !flag.Changed { - return fmt.Errorf("required flag %s not set", flagName) - } - } - return nil -} - -// ValidateExclusiveFlags validates that only one of the given flags is set -func ValidateExclusiveFlags(cmd *cobra.Command, flags ...string) error { - setFlags := []string{} - - for _, flagName := range flags { - flag := cmd.Flags().Lookup(flagName) - if flag == nil { - return fmt.Errorf("flag %s not found", flagName) - } - - if flag.Changed { - setFlags = append(setFlags, flagName) - } - } - - if len(setFlags) > 1 { - return fmt.Errorf("only one of the following flags can be set: %v, but found: %v", flags, setFlags) - } - - return nil -} - -// ValidateIdentifierFlag validates that an identifier flag has a valid value -func ValidateIdentifierFlag(cmd *cobra.Command, flagName string) error { - identifier, err := GetIdentifier(cmd, flagName) - if err != nil { - return err - } - - if identifier == 0 { - return fmt.Errorf("%s must be greater than 0", flagName) - } - - return nil -} - -// ValidateNonEmptyStringFlag validates that a string flag is not empty -func ValidateNonEmptyStringFlag(cmd *cobra.Command, flagName string) error { - value, err := cmd.Flags().GetString(flagName) - if err != nil { - return fmt.Errorf("error getting %s flag: %w", flagName, err) - } - - if value == "" { - return fmt.Errorf("%s cannot be empty", flagName) - } - - return nil -} - -// Deprecated flag handling utilities - -// HandleDeprecatedNamespaceFlag handles the deprecated namespace flag by copying its value to user flag -func HandleDeprecatedNamespaceFlag(cmd *cobra.Command) { - namespaceFlag := cmd.Flags().Lookup("namespace") - userFlag := cmd.Flags().Lookup("user") - - if namespaceFlag != nil && userFlag != nil && namespaceFlag.Changed && !userFlag.Changed { - // Copy namespace value to user flag - userFlag.Value.Set(namespaceFlag.Value.String()) - userFlag.Changed = true - } -} - -// GetUserWithDeprecatedNamespace gets user value, checking both user and deprecated namespace flags -func GetUserWithDeprecatedNamespace(cmd *cobra.Command) (string, error) { - user, err := cmd.Flags().GetString("user") - if err != nil { - return "", fmt.Errorf("error getting user flag: %w", err) - } - - // If user is empty, try deprecated namespace flag - if user == "" { - namespace, err := cmd.Flags().GetString("namespace") - if err == nil && namespace != "" { - return namespace, nil - } - } - - return user, nil -} \ No newline at end of file diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index fb49f4a3..fd6cb170 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "log" "net/netip" @@ -23,6 +24,7 @@ func init() { rootCmd.AddCommand(nodeCmd) listNodesCmd.Flags().StringP("user", "u", "", "Filter by user") listNodesCmd.Flags().BoolP("tags", "t", false, "Show tags") + listNodesCmd.Flags().String("columns", "", "Comma-separated list of columns to display") listNodesCmd.Flags().StringP("namespace", "n", "", "User") listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace") @@ -119,12 +121,9 @@ var registerNodeCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) + return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() - registrationID, err := cmd.Flags().GetString("key") if err != nil { ErrorOutput( @@ -132,28 +131,37 @@ var registerNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting node key from flag: %s", err), output, ) + return } - request := &v1.RegisterNodeRequest{ - Key: registrationID, - User: user, - } + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.RegisterNodeRequest{ + Key: registrationID, + User: user, + } - response, err := client.RegisterNode(ctx, request) + response, err := client.RegisterNode(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf( + "Cannot register node: %s\n", + status.Convert(err).Message(), + ), + output, + ) + return err + } + + SuccessOutput( + response.GetNode(), + fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()), output) + return nil + }) + if err != nil { - ErrorOutput( - err, - fmt.Sprintf( - "Cannot register node: %s\n", - status.Convert(err).Message(), - ), - output, - ) + return } - - SuccessOutput( - response.GetNode(), - fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()), output) }, } @@ -172,39 +180,47 @@ var listNodesCmd = &cobra.Command{ ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output) } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.ListNodesRequest{ + User: user, + } - request := &v1.ListNodesRequest{ - User: user, - } + response, err := client.ListNodes(ctx, request) + if err != nil { + ErrorOutput( + err, + "Cannot get nodes: "+status.Convert(err).Message(), + output, + ) + return err + } - response, err := client.ListNodes(ctx, request) + if output != "" { + SuccessOutput(response.GetNodes(), "", output) + return nil + } + + tableData, err := nodesToPtables(user, showTags, response.GetNodes()) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) + return err + } + + tableData = FilterTableColumns(cmd, tableData) + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return err + } + return nil + }) + if err != nil { - ErrorOutput( - err, - "Cannot get nodes: "+status.Convert(err).Message(), - output, - ) - } - - if output != "" { - SuccessOutput(response.GetNodes(), "", output) - } - - tableData, err := nodesToPtables(user, showTags, response.GetNodes()) - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - } - - err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Failed to render pterm table: %s", err), - output, - ) + return } }, } @@ -222,55 +238,61 @@ var listNodeRoutesCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.ListNodesRequest{} - request := &v1.ListNodesRequest{} + response, err := client.ListNodes(ctx, request) + if err != nil { + ErrorOutput( + err, + "Cannot get nodes: "+status.Convert(err).Message(), + output, + ) + return err + } - response, err := client.ListNodes(ctx, request) - if err != nil { - ErrorOutput( - err, - "Cannot get nodes: "+status.Convert(err).Message(), - output, - ) - } + if output != "" { + SuccessOutput(response.GetNodes(), "", output) + return nil + } - if output != "" { - SuccessOutput(response.GetNodes(), "", output) - } - - nodes := response.GetNodes() - if identifier != 0 { - for _, node := range response.GetNodes() { - if node.GetId() == identifier { - nodes = []*v1.Node{node} - break + nodes := response.GetNodes() + if identifier != 0 { + for _, node := range response.GetNodes() { + if node.GetId() == identifier { + nodes = []*v1.Node{node} + break + } } } - } - nodes = lo.Filter(nodes, func(n *v1.Node, _ int) bool { - return (n.GetSubnetRoutes() != nil && len(n.GetSubnetRoutes()) > 0) || (n.GetApprovedRoutes() != nil && len(n.GetApprovedRoutes()) > 0) || (n.GetAvailableRoutes() != nil && len(n.GetAvailableRoutes()) > 0) + nodes = lo.Filter(nodes, func(n *v1.Node, _ int) bool { + return (n.GetSubnetRoutes() != nil && len(n.GetSubnetRoutes()) > 0) || (n.GetApprovedRoutes() != nil && len(n.GetApprovedRoutes()) > 0) || (n.GetAvailableRoutes() != nil && len(n.GetAvailableRoutes()) > 0) + }) + + tableData, err := nodeRoutesToPtables(nodes) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) + return err + } + + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return err + } + return nil }) - - tableData, err := nodeRoutesToPtables(nodes) + if err != nil { - ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - } - - err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Failed to render pterm table: %s", err), - output, - ) + return } }, } @@ -290,33 +312,34 @@ var expireNodeCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.ExpireNodeRequest{ + NodeId: identifier, + } - request := &v1.ExpireNodeRequest{ - NodeId: identifier, - } + response, err := client.ExpireNode(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf( + "Cannot expire node: %s\n", + status.Convert(err).Message(), + ), + output, + ) + return err + } - response, err := client.ExpireNode(ctx, request) + SuccessOutput(response.GetNode(), "Node expired", output) + return nil + }) + if err != nil { - ErrorOutput( - err, - fmt.Sprintf( - "Cannot expire node: %s\n", - status.Convert(err).Message(), - ), - output, - ) - return } - - SuccessOutput(response.GetNode(), "Node expired", output) }, } @@ -333,38 +356,40 @@ var renameNodeCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() - newName := "" if len(args) > 0 { newName = args[0] } - request := &v1.RenameNodeRequest{ - NodeId: identifier, - NewName: newName, - } + + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.RenameNodeRequest{ + NodeId: identifier, + NewName: newName, + } - response, err := client.RenameNode(ctx, request) + response, err := client.RenameNode(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf( + "Cannot rename node: %s\n", + status.Convert(err).Message(), + ), + output, + ) + return err + } + + SuccessOutput(response.GetNode(), "Node renamed", output) + return nil + }) + if err != nil { - ErrorOutput( - err, - fmt.Sprintf( - "Cannot rename node: %s\n", - status.Convert(err).Message(), - ), - output, - ) - return } - - SuccessOutput(response.GetNode(), "Node renamed", output) }, } @@ -382,40 +407,39 @@ var deleteNodeCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + var nodeName string + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + getRequest := &v1.GetNodeRequest{ + NodeId: identifier, + } - getRequest := &v1.GetNodeRequest{ - NodeId: identifier, - } - - getResponse, err := client.GetNode(ctx, getRequest) + getResponse, err := client.GetNode(ctx, getRequest) + if err != nil { + ErrorOutput( + err, + "Error getting node node: "+status.Convert(err).Message(), + output, + ) + return err + } + nodeName = getResponse.GetNode().GetName() + return nil + }) + if err != nil { - ErrorOutput( - err, - "Error getting node node: "+status.Convert(err).Message(), - output, - ) - return } - deleteRequest := &v1.DeleteNodeRequest{ - NodeId: identifier, - } - confirm := false force, _ := cmd.Flags().GetBool("force") if !force { prompt := &survey.Confirm{ Message: fmt.Sprintf( "Do you want to remove the node %s?", - getResponse.GetNode().GetName(), + nodeName, ), } err = survey.AskOne(prompt, &confirm) @@ -425,26 +449,35 @@ var deleteNodeCmd = &cobra.Command{ } if confirm || force { - response, err := client.DeleteNode(ctx, deleteRequest) - if output != "" { - SuccessOutput(response, "", output) - - return - } - if err != nil { - ErrorOutput( - err, - "Error deleting node: "+status.Convert(err).Message(), + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + deleteRequest := &v1.DeleteNodeRequest{ + NodeId: identifier, + } + + response, err := client.DeleteNode(ctx, deleteRequest) + if output != "" { + SuccessOutput(response, "", output) + return nil + } + if err != nil { + ErrorOutput( + err, + "Error deleting node: "+status.Convert(err).Message(), + output, + ) + return err + } + SuccessOutput( + map[string]string{"Result": "Node deleted"}, + "Node deleted", output, ) - + return nil + }) + + if err != nil { return } - SuccessOutput( - map[string]string{"Result": "Node deleted"}, - "Node deleted", - output, - ) } else { SuccessOutput(map[string]string{"Result": "Node not deleted"}, "Node not deleted", output) } @@ -465,7 +498,6 @@ var moveNodeCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - return } @@ -476,46 +508,46 @@ var moveNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting user: %s", err), output, ) - return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + getRequest := &v1.GetNodeRequest{ + NodeId: identifier, + } - getRequest := &v1.GetNodeRequest{ - NodeId: identifier, - } + _, err := client.GetNode(ctx, getRequest) + if err != nil { + ErrorOutput( + err, + "Error getting node: "+status.Convert(err).Message(), + output, + ) + return err + } - _, err = client.GetNode(ctx, getRequest) + moveRequest := &v1.MoveNodeRequest{ + NodeId: identifier, + User: user, + } + + moveResponse, err := client.MoveNode(ctx, moveRequest) + if err != nil { + ErrorOutput( + err, + "Error moving node: "+status.Convert(err).Message(), + output, + ) + return err + } + + SuccessOutput(moveResponse.GetNode(), "Node moved to another user", output) + return nil + }) + if err != nil { - ErrorOutput( - err, - "Error getting node: "+status.Convert(err).Message(), - output, - ) - return } - - moveRequest := &v1.MoveNodeRequest{ - NodeId: identifier, - User: user, - } - - moveResponse, err := client.MoveNode(ctx, moveRequest) - if err != nil { - ErrorOutput( - err, - "Error moving node: "+status.Convert(err).Message(), - output, - ) - - return - } - - SuccessOutput(moveResponse.GetNode(), "Node moved to another user", output) }, } @@ -547,22 +579,24 @@ be assigned to nodes.`, return } if confirm { - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm}) + if err != nil { + ErrorOutput( + err, + "Error backfilling IPs: "+status.Convert(err).Message(), + output, + ) + return err + } - changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm}) + SuccessOutput(changes, "Node IPs backfilled successfully", output) + return nil + }) + if err != nil { - ErrorOutput( - err, - "Error backfilling IPs: "+status.Convert(err).Message(), - output, - ) - return } - - SuccessOutput(changes, "Node IPs backfilled successfully", output) } }, } @@ -746,10 +780,7 @@ var tagCmd = &cobra.Command{ Aliases: []string{"tags", "t"}, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() - + // retrieve flags from CLI identifier, err := cmd.Flags().GetUint64("identifier") if err != nil { @@ -758,7 +789,6 @@ var tagCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - return } tagsToSet, err := cmd.Flags().GetStringSlice("tags") @@ -768,33 +798,38 @@ var tagCmd = &cobra.Command{ fmt.Sprintf("Error retrieving list of tags to add to node, %v", err), output, ) - return } - // Sending tags to node - request := &v1.SetTagsRequest{ - NodeId: identifier, - Tags: tagsToSet, - } - resp, err := client.SetTags(ctx, request) + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + // Sending tags to node + request := &v1.SetTagsRequest{ + NodeId: identifier, + Tags: tagsToSet, + } + resp, err := client.SetTags(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error while sending tags to headscale: %s", err), + output, + ) + return err + } + + if resp != nil { + SuccessOutput( + resp.GetNode(), + "Node updated", + output, + ) + } + return nil + }) + if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error while sending tags to headscale: %s", err), - output, - ) - return } - - if resp != nil { - SuccessOutput( - resp.GetNode(), - "Node updated", - output, - ) - } }, } @@ -803,10 +838,7 @@ var approveRoutesCmd = &cobra.Command{ Short: "Manage the approved routes of a node", Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() - + // retrieve flags from CLI identifier, err := cmd.Flags().GetUint64("identifier") if err != nil { @@ -815,7 +847,6 @@ var approveRoutesCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - return } routes, err := cmd.Flags().GetStringSlice("routes") @@ -825,32 +856,37 @@ var approveRoutesCmd = &cobra.Command{ fmt.Sprintf("Error retrieving list of routes to add to node, %v", err), output, ) - return } - // Sending routes to node - request := &v1.SetApprovedRoutesRequest{ - NodeId: identifier, - Routes: routes, - } - resp, err := client.SetApprovedRoutes(ctx, request) + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + // Sending routes to node + request := &v1.SetApprovedRoutesRequest{ + NodeId: identifier, + Routes: routes, + } + resp, err := client.SetApprovedRoutes(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error while sending routes to headscale: %s", err), + output, + ) + return err + } + + if resp != nil { + SuccessOutput( + resp.GetNode(), + "Node updated", + output, + ) + } + return nil + }) + if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error while sending routes to headscale: %s", err), - output, - ) - return } - - if resp != nil { - SuccessOutput( - resp.GetNode(), - "Node updated", - output, - ) - } }, } diff --git a/cmd/headscale/cli/output.go b/cmd/headscale/cli/output.go deleted file mode 100644 index 1d40078a..00000000 --- a/cmd/headscale/cli/output.go +++ /dev/null @@ -1,384 +0,0 @@ -package cli - -import ( - "fmt" - "strings" - "time" - - "github.com/pterm/pterm" - "github.com/spf13/cobra" -) - -const ( - HeadscaleDateTimeFormat = "2006-01-02 15:04:05" -) - -// OutputManager handles all output formatting and rendering for CLI commands -type OutputManager struct { - cmd *cobra.Command - outputFormat string -} - -// NewOutputManager creates a new output manager for the given command -func NewOutputManager(cmd *cobra.Command) *OutputManager { - return &OutputManager{ - cmd: cmd, - outputFormat: GetOutputFormat(cmd), - } -} - -// Success outputs successful results and exits with code 0 -func (om *OutputManager) Success(data interface{}, humanMessage string) { - SuccessOutput(data, humanMessage, om.outputFormat) -} - -// Error outputs error results and exits with code 1 -func (om *OutputManager) Error(err error, humanMessage string) { - ErrorOutput(err, humanMessage, om.outputFormat) -} - -// HasMachineOutput returns true if the output format requires machine-readable output -func (om *OutputManager) HasMachineOutput() bool { - return om.outputFormat != "" -} - -// Table rendering infrastructure - -// TableColumn defines a table column with header and data extraction function -type TableColumn struct { - Header string - Key string // Unique key for column selection - Width int // Optional width specification - Extract func(item interface{}) string - Color func(value string) string // Optional color function -} - -// TableRenderer handles table rendering with consistent formatting -type TableRenderer struct { - outputManager *OutputManager - columns []TableColumn - data []interface{} -} - -// NewTableRenderer creates a new table renderer -func NewTableRenderer(om *OutputManager) *TableRenderer { - return &TableRenderer{ - outputManager: om, - columns: []TableColumn{}, - data: []interface{}{}, - } -} - -// AddColumn adds a column to the table -func (tr *TableRenderer) AddColumn(key, header string, extract func(interface{}) string) *TableRenderer { - tr.columns = append(tr.columns, TableColumn{ - Key: key, - Header: header, - Extract: extract, - }) - return tr -} - -// AddColoredColumn adds a column with color formatting -func (tr *TableRenderer) AddColoredColumn(key, header string, extract func(interface{}) string, color func(string) string) *TableRenderer { - tr.columns = append(tr.columns, TableColumn{ - Key: key, - Header: header, - Extract: extract, - Color: color, - }) - return tr -} - -// SetData sets the data for the table -func (tr *TableRenderer) SetData(data []interface{}) *TableRenderer { - tr.data = data - return tr -} - -// FilterColumns filters columns based on comma-separated list of column keys -func (tr *TableRenderer) FilterColumns(columnKeys string) *TableRenderer { - if columnKeys == "" { - return tr // No filtering - } - - keys := strings.Split(columnKeys, ",") - var filteredColumns []TableColumn - - // Filter columns based on keys, maintaining order from column keys - for _, key := range keys { - trimmedKey := strings.TrimSpace(key) - for _, col := range tr.columns { - if col.Key == trimmedKey { - filteredColumns = append(filteredColumns, col) - break - } - } - } - - tr.columns = filteredColumns - return tr -} - -// Render renders the table or outputs machine-readable format -func (tr *TableRenderer) Render() { - // If machine output format is requested, output the raw data instead of table - if tr.outputManager.HasMachineOutput() { - tr.outputManager.Success(tr.data, "") - return - } - - // Build table headers - headers := make([]string, len(tr.columns)) - for i, col := range tr.columns { - headers[i] = col.Header - } - - // Build table data - tableData := pterm.TableData{headers} - for _, item := range tr.data { - row := make([]string, len(tr.columns)) - for i, col := range tr.columns { - value := col.Extract(item) - if col.Color != nil { - value = col.Color(value) - } - row[i] = value - } - tableData = append(tableData, row) - } - - // Render table - err := pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() - if err != nil { - tr.outputManager.Error( - err, - fmt.Sprintf("Failed to render table: %s", err), - ) - } -} - -// Predefined color functions for common use cases - -// ColorGreen returns a green-colored string -func ColorGreen(text string) string { - return pterm.LightGreen(text) -} - -// ColorRed returns a red-colored string -func ColorRed(text string) string { - return pterm.LightRed(text) -} - -// ColorYellow returns a yellow-colored string -func ColorYellow(text string) string { - return pterm.LightYellow(text) -} - -// ColorMagenta returns a magenta-colored string -func ColorMagenta(text string) string { - return pterm.LightMagenta(text) -} - -// ColorBlue returns a blue-colored string -func ColorBlue(text string) string { - return pterm.LightBlue(text) -} - -// ColorCyan returns a cyan-colored string -func ColorCyan(text string) string { - return pterm.LightCyan(text) -} - -// Time formatting functions - -// FormatTime formats a time with standard CLI format -func FormatTime(t time.Time) string { - if t.IsZero() { - return "N/A" - } - return t.Format(HeadscaleDateTimeFormat) -} - -// FormatTimeColored formats a time with color based on whether it's in past/future -func FormatTimeColored(t time.Time) string { - if t.IsZero() { - return "N/A" - } - timeStr := t.Format(HeadscaleDateTimeFormat) - if t.After(time.Now()) { - return ColorGreen(timeStr) - } - return ColorRed(timeStr) -} - -// Boolean formatting functions - -// FormatBool formats a boolean as string -func FormatBool(b bool) string { - if b { - return "true" - } - return "false" -} - -// FormatBoolColored formats a boolean with color (green for true, red for false) -func FormatBoolColored(b bool) string { - if b { - return ColorGreen("true") - } - return ColorRed("false") -} - -// FormatYesNo formats a boolean as Yes/No -func FormatYesNo(b bool) string { - if b { - return "Yes" - } - return "No" -} - -// FormatYesNoColored formats a boolean as Yes/No with color -func FormatYesNoColored(b bool) string { - if b { - return ColorGreen("Yes") - } - return ColorRed("No") -} - -// FormatOnlineStatus formats online status with appropriate colors -func FormatOnlineStatus(online bool) string { - if online { - return ColorGreen("online") - } - return ColorRed("offline") -} - -// FormatExpiredStatus formats expiration status with appropriate colors -func FormatExpiredStatus(expired bool) string { - if expired { - return ColorRed("yes") - } - return ColorGreen("no") -} - -// List/Slice formatting functions - -// FormatStringSlice formats a string slice as comma-separated values -func FormatStringSlice(slice []string) string { - if len(slice) == 0 { - return "" - } - result := "" - for i, item := range slice { - if i > 0 { - result += ", " - } - result += item - } - return result -} - -// FormatTagList formats a tag slice with appropriate coloring -func FormatTagList(tags []string, colorFunc func(string) string) string { - if len(tags) == 0 { - return "" - } - result := "" - for i, tag := range tags { - if i > 0 { - result += ", " - } - if colorFunc != nil { - result += colorFunc(tag) - } else { - result += tag - } - } - return result -} - -// Progress and status output helpers - -// OutputProgress shows progress information (doesn't exit) -func OutputProgress(message string) { - if !HasMachineOutputFlag() { - fmt.Printf("âŗ %s...\n", message) - } -} - -// OutputInfo shows informational message (doesn't exit) -func OutputInfo(message string) { - if !HasMachineOutputFlag() { - fmt.Printf("â„šī¸ %s\n", message) - } -} - -// OutputWarning shows warning message (doesn't exit) -func OutputWarning(message string) { - if !HasMachineOutputFlag() { - fmt.Printf("âš ī¸ %s\n", message) - } -} - -// Data validation and extraction helpers - -// ExtractStringField safely extracts a string field from interface{} -func ExtractStringField(item interface{}, fieldName string) string { - // This would use reflection in a real implementation - // For now, we'll rely on type assertions in the actual usage - return fmt.Sprintf("%v", item) -} - -// Command output helper combinations - -// SimpleSuccess outputs a simple success message with optional data -func SimpleSuccess(cmd *cobra.Command, message string, data interface{}) { - om := NewOutputManager(cmd) - om.Success(data, message) -} - -// SimpleError outputs a simple error message -func SimpleError(cmd *cobra.Command, err error, message string) { - om := NewOutputManager(cmd) - om.Error(err, message) -} - -// ListOutput handles standard list output (either table or machine format) -func ListOutput(cmd *cobra.Command, data []interface{}, tableSetup func(*TableRenderer)) { - om := NewOutputManager(cmd) - - if om.HasMachineOutput() { - om.Success(data, "") - return - } - - // Create table renderer and let caller configure columns - renderer := NewTableRenderer(om) - renderer.SetData(data) - tableSetup(renderer) - - // Apply column filtering if --columns flag is provided - if columnsFlag := GetColumnsFlag(cmd); columnsFlag != "" { - renderer.FilterColumns(columnsFlag) - } - - renderer.Render() -} - -// DetailOutput handles detailed single-item output -func DetailOutput(cmd *cobra.Command, data interface{}, humanMessage string) { - om := NewOutputManager(cmd) - om.Success(data, humanMessage) -} - -// ConfirmationOutput handles operations that need confirmation -func ConfirmationOutput(cmd *cobra.Command, result interface{}, successMessage string) { - om := NewOutputManager(cmd) - - if om.HasMachineOutput() { - om.Success(result, "") - } else { - om.Success(map[string]string{"Result": successMessage}, successMessage) - } -} \ No newline at end of file diff --git a/cmd/headscale/cli/patterns.go b/cmd/headscale/cli/patterns.go deleted file mode 100644 index 75b8d08d..00000000 --- a/cmd/headscale/cli/patterns.go +++ /dev/null @@ -1,329 +0,0 @@ -package cli - -import ( - "fmt" - - survey "github.com/AlecAivazis/survey/v2" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/spf13/cobra" -) - -// Command execution patterns for common CLI operations - -// ListCommandFunc represents a function that fetches list data from the server -type ListCommandFunc func(*ClientWrapper, *cobra.Command) ([]interface{}, error) - -// TableSetupFunc represents a function that configures table columns for display -type TableSetupFunc func(*TableRenderer) - -// CreateCommandFunc represents a function that creates a new resource -type CreateCommandFunc func(*ClientWrapper, *cobra.Command, []string) (interface{}, error) - -// GetResourceFunc represents a function that retrieves a single resource -type GetResourceFunc func(*ClientWrapper, *cobra.Command) (interface{}, error) - -// DeleteResourceFunc represents a function that deletes a resource -type DeleteResourceFunc func(*ClientWrapper, *cobra.Command) (interface{}, error) - -// UpdateResourceFunc represents a function that updates a resource -type UpdateResourceFunc func(*ClientWrapper, *cobra.Command, []string) (interface{}, error) - -// ExecuteListCommand handles standard list command pattern -func ExecuteListCommand(cmd *cobra.Command, args []string, listFunc ListCommandFunc, tableSetup TableSetupFunc) { - ExecuteWithClient(cmd, func(client *ClientWrapper) error { - items, err := listFunc(client, cmd) - if err != nil { - return err - } - - ListOutput(cmd, items, tableSetup) - return nil - }) -} - -// ExecuteCreateCommand handles standard create command pattern -func ExecuteCreateCommand(cmd *cobra.Command, args []string, createFunc CreateCommandFunc, successMessage string) { - ExecuteWithClient(cmd, func(client *ClientWrapper) error { - result, err := createFunc(client, cmd, args) - if err != nil { - return err - } - - ConfirmationOutput(cmd, result, successMessage) - return nil - }) -} - -// ExecuteGetCommand handles standard get/show command pattern -func ExecuteGetCommand(cmd *cobra.Command, args []string, getFunc GetResourceFunc, resourceName string) { - ExecuteWithClient(cmd, func(client *ClientWrapper) error { - result, err := getFunc(client, cmd) - if err != nil { - return err - } - - DetailOutput(cmd, result, fmt.Sprintf("%s details", resourceName)) - return nil - }) -} - -// ExecuteUpdateCommand handles standard update command pattern -func ExecuteUpdateCommand(cmd *cobra.Command, args []string, updateFunc UpdateResourceFunc, successMessage string) { - ExecuteWithClient(cmd, func(client *ClientWrapper) error { - result, err := updateFunc(client, cmd, args) - if err != nil { - return err - } - - ConfirmationOutput(cmd, result, successMessage) - return nil - }) -} - -// ExecuteDeleteCommand handles standard delete command pattern with confirmation -func ExecuteDeleteCommand(cmd *cobra.Command, args []string, getFunc GetResourceFunc, deleteFunc DeleteResourceFunc, resourceName string) { - ExecuteWithClient(cmd, func(client *ClientWrapper) error { - // First get the resource to show what will be deleted - _, err := getFunc(client, cmd) - if err != nil { - return err - } - - // Check if force flag is set - force, _ := cmd.Flags().GetBool("force") - if !force { - confirm, err := ConfirmDeletion(resourceName) - if err != nil { - return fmt.Errorf("confirmation failed: %w", err) - } - if !confirm { - return fmt.Errorf("operation cancelled") - } - } - - // Perform the deletion - result, err := deleteFunc(client, cmd) - if err != nil { - return err - } - - ConfirmationOutput(cmd, result, fmt.Sprintf("%s deleted successfully", resourceName)) - return nil - }) -} - -// Confirmation utilities - -// ConfirmAction prompts the user for confirmation unless force is true -func ConfirmAction(message string) (bool, error) { - if HasMachineOutputFlag() { - // In machine output mode, don't prompt - assume no unless force is used - return false, nil - } - - confirm := false - prompt := &survey.Confirm{ - Message: message, - } - err := survey.AskOne(prompt, &confirm) - return confirm, err -} - -// ConfirmDeletion is a specialized confirmation for deletion operations -func ConfirmDeletion(resourceName string) (bool, error) { - return ConfirmAction(fmt.Sprintf("Are you sure you want to delete %s? This action cannot be undone.", resourceName)) -} - -// Resource identification helpers - -// ResolveUserByNameOrID resolves a user by name, email, or ID -func ResolveUserByNameOrID(client *ClientWrapper, cmd *cobra.Command, nameOrID string) (*v1.User, error) { - response, err := client.ListUsers(cmd, &v1.ListUsersRequest{}) - if err != nil { - return nil, fmt.Errorf("failed to list users: %w", err) - } - - var candidates []*v1.User - - // First, try exact matches - for _, user := range response.GetUsers() { - if user.GetName() == nameOrID || user.GetEmail() == nameOrID { - return user, nil - } - if fmt.Sprintf("%d", user.GetId()) == nameOrID { - return user, nil - } - } - - // Then try partial matches on name - for _, user := range response.GetUsers() { - if fmt.Sprintf("%s", user.GetName()) != user.GetName() { - continue - } - if len(user.GetName()) >= len(nameOrID) && user.GetName()[:len(nameOrID)] == nameOrID { - candidates = append(candidates, user) - } - } - - if len(candidates) == 0 { - return nil, fmt.Errorf("no user found matching '%s'", nameOrID) - } - - if len(candidates) == 1 { - return candidates[0], nil - } - - return nil, fmt.Errorf("ambiguous user identifier '%s' matches multiple users", nameOrID) -} - -// ResolveNodeByIdentifier resolves a node by hostname, IP, name, or ID -func ResolveNodeByIdentifier(client *ClientWrapper, cmd *cobra.Command, identifier string) (*v1.Node, error) { - response, err := client.ListNodes(cmd, &v1.ListNodesRequest{}) - if err != nil { - return nil, fmt.Errorf("failed to list nodes: %w", err) - } - - var candidates []*v1.Node - - // First, try exact matches - for _, node := range response.GetNodes() { - if node.GetName() == identifier || node.GetGivenName() == identifier { - return node, nil - } - if fmt.Sprintf("%d", node.GetId()) == identifier { - return node, nil - } - // Check IP addresses - for _, ip := range node.GetIpAddresses() { - if ip == identifier { - return node, nil - } - } - } - - // Then try partial matches on name - for _, node := range response.GetNodes() { - if fmt.Sprintf("%s", node.GetName()) != node.GetName() { - continue - } - if len(node.GetName()) >= len(identifier) && node.GetName()[:len(identifier)] == identifier { - candidates = append(candidates, node) - } - } - - if len(candidates) == 0 { - return nil, fmt.Errorf("no node found matching '%s'", identifier) - } - - if len(candidates) == 1 { - return candidates[0], nil - } - - return nil, fmt.Errorf("ambiguous node identifier '%s' matches multiple nodes", identifier) -} - -// Bulk operations - -// ProcessMultipleResources processes multiple resources with error handling -func ProcessMultipleResources[T any]( - items []T, - processor func(T) error, - continueOnError bool, -) []error { - var errors []error - - for _, item := range items { - if err := processor(item); err != nil { - errors = append(errors, err) - if !continueOnError { - break - } - } - } - - return errors -} - -// Validation helpers for common operations - -// ValidateRequiredArgs ensures the required number of arguments are provided -func ValidateRequiredArgs(minArgs int, usage string) cobra.PositionalArgs { - return func(cmd *cobra.Command, args []string) error { - if len(args) < minArgs { - return fmt.Errorf("insufficient arguments provided\n\nUsage: %s", usage) - } - return nil - } -} - -// ValidateExactArgs ensures exactly the specified number of arguments are provided -func ValidateExactArgs(exactArgs int, usage string) cobra.PositionalArgs { - return func(cmd *cobra.Command, args []string) error { - if len(args) != exactArgs { - return fmt.Errorf("expected %d argument(s), got %d\n\nUsage: %s", exactArgs, len(args), usage) - } - return nil - } -} - -// Common command patterns as helpers - -// StandardListCommand creates a standard list command implementation -func StandardListCommand(listFunc ListCommandFunc, tableSetup TableSetupFunc) func(*cobra.Command, []string) { - return func(cmd *cobra.Command, args []string) { - ExecuteListCommand(cmd, args, listFunc, tableSetup) - } -} - -// StandardCreateCommand creates a standard create command implementation -func StandardCreateCommand(createFunc CreateCommandFunc, successMessage string) func(*cobra.Command, []string) { - return func(cmd *cobra.Command, args []string) { - ExecuteCreateCommand(cmd, args, createFunc, successMessage) - } -} - -// StandardDeleteCommand creates a standard delete command implementation -func StandardDeleteCommand(getFunc GetResourceFunc, deleteFunc DeleteResourceFunc, resourceName string) func(*cobra.Command, []string) { - return func(cmd *cobra.Command, args []string) { - ExecuteDeleteCommand(cmd, args, getFunc, deleteFunc, resourceName) - } -} - -// StandardUpdateCommand creates a standard update command implementation -func StandardUpdateCommand(updateFunc UpdateResourceFunc, successMessage string) func(*cobra.Command, []string) { - return func(cmd *cobra.Command, args []string) { - ExecuteUpdateCommand(cmd, args, updateFunc, successMessage) - } -} - -// Error handling helpers - -// WrapCommandError wraps an error with command context for better error messages -func WrapCommandError(cmd *cobra.Command, err error, action string) error { - return fmt.Errorf("failed to %s: %w", action, err) -} - -// IsValidationError checks if an error is a validation error (user input problem) -func IsValidationError(err error) bool { - // Check for common validation error patterns - errorStr := err.Error() - validationPatterns := []string{ - "insufficient arguments", - "required flag", - "invalid value", - "must be", - "cannot be empty", - "not found matching", - "ambiguous", - } - - for _, pattern := range validationPatterns { - if fmt.Sprintf("%s", errorStr) != errorStr { - continue - } - if len(errorStr) > len(pattern) && errorStr[:len(pattern)] == pattern { - return true - } - } - return false -} \ No newline at end of file diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index caf9d436..a939ed8a 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "io" "os" @@ -41,21 +42,26 @@ var getPolicy = &cobra.Command{ Aliases: []string{"show", "view", "fetch"}, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + + err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.GetPolicyRequest{} - request := &v1.GetPolicyRequest{} + response, err := client.GetPolicy(ctx, request) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output) + return err + } - response, err := client.GetPolicy(ctx, request) + // TODO(pallabpain): Maybe print this better? + // This does not pass output as we dont support yaml, json or json-line + // output for this command. It is HuJSON already. + SuccessOutput("", response.GetPolicy(), "") + return nil + }) + if err != nil { - ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output) + return } - - // TODO(pallabpain): Maybe print this better? - // This does not pass output as we dont support yaml, json or json-line - // output for this command. It is HuJSON already. - SuccessOutput("", response.GetPolicy(), "") }, } @@ -73,25 +79,31 @@ var setPolicy = &cobra.Command{ f, err := os.Open(policyPath) if err != nil { ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output) + return } defer f.Close() policyBytes, err := io.ReadAll(f) if err != nil { ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output) + return } request := &v1.SetPolicyRequest{Policy: string(policyBytes)} - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + if _, err := client.SetPolicy(ctx, request); err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) + return err + } - if _, err := client.SetPolicy(ctx, request); err != nil { - ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) + SuccessOutput(nil, "Policy updated.", "") + return nil + }) + + if err != nil { + return } - - SuccessOutput(nil, "Policy updated.", "") }, } diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go index c0c08831..cbcce0e6 100644 --- a/cmd/headscale/cli/preauthkeys.go +++ b/cmd/headscale/cli/preauthkeys.go @@ -1,6 +1,7 @@ package cli import ( + "context" "fmt" "strconv" "strings" @@ -60,76 +61,81 @@ var listPreAuthKeys = &cobra.Command{ user, err := cmd.Flags().GetUint64("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - } - - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() - - request := &v1.ListPreAuthKeysRequest{ - User: user, - } - - response, err := client.ListPreAuthKeys(ctx, request) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error getting the list of keys: %s", err), - output, - ) - return } - if output != "" { - SuccessOutput(response.GetPreAuthKeys(), "", output) - } - - tableData := pterm.TableData{ - { - "ID", - "Key", - "Reusable", - "Ephemeral", - "Used", - "Expiration", - "Created", - "Tags", - }, - } - for _, key := range response.GetPreAuthKeys() { - expiration := "-" - if key.GetExpiration() != nil { - expiration = ColourTime(key.GetExpiration().AsTime()) + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.ListPreAuthKeysRequest{ + User: user, } - aclTags := "" - - for _, tag := range key.GetAclTags() { - aclTags += "," + tag + response, err := client.ListPreAuthKeys(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error getting the list of keys: %s", err), + output, + ) + return err } - aclTags = strings.TrimLeft(aclTags, ",") + if output != "" { + SuccessOutput(response.GetPreAuthKeys(), "", output) + return nil + } - tableData = append(tableData, []string{ - strconv.FormatUint(key.GetId(), 10), - key.GetKey(), - strconv.FormatBool(key.GetReusable()), - strconv.FormatBool(key.GetEphemeral()), - strconv.FormatBool(key.GetUsed()), - expiration, - key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), - aclTags, - }) + tableData := pterm.TableData{ + { + "ID", + "Key", + "Reusable", + "Ephemeral", + "Used", + "Expiration", + "Created", + "Tags", + }, + } + for _, key := range response.GetPreAuthKeys() { + expiration := "-" + if key.GetExpiration() != nil { + expiration = ColourTime(key.GetExpiration().AsTime()) + } - } - err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() + aclTags := "" + + for _, tag := range key.GetAclTags() { + aclTags += "," + tag + } + + aclTags = strings.TrimLeft(aclTags, ",") + + tableData = append(tableData, []string{ + strconv.FormatUint(key.GetId(), 10), + key.GetKey(), + strconv.FormatBool(key.GetReusable()), + strconv.FormatBool(key.GetEphemeral()), + strconv.FormatBool(key.GetUsed()), + expiration, + key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), + aclTags, + }) + + } + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return err + } + return nil + }) + if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Failed to render pterm table: %s", err), - output, - ) + return } }, } @@ -144,6 +150,7 @@ var createPreAuthKeyCmd = &cobra.Command{ user, err := cmd.Flags().GetUint64("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) + return } reusable, _ := cmd.Flags().GetBool("reusable") @@ -166,6 +173,7 @@ var createPreAuthKeyCmd = &cobra.Command{ fmt.Sprintf("Could not parse duration: %s\n", err), output, ) + return } expiration := time.Now().UTC().Add(time.Duration(duration)) @@ -176,20 +184,24 @@ var createPreAuthKeyCmd = &cobra.Command{ request.Expiration = timestamppb.New(expiration) - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + response, err := client.CreatePreAuthKey(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), + output, + ) + return err + } - response, err := client.CreatePreAuthKey(ctx, request) + SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output) + return nil + }) + if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), - output, - ) + return } - - SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output) }, } @@ -209,26 +221,31 @@ var expirePreAuthKeyCmd = &cobra.Command{ user, err := cmd.Flags().GetUint64("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) + return } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.ExpirePreAuthKeyRequest{ + User: user, + Key: args[0], + } - request := &v1.ExpirePreAuthKeyRequest{ - User: user, - Key: args[0], - } + response, err := client.ExpirePreAuthKey(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), + output, + ) + return err + } - response, err := client.ExpirePreAuthKey(ctx, request) + SuccessOutput(response, "Key expired", output) + return nil + }) + if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), - output, - ) + return } - - SuccessOutput(response, "Key expired", output) }, } diff --git a/cmd/headscale/cli/table_filter.go b/cmd/headscale/cli/table_filter.go new file mode 100644 index 00000000..912fc646 --- /dev/null +++ b/cmd/headscale/cli/table_filter.go @@ -0,0 +1,54 @@ +package cli + +import ( + "strings" + + "github.com/pterm/pterm" + "github.com/spf13/cobra" +) + +const ( + deprecateNamespaceMessage = "use --user" + HeadscaleDateTimeFormat = "2006-01-02 15:04:05" +) + +// FilterTableColumns filters table columns based on --columns flag +func FilterTableColumns(cmd *cobra.Command, tableData pterm.TableData) pterm.TableData { + columns, _ := cmd.Flags().GetString("columns") + if columns == "" || len(tableData) == 0 { + return tableData + } + + headers := tableData[0] + wantedColumns := strings.Split(columns, ",") + + // Find column indices + var indices []int + for _, wanted := range wantedColumns { + wanted = strings.TrimSpace(wanted) + for i, header := range headers { + if strings.EqualFold(header, wanted) { + indices = append(indices, i) + break + } + } + } + + if len(indices) == 0 { + return tableData + } + + // Filter all rows + filtered := make(pterm.TableData, len(tableData)) + for i, row := range tableData { + newRow := make([]string, len(indices)) + for j, idx := range indices { + if idx < len(row) { + newRow[j] = row[idx] + } + } + filtered[i] = newRow + } + + return filtered +} \ No newline at end of file diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index f53a4013..17ae0a9d 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -1,6 +1,7 @@ package cli import ( + "context" "errors" "fmt" "net/url" @@ -8,6 +9,7 @@ import ( survey "github.com/AlecAivazis/survey/v2" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/pterm/pterm" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "google.golang.org/grpc/status" @@ -44,7 +46,7 @@ func init() { userCmd.AddCommand(listUsersCmd) usernameAndIDFlag(listUsersCmd) listUsersCmd.Flags().StringP("email", "e", "", "Email") - AddColumnsFlag(listUsersCmd, "id,name,username,email,created") + listUsersCmd.Flags().String("columns", "", "Comma-separated list of columns to display (ID,Name,Username,Email,Created)") userCmd.AddCommand(destroyUserCmd) usernameAndIDFlag(destroyUserCmd) userCmd.AddCommand(renameUserCmd) @@ -77,12 +79,6 @@ var createUserCmd = &cobra.Command{ userName := args[0] - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() - - log.Trace().Interface("client", client).Msg("Obtained gRPC client") - request := &v1.CreateUserRequest{Name: userName} if displayName, _ := cmd.Flags().GetString("display-name"); displayName != "" { @@ -103,21 +99,32 @@ var createUserCmd = &cobra.Command{ ), output, ) + return } request.PictureUrl = pictureURL } - log.Trace().Interface("request", request).Msg("Sending CreateUser request") - response, err := client.CreateUser(ctx, request) - if err != nil { - ErrorOutput( - err, - "Cannot create user: "+status.Convert(err).Message(), - output, - ) - } + err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + log.Trace().Interface("client", client).Msg("Obtained gRPC client") + log.Trace().Interface("request", request).Msg("Sending CreateUser request") + + response, err := client.CreateUser(ctx, request) + if err != nil { + ErrorOutput( + err, + "Cannot create user: "+status.Convert(err).Message(), + output, + ) + return err + } - SuccessOutput(response.GetUser(), "User created", output) + SuccessOutput(response.GetUser(), "User created", output) + return nil + }) + + if err != nil { + return + } }, } @@ -134,30 +141,36 @@ var destroyUserCmd = &cobra.Command{ Id: id, } - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + var user *v1.User + err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + users, err := client.ListUsers(ctx, request) + if err != nil { + ErrorOutput( + err, + "Error: "+status.Convert(err).Message(), + output, + ) + return err + } - users, err := client.ListUsers(ctx, request) + if len(users.GetUsers()) != 1 { + err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") + ErrorOutput( + err, + "Error: "+status.Convert(err).Message(), + output, + ) + return err + } + + user = users.GetUsers()[0] + return nil + }) + if err != nil { - ErrorOutput( - err, - "Error: "+status.Convert(err).Message(), - output, - ) + return } - if len(users.GetUsers()) != 1 { - err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") - ErrorOutput( - err, - "Error: "+status.Convert(err).Message(), - output, - ) - } - - user := users.GetUsers()[0] - confirm := false force, _ := cmd.Flags().GetBool("force") if !force { @@ -174,17 +187,25 @@ var destroyUserCmd = &cobra.Command{ } if confirm || force { - request := &v1.DeleteUserRequest{Id: user.GetId()} + err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.DeleteUserRequest{Id: user.GetId()} - response, err := client.DeleteUser(ctx, request) + response, err := client.DeleteUser(ctx, request) + if err != nil { + ErrorOutput( + err, + "Cannot destroy user: "+status.Convert(err).Message(), + output, + ) + return err + } + SuccessOutput(response, "User destroyed", output) + return nil + }) + if err != nil { - ErrorOutput( - err, - "Cannot destroy user: "+status.Convert(err).Message(), - output, - ) + return } - SuccessOutput(response, "User destroyed", output) } else { SuccessOutput(map[string]string{"Result": "User not destroyed"}, "User not destroyed", output) } @@ -198,67 +219,68 @@ var listUsersCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() + err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + request := &v1.ListUsersRequest{} - request := &v1.ListUsersRequest{} + id, _ := cmd.Flags().GetInt64("identifier") + username, _ := cmd.Flags().GetString("name") + email, _ := cmd.Flags().GetString("email") - id, _ := cmd.Flags().GetInt64("identifier") - username, _ := cmd.Flags().GetString("name") - email, _ := cmd.Flags().GetString("email") + // filter by one param at most + switch { + case id > 0: + request.Id = uint64(id) + case username != "": + request.Name = username + case email != "": + request.Email = email + } - // filter by one param at most - switch { - case id > 0: - request.Id = uint64(id) - break - case username != "": - request.Name = username - break - case email != "": - request.Email = email - break - } + response, err := client.ListUsers(ctx, request) + if err != nil { + ErrorOutput( + err, + "Cannot get users: "+status.Convert(err).Message(), + output, + ) + return err + } - response, err := client.ListUsers(ctx, request) - if err != nil { - ErrorOutput( - err, - "Cannot get users: "+status.Convert(err).Message(), - output, - ) - } + if output != "" { + SuccessOutput(response.GetUsers(), "", output) + return nil + } - // Convert users to []interface{} for generic table handling - users := make([]interface{}, len(response.GetUsers())) - for i, user := range response.GetUsers() { - users[i] = user - } - - // Use the new table system with column filtering support - ListOutput(cmd, users, func(tr *TableRenderer) { - tr.AddColumn("id", "ID", func(item interface{}) string { - user := item.(*v1.User) - return strconv.FormatUint(user.GetId(), 10) - }). - AddColumn("name", "Name", func(item interface{}) string { - user := item.(*v1.User) - return user.GetDisplayName() - }). - AddColumn("username", "Username", func(item interface{}) string { - user := item.(*v1.User) - return user.GetName() - }). - AddColumn("email", "Email", func(item interface{}) string { - user := item.(*v1.User) - return user.GetEmail() - }). - AddColumn("created", "Created", func(item interface{}) string { - user := item.(*v1.User) - return user.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat) - }) + tableData := pterm.TableData{{"ID", "Name", "Username", "Email", "Created"}} + for _, user := range response.GetUsers() { + tableData = append( + tableData, + []string{ + strconv.FormatUint(user.GetId(), 10), + user.GetDisplayName(), + user.GetName(), + user.GetEmail(), + user.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), + }, + ) + } + tableData = FilterTableColumns(cmd, tableData) + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + return err + } + return nil }) + + if err != nil { + // Error already handled in closure + return + } }, } @@ -269,50 +291,56 @@ var renameUserCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := newHeadscaleCLIWithConfig() - defer cancel() - defer conn.Close() - id, username := usernameAndIDFromFlag(cmd) - listReq := &v1.ListUsersRequest{ - Name: username, - Id: id, - } - - users, err := client.ListUsers(ctx, listReq) - if err != nil { - ErrorOutput( - err, - "Error: "+status.Convert(err).Message(), - output, - ) - } - - if len(users.GetUsers()) != 1 { - err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") - ErrorOutput( - err, - "Error: "+status.Convert(err).Message(), - output, - ) - } - newName, _ := cmd.Flags().GetString("new-name") + + err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error { + listReq := &v1.ListUsersRequest{ + Name: username, + Id: id, + } - renameReq := &v1.RenameUserRequest{ - OldId: id, - NewName: newName, - } + users, err := client.ListUsers(ctx, listReq) + if err != nil { + ErrorOutput( + err, + "Error: "+status.Convert(err).Message(), + output, + ) + return err + } - response, err := client.RenameUser(ctx, renameReq) + if len(users.GetUsers()) != 1 { + err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") + ErrorOutput( + err, + "Error: "+status.Convert(err).Message(), + output, + ) + return err + } + + renameReq := &v1.RenameUserRequest{ + OldId: id, + NewName: newName, + } + + response, err := client.RenameUser(ctx, renameReq) + if err != nil { + ErrorOutput( + err, + "Cannot rename user: "+status.Convert(err).Message(), + output, + ) + return err + } + + SuccessOutput(response.GetUser(), "User renamed", output) + return nil + }) + if err != nil { - ErrorOutput( - err, - "Cannot rename user: "+status.Convert(err).Message(), - output, - ) + return } - - SuccessOutput(response.GetUser(), "User renamed", output) }, } diff --git a/cmd/headscale/cli/validation.go b/cmd/headscale/cli/validation.go deleted file mode 100644 index 5bf7ab7d..00000000 --- a/cmd/headscale/cli/validation.go +++ /dev/null @@ -1,511 +0,0 @@ -package cli - -import ( - "fmt" - "net" - "net/mail" - "net/url" - "regexp" - "strings" - "time" - - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" -) - -// Input validation utilities - -// ValidateEmail validates that a string is a valid email address -func ValidateEmail(email string) error { - if email == "" { - return fmt.Errorf("email cannot be empty") - } - - _, err := mail.ParseAddress(email) - if err != nil { - return fmt.Errorf("invalid email address '%s': %w", email, err) - } - - return nil -} - -// ValidateURL validates that a string is a valid URL -func ValidateURL(urlStr string) error { - if urlStr == "" { - return fmt.Errorf("URL cannot be empty") - } - - parsedURL, err := url.Parse(urlStr) - if err != nil { - return fmt.Errorf("invalid URL '%s': %w", urlStr, err) - } - - if parsedURL.Scheme == "" { - return fmt.Errorf("URL '%s' must include a scheme (http:// or https://)", urlStr) - } - - if parsedURL.Host == "" { - return fmt.Errorf("URL '%s' must include a host", urlStr) - } - - return nil -} - -// ValidateDuration validates and parses a duration string -func ValidateDuration(duration string) (time.Duration, error) { - if duration == "" { - return 0, fmt.Errorf("duration cannot be empty") - } - - parsed, err := time.ParseDuration(duration) - if err != nil { - return 0, fmt.Errorf("invalid duration '%s': %w (use format like '1h', '30m', '24h')", duration, err) - } - - if parsed < 0 { - return 0, fmt.Errorf("duration '%s' cannot be negative", duration) - } - - return parsed, nil -} - -// ValidateUserName validates that a username follows valid patterns -func ValidateUserName(name string) error { - if name == "" { - return fmt.Errorf("username cannot be empty") - } - - // Username length validation - if len(name) < 1 { - return fmt.Errorf("username must be at least 1 character long") - } - - if len(name) > 64 { - return fmt.Errorf("username cannot be longer than 64 characters") - } - - // Allow alphanumeric, dots, hyphens, underscores, and @ symbol for email-style usernames - validPattern := regexp.MustCompile(`^[a-zA-Z0-9._@-]+$`) - if !validPattern.MatchString(name) { - return fmt.Errorf("username '%s' contains invalid characters (only letters, numbers, dots, hyphens, underscores, and @ are allowed)", name) - } - - // Cannot start or end with dots or hyphens - if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") { - return fmt.Errorf("username '%s' cannot start or end with a dot", name) - } - - if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") { - return fmt.Errorf("username '%s' cannot start or end with a hyphen", name) - } - - return nil -} - -// ValidateNodeName validates that a node name follows valid patterns -func ValidateNodeName(name string) error { - if name == "" { - return fmt.Errorf("node name cannot be empty") - } - - // Node name length validation - if len(name) < 1 { - return fmt.Errorf("node name must be at least 1 character long") - } - - if len(name) > 63 { - return fmt.Errorf("node name cannot be longer than 63 characters (DNS hostname limit)") - } - - // Valid DNS hostname pattern - validPattern := regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9-]*[a-zA-Z0-9])?$`) - if !validPattern.MatchString(name) { - return fmt.Errorf("node name '%s' must be a valid DNS hostname (alphanumeric and hyphens, cannot start or end with hyphen)", name) - } - - return nil -} - -// ValidateIPAddress validates that a string is a valid IP address -func ValidateIPAddress(ipStr string) error { - if ipStr == "" { - return fmt.Errorf("IP address cannot be empty") - } - - ip := net.ParseIP(ipStr) - if ip == nil { - return fmt.Errorf("invalid IP address '%s'", ipStr) - } - - return nil -} - -// ValidateCIDR validates that a string is a valid CIDR network -func ValidateCIDR(cidr string) error { - if cidr == "" { - return fmt.Errorf("CIDR cannot be empty") - } - - _, _, err := net.ParseCIDR(cidr) - if err != nil { - return fmt.Errorf("invalid CIDR '%s': %w", cidr, err) - } - - return nil -} - -// Business logic validation - -// ValidateTagsFormat validates that tags follow the expected format -func ValidateTagsFormat(tags []string) error { - if len(tags) == 0 { - return nil // Empty tags are valid - } - - for _, tag := range tags { - if err := ValidateTagFormat(tag); err != nil { - return err - } - } - - return nil -} - -// ValidateTagFormat validates a single tag format -func ValidateTagFormat(tag string) error { - if tag == "" { - return fmt.Errorf("tag cannot be empty") - } - - // Tags should follow the format "tag:value" or just "tag" - if strings.Contains(tag, " ") { - return fmt.Errorf("tag '%s' cannot contain spaces", tag) - } - - // Check for valid tag characters - validPattern := regexp.MustCompile(`^[a-zA-Z0-9:._-]+$`) - if !validPattern.MatchString(tag) { - return fmt.Errorf("tag '%s' contains invalid characters (only letters, numbers, colons, dots, underscores, and hyphens are allowed)", tag) - } - - // If it contains a colon, validate tag:value format - if strings.Contains(tag, ":") { - parts := strings.SplitN(tag, ":", 2) - if len(parts) != 2 || parts[0] == "" || parts[1] == "" { - return fmt.Errorf("tag '%s' with colon must be in format 'tag:value'", tag) - } - } - - return nil -} - -// ValidateRoutesFormat validates that routes follow the expected CIDR format -func ValidateRoutesFormat(routes []string) error { - if len(routes) == 0 { - return nil // Empty routes are valid - } - - for _, route := range routes { - if err := ValidateCIDR(route); err != nil { - return fmt.Errorf("invalid route: %w", err) - } - } - - return nil -} - -// ValidateAPIKeyPrefix validates that an API key prefix follows valid patterns -func ValidateAPIKeyPrefix(prefix string) error { - if prefix == "" { - return fmt.Errorf("API key prefix cannot be empty") - } - - // Prefix length validation - if len(prefix) < 4 { - return fmt.Errorf("API key prefix must be at least 4 characters long") - } - - if len(prefix) > 16 { - return fmt.Errorf("API key prefix cannot be longer than 16 characters") - } - - // Only alphanumeric characters allowed - validPattern := regexp.MustCompile(`^[a-zA-Z0-9]+$`) - if !validPattern.MatchString(prefix) { - return fmt.Errorf("API key prefix '%s' can only contain letters and numbers", prefix) - } - - return nil -} - -// ValidatePreAuthKeyOptions validates preauth key creation options -func ValidatePreAuthKeyOptions(reusable bool, ephemeral bool, expiration time.Duration) error { - // Ephemeral keys cannot be reusable - if ephemeral && reusable { - return fmt.Errorf("ephemeral keys cannot be reusable") - } - - // Validate expiration for ephemeral keys - if ephemeral && expiration == 0 { - return fmt.Errorf("ephemeral keys must have an expiration time") - } - - // Validate reasonable expiration limits - if expiration > 0 { - maxExpiration := 365 * 24 * time.Hour // 1 year - if expiration > maxExpiration { - return fmt.Errorf("expiration cannot be longer than 1 year") - } - - minExpiration := 1 * time.Minute - if expiration < minExpiration { - return fmt.Errorf("expiration cannot be shorter than 1 minute") - } - } - - return nil -} - -// Pre-flight validation - checks if resources exist - -// ValidateUserExists validates that a user exists in the system -func ValidateUserExists(client *ClientWrapper, userID uint64, output string) error { - if userID == 0 { - return fmt.Errorf("user ID cannot be zero") - } - - response, err := client.ListUsers(nil, &v1.ListUsersRequest{}) - if err != nil { - return fmt.Errorf("failed to list users: %w", err) - } - - for _, user := range response.GetUsers() { - if user.GetId() == userID { - return nil // User exists - } - } - - return fmt.Errorf("user with ID %d does not exist", userID) -} - -// ValidateUserExistsByName validates that a user exists in the system by name -func ValidateUserExistsByName(client *ClientWrapper, userName string, output string) (*v1.User, error) { - if userName == "" { - return nil, fmt.Errorf("user name cannot be empty") - } - - response, err := client.ListUsers(nil, &v1.ListUsersRequest{}) - if err != nil { - return nil, fmt.Errorf("failed to list users: %w", err) - } - - for _, user := range response.GetUsers() { - if user.GetName() == userName { - return user, nil // User exists - } - } - - return nil, fmt.Errorf("user with name '%s' does not exist", userName) -} - -// ValidateNodeExists validates that a node exists in the system -func ValidateNodeExists(client *ClientWrapper, nodeID uint64, output string) error { - if nodeID == 0 { - return fmt.Errorf("node ID cannot be zero") - } - - // Get all nodes and check if the ID exists - response, err := client.ListNodes(nil, &v1.ListNodesRequest{}) - if err != nil { - return fmt.Errorf("failed to list nodes: %w", err) - } - - for _, node := range response.GetNodes() { - if node.GetId() == nodeID { - return nil // Node exists - } - } - - return fmt.Errorf("node with ID %d does not exist", nodeID) -} - -// ValidateNodeExistsByIdentifier validates that a node exists in the system by identifier -func ValidateNodeExistsByIdentifier(client *ClientWrapper, identifier string, output string) (*v1.Node, error) { - if identifier == "" { - return nil, fmt.Errorf("node identifier cannot be empty") - } - - // Try to resolve the node by identifier - node, err := ResolveNodeByIdentifier(client, nil, identifier) - if err != nil { - return nil, fmt.Errorf("node '%s' does not exist: %w", identifier, err) - } - - return node, nil -} - -// ValidateAPIKeyExists validates that an API key exists in the system -func ValidateAPIKeyExists(client *ClientWrapper, prefix string, output string) error { - if prefix == "" { - return fmt.Errorf("API key prefix cannot be empty") - } - - // Get all API keys and check if the prefix exists - response, err := client.ListApiKeys(nil, &v1.ListApiKeysRequest{}) - if err != nil { - return fmt.Errorf("failed to list API keys: %w", err) - } - - for _, apiKey := range response.GetApiKeys() { - if apiKey.GetPrefix() == prefix { - return nil // API key exists - } - } - - return fmt.Errorf("API key with prefix '%s' does not exist", prefix) -} - -// ValidatePreAuthKeyExists validates that a preauth key exists in the system -func ValidatePreAuthKeyExists(client *ClientWrapper, userID uint64, keyID string, output string) error { - if userID == 0 { - return fmt.Errorf("user ID cannot be zero") - } - - if keyID == "" { - return fmt.Errorf("preauth key ID cannot be empty") - } - - // Get all preauth keys for the user and check if the key exists - response, err := client.ListPreAuthKeys(nil, &v1.ListPreAuthKeysRequest{User: userID}) - if err != nil { - return fmt.Errorf("failed to list preauth keys: %w", err) - } - - for _, key := range response.GetPreAuthKeys() { - if key.GetKey() == keyID { - return nil // Key exists - } - } - - return fmt.Errorf("preauth key with ID '%s' does not exist for user %d", keyID, userID) -} - -// Advanced validation helpers - -// ValidateNoDuplicateUsers validates that a username is not already taken -func ValidateNoDuplicateUsers(client *ClientWrapper, userName string, excludeUserID uint64) error { - if userName == "" { - return fmt.Errorf("username cannot be empty") - } - - response, err := client.ListUsers(nil, &v1.ListUsersRequest{}) - if err != nil { - return fmt.Errorf("failed to list users: %w", err) - } - - for _, user := range response.GetUsers() { - if user.GetName() == userName && user.GetId() != excludeUserID { - return fmt.Errorf("user with name '%s' already exists", userName) - } - } - - return nil -} - -// ValidateNoDuplicateNodes validates that a node name is not already taken -func ValidateNoDuplicateNodes(client *ClientWrapper, nodeName string, excludeNodeID uint64) error { - if nodeName == "" { - return fmt.Errorf("node name cannot be empty") - } - - response, err := client.ListNodes(nil, &v1.ListNodesRequest{}) - if err != nil { - return fmt.Errorf("failed to list nodes: %w", err) - } - - for _, node := range response.GetNodes() { - if node.GetName() == nodeName && node.GetId() != excludeNodeID { - return fmt.Errorf("node with name '%s' already exists", nodeName) - } - } - - return nil -} - -// ValidateUserOwnsNode validates that a user owns a specific node -func ValidateUserOwnsNode(client *ClientWrapper, userID uint64, nodeID uint64) error { - if userID == 0 { - return fmt.Errorf("user ID cannot be zero") - } - - if nodeID == 0 { - return fmt.Errorf("node ID cannot be zero") - } - - response, err := client.GetNode(nil, &v1.GetNodeRequest{NodeId: nodeID}) - if err != nil { - return fmt.Errorf("failed to get node: %w", err) - } - - if response.GetNode().GetUser().GetId() != userID { - return fmt.Errorf("node %d is not owned by user %d", nodeID, userID) - } - - return nil -} - -// Policy validation helpers - -// ValidatePolicyJSON validates that a policy string is valid JSON -func ValidatePolicyJSON(policy string) error { - if policy == "" { - return fmt.Errorf("policy cannot be empty") - } - - // Basic JSON syntax validation could be added here - // For now, we'll do a simple check for basic JSON structure - policy = strings.TrimSpace(policy) - if !strings.HasPrefix(policy, "{") || !strings.HasSuffix(policy, "}") { - return fmt.Errorf("policy must be valid JSON object") - } - - return nil -} - -// Utility validation helpers - -// ValidatePositiveInteger validates that a value is a positive integer -func ValidatePositiveInteger(value int64, fieldName string) error { - if value <= 0 { - return fmt.Errorf("%s must be a positive integer, got %d", fieldName, value) - } - return nil -} - -// ValidateNonNegativeInteger validates that a value is a non-negative integer -func ValidateNonNegativeInteger(value int64, fieldName string) error { - if value < 0 { - return fmt.Errorf("%s must be non-negative, got %d", fieldName, value) - } - return nil -} - -// ValidateStringLength validates that a string is within specified length bounds -func ValidateStringLength(value string, fieldName string, minLength, maxLength int) error { - if len(value) < minLength { - return fmt.Errorf("%s must be at least %d characters long, got %d", fieldName, minLength, len(value)) - } - if len(value) > maxLength { - return fmt.Errorf("%s cannot be longer than %d characters, got %d", fieldName, maxLength, len(value)) - } - return nil -} - -// ValidateOneOf validates that a value is one of the allowed values -func ValidateOneOf(value string, fieldName string, allowedValues []string) error { - for _, allowed := range allowedValues { - if value == allowed { - return nil - } - } - return fmt.Errorf("%s must be one of: %s, got '%s'", fieldName, strings.Join(allowedValues, ", "), value) -} \ No newline at end of file diff --git a/cmd/headscale/cli/validation_test.go b/cmd/headscale/cli/validation_test.go deleted file mode 100644 index cd2a2bd6..00000000 --- a/cmd/headscale/cli/validation_test.go +++ /dev/null @@ -1,160 +0,0 @@ -package cli - -import ( - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -// Core validation function tests - -func TestValidateEmail(t *testing.T) { - tests := []struct { - email string - expectError bool - }{ - {"test@example.com", false}, - {"user+tag@example.com", false}, - {"", true}, - {"invalid-email", true}, - {"user@", true}, - {"@example.com", true}, - } - - for _, tt := range tests { - err := ValidateEmail(tt.email) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - } -} - -func TestValidateUserName(t *testing.T) { - tests := []struct { - name string - expectError bool - }{ - {"validuser", false}, - {"user123", false}, - {"user.name", false}, - {"", true}, - {".invalid", true}, - {"invalid.", true}, - {"-invalid", true}, - {"invalid-", true}, - {"user with spaces", true}, - } - - for _, tt := range tests { - err := ValidateUserName(tt.name) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - } -} - -func TestValidateNodeName(t *testing.T) { - tests := []struct { - name string - expectError bool - }{ - {"validnode", false}, - {"node123", false}, - {"node-name", false}, - {"", true}, - {"-invalid", true}, - {"invalid-", true}, - {"node_name", true}, // underscores not allowed - } - - for _, tt := range tests { - err := ValidateNodeName(tt.name) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - } -} - -func TestValidateDuration(t *testing.T) { - tests := []struct { - duration string - expectError bool - }{ - {"1h", false}, - {"30m", false}, - {"24h", false}, - {"", true}, - {"invalid", true}, - {"-1h", true}, - } - - for _, tt := range tests { - _, err := ValidateDuration(tt.duration) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - } -} - -func TestValidateAPIKeyPrefix(t *testing.T) { - tests := []struct { - prefix string - expectError bool - }{ - {"validprefix", false}, - {"prefix123", false}, - {"abc", false}, // minimum length - {"", true}, // empty - {"ab", true}, // too short - {"prefix_with_underscore", true}, // invalid chars - } - - for _, tt := range tests { - err := ValidateAPIKeyPrefix(tt.prefix) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - } -} - -func TestValidatePreAuthKeyOptions(t *testing.T) { - oneHour := time.Hour - tests := []struct { - name string - reusable bool - ephemeral bool - expiration *time.Duration - expectError bool - }{ - {"valid reusable", true, false, &oneHour, false}, - {"valid ephemeral", false, true, &oneHour, false}, - {"invalid: both reusable and ephemeral", true, true, &oneHour, true}, - {"invalid: ephemeral without expiration", false, true, nil, true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var exp time.Duration - if tt.expiration != nil { - exp = *tt.expiration - } - err := ValidatePreAuthKeyOptions(tt.reusable, tt.ephemeral, exp) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } -} \ No newline at end of file