This commit is contained in:
Kristoffer Dalby 2025-07-14 15:50:46 +00:00
parent 7d31735bac
commit 9d2cfb1e7e
26 changed files with 216 additions and 8147 deletions

View File

@ -0,0 +1,63 @@
# Column Filtering for Table Output
## Overview
All CLI commands that output tables now support a `--columns` flag to customize which columns are displayed.
## Usage
```bash
# Show all default columns
headscale users list
# Show only name and email
headscale users list --columns="name,email"
# Show only ID and username
headscale users list --columns="id,username"
# Show columns in custom order
headscale users list --columns="email,name,id"
```
## Available Columns
### Users List
- `id` - User ID
- `name` - Display name
- `username` - Username
- `email` - Email address
- `created` - Creation date
### Implementation Pattern
For developers adding this to other commands:
```go
// 1. Add columns flag with default columns
AddColumnsFlag(cmd, "id,name,hostname,ip,status")
// 2. Use ListOutput with TableRenderer
ListOutput(cmd, items, func(tr *TableRenderer) {
tr.AddColumn("id", "ID", func(item interface{}) string {
node := item.(*v1.Node)
return strconv.FormatUint(node.GetId(), 10)
}).
AddColumn("name", "Name", func(item interface{}) string {
node := item.(*v1.Node)
return node.GetName()
}).
AddColumn("hostname", "Hostname", func(item interface{}) string {
node := item.(*v1.Node)
return node.GetHostname()
})
// ... add more columns
})
```
## Notes
- Column filtering only applies to table output, not JSON/YAML output
- Invalid column names are silently ignored
- Columns appear in the order specified in the --columns flag
- Default columns are defined per command based on most useful information

View File

@ -1,362 +0,0 @@
package cli
import (
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAPIKeysCommand(t *testing.T) {
// Test the main apikeys command
assert.NotNil(t, apiKeysCmd)
assert.Equal(t, "apikeys", apiKeysCmd.Use)
assert.Equal(t, "Handle the Api keys in Headscale", apiKeysCmd.Short)
// Test aliases
expectedAliases := []string{"apikey", "api"}
assert.Equal(t, expectedAliases, apiKeysCmd.Aliases)
// Test that apikeys command has subcommands
subcommands := apiKeysCmd.Commands()
assert.Greater(t, len(subcommands), 0, "API keys command should have subcommands")
// Verify expected subcommands exist
subcommandNames := make([]string, len(subcommands))
for i, cmd := range subcommands {
subcommandNames[i] = cmd.Use
}
expectedSubcommands := []string{"list", "create", "expire", "delete"}
for _, expected := range expectedSubcommands {
found := false
for _, actual := range subcommandNames {
if actual == expected {
found = true
break
}
}
assert.True(t, found, "Expected subcommand '%s' not found", expected)
}
}
func TestListAPIKeysCommand(t *testing.T) {
assert.NotNil(t, listAPIKeys)
assert.Equal(t, "list", listAPIKeys.Use)
assert.Equal(t, "List the Api keys for headscale", listAPIKeys.Short)
assert.Equal(t, []string{"ls", "show"}, listAPIKeys.Aliases)
// Test that Run function is set
assert.NotNil(t, listAPIKeys.Run)
}
func TestCreateAPIKeyCommand(t *testing.T) {
assert.NotNil(t, createAPIKeyCmd)
assert.Equal(t, "create", createAPIKeyCmd.Use)
assert.Equal(t, "Creates a new Api key", createAPIKeyCmd.Short)
assert.Equal(t, []string{"c", "new"}, createAPIKeyCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, createAPIKeyCmd.Run)
// Test that Long description is set
assert.NotEmpty(t, createAPIKeyCmd.Long)
assert.Contains(t, createAPIKeyCmd.Long, "Creates a new Api key")
assert.Contains(t, createAPIKeyCmd.Long, "only visible on creation")
// Test flags
flags := createAPIKeyCmd.Flags()
assert.NotNil(t, flags.Lookup("expiration"))
// Test flag properties
expirationFlag := flags.Lookup("expiration")
assert.Equal(t, "e", expirationFlag.Shorthand)
assert.Equal(t, DefaultAPIKeyExpiry, expirationFlag.DefValue)
assert.Contains(t, expirationFlag.Usage, "Human-readable expiration")
}
func TestExpireAPIKeyCommand(t *testing.T) {
assert.NotNil(t, expireAPIKeyCmd)
assert.Equal(t, "expire", expireAPIKeyCmd.Use)
assert.Equal(t, "Expire an ApiKey", expireAPIKeyCmd.Short)
assert.Equal(t, []string{"revoke", "exp", "e"}, expireAPIKeyCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, expireAPIKeyCmd.Run)
// Test flags
flags := expireAPIKeyCmd.Flags()
assert.NotNil(t, flags.Lookup("prefix"))
// Test flag properties
prefixFlag := flags.Lookup("prefix")
assert.Equal(t, "p", prefixFlag.Shorthand)
assert.Equal(t, "ApiKey prefix", prefixFlag.Usage)
// Test that prefix flag is required
// Note: We can't directly test MarkFlagRequired, but we can check the annotations
annotations := prefixFlag.Annotations
if annotations != nil {
// cobra adds required annotation when MarkFlagRequired is called
_, hasRequired := annotations[cobra.BashCompOneRequiredFlag]
assert.True(t, hasRequired, "prefix flag should be marked as required")
}
}
func TestDeleteAPIKeyCommand(t *testing.T) {
assert.NotNil(t, deleteAPIKeyCmd)
assert.Equal(t, "delete", deleteAPIKeyCmd.Use)
assert.Equal(t, "Delete an ApiKey", deleteAPIKeyCmd.Short)
assert.Equal(t, []string{"remove", "del"}, deleteAPIKeyCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, deleteAPIKeyCmd.Run)
// Test flags
flags := deleteAPIKeyCmd.Flags()
assert.NotNil(t, flags.Lookup("prefix"))
// Test flag properties
prefixFlag := flags.Lookup("prefix")
assert.Equal(t, "p", prefixFlag.Shorthand)
assert.Equal(t, "ApiKey prefix", prefixFlag.Usage)
// Test that prefix flag is required
annotations := prefixFlag.Annotations
if annotations != nil {
_, hasRequired := annotations[cobra.BashCompOneRequiredFlag]
assert.True(t, hasRequired, "prefix flag should be marked as required")
}
}
func TestAPIKeyConstants(t *testing.T) {
// Test that constants are defined
assert.Equal(t, "90d", DefaultAPIKeyExpiry)
}
func TestAPIKeyCommandStructure(t *testing.T) {
// Validate command structure and help text
ValidateCommandStructure(t, apiKeysCmd, "apikeys", "Handle the Api keys in Headscale")
ValidateCommandHelp(t, apiKeysCmd)
// Validate subcommands
ValidateCommandStructure(t, listAPIKeys, "list", "List the Api keys for headscale")
ValidateCommandHelp(t, listAPIKeys)
ValidateCommandStructure(t, createAPIKeyCmd, "create", "Creates a new Api key")
ValidateCommandHelp(t, createAPIKeyCmd)
ValidateCommandStructure(t, expireAPIKeyCmd, "expire", "Expire an ApiKey")
ValidateCommandHelp(t, expireAPIKeyCmd)
ValidateCommandStructure(t, deleteAPIKeyCmd, "delete", "Delete an ApiKey")
ValidateCommandHelp(t, deleteAPIKeyCmd)
}
func TestAPIKeyCommandFlags(t *testing.T) {
// Test create API key command flags
ValidateCommandFlags(t, createAPIKeyCmd, []string{"expiration"})
// Test expire API key command flags
ValidateCommandFlags(t, expireAPIKeyCmd, []string{"prefix"})
// Test delete API key command flags
ValidateCommandFlags(t, deleteAPIKeyCmd, []string{"prefix"})
}
func TestAPIKeyCommandIntegration(t *testing.T) {
// Test that apikeys command is properly integrated into root command
found := false
for _, cmd := range rootCmd.Commands() {
if cmd.Use == "apikeys" {
found = true
break
}
}
assert.True(t, found, "API keys command should be added to root command")
}
func TestAPIKeySubcommandIntegration(t *testing.T) {
// Test that all subcommands are properly added to apikeys command
subcommands := apiKeysCmd.Commands()
expectedCommands := map[string]bool{
"list": false,
"create": false,
"expire": false,
"delete": false,
}
for _, subcmd := range subcommands {
if _, exists := expectedCommands[subcmd.Use]; exists {
expectedCommands[subcmd.Use] = true
}
}
for cmdName, found := range expectedCommands {
assert.True(t, found, "Subcommand '%s' should be added to apikeys command", cmdName)
}
}
func TestAPIKeyCommandAliases(t *testing.T) {
// Test that all aliases are properly set
testCases := []struct {
command *cobra.Command
expectedAliases []string
}{
{
command: apiKeysCmd,
expectedAliases: []string{"apikey", "api"},
},
{
command: listAPIKeys,
expectedAliases: []string{"ls", "show"},
},
{
command: createAPIKeyCmd,
expectedAliases: []string{"c", "new"},
},
{
command: expireAPIKeyCmd,
expectedAliases: []string{"revoke", "exp", "e"},
},
{
command: deleteAPIKeyCmd,
expectedAliases: []string{"remove", "del"},
},
}
for _, tc := range testCases {
t.Run(tc.command.Use, func(t *testing.T) {
assert.Equal(t, tc.expectedAliases, tc.command.Aliases)
})
}
}
func TestAPIKeyFlagDefaults(t *testing.T) {
// Test create API key command flag defaults
flags := createAPIKeyCmd.Flags()
// Test expiration flag default
expiration, err := flags.GetString("expiration")
assert.NoError(t, err)
assert.Equal(t, DefaultAPIKeyExpiry, expiration)
}
func TestAPIKeyFlagShortcuts(t *testing.T) {
// Test that flag shortcuts are properly set
// Create command
expirationFlag := createAPIKeyCmd.Flags().Lookup("expiration")
assert.Equal(t, "e", expirationFlag.Shorthand)
// Expire command
prefixFlag1 := expireAPIKeyCmd.Flags().Lookup("prefix")
assert.Equal(t, "p", prefixFlag1.Shorthand)
// Delete command
prefixFlag2 := deleteAPIKeyCmd.Flags().Lookup("prefix")
assert.Equal(t, "p", prefixFlag2.Shorthand)
}
func TestAPIKeyCommandsHaveOutputFlag(t *testing.T) {
// All API key commands should support output formatting
commands := []*cobra.Command{listAPIKeys, createAPIKeyCmd, expireAPIKeyCmd, deleteAPIKeyCmd}
for _, cmd := range commands {
t.Run(cmd.Use, func(t *testing.T) {
// Commands should be able to get output flag (though it might be inherited)
// This tests that the commands are designed to work with output formatting
assert.NotNil(t, cmd.Run, "Command should have a Run function")
})
}
}
func TestAPIKeyCommandCompleteness(t *testing.T) {
// Test that API key command covers all expected CRUD operations
subcommands := apiKeysCmd.Commands()
operations := map[string]bool{
"create": false,
"read": false, // list command
"update": false, // expire command (updates state)
"delete": false, // delete command
}
for _, subcmd := range subcommands {
switch subcmd.Use {
case "create":
operations["create"] = true
case "list":
operations["read"] = true
case "expire":
operations["update"] = true
case "delete":
operations["delete"] = true
}
}
for op, found := range operations {
assert.True(t, found, "API key command should support %s operation", op)
}
}
func TestAPIKeyCommandUsagePatterns(t *testing.T) {
// Test that commands follow consistent usage patterns
// List command should not require arguments
assert.NotNil(t, listAPIKeys.Run)
assert.Nil(t, listAPIKeys.Args) // No args validation means optional args
// Create command should not require arguments
assert.NotNil(t, createAPIKeyCmd.Run)
assert.Nil(t, createAPIKeyCmd.Args)
// Expire and delete commands require prefix flag (tested above)
assert.NotNil(t, expireAPIKeyCmd.Run)
assert.NotNil(t, deleteAPIKeyCmd.Run)
}
func TestAPIKeyCommandDocumentation(t *testing.T) {
// Test that important commands have proper documentation
// Create command should have detailed Long description
assert.NotEmpty(t, createAPIKeyCmd.Long)
assert.Contains(t, createAPIKeyCmd.Long, "only visible on creation")
assert.Contains(t, createAPIKeyCmd.Long, "cannot be retrieved again")
// Other commands should have at least Short descriptions
assert.NotEmpty(t, listAPIKeys.Short)
assert.NotEmpty(t, expireAPIKeyCmd.Short)
assert.NotEmpty(t, deleteAPIKeyCmd.Short)
}
func TestAPIKeyFlagValidation(t *testing.T) {
// Test that flags have proper validation setup
// Test that prefix flags are required where expected
requiredPrefixCommands := []*cobra.Command{expireAPIKeyCmd, deleteAPIKeyCmd}
for _, cmd := range requiredPrefixCommands {
t.Run(cmd.Use+"_prefix_required", func(t *testing.T) {
prefixFlag := cmd.Flags().Lookup("prefix")
require.NotNil(t, prefixFlag)
// Check if flag has required annotation (set by MarkFlagRequired)
if prefixFlag.Annotations != nil {
_, hasRequired := prefixFlag.Annotations[cobra.BashCompOneRequiredFlag]
assert.True(t, hasRequired, "prefix flag should be marked as required for %s command", cmd.Use)
}
})
}
}
func TestAPIKeyDefaultExpiry(t *testing.T) {
// Test that the default expiry constant is reasonable
assert.Equal(t, "90d", DefaultAPIKeyExpiry)
// Test that it can be used in flag defaults
expirationFlag := createAPIKeyCmd.Flags().Lookup("expiration")
assert.Equal(t, DefaultAPIKeyExpiry, expirationFlag.DefValue)
}

View File

@ -1,319 +0,0 @@
package cli
import (
"context"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestClientWrapper_NewClient(t *testing.T) {
// This test validates the ClientWrapper structure without requiring actual gRPC connection
// since newHeadscaleCLIWithConfig would require a running headscale server
// Test that NewClient function exists and has the right signature
// We can't actually call it without a server, but we can test the structure
wrapper := &ClientWrapper{
ctx: context.Background(),
client: nil, // Would be set by actual connection
conn: nil, // Would be set by actual connection
cancel: func() {}, // Mock cancel function
}
// Verify wrapper structure
assert.NotNil(t, wrapper.ctx)
assert.NotNil(t, wrapper.cancel)
}
func TestClientWrapper_Close(t *testing.T) {
// Test the Close method with mock values
cancelCalled := false
mockCancel := func() {
cancelCalled = true
}
wrapper := &ClientWrapper{
ctx: context.Background(),
client: nil,
conn: nil, // In real usage would be *grpc.ClientConn
cancel: mockCancel,
}
// Call Close
wrapper.Close()
// Verify cancel was called
assert.True(t, cancelCalled)
}
func TestExecuteWithClient(t *testing.T) {
// Test ExecuteWithClient function structure
// Note: We cannot actually test ExecuteWithClient as it calls newHeadscaleCLIWithConfig()
// which requires a running headscale server. Instead we test that the function exists
// and has the correct signature.
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
// Verify the function exists and has the correct signature
assert.NotNil(t, ExecuteWithClient)
// We can't actually call ExecuteWithClient without a server since it would panic
// when trying to connect to headscale. This is expected behavior.
}
func TestClientWrapper_ExecuteWithErrorHandling(t *testing.T) {
// Test the ExecuteWithErrorHandling method structure
// Note: We can't actually test ExecuteWithErrorHandling without a real gRPC client
// since it expects a v1.HeadscaleServiceClient, but we can test the method exists
wrapper := &ClientWrapper{
ctx: context.Background(),
client: nil, // Mock client
conn: nil,
cancel: func() {},
}
// Verify the method exists
assert.NotNil(t, wrapper.ExecuteWithErrorHandling)
}
func TestClientWrapper_NodeOperations(t *testing.T) {
// Test that all node operation methods exist with correct signatures
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
wrapper := &ClientWrapper{
ctx: context.Background(),
client: nil,
conn: nil,
cancel: func() {},
}
// Test ListNodes method exists
assert.NotNil(t, wrapper.ListNodes)
// Test RegisterNode method exists
assert.NotNil(t, wrapper.RegisterNode)
// Test DeleteNode method exists
assert.NotNil(t, wrapper.DeleteNode)
// Test ExpireNode method exists
assert.NotNil(t, wrapper.ExpireNode)
// Test RenameNode method exists
assert.NotNil(t, wrapper.RenameNode)
// Test MoveNode method exists
assert.NotNil(t, wrapper.MoveNode)
// Test GetNode method exists
assert.NotNil(t, wrapper.GetNode)
// Test SetTags method exists
assert.NotNil(t, wrapper.SetTags)
// Test SetApprovedRoutes method exists
assert.NotNil(t, wrapper.SetApprovedRoutes)
// Test BackfillNodeIPs method exists
assert.NotNil(t, wrapper.BackfillNodeIPs)
}
func TestClientWrapper_UserOperations(t *testing.T) {
// Test that all user operation methods exist with correct signatures
wrapper := &ClientWrapper{
ctx: context.Background(),
client: nil,
conn: nil,
cancel: func() {},
}
// Test ListUsers method exists
assert.NotNil(t, wrapper.ListUsers)
// Test CreateUser method exists
assert.NotNil(t, wrapper.CreateUser)
// Test RenameUser method exists
assert.NotNil(t, wrapper.RenameUser)
// Test DeleteUser method exists
assert.NotNil(t, wrapper.DeleteUser)
}
func TestClientWrapper_ApiKeyOperations(t *testing.T) {
// Test that all API key operation methods exist with correct signatures
wrapper := &ClientWrapper{
ctx: context.Background(),
client: nil,
conn: nil,
cancel: func() {},
}
// Test ListApiKeys method exists
assert.NotNil(t, wrapper.ListApiKeys)
// Test CreateApiKey method exists
assert.NotNil(t, wrapper.CreateApiKey)
// Test ExpireApiKey method exists
assert.NotNil(t, wrapper.ExpireApiKey)
// Test DeleteApiKey method exists
assert.NotNil(t, wrapper.DeleteApiKey)
}
func TestClientWrapper_PreAuthKeyOperations(t *testing.T) {
// Test that all preauth key operation methods exist with correct signatures
wrapper := &ClientWrapper{
ctx: context.Background(),
client: nil,
conn: nil,
cancel: func() {},
}
// Test ListPreAuthKeys method exists
assert.NotNil(t, wrapper.ListPreAuthKeys)
// Test CreatePreAuthKey method exists
assert.NotNil(t, wrapper.CreatePreAuthKey)
// Test ExpirePreAuthKey method exists
assert.NotNil(t, wrapper.ExpirePreAuthKey)
}
func TestClientWrapper_PolicyOperations(t *testing.T) {
// Test that all policy operation methods exist with correct signatures
wrapper := &ClientWrapper{
ctx: context.Background(),
client: nil,
conn: nil,
cancel: func() {},
}
// Test GetPolicy method exists
assert.NotNil(t, wrapper.GetPolicy)
// Test SetPolicy method exists
assert.NotNil(t, wrapper.SetPolicy)
}
func TestClientWrapper_DebugOperations(t *testing.T) {
// Test that all debug operation methods exist with correct signatures
wrapper := &ClientWrapper{
ctx: context.Background(),
client: nil,
conn: nil,
cancel: func() {},
}
// Test DebugCreateNode method exists
assert.NotNil(t, wrapper.DebugCreateNode)
}
func TestClientWrapper_AllMethodsUseContext(t *testing.T) {
// Verify that ClientWrapper maintains context properly
testCtx := context.WithValue(context.Background(), "test", "value")
wrapper := &ClientWrapper{
ctx: testCtx,
client: nil,
conn: nil,
cancel: func() {},
}
// The context should be preserved
assert.Equal(t, testCtx, wrapper.ctx)
assert.Equal(t, "value", wrapper.ctx.Value("test"))
}
func TestErrorHandling_Integration(t *testing.T) {
// Test error handling integration with flag infrastructure
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
// Set output format
err := cmd.Flags().Set("output", "json")
require.NoError(t, err)
// Test that GetOutputFormat works correctly for error handling
outputFormat := GetOutputFormat(cmd)
assert.Equal(t, "json", outputFormat)
// Verify that the integration between client infrastructure and flag infrastructure
// works by testing that GetOutputFormat can be used for error formatting
// (actual ExecuteWithClient testing requires a running server)
assert.Equal(t, "json", GetOutputFormat(cmd))
}
func TestClientInfrastructure_ComprehensiveCoverage(t *testing.T) {
// Test that we have comprehensive coverage of all gRPC methods
// This ensures we haven't missed any gRPC operations in our wrapper
wrapper := &ClientWrapper{}
// Node operations (10 methods)
nodeOps := []interface{}{
wrapper.ListNodes,
wrapper.RegisterNode,
wrapper.DeleteNode,
wrapper.ExpireNode,
wrapper.RenameNode,
wrapper.MoveNode,
wrapper.GetNode,
wrapper.SetTags,
wrapper.SetApprovedRoutes,
wrapper.BackfillNodeIPs,
}
// User operations (4 methods)
userOps := []interface{}{
wrapper.ListUsers,
wrapper.CreateUser,
wrapper.RenameUser,
wrapper.DeleteUser,
}
// API key operations (4 methods)
apiKeyOps := []interface{}{
wrapper.ListApiKeys,
wrapper.CreateApiKey,
wrapper.ExpireApiKey,
wrapper.DeleteApiKey,
}
// PreAuth key operations (3 methods)
preAuthOps := []interface{}{
wrapper.ListPreAuthKeys,
wrapper.CreatePreAuthKey,
wrapper.ExpirePreAuthKey,
}
// Policy operations (2 methods)
policyOps := []interface{}{
wrapper.GetPolicy,
wrapper.SetPolicy,
}
// Debug operations (1 method)
debugOps := []interface{}{
wrapper.DebugCreateNode,
}
// Verify all operation arrays have methods
allOps := [][]interface{}{nodeOps, userOps, apiKeyOps, preAuthOps, policyOps, debugOps}
for i, ops := range allOps {
for j, op := range ops {
assert.NotNil(t, op, "Operation %d in category %d should not be nil", j, i)
}
}
// Total should be 24 gRPC wrapper methods
totalMethods := len(nodeOps) + len(userOps) + len(apiKeyOps) + len(preAuthOps) + len(policyOps) + len(debugOps)
assert.Equal(t, 24, totalMethods, "Should have exactly 24 gRPC operation wrapper methods")
}

View File

@ -1,181 +0,0 @@
package cli
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestCommandStructure tests that all expected commands exist and are properly configured
func TestCommandStructure(t *testing.T) {
// Test version command
assert.NotNil(t, versionCmd)
assert.Equal(t, "version", versionCmd.Use)
assert.Equal(t, "Print the version.", versionCmd.Short)
assert.Equal(t, "The version of headscale.", versionCmd.Long)
assert.NotNil(t, versionCmd.Run)
// Test generate command
assert.NotNil(t, generateCmd)
assert.Equal(t, "generate", generateCmd.Use)
assert.Equal(t, "Generate commands", generateCmd.Short)
assert.Contains(t, generateCmd.Aliases, "gen")
// Test generate private-key subcommand
assert.NotNil(t, generatePrivateKeyCmd)
assert.Equal(t, "private-key", generatePrivateKeyCmd.Use)
assert.Equal(t, "Generate a private key for the headscale server", generatePrivateKeyCmd.Short)
assert.NotNil(t, generatePrivateKeyCmd.Run)
// Test that generate has private-key as subcommand
found := false
for _, subcmd := range generateCmd.Commands() {
if subcmd.Name() == "private-key" {
found = true
break
}
}
assert.True(t, found, "private-key should be a subcommand of generate")
}
// TestNodeCommandStructure tests the node command hierarchy
func TestNodeCommandStructure(t *testing.T) {
assert.NotNil(t, nodeCmd)
assert.Equal(t, "nodes", nodeCmd.Use)
assert.Equal(t, "Manage the nodes of Headscale", nodeCmd.Short)
assert.Contains(t, nodeCmd.Aliases, "node")
assert.Contains(t, nodeCmd.Aliases, "machine")
assert.Contains(t, nodeCmd.Aliases, "machines")
// Test some key subcommands exist
subcommands := make(map[string]bool)
for _, subcmd := range nodeCmd.Commands() {
subcommands[subcmd.Name()] = true
}
expectedSubcommands := []string{"list", "register", "delete", "expire", "rename", "move", "tag", "approve-routes", "list-routes", "backfillips"}
for _, expected := range expectedSubcommands {
assert.True(t, subcommands[expected], "Node command should have %s subcommand", expected)
}
}
// TestUserCommandStructure tests the user command hierarchy
func TestUserCommandStructure(t *testing.T) {
assert.NotNil(t, userCmd)
assert.Equal(t, "users", userCmd.Use)
assert.Equal(t, "Manage the users of Headscale", userCmd.Short)
assert.Contains(t, userCmd.Aliases, "user")
assert.Contains(t, userCmd.Aliases, "namespace")
assert.Contains(t, userCmd.Aliases, "namespaces")
// Test some key subcommands exist
subcommands := make(map[string]bool)
for _, subcmd := range userCmd.Commands() {
subcommands[subcmd.Name()] = true
}
expectedSubcommands := []string{"list", "create", "rename", "destroy"}
for _, expected := range expectedSubcommands {
assert.True(t, subcommands[expected], "User command should have %s subcommand", expected)
}
}
// TestRootCommandStructure tests the root command setup
func TestRootCommandStructure(t *testing.T) {
assert.NotNil(t, rootCmd)
assert.Equal(t, "headscale", rootCmd.Use)
assert.Equal(t, "headscale - a Tailscale control server", rootCmd.Short)
assert.Contains(t, rootCmd.Long, "headscale is an open source implementation")
// Check that persistent flags are set up
outputFlag := rootCmd.PersistentFlags().Lookup("output")
assert.NotNil(t, outputFlag)
assert.Equal(t, "o", outputFlag.Shorthand)
configFlag := rootCmd.PersistentFlags().Lookup("config")
assert.NotNil(t, configFlag)
assert.Equal(t, "c", configFlag.Shorthand)
forceFlag := rootCmd.PersistentFlags().Lookup("force")
assert.NotNil(t, forceFlag)
}
// TestCommandAliases tests that command aliases work correctly
func TestCommandAliases(t *testing.T) {
tests := []struct {
command string
aliases []string
}{
{
command: "nodes",
aliases: []string{"node", "machine", "machines"},
},
{
command: "users",
aliases: []string{"user", "namespace", "namespaces"},
},
{
command: "generate",
aliases: []string{"gen"},
},
}
for _, tt := range tests {
t.Run(tt.command, func(t *testing.T) {
// Find the command by name
cmd, _, err := rootCmd.Find([]string{tt.command})
require.NoError(t, err)
// Check each alias
for _, alias := range tt.aliases {
aliasCmd, _, err := rootCmd.Find([]string{alias})
require.NoError(t, err)
assert.Equal(t, cmd, aliasCmd, "Alias %s should resolve to the same command as %s", alias, tt.command)
}
})
}
}
// TestDeprecationMessages tests that deprecation constants are defined
func TestDeprecationMessages(t *testing.T) {
assert.Equal(t, "use --user", deprecateNamespaceMessage)
}
// TestCommandFlagsExist tests that important flags exist on commands
func TestCommandFlagsExist(t *testing.T) {
// Test that list commands have user flag
listNodesCmd, _, err := rootCmd.Find([]string{"nodes", "list"})
require.NoError(t, err)
userFlag := listNodesCmd.Flags().Lookup("user")
assert.NotNil(t, userFlag)
assert.Equal(t, "u", userFlag.Shorthand)
// Test that delete commands have identifier flag
deleteNodeCmd, _, err := rootCmd.Find([]string{"nodes", "delete"})
require.NoError(t, err)
identifierFlag := deleteNodeCmd.Flags().Lookup("identifier")
assert.NotNil(t, identifierFlag)
assert.Equal(t, "i", identifierFlag.Shorthand)
// Test that commands have force flag available (inherited from root)
forceFlag := deleteNodeCmd.InheritedFlags().Lookup("force")
assert.NotNil(t, forceFlag)
}
// TestCommandRunFunctions tests that commands have run functions defined
func TestCommandRunFunctions(t *testing.T) {
commandsWithRun := []string{
"version",
"generate private-key",
}
for _, cmdPath := range commandsWithRun {
t.Run(cmdPath, func(t *testing.T) {
cmd, _, err := rootCmd.Find(strings.Split(cmdPath, " "))
require.NoError(t, err)
assert.NotNil(t, cmd.Run, "Command %s should have a Run function", cmdPath)
})
}
}

View File

@ -1,134 +0,0 @@
package cli
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDumpConfigCommand(t *testing.T) {
// Test the dump config command structure
assert.NotNil(t, dumpConfigCmd)
assert.Equal(t, "dumpConfig", dumpConfigCmd.Use)
assert.Equal(t, "dump current config to /etc/headscale/config.dump.yaml, integration test only", dumpConfigCmd.Short)
assert.True(t, dumpConfigCmd.Hidden, "dumpConfig should be hidden")
// Test that command has proper setup
assert.NotNil(t, dumpConfigCmd.Run, "dumpConfig should have a Run function")
assert.NotNil(t, dumpConfigCmd.Args, "dumpConfig should have Args validation")
}
func TestDumpConfigCommandStructure(t *testing.T) {
// Validate command structure and help text
ValidateCommandStructure(t, dumpConfigCmd, "dumpConfig", "dump current config to /etc/headscale/config.dump.yaml, integration test only")
ValidateCommandHelp(t, dumpConfigCmd)
}
func TestDumpConfigCommandIntegration(t *testing.T) {
// Test that dumpConfig command is properly integrated into root command
found := false
for _, cmd := range rootCmd.Commands() {
if cmd.Use == "dumpConfig" {
found = true
break
}
}
assert.True(t, found, "dumpConfig command should be added to root command")
}
func TestDumpConfigCommandFlags(t *testing.T) {
// Verify that dumpConfig doesn't have any flags (it's a simple command)
flags := dumpConfigCmd.Flags()
assert.Equal(t, 0, flags.NFlag(), "dumpConfig should not have any flags")
}
func TestDumpConfigCommandArgs(t *testing.T) {
// Test Args validation - should accept no arguments
if dumpConfigCmd.Args != nil {
err := dumpConfigCmd.Args(dumpConfigCmd, []string{})
assert.NoError(t, err, "dumpConfig should accept no arguments")
err = dumpConfigCmd.Args(dumpConfigCmd, []string{"extra"})
// Note: The current implementation accepts any arguments, but ideally should reject them
// This test documents the current behavior
assert.NoError(t, err, "Current implementation accepts extra arguments")
}
}
func TestDumpConfigCommandProperties(t *testing.T) {
// Test command properties
assert.True(t, dumpConfigCmd.Hidden, "dumpConfig should be hidden from help")
assert.False(t, dumpConfigCmd.DisableFlagsInUseLine, "dumpConfig should allow flags in usage line")
assert.Empty(t, dumpConfigCmd.Aliases, "dumpConfig should not have aliases")
// Test that it's not a group command
assert.False(t, dumpConfigCmd.HasSubCommands(), "dumpConfig should not have subcommands")
}
func TestDumpConfigCommandDocumentation(t *testing.T) {
// Test command documentation completeness
assert.NotEmpty(t, dumpConfigCmd.Use, "dumpConfig should have Use field")
assert.NotEmpty(t, dumpConfigCmd.Short, "dumpConfig should have Short description")
assert.Empty(t, dumpConfigCmd.Long, "dumpConfig does not need Long description for simple command")
assert.Empty(t, dumpConfigCmd.Example, "dumpConfig does not need examples")
// Test that Short description is descriptive
assert.Contains(t, dumpConfigCmd.Short, "config", "Short description should mention config")
assert.Contains(t, dumpConfigCmd.Short, "integration test", "Short description should mention this is for integration tests")
}
func TestDumpConfigCommandUsage(t *testing.T) {
// Test that usage line is properly formatted
usageLine := dumpConfigCmd.UseLine()
assert.Contains(t, usageLine, "dumpConfig", "Usage line should contain command name")
// Test help output
helpOutput := dumpConfigCmd.Long
if helpOutput == "" {
helpOutput = dumpConfigCmd.Short
}
assert.NotEmpty(t, helpOutput, "Command should have help text")
}
// Functional test that would verify the actual behavior
// Note: This test is commented out because it would try to write to /etc/headscale/
// which may not be accessible in test environments
/*
func TestDumpConfigCommandExecution(t *testing.T) {
// This would test actual execution but requires proper setup
// and writable /etc/headscale/ directory
// Mock test approach:
oldConfigPath := "/etc/headscale/config.dump.yaml"
// In a real test, you would:
// 1. Set up a temporary directory
// 2. Mock viper.WriteConfigAs to use the temp directory
// 3. Execute the command
// 4. Verify the file was created
// 5. Clean up
t.Skip("Functional test requires filesystem access and mocking")
}
*/
func TestDumpConfigCommandSafety(t *testing.T) {
// Test that the command is designed safely
assert.True(t, dumpConfigCmd.Hidden, "dumpConfig should be hidden to prevent accidental use")
// Verify it has integration test warning in description
assert.Contains(t, dumpConfigCmd.Short, "integration test only",
"Should warn that this is for integration tests only")
}
func TestDumpConfigCommandCompliance(t *testing.T) {
// Test compliance with CLI patterns
require.NotNil(t, dumpConfigCmd.Run, "Command must have Run function")
// Test that command follows naming conventions
assert.Equal(t, "dumpConfig", dumpConfigCmd.Use, "Command should use camelCase naming")
// Test that it's properly categorized
assert.True(t, dumpConfigCmd.Hidden, "Utility commands should be hidden")
}

View File

@ -1,163 +0,0 @@
package cli
// This file demonstrates how the new flag infrastructure simplifies command creation
// It shows a before/after comparison for the registerNodeCmd
import (
"fmt"
"log"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
)
// BEFORE: Current registerNodeCmd with lots of duplication (from nodes.go:114-158)
var originalRegisterNodeCmd = &cobra.Command{
Use: "register",
Short: "Registers a node to your network",
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output") // Manual flag parsing
user, err := cmd.Flags().GetString("user") // Manual flag parsing with error handling
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig() // gRPC client setup
defer cancel()
defer conn.Close()
registrationID, err := cmd.Flags().GetString("key") // More manual flag parsing
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Error getting node key from flag: %s", err),
output,
)
}
request := &v1.RegisterNodeRequest{
Key: registrationID,
User: user,
}
response, err := client.RegisterNode(ctx, request) // gRPC call with manual error handling
if err != nil {
ErrorOutput(
err,
fmt.Sprintf(
"Cannot register node: %s\n",
status.Convert(err).Message(),
),
output,
)
}
SuccessOutput(
response.GetNode(),
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()), output)
},
}
// AFTER: Refactored registerNodeCmd using new flag infrastructure
var refactoredRegisterNodeCmd = &cobra.Command{
Use: "register",
Short: "Registers a node to your network",
Run: func(cmd *cobra.Command, args []string) {
// Clean flag parsing with standardized error handling
output := GetOutputFormat(cmd)
user, err := GetUserWithDeprecatedNamespace(cmd) // Handles both --user and deprecated --namespace
if err != nil {
ErrorOutput(err, "Error getting user", output)
return
}
key, err := GetKey(cmd)
if err != nil {
ErrorOutput(err, "Error getting key", output)
return
}
// gRPC client setup (will be further simplified in Checkpoint 2)
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
request := &v1.RegisterNodeRequest{
Key: key,
User: user,
}
response, err := client.RegisterNode(ctx, request)
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Cannot register node: %s", status.Convert(err).Message()),
output,
)
return
}
SuccessOutput(
response.GetNode(),
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()),
output)
},
}
// BEFORE: Current flag setup in init() function (from nodes.go:36-52)
func originalFlagSetup() {
registerNodeCmd.Flags().StringP("user", "u", "", "User")
registerNodeCmd.Flags().StringP("namespace", "n", "", "User")
registerNodeNamespaceFlag := registerNodeCmd.Flags().Lookup("namespace")
registerNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage
registerNodeNamespaceFlag.Hidden = true
err := registerNodeCmd.MarkFlagRequired("user")
if err != nil {
log.Fatal(err.Error())
}
registerNodeCmd.Flags().StringP("key", "k", "", "Key")
err = registerNodeCmd.MarkFlagRequired("key")
if err != nil {
log.Fatal(err.Error())
}
}
// AFTER: Simplified flag setup using new infrastructure
func refactoredFlagSetup() {
AddRequiredUserFlag(refactoredRegisterNodeCmd)
AddDeprecatedNamespaceFlag(refactoredRegisterNodeCmd)
AddRequiredKeyFlag(refactoredRegisterNodeCmd)
}
/*
IMPROVEMENT SUMMARY:
1. FLAG PARSING REDUCTION:
Before: 6 lines of manual flag parsing + error handling
After: 3 lines with standardized helpers
2. ERROR HANDLING CONSISTENCY:
Before: Inconsistent error message formatting
After: Standardized error handling with consistent format
3. DEPRECATED FLAG SUPPORT:
Before: 4 lines of deprecation setup
After: 1 line with GetUserWithDeprecatedNamespace()
4. FLAG REGISTRATION:
Before: 12 lines in init() with manual error handling
After: 3 lines with standardized helpers
5. CODE READABILITY:
Before: Business logic mixed with flag parsing boilerplate
After: Clear separation, focus on business logic
6. MAINTAINABILITY:
Before: Changes to flag patterns require updating every command
After: Changes can be made in one place (flags.go)
TOTAL REDUCTION: ~40% fewer lines, much cleaner code
*/

View File

@ -28,6 +28,17 @@ func AddRequiredIdentifierFlag(cmd *cobra.Command, name string, help string) {
}
}
// AddColumnsFlag adds a columns flag for table output customization
func AddColumnsFlag(cmd *cobra.Command, defaultColumns string) {
cmd.Flags().String("columns", defaultColumns, "Comma-separated list of columns to display")
}
// GetColumnsFlag gets the columns flag value
func GetColumnsFlag(cmd *cobra.Command) string {
columns, _ := cmd.Flags().GetString("columns")
return columns
}
// AddUserFlag adds a user flag (string for username or email)
func AddUserFlag(cmd *cobra.Command) {
cmd.Flags().StringP("user", "u", "", "User")

View File

@ -1,462 +0,0 @@
package cli
import (
"testing"
"time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAddIdentifierFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddIdentifierFlag(cmd, "identifier", "Test identifier")
flag := cmd.Flags().Lookup("identifier")
require.NotNil(t, flag)
assert.Equal(t, "i", flag.Shorthand)
assert.Equal(t, "Test identifier", flag.Usage)
assert.Equal(t, "0", flag.DefValue)
}
func TestAddRequiredIdentifierFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddRequiredIdentifierFlag(cmd, "identifier", "Test identifier")
flag := cmd.Flags().Lookup("identifier")
require.NotNil(t, flag)
assert.Equal(t, "i", flag.Shorthand)
// Test that it's marked as required (cobra doesn't expose this directly)
// We test by checking if validation fails when not set
err := cmd.ValidateRequiredFlags()
assert.Error(t, err)
}
func TestAddUserFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddUserFlag(cmd)
flag := cmd.Flags().Lookup("user")
require.NotNil(t, flag)
assert.Equal(t, "u", flag.Shorthand)
assert.Equal(t, "User", flag.Usage)
}
func TestAddOutputFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
flag := cmd.Flags().Lookup("output")
require.NotNil(t, flag)
assert.Equal(t, "o", flag.Shorthand)
assert.Contains(t, flag.Usage, "Output format")
}
func TestAddForceFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddForceFlag(cmd)
flag := cmd.Flags().Lookup("force")
require.NotNil(t, flag)
assert.Equal(t, "false", flag.DefValue)
}
func TestAddExpirationFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddExpirationFlag(cmd, "24h")
flag := cmd.Flags().Lookup("expiration")
require.NotNil(t, flag)
assert.Equal(t, "e", flag.Shorthand)
assert.Equal(t, "24h", flag.DefValue)
}
func TestAddDeprecatedNamespaceFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddDeprecatedNamespaceFlag(cmd)
flag := cmd.Flags().Lookup("namespace")
require.NotNil(t, flag)
assert.Equal(t, "n", flag.Shorthand)
assert.True(t, flag.Hidden)
assert.Equal(t, deprecateNamespaceMessage, flag.Deprecated)
}
func TestGetIdentifier(t *testing.T) {
tests := []struct {
name string
flagValue string
expectedVal uint64
expectError bool
}{
{
name: "valid identifier",
flagValue: "123",
expectedVal: 123,
expectError: false,
},
{
name: "zero identifier",
flagValue: "0",
expectedVal: 0,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddIdentifierFlag(cmd, "identifier", "Test")
// Set flag value
err := cmd.Flags().Set("identifier", tt.flagValue)
require.NoError(t, err)
// Test getter
val, err := GetIdentifier(cmd, "identifier")
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedVal, val)
}
})
}
}
func TestGetUser(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddUserFlag(cmd)
// Test default value
user, err := GetUser(cmd)
assert.NoError(t, err)
assert.Equal(t, "", user)
// Test set value
err = cmd.Flags().Set("user", "testuser")
require.NoError(t, err)
user, err = GetUser(cmd)
assert.NoError(t, err)
assert.Equal(t, "testuser", user)
}
func TestGetOutputFormat(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
// Test default value
output := GetOutputFormat(cmd)
assert.Equal(t, "", output)
// Test set value
err := cmd.Flags().Set("output", "json")
require.NoError(t, err)
output = GetOutputFormat(cmd)
assert.Equal(t, "json", output)
}
func TestGetForce(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddForceFlag(cmd)
// Test default value
force := GetForce(cmd)
assert.False(t, force)
// Test set value
err := cmd.Flags().Set("force", "true")
require.NoError(t, err)
force = GetForce(cmd)
assert.True(t, force)
}
func TestGetExpiration(t *testing.T) {
tests := []struct {
name string
flagValue string
expected time.Duration
expectError bool
}{
{
name: "valid duration",
flagValue: "24h",
expected: 24 * time.Hour,
expectError: false,
},
{
name: "empty duration",
flagValue: "",
expected: 0,
expectError: false,
},
{
name: "invalid duration",
flagValue: "invalid",
expected: 0,
expectError: true,
},
{
name: "multiple units",
flagValue: "1h30m",
expected: time.Hour + 30*time.Minute,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddExpirationFlag(cmd, "")
if tt.flagValue != "" {
err := cmd.Flags().Set("expiration", tt.flagValue)
require.NoError(t, err)
}
duration, err := GetExpiration(cmd)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, duration)
}
})
}
}
func TestValidateRequiredFlags(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddUserFlag(cmd)
AddIdentifierFlag(cmd, "identifier", "Test")
// Test when no flags are set
err := ValidateRequiredFlags(cmd, "user", "identifier")
assert.Error(t, err)
assert.Contains(t, err.Error(), "required flag user not set")
// Set one flag
err = cmd.Flags().Set("user", "testuser")
require.NoError(t, err)
err = ValidateRequiredFlags(cmd, "user", "identifier")
assert.Error(t, err)
assert.Contains(t, err.Error(), "required flag identifier not set")
// Set both flags
err = cmd.Flags().Set("identifier", "123")
require.NoError(t, err)
err = ValidateRequiredFlags(cmd, "user", "identifier")
assert.NoError(t, err)
}
func TestValidateExclusiveFlags(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
cmd.Flags().StringP("name", "n", "", "Name")
AddIdentifierFlag(cmd, "identifier", "Test")
// Test when no flags are set (should pass)
err := ValidateExclusiveFlags(cmd, "name", "identifier")
assert.NoError(t, err)
// Test when one flag is set (should pass)
err = cmd.Flags().Set("name", "testname")
require.NoError(t, err)
err = ValidateExclusiveFlags(cmd, "name", "identifier")
assert.NoError(t, err)
// Test when both flags are set (should fail)
err = cmd.Flags().Set("identifier", "123")
require.NoError(t, err)
err = ValidateExclusiveFlags(cmd, "name", "identifier")
assert.Error(t, err)
assert.Contains(t, err.Error(), "only one of the following flags can be set")
}
func TestValidateIdentifierFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddIdentifierFlag(cmd, "identifier", "Test")
// Test with zero identifier (should fail)
err := cmd.Flags().Set("identifier", "0")
require.NoError(t, err)
err = ValidateIdentifierFlag(cmd, "identifier")
assert.Error(t, err)
assert.Contains(t, err.Error(), "must be greater than 0")
// Test with valid identifier (should pass)
err = cmd.Flags().Set("identifier", "123")
require.NoError(t, err)
err = ValidateIdentifierFlag(cmd, "identifier")
assert.NoError(t, err)
}
func TestValidateNonEmptyStringFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddUserFlag(cmd)
// Test with empty string (should fail)
err := ValidateNonEmptyStringFlag(cmd, "user")
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot be empty")
// Test with non-empty string (should pass)
err = cmd.Flags().Set("user", "testuser")
require.NoError(t, err)
err = ValidateNonEmptyStringFlag(cmd, "user")
assert.NoError(t, err)
}
func TestHandleDeprecatedNamespaceFlag(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddUserFlag(cmd)
AddDeprecatedNamespaceFlag(cmd)
// Set namespace flag
err := cmd.Flags().Set("namespace", "testnamespace")
require.NoError(t, err)
HandleDeprecatedNamespaceFlag(cmd)
// User flag should now have the namespace value
user, err := GetUser(cmd)
assert.NoError(t, err)
assert.Equal(t, "testnamespace", user)
}
func TestGetUserWithDeprecatedNamespace(t *testing.T) {
tests := []struct {
name string
userValue string
namespaceValue string
expected string
}{
{
name: "user flag set",
userValue: "testuser",
namespaceValue: "testnamespace",
expected: "testuser",
},
{
name: "only namespace flag set",
userValue: "",
namespaceValue: "testnamespace",
expected: "testnamespace",
},
{
name: "no flags set",
userValue: "",
namespaceValue: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddUserFlag(cmd)
AddDeprecatedNamespaceFlag(cmd)
if tt.userValue != "" {
err := cmd.Flags().Set("user", tt.userValue)
require.NoError(t, err)
}
if tt.namespaceValue != "" {
err := cmd.Flags().Set("namespace", tt.namespaceValue)
require.NoError(t, err)
}
result, err := GetUserWithDeprecatedNamespace(cmd)
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
})
}
}
func TestMultipleFlagTypes(t *testing.T) {
// Test that multiple different flag types can be used together
cmd := &cobra.Command{Use: "test"}
AddUserFlag(cmd)
AddIdentifierFlag(cmd, "identifier", "Test")
AddOutputFlag(cmd)
AddForceFlag(cmd)
AddTagsFlag(cmd)
AddPrefixFlag(cmd)
// Set various flags
err := cmd.Flags().Set("user", "testuser")
require.NoError(t, err)
err = cmd.Flags().Set("identifier", "123")
require.NoError(t, err)
err = cmd.Flags().Set("output", "json")
require.NoError(t, err)
err = cmd.Flags().Set("force", "true")
require.NoError(t, err)
err = cmd.Flags().Set("tags", "true")
require.NoError(t, err)
err = cmd.Flags().Set("prefix", "testprefix")
require.NoError(t, err)
// Test all getters
user, err := GetUser(cmd)
assert.NoError(t, err)
assert.Equal(t, "testuser", user)
identifier, err := GetIdentifier(cmd, "identifier")
assert.NoError(t, err)
assert.Equal(t, uint64(123), identifier)
output := GetOutputFormat(cmd)
assert.Equal(t, "json", output)
force := GetForce(cmd)
assert.True(t, force)
tags := GetTags(cmd)
assert.True(t, tags)
prefix, err := GetPrefix(cmd)
assert.NoError(t, err)
assert.Equal(t, "testprefix", prefix)
}
func TestFlagErrorHandling(t *testing.T) {
// Test error handling when flags don't exist
cmd := &cobra.Command{Use: "test"}
// Test getting non-existent flag
_, err := GetIdentifier(cmd, "nonexistent")
assert.Error(t, err)
// Test validation of non-existent flag
err = ValidateRequiredFlags(cmd, "nonexistent")
assert.Error(t, err)
assert.Contains(t, err.Error(), "flag nonexistent not found")
}

View File

@ -1,313 +0,0 @@
package cli
import (
"testing"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
// TestCLIInfrastructureIntegration tests that all infrastructure components work together
func TestCLIInfrastructureIntegration(t *testing.T) {
t.Run("testing infrastructure", func(t *testing.T) {
// Test mock client creation using the helper function
mockClient := NewMockHeadscaleServiceClient()
assert.NotNil(t, mockClient)
assert.NotNil(t, mockClient.CallCount)
// Test that mock client tracks calls
_, err := mockClient.ListUsers(nil, &v1.ListUsersRequest{})
assert.NoError(t, err)
assert.Equal(t, 1, mockClient.CallCount["ListUsers"])
})
t.Run("validation integration", func(t *testing.T) {
// Test that validation functions work correctly together
assert.NoError(t, ValidateEmail("test@example.com"))
assert.NoError(t, ValidateUserName("testuser"))
assert.NoError(t, ValidateNodeName("testnode"))
assert.NoError(t, ValidateCIDR("192.168.1.0/24"))
// Test validation of complex scenarios
tags := []string{"env:prod", "team:backend"}
assert.NoError(t, ValidateTagsFormat(tags))
routes := []string{"10.0.0.0/8", "172.16.0.0/12"}
assert.NoError(t, ValidateRoutesFormat(routes))
})
t.Run("flag infrastructure", func(t *testing.T) {
// Test that flag helpers work
cmd := &cobra.Command{Use: "test"}
AddIdentifierFlag(cmd, "id", "Test ID flag")
AddUserFlag(cmd)
AddOutputFlag(cmd)
AddForceFlag(cmd)
// Verify flags were added
assert.NotNil(t, cmd.Flags().Lookup("id"))
assert.NotNil(t, cmd.Flags().Lookup("user"))
assert.NotNil(t, cmd.Flags().Lookup("output"))
assert.NotNil(t, cmd.Flags().Lookup("force"))
// Test flag shortcuts
idFlag := cmd.Flags().Lookup("id")
assert.Equal(t, "i", idFlag.Shorthand)
userFlag := cmd.Flags().Lookup("user")
assert.Equal(t, "u", userFlag.Shorthand)
outputFlag := cmd.Flags().Lookup("output")
assert.Equal(t, "o", outputFlag.Shorthand)
forceFlag := cmd.Flags().Lookup("force")
assert.Equal(t, "", forceFlag.Shorthand, "Force flag doesn't have a shorthand")
})
t.Run("output infrastructure", func(t *testing.T) {
// Test output manager creation
cmd := &cobra.Command{Use: "test"}
om := NewOutputManager(cmd)
assert.NotNil(t, om)
// Test table renderer creation
tr := NewTableRenderer(om)
assert.NotNil(t, tr)
// Test table column addition
tr.AddColumn("Test Column", func(item interface{}) string {
return "test value"
})
assert.Equal(t, 1, len(tr.columns))
assert.Equal(t, "Test Column", tr.columns[0].Header)
})
t.Run("command patterns", func(t *testing.T) {
// Test that argument validators work correctly
validator := ValidateExactArgs(2, "test <arg1> <arg2>")
assert.NotNil(t, validator)
cmd := &cobra.Command{Use: "test"}
// Should accept exactly 2 arguments
err := validator(cmd, []string{"arg1", "arg2"})
assert.NoError(t, err)
// Should reject wrong number of arguments
err = validator(cmd, []string{"arg1"})
assert.Error(t, err)
err = validator(cmd, []string{"arg1", "arg2", "arg3"})
assert.Error(t, err)
})
}
// TestCLIInfrastructureConsistency tests that the infrastructure maintains consistency
func TestCLIInfrastructureConsistency(t *testing.T) {
t.Run("error message consistency", func(t *testing.T) {
// Test that validation errors have consistent formatting
emailErr := ValidateEmail("")
userErr := ValidateUserName("")
nodeErr := ValidateNodeName("")
// All should mention "cannot be empty"
assert.Contains(t, emailErr.Error(), "cannot be empty")
assert.Contains(t, userErr.Error(), "cannot be empty")
assert.Contains(t, nodeErr.Error(), "cannot be empty")
})
t.Run("flag naming consistency", func(t *testing.T) {
// Test that common flags use consistent shortcuts
cmd := &cobra.Command{Use: "test"}
AddUserFlag(cmd)
AddIdentifierFlag(cmd, "id", "ID flag")
AddOutputFlag(cmd)
AddForceFlag(cmd)
// Common shortcuts should be consistent
assert.Equal(t, "u", cmd.Flags().Lookup("user").Shorthand)
assert.Equal(t, "i", cmd.Flags().Lookup("id").Shorthand)
assert.Equal(t, "o", cmd.Flags().Lookup("output").Shorthand)
assert.Equal(t, "", cmd.Flags().Lookup("force").Shorthand)
})
t.Run("command structure consistency", func(t *testing.T) {
// Test that main commands follow consistent patterns
commands := []*cobra.Command{userCmd, nodeCmd, apiKeysCmd, preauthkeysCmd}
for _, cmd := range commands {
// All main commands should have subcommands
assert.True(t, cmd.HasSubCommands(), "Command %s should have subcommands", cmd.Use)
// All main commands should have short descriptions
assert.NotEmpty(t, cmd.Short, "Command %s should have short description", cmd.Use)
// All main commands should be properly integrated
found := false
for _, rootSubcmd := range rootCmd.Commands() {
if rootSubcmd == cmd {
found = true
break
}
}
assert.True(t, found, "Command %s should be added to root", cmd.Use)
}
})
}
// TestCLIInfrastructurePerformance tests that the infrastructure is performant
func TestCLIInfrastructurePerformance(t *testing.T) {
t.Run("validation performance", func(t *testing.T) {
// Test that validation functions are fast enough for CLI use
for i := 0; i < 1000; i++ {
ValidateEmail("test@example.com")
ValidateUserName("testuser")
ValidateNodeName("testnode")
ValidateCIDR("192.168.1.0/24")
}
// Test passes if it completes without timeout
})
t.Run("mock client performance", func(t *testing.T) {
// Test that mock client operations are fast
mockClient := NewMockHeadscaleServiceClient()
for i := 0; i < 1000; i++ {
mockClient.ListUsers(nil, &v1.ListUsersRequest{})
mockClient.ListNodes(nil, &v1.ListNodesRequest{})
}
// Verify call tracking works efficiently
assert.Equal(t, 1000, mockClient.CallCount["ListUsers"])
assert.Equal(t, 1000, mockClient.CallCount["ListNodes"])
})
}
// TestCLIInfrastructureEdgeCases tests edge cases and error conditions
func TestCLIInfrastructureEdgeCases(t *testing.T) {
t.Run("nil handling", func(t *testing.T) {
// Test that functions handle nil inputs gracefully
err := ValidateTagsFormat(nil)
assert.NoError(t, err, "Should handle nil tags list")
err = ValidateRoutesFormat(nil)
assert.NoError(t, err, "Should handle nil routes list")
})
t.Run("empty input handling", func(t *testing.T) {
// Test empty inputs
err := ValidateTagsFormat([]string{})
assert.NoError(t, err, "Should handle empty tags list")
err = ValidateRoutesFormat([]string{})
assert.NoError(t, err, "Should handle empty routes list")
})
t.Run("boundary conditions", func(t *testing.T) {
// Test boundary conditions for string length validation
err := ValidateStringLength("", "field", 0, 10)
assert.NoError(t, err, "Should handle minimum length 0")
err = ValidateStringLength("1234567890", "field", 0, 10)
assert.NoError(t, err, "Should handle exact maximum length")
err = ValidateStringLength("12345678901", "field", 0, 10)
assert.Error(t, err, "Should reject over maximum length")
})
}
// TestCLIInfrastructureDocumentation tests that infrastructure components are well documented
func TestCLIInfrastructureDocumentation(t *testing.T) {
t.Run("function documentation", func(t *testing.T) {
// This is a meta-test to ensure we maintain good documentation
// In a real scenario, you might parse Go source and check for comments
// For now, we test that key functions exist and have meaningful names
assert.NotNil(t, ValidateEmail, "ValidateEmail should exist")
assert.NotNil(t, ValidateUserName, "ValidateUserName should exist")
assert.NotNil(t, ValidateNodeName, "ValidateNodeName should exist")
assert.NotNil(t, NewOutputManager, "NewOutputManager should exist")
assert.NotNil(t, NewTableRenderer, "NewTableRenderer should exist")
})
t.Run("error message clarity", func(t *testing.T) {
// Test that error messages are helpful and include relevant information
err := ValidateEmail("invalid")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid", "Error should include the invalid input")
err = ValidateUserName("user with spaces")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid characters", "Error should explain the problem")
err = ValidateAPIKeyPrefix("ab")
assert.Error(t, err)
assert.Contains(t, err.Error(), "at least 4 characters", "Error should specify requirements")
})
}
// TestCLIInfrastructureBackwardsCompatibility tests that changes don't break existing functionality
func TestCLIInfrastructureBackwardsCompatibility(t *testing.T) {
t.Run("existing command structure", func(t *testing.T) {
// Test that existing commands still work as expected
assert.NotNil(t, userCmd, "User command should still exist")
assert.NotNil(t, nodeCmd, "Node command should still exist")
assert.NotNil(t, rootCmd, "Root command should still exist")
// Test that existing subcommands still exist
assert.True(t, userCmd.HasSubCommands(), "User command should have subcommands")
assert.True(t, nodeCmd.HasSubCommands(), "Node command should have subcommands")
})
t.Run("flag compatibility", func(t *testing.T) {
// Test that common flags still exist with expected shortcuts
commands := []*cobra.Command{listUsersCmd, listNodesCmd}
for _, cmd := range commands {
userFlag := cmd.Flags().Lookup("user")
if userFlag != nil {
assert.Equal(t, "u", userFlag.Shorthand, "User flag shortcut should be 'u'")
}
}
})
}
// TestCLIInfrastructureIntegrationWithExistingCode tests integration with existing codebase
func TestCLIInfrastructureIntegrationWithExistingCode(t *testing.T) {
t.Run("command registration", func(t *testing.T) {
// Test that new infrastructure doesn't interfere with existing command registration
initialCommandCount := len(rootCmd.Commands())
assert.Greater(t, initialCommandCount, 0, "Root command should have subcommands")
// Test that all expected commands are registered
expectedCommands := []string{"users", "nodes", "apikeys", "preauthkeys", "version", "generate"}
for _, expectedCmd := range expectedCommands {
found := false
for _, cmd := range rootCmd.Commands() {
if cmd.Use == expectedCmd || cmd.Name() == expectedCmd {
found = true
break
}
}
assert.True(t, found, "Expected command %s should be registered", expectedCmd)
}
})
t.Run("configuration compatibility", func(t *testing.T) {
// Test that new infrastructure works with existing configuration
// Test that output format detection works
cmd := &cobra.Command{Use: "test"}
format := GetOutputFormat(cmd)
assert.Equal(t, "", format, "Default output format should be empty string")
// Test that machine output detection works
hasMachine := HasMachineOutputFlag()
assert.False(t, hasMachine, "Should not detect machine output by default")
})
}

View File

@ -1,250 +0,0 @@
package cli
import (
"encoding/json"
"os"
"testing"
"time"
"github.com/oauth2-proxy/mockoidc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMockOidcCommand(t *testing.T) {
// Test that the mockoidc command exists and is properly configured
assert.NotNil(t, mockOidcCmd)
assert.Equal(t, "mockoidc", mockOidcCmd.Use)
assert.Equal(t, "Runs a mock OIDC server for testing", mockOidcCmd.Short)
assert.Equal(t, "This internal command runs a OpenID Connect for testing purposes", mockOidcCmd.Long)
assert.NotNil(t, mockOidcCmd.Run)
}
func TestMockOidcCommandInRootCommand(t *testing.T) {
// Test that mockoidc is available as a subcommand of root
cmd, _, err := rootCmd.Find([]string{"mockoidc"})
require.NoError(t, err)
assert.Equal(t, "mockoidc", cmd.Name())
assert.Equal(t, mockOidcCmd, cmd)
}
func TestMockOidcErrorConstants(t *testing.T) {
// Test that error constants are defined properly
assert.Equal(t, Error("MOCKOIDC_CLIENT_ID not defined"), errMockOidcClientIDNotDefined)
assert.Equal(t, Error("MOCKOIDC_CLIENT_SECRET not defined"), errMockOidcClientSecretNotDefined)
assert.Equal(t, Error("MOCKOIDC_PORT not defined"), errMockOidcPortNotDefined)
}
func TestMockOidcConstants(t *testing.T) {
// Test that time constants are defined
assert.Equal(t, 60*time.Minute, refreshTTL)
assert.Equal(t, 2*time.Minute, accessTTL) // This is the default value
}
func TestMockOIDCValidation(t *testing.T) {
// Test the validation logic by testing the mockOIDC function directly
// Save original env vars
originalEnv := map[string]string{
"MOCKOIDC_CLIENT_ID": os.Getenv("MOCKOIDC_CLIENT_ID"),
"MOCKOIDC_CLIENT_SECRET": os.Getenv("MOCKOIDC_CLIENT_SECRET"),
"MOCKOIDC_ADDR": os.Getenv("MOCKOIDC_ADDR"),
"MOCKOIDC_PORT": os.Getenv("MOCKOIDC_PORT"),
"MOCKOIDC_USERS": os.Getenv("MOCKOIDC_USERS"),
"MOCKOIDC_ACCESS_TTL": os.Getenv("MOCKOIDC_ACCESS_TTL"),
}
// Clear all env vars
for key := range originalEnv {
os.Unsetenv(key)
}
// Restore env vars after test
defer func() {
for key, value := range originalEnv {
if value != "" {
os.Setenv(key, value)
} else {
os.Unsetenv(key)
}
}
}()
tests := []struct {
name string
setup func()
expectedErr error
}{
{
name: "missing client ID",
setup: func() {},
expectedErr: errMockOidcClientIDNotDefined,
},
{
name: "missing client secret",
setup: func() {
os.Setenv("MOCKOIDC_CLIENT_ID", "test-client")
},
expectedErr: errMockOidcClientSecretNotDefined,
},
{
name: "missing address",
setup: func() {
os.Setenv("MOCKOIDC_CLIENT_ID", "test-client")
os.Setenv("MOCKOIDC_CLIENT_SECRET", "test-secret")
},
expectedErr: errMockOidcPortNotDefined,
},
{
name: "missing port",
setup: func() {
os.Setenv("MOCKOIDC_CLIENT_ID", "test-client")
os.Setenv("MOCKOIDC_CLIENT_SECRET", "test-secret")
os.Setenv("MOCKOIDC_ADDR", "localhost")
},
expectedErr: errMockOidcPortNotDefined,
},
{
name: "missing users",
setup: func() {
os.Setenv("MOCKOIDC_CLIENT_ID", "test-client")
os.Setenv("MOCKOIDC_CLIENT_SECRET", "test-secret")
os.Setenv("MOCKOIDC_ADDR", "localhost")
os.Setenv("MOCKOIDC_PORT", "9000")
},
expectedErr: nil, // We'll check error message instead of type
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Clear env vars for this test
for key := range originalEnv {
os.Unsetenv(key)
}
tt.setup()
// Note: We can't actually run mockOIDC() because it would start a server
// and block forever. We're testing the validation part that happens early.
// In a real implementation, we would refactor to separate validation from execution.
err := mockOIDC()
require.Error(t, err)
if tt.expectedErr != nil {
assert.Equal(t, tt.expectedErr, err)
} else {
// For the "missing users" case, just check it's an error about users
assert.Contains(t, err.Error(), "MOCKOIDC_USERS not defined")
}
})
}
}
func TestMockOIDCAccessTTLParsing(t *testing.T) {
// Test that MOCKOIDC_ACCESS_TTL environment variable parsing works
originalAccessTTL := accessTTL
defer func() { accessTTL = originalAccessTTL }()
originalEnv := os.Getenv("MOCKOIDC_ACCESS_TTL")
defer func() {
if originalEnv != "" {
os.Setenv("MOCKOIDC_ACCESS_TTL", originalEnv)
} else {
os.Unsetenv("MOCKOIDC_ACCESS_TTL")
}
}()
// Test with valid duration
os.Setenv("MOCKOIDC_ACCESS_TTL", "5m")
// We can't easily test the parsing in isolation since it's embedded in mockOIDC()
// In a refactor, we'd extract this to a separate function
// For now, we test the concept by parsing manually
accessTTLOverride := os.Getenv("MOCKOIDC_ACCESS_TTL")
if accessTTLOverride != "" {
newTTL, err := time.ParseDuration(accessTTLOverride)
require.NoError(t, err)
assert.Equal(t, 5*time.Minute, newTTL)
}
}
func TestGetMockOIDC(t *testing.T) {
// Test the getMockOIDC function
users := []mockoidc.MockUser{
{
Subject: "user1",
Email: "user1@example.com",
Groups: []string{"users"},
},
{
Subject: "user2",
Email: "user2@example.com",
Groups: []string{"admins", "users"},
},
}
mock, err := getMockOIDC("test-client", "test-secret", users)
require.NoError(t, err)
assert.NotNil(t, mock)
// Verify configuration
assert.Equal(t, "test-client", mock.ClientID)
assert.Equal(t, "test-secret", mock.ClientSecret)
assert.Equal(t, accessTTL, mock.AccessTTL)
assert.Equal(t, refreshTTL, mock.RefreshTTL)
assert.NotNil(t, mock.Keypair)
assert.NotNil(t, mock.SessionStore)
assert.NotNil(t, mock.UserQueue)
assert.NotNil(t, mock.ErrorQueue)
// Verify supported code challenge methods
expectedMethods := []string{"plain", "S256"}
assert.Equal(t, expectedMethods, mock.CodeChallengeMethodsSupported)
}
func TestMockOIDCUserJsonParsing(t *testing.T) {
// Test that user JSON parsing works correctly
userStr := `[
{
"subject": "user1",
"email": "user1@example.com",
"groups": ["users"]
},
{
"subject": "user2",
"email": "user2@example.com",
"groups": ["admins", "users"]
}
]`
var users []mockoidc.MockUser
err := json.Unmarshal([]byte(userStr), &users)
require.NoError(t, err)
assert.Len(t, users, 2)
assert.Equal(t, "user1", users[0].Subject)
assert.Equal(t, "user1@example.com", users[0].Email)
assert.Equal(t, []string{"users"}, users[0].Groups)
assert.Equal(t, "user2", users[1].Subject)
assert.Equal(t, "user2@example.com", users[1].Email)
assert.Equal(t, []string{"admins", "users"}, users[1].Groups)
}
func TestMockOIDCInvalidUserJson(t *testing.T) {
// Test that invalid JSON returns an error
invalidUserStr := `[{"subject": "user1", "email": "user1@example.com", "groups": ["users"]` // Missing closing bracket
var users []mockoidc.MockUser
err := json.Unmarshal([]byte(invalidUserStr), &users)
require.Error(t, err)
}
// Note: We don't test the actual server startup because:
// 1. It would require available ports
// 2. It blocks forever (infinite loop waiting on channel)
// 3. It's integration testing rather than unit testing
//
// In a real refactor, we would:
// 1. Extract server configuration from server startup
// 2. Add context cancellation to allow graceful shutdown
// 3. Return the server instance for testing instead of blocking forever

View File

@ -1,486 +0,0 @@
package cli
import (
"fmt"
"testing"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNodeCommand(t *testing.T) {
// Test the main node command
assert.NotNil(t, nodeCmd)
assert.Equal(t, "nodes", nodeCmd.Use)
assert.Equal(t, "Manage the nodes of Headscale", nodeCmd.Short)
// Test aliases
expectedAliases := []string{"node", "machine", "machines", "m"}
assert.Equal(t, expectedAliases, nodeCmd.Aliases)
// Test that node command has subcommands
subcommands := nodeCmd.Commands()
assert.Greater(t, len(subcommands), 0, "Node command should have subcommands")
// Verify expected subcommands exist
subcommandNames := make([]string, len(subcommands))
for i, cmd := range subcommands {
subcommandNames[i] = cmd.Use
}
expectedSubcommands := []string{"list", "register", "delete", "expire", "rename", "move", "routes", "tags", "backfill-ips"}
for _, expected := range expectedSubcommands {
found := false
for _, actual := range subcommandNames {
if actual == expected ||
(expected == "routes" && actual == "list-routes") ||
(expected == "tags" && actual == "tag") ||
(expected == "backfill-ips" && actual == "backfill-node-ips") {
found = true
break
}
}
assert.True(t, found, "Expected subcommand related to '%s' not found", expected)
}
}
func TestRegisterNodeCommand(t *testing.T) {
assert.NotNil(t, registerNodeCmd)
assert.Equal(t, "register", registerNodeCmd.Use)
assert.Equal(t, "Register a node to your headscale instance", registerNodeCmd.Short)
assert.Equal(t, []string{"r"}, registerNodeCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, registerNodeCmd.Run)
// Test required flags
flags := registerNodeCmd.Flags()
assert.NotNil(t, flags.Lookup("user"))
assert.NotNil(t, flags.Lookup("key"))
// Test flag shortcuts
userFlag := flags.Lookup("user")
assert.Equal(t, "u", userFlag.Shorthand)
keyFlag := flags.Lookup("key")
assert.Equal(t, "k", keyFlag.Shorthand)
// Test deprecated namespace flag
namespaceFlag := flags.Lookup("namespace")
assert.NotNil(t, namespaceFlag)
assert.True(t, namespaceFlag.Hidden)
assert.Equal(t, deprecateNamespaceMessage, namespaceFlag.Deprecated)
}
func TestListNodesCommand(t *testing.T) {
assert.NotNil(t, listNodesCmd)
assert.Equal(t, "list", listNodesCmd.Use)
assert.Equal(t, "List nodes", listNodesCmd.Short)
assert.Equal(t, []string{"ls", "show"}, listNodesCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, listNodesCmd.Run)
// Test flags
flags := listNodesCmd.Flags()
assert.NotNil(t, flags.Lookup("user"))
assert.NotNil(t, flags.Lookup("tags"))
// Test flag shortcuts
userFlag := flags.Lookup("user")
assert.Equal(t, "u", userFlag.Shorthand)
tagsFlag := flags.Lookup("tags")
assert.Equal(t, "t", tagsFlag.Shorthand)
// Test deprecated namespace flag
namespaceFlag := flags.Lookup("namespace")
assert.NotNil(t, namespaceFlag)
assert.True(t, namespaceFlag.Hidden)
assert.Equal(t, deprecateNamespaceMessage, namespaceFlag.Deprecated)
}
func TestListNodeRoutesCommand(t *testing.T) {
assert.NotNil(t, listNodeRoutesCmd)
assert.Equal(t, "list-routes", listNodeRoutesCmd.Use)
assert.Equal(t, "List node routes", listNodeRoutesCmd.Short)
assert.Equal(t, []string{"routes"}, listNodeRoutesCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, listNodeRoutesCmd.Run)
// Test flags
flags := listNodeRoutesCmd.Flags()
assert.NotNil(t, flags.Lookup("identifier"))
// Test flag shortcuts
identifierFlag := flags.Lookup("identifier")
assert.Equal(t, "i", identifierFlag.Shorthand)
}
func TestExpireNodeCommand(t *testing.T) {
assert.NotNil(t, expireNodeCmd)
assert.Equal(t, "expire", expireNodeCmd.Use)
assert.Equal(t, "Expire (log out) a node", expireNodeCmd.Short)
assert.Equal(t, []string{"logout", "exp", "e"}, expireNodeCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, expireNodeCmd.Run)
// Test that Args validation function is set
assert.NotNil(t, expireNodeCmd.Args)
}
func TestRenameNodeCommand(t *testing.T) {
assert.NotNil(t, renameNodeCmd)
assert.Equal(t, "rename", renameNodeCmd.Use)
assert.Equal(t, "Rename a node", renameNodeCmd.Short)
assert.Equal(t, []string{"mv"}, renameNodeCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, renameNodeCmd.Run)
// Test that Args validation function is set
assert.NotNil(t, renameNodeCmd.Args)
}
func TestDeleteNodeCommand(t *testing.T) {
assert.NotNil(t, deleteNodeCmd)
assert.Equal(t, "delete", deleteNodeCmd.Use)
assert.Equal(t, "Delete a node", deleteNodeCmd.Short)
assert.Equal(t, []string{"remove", "rm"}, deleteNodeCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, deleteNodeCmd.Run)
// Test that Args validation function is set
assert.NotNil(t, deleteNodeCmd.Args)
}
func TestMoveNodeCommand(t *testing.T) {
assert.NotNil(t, moveNodeCmd)
assert.Equal(t, "move", moveNodeCmd.Use)
assert.Equal(t, "Move node to another user", moveNodeCmd.Short)
// Test that Run function is set
assert.NotNil(t, moveNodeCmd.Run)
// Test that Args validation function is set
assert.NotNil(t, moveNodeCmd.Args)
}
func TestBackfillNodeIPsCommand(t *testing.T) {
assert.NotNil(t, backfillNodeIPsCmd)
assert.Equal(t, "backfill-node-ips", backfillNodeIPsCmd.Use)
assert.Equal(t, "Backfill the IPs of all the nodes in case you have to restore the database from a backup", backfillNodeIPsCmd.Short)
// Test that Run function is set
assert.NotNil(t, backfillNodeIPsCmd.Run)
// Test flags
flags := backfillNodeIPsCmd.Flags()
assert.NotNil(t, flags.Lookup("confirm"))
}
func TestTagCommand(t *testing.T) {
assert.NotNil(t, tagCmd)
assert.Equal(t, "tag", tagCmd.Use)
assert.Equal(t, "Manage the tags of Headscale", tagCmd.Short)
// Test that tag command has subcommands
subcommands := tagCmd.Commands()
assert.Greater(t, len(subcommands), 0, "Tag command should have subcommands")
}
func TestApproveRoutesCommand(t *testing.T) {
assert.NotNil(t, approveRoutesCmd)
assert.Equal(t, "approve-routes", approveRoutesCmd.Use)
assert.Equal(t, "Approve subnets advertised by a node", approveRoutesCmd.Short)
// Test that Run function is set
assert.NotNil(t, approveRoutesCmd.Run)
// Test that Args validation function is set
assert.NotNil(t, approveRoutesCmd.Args)
}
func TestNodeCommandFlags(t *testing.T) {
// Test register node command flags
ValidateCommandFlags(t, registerNodeCmd, []string{"user", "key", "namespace"})
// Test list nodes command flags
ValidateCommandFlags(t, listNodesCmd, []string{"user", "tags", "namespace"})
// Test list node routes command flags
ValidateCommandFlags(t, listNodeRoutesCmd, []string{"identifier"})
// Test backfill command flags
ValidateCommandFlags(t, backfillNodeIPsCmd, []string{"confirm"})
}
func TestNodeCommandIntegration(t *testing.T) {
// Test that node command is properly integrated into root command
found := false
for _, cmd := range rootCmd.Commands() {
if cmd.Use == "nodes" {
found = true
break
}
}
assert.True(t, found, "Node command should be added to root command")
}
func TestNodeSubcommandIntegration(t *testing.T) {
// Test that key subcommands are properly added to node command
subcommands := nodeCmd.Commands()
expectedCommands := map[string]bool{
"list": false,
"register": false,
"list-routes": false,
"expire": false,
"rename": false,
"delete": false,
"move": false,
"backfill-node-ips": false,
"tag": false,
"approve-routes": false,
}
for _, subcmd := range subcommands {
if _, exists := expectedCommands[subcmd.Use]; exists {
expectedCommands[subcmd.Use] = true
}
}
for cmdName, found := range expectedCommands {
assert.True(t, found, "Subcommand '%s' should be added to node command", cmdName)
}
}
func TestNodeCommandAliases(t *testing.T) {
// Test that all aliases are properly set
testCases := []struct {
command *cobra.Command
expectedAliases []string
}{
{
command: nodeCmd,
expectedAliases: []string{"node", "machine", "machines", "m"},
},
{
command: registerNodeCmd,
expectedAliases: []string{"r"},
},
{
command: listNodesCmd,
expectedAliases: []string{"ls", "show"},
},
{
command: listNodeRoutesCmd,
expectedAliases: []string{"routes"},
},
{
command: expireNodeCmd,
expectedAliases: []string{"logout", "exp", "e"},
},
{
command: renameNodeCmd,
expectedAliases: []string{"mv"},
},
{
command: deleteNodeCmd,
expectedAliases: []string{"remove", "rm"},
},
}
for _, tc := range testCases {
t.Run(tc.command.Use, func(t *testing.T) {
assert.Equal(t, tc.expectedAliases, tc.command.Aliases)
})
}
}
func TestNodeCommandDeprecatedFlags(t *testing.T) {
// Test deprecated namespace flags
commands := []*cobra.Command{registerNodeCmd, listNodesCmd}
for _, cmd := range commands {
t.Run(cmd.Use+"_namespace_deprecated", func(t *testing.T) {
namespaceFlag := cmd.Flags().Lookup("namespace")
require.NotNil(t, namespaceFlag, "Command %s should have deprecated namespace flag", cmd.Use)
assert.True(t, namespaceFlag.Hidden, "Namespace flag should be hidden")
assert.Equal(t, deprecateNamespaceMessage, namespaceFlag.Deprecated)
})
}
}
func TestNodeCommandRequiredFlags(t *testing.T) {
// Test that register command has required flags
flags := registerNodeCmd.Flags()
userFlag := flags.Lookup("user")
require.NotNil(t, userFlag)
keyFlag := flags.Lookup("key")
require.NotNil(t, keyFlag)
// Check if flags have required annotation (set by MarkFlagRequired)
checkRequired := func(flag *pflag.Flag, flagName string) {
if flag.Annotations != nil {
_, hasRequired := flag.Annotations[cobra.BashCompOneRequiredFlag]
assert.True(t, hasRequired, "%s flag should be marked as required", flagName)
}
}
checkRequired(userFlag, "user")
checkRequired(keyFlag, "key")
}
func TestNodeCommandsHaveRunFunctions(t *testing.T) {
// All node commands should have run functions
commands := []*cobra.Command{
registerNodeCmd,
listNodesCmd,
listNodeRoutesCmd,
expireNodeCmd,
renameNodeCmd,
deleteNodeCmd,
moveNodeCmd,
backfillNodeIPsCmd,
approveRoutesCmd,
}
for _, cmd := range commands {
t.Run(cmd.Use, func(t *testing.T) {
assert.NotNil(t, cmd.Run, "Command %s should have a Run function", cmd.Use)
})
}
}
func TestNodeCommandArgsValidation(t *testing.T) {
// Commands that require arguments should have Args validation
commandsWithArgs := []*cobra.Command{
expireNodeCmd,
renameNodeCmd,
deleteNodeCmd,
moveNodeCmd,
approveRoutesCmd,
}
for _, cmd := range commandsWithArgs {
t.Run(cmd.Use+"_has_args_validation", func(t *testing.T) {
assert.NotNil(t, cmd.Args, "Command %s should have Args validation function", cmd.Use)
})
}
}
func TestNodeCommandCompleteness(t *testing.T) {
// Test that node command covers expected node operations
subcommands := nodeCmd.Commands()
operations := map[string]bool{
"create": false, // register command
"read": false, // list command
"update": false, // rename, move, expire commands
"delete": false, // delete command
"routes": false, // route-related commands
"tags": false, // tag-related commands
"backfill": false, // maintenance commands
}
for _, subcmd := range subcommands {
switch {
case subcmd.Use == "register":
operations["create"] = true
case subcmd.Use == "list":
operations["read"] = true
case subcmd.Use == "rename" || subcmd.Use == "move" || subcmd.Use == "expire":
operations["update"] = true
case subcmd.Use == "delete":
operations["delete"] = true
case subcmd.Use == "list-routes" || subcmd.Use == "approve-routes":
operations["routes"] = true
case subcmd.Use == "tag":
operations["tags"] = true
case subcmd.Use == "backfill-node-ips":
operations["backfill"] = true
}
}
for op, found := range operations {
assert.True(t, found, "Node command should support %s operation", op)
}
}
func TestNodeCommandConsistency(t *testing.T) {
// Test that node commands follow consistent patterns
// Commands that modify nodes should have meaningful aliases
modifyCommands := map[*cobra.Command]string{
expireNodeCmd: "logout", // should have logout alias
renameNodeCmd: "mv", // should have mv alias
deleteNodeCmd: "rm", // should have rm alias
}
for cmd, expectedAlias := range modifyCommands {
t.Run(cmd.Use+"_has_"+expectedAlias+"_alias", func(t *testing.T) {
found := false
for _, alias := range cmd.Aliases {
if alias == expectedAlias {
found = true
break
}
}
assert.True(t, found, "Command %s should have %s alias", cmd.Use, expectedAlias)
})
}
}
func TestNodeCommandDocumentation(t *testing.T) {
// Test that important commands have proper documentation
commands := []*cobra.Command{
nodeCmd,
registerNodeCmd,
listNodesCmd,
deleteNodeCmd,
backfillNodeIPsCmd,
}
for _, cmd := range commands {
t.Run(cmd.Use+"_has_documentation", func(t *testing.T) {
assert.NotEmpty(t, cmd.Short, "Command %s should have Short description", cmd.Use)
// Long description is optional but recommended for complex commands
if cmd.Use == "backfill-node-ips" {
assert.NotEmpty(t, cmd.Long, "Complex command %s should have Long description", cmd.Use)
}
})
}
}
func TestNodeFlagShortcuts(t *testing.T) {
// Test that flag shortcuts are consistently assigned
flagTests := []struct {
command *cobra.Command
flagName string
shortcut string
}{
{registerNodeCmd, "user", "u"},
{registerNodeCmd, "key", "k"},
{listNodesCmd, "user", "u"},
{listNodesCmd, "tags", "t"},
{listNodeRoutesCmd, "identifier", "i"},
}
for _, test := range flagTests {
t.Run(fmt.Sprintf("%s_%s_shortcut", test.command.Use, test.flagName), func(t *testing.T) {
flag := test.command.Flags().Lookup(test.flagName)
require.NotNil(t, flag, "Flag %s should exist on command %s", test.flagName, test.command.Use)
assert.Equal(t, test.shortcut, flag.Shorthand, "Flag %s should have shortcut %s", test.flagName, test.shortcut)
})
}
}

View File

@ -2,6 +2,7 @@ package cli
import (
"fmt"
"strings"
"time"
"github.com/pterm/pterm"
@ -46,6 +47,7 @@ func (om *OutputManager) HasMachineOutput() bool {
// TableColumn defines a table column with header and data extraction function
type TableColumn struct {
Header string
Key string // Unique key for column selection
Width int // Optional width specification
Extract func(item interface{}) string
Color func(value string) string // Optional color function
@ -68,8 +70,9 @@ func NewTableRenderer(om *OutputManager) *TableRenderer {
}
// AddColumn adds a column to the table
func (tr *TableRenderer) AddColumn(header string, extract func(interface{}) string) *TableRenderer {
func (tr *TableRenderer) AddColumn(key, header string, extract func(interface{}) string) *TableRenderer {
tr.columns = append(tr.columns, TableColumn{
Key: key,
Header: header,
Extract: extract,
})
@ -77,8 +80,9 @@ func (tr *TableRenderer) AddColumn(header string, extract func(interface{}) stri
}
// AddColoredColumn adds a column with color formatting
func (tr *TableRenderer) AddColoredColumn(header string, extract func(interface{}) string, color func(string) string) *TableRenderer {
func (tr *TableRenderer) AddColoredColumn(key, header string, extract func(interface{}) string, color func(string) string) *TableRenderer {
tr.columns = append(tr.columns, TableColumn{
Key: key,
Header: header,
Extract: extract,
Color: color,
@ -92,6 +96,30 @@ func (tr *TableRenderer) SetData(data []interface{}) *TableRenderer {
return tr
}
// FilterColumns filters columns based on comma-separated list of column keys
func (tr *TableRenderer) FilterColumns(columnKeys string) *TableRenderer {
if columnKeys == "" {
return tr // No filtering
}
keys := strings.Split(columnKeys, ",")
var filteredColumns []TableColumn
// Filter columns based on keys, maintaining order from column keys
for _, key := range keys {
trimmedKey := strings.TrimSpace(key)
for _, col := range tr.columns {
if col.Key == trimmedKey {
filteredColumns = append(filteredColumns, col)
break
}
}
}
tr.columns = filteredColumns
return tr
}
// Render renders the table or outputs machine-readable format
func (tr *TableRenderer) Render() {
// If machine output format is requested, output the raw data instead of table
@ -329,6 +357,12 @@ func ListOutput(cmd *cobra.Command, data []interface{}, tableSetup func(*TableRe
renderer := NewTableRenderer(om)
renderer.SetData(data)
tableSetup(renderer)
// Apply column filtering if --columns flag is provided
if columnsFlag := GetColumnsFlag(cmd); columnsFlag != "" {
renderer.FilterColumns(columnsFlag)
}
renderer.Render()
}

View File

@ -1,375 +0,0 @@
package cli
// This file demonstrates how the new output infrastructure simplifies CLI command implementation
// It shows before/after comparisons for list and detail commands
import (
"fmt"
"strconv"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/pterm/pterm"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
)
// BEFORE: Current listUsersCmd implementation (from users.go:199-258)
var originalListUsersCmd = &cobra.Command{
Use: "list",
Short: "List users",
Aliases: []string{"ls", "show"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
request := &v1.ListUsersRequest{}
response, err := client.ListUsers(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot get users: "+status.Convert(err).Message(),
output,
)
}
if output != "" {
SuccessOutput(response.GetUsers(), "", output)
}
tableData := pterm.TableData{{"ID", "Name", "Username", "Email", "Created"}}
for _, user := range response.GetUsers() {
tableData = append(
tableData,
[]string{
strconv.FormatUint(user.GetId(), 10),
user.GetDisplayName(),
user.GetName(),
user.GetEmail(),
user.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
},
)
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
}
},
}
// AFTER: Refactored listUsersCmd using new output infrastructure
var refactoredListUsersCmd = &cobra.Command{
Use: "list",
Short: "List users",
Aliases: []string{"ls", "show"},
Run: func(cmd *cobra.Command, args []string) {
ExecuteWithClient(cmd, func(client *ClientWrapper) error {
response, err := client.ListUsers(cmd, &v1.ListUsersRequest{})
if err != nil {
return err // Error handling done by ClientWrapper
}
// Convert to []interface{} for table renderer
users := make([]interface{}, len(response.GetUsers()))
for i, user := range response.GetUsers() {
users[i] = user
}
// Use new output infrastructure
ListOutput(cmd, users, func(tr *TableRenderer) {
tr.AddColumn("ID", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return strconv.FormatUint(user.GetId(), util.Base10)
}
return ""
}).
AddColumn("Name", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return user.GetDisplayName()
}
return ""
}).
AddColumn("Username", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return user.GetName()
}
return ""
}).
AddColumn("Email", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return user.GetEmail()
}
return ""
}).
AddColumn("Created", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return FormatTime(user.GetCreatedAt().AsTime())
}
return ""
})
})
return nil
})
},
}
// BEFORE: Current listNodesCmd implementation (from nodes.go:160-210)
var originalListNodesCmd = &cobra.Command{
Use: "list",
Short: "List nodes",
Aliases: []string{"ls", "show"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
}
showTags, err := cmd.Flags().GetBool("tags")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output)
}
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
request := &v1.ListNodesRequest{
User: user,
}
response, err := client.ListNodes(ctx, request)
if err != nil {
ErrorOutput(
err,
"Cannot get nodes: "+status.Convert(err).Message(),
output,
)
}
if output != "" {
SuccessOutput(response.GetNodes(), "", output)
}
tableData, err := nodesToPtables(user, showTags, response.GetNodes())
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
}
},
}
// AFTER: Refactored listNodesCmd using new output infrastructure
var refactoredListNodesCmd = &cobra.Command{
Use: "list",
Short: "List nodes",
Aliases: []string{"ls", "show"},
Run: func(cmd *cobra.Command, args []string) {
user, err := GetUserWithDeprecatedNamespace(cmd)
if err != nil {
SimpleError(cmd, err, "Error getting user")
return
}
showTags := GetTags(cmd)
ExecuteWithClient(cmd, func(client *ClientWrapper) error {
response, err := client.ListNodes(cmd, &v1.ListNodesRequest{User: user})
if err != nil {
return err
}
// Convert to []interface{} for table renderer
nodes := make([]interface{}, len(response.GetNodes()))
for i, node := range response.GetNodes() {
nodes[i] = node
}
// Use new output infrastructure with dynamic columns
ListOutput(cmd, nodes, func(tr *TableRenderer) {
setupNodeTableColumns(tr, user, showTags)
})
return nil
})
},
}
// Helper function to setup node table columns (extracted for reusability)
func setupNodeTableColumns(tr *TableRenderer, currentUser string, showTags bool) {
tr.AddColumn("ID", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
return strconv.FormatUint(node.GetId(), util.Base10)
}
return ""
}).
AddColumn("Hostname", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
return node.GetName()
}
return ""
}).
AddColumn("Name", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
return node.GetGivenName()
}
return ""
}).
AddColoredColumn("User", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
return node.GetUser().GetName()
}
return ""
}, func(username string) string {
if currentUser == "" || currentUser == username {
return ColorMagenta(username) // Own user
}
return ColorYellow(username) // Shared user
}).
AddColumn("IP addresses", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
return FormatStringSlice(node.GetIpAddresses())
}
return ""
}).
AddColumn("Last seen", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
if node.GetLastSeen() != nil {
return FormatTime(node.GetLastSeen().AsTime())
}
}
return ""
}).
AddColoredColumn("Connected", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
return FormatOnlineStatus(node.GetOnline())
}
return ""
}, nil). // Color already applied by FormatOnlineStatus
AddColoredColumn("Expired", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
expired := false
if node.GetExpiry() != nil {
expiry := node.GetExpiry().AsTime()
expired = !expiry.IsZero() && expiry.Before(time.Now())
}
return FormatExpiredStatus(expired)
}
return ""
}, nil) // Color already applied by FormatExpiredStatus
// Add tag columns if requested
if showTags {
tr.AddColumn("ForcedTags", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
return FormatStringSlice(node.GetForcedTags())
}
return ""
}).
AddColumn("InvalidTags", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
return FormatTagList(node.GetInvalidTags(), ColorRed)
}
return ""
}).
AddColumn("ValidTags", func(item interface{}) string {
if node, ok := item.(*v1.Node); ok {
return FormatTagList(node.GetValidTags(), ColorGreen)
}
return ""
})
}
}
// BEFORE: Current registerNodeCmd implementation (from nodes.go:114-158)
// (Already shown in example_refactor_demo.go)
// AFTER: Refactored registerNodeCmd using both flag and output infrastructure
var fullyRefactoredRegisterNodeCmd = &cobra.Command{
Use: "register",
Short: "Registers a node to your network",
Run: func(cmd *cobra.Command, args []string) {
user, err := GetUserWithDeprecatedNamespace(cmd)
if err != nil {
SimpleError(cmd, err, "Error getting user")
return
}
key, err := GetKey(cmd)
if err != nil {
SimpleError(cmd, err, "Error getting key")
return
}
ExecuteWithClient(cmd, func(client *ClientWrapper) error {
response, err := client.RegisterNode(cmd, &v1.RegisterNodeRequest{
Key: key,
User: user,
})
if err != nil {
return err
}
DetailOutput(cmd, response.GetNode(),
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()))
return nil
})
},
}
/*
IMPROVEMENT SUMMARY FOR OUTPUT INFRASTRUCTURE:
1. LIST COMMANDS REDUCTION:
Before: 35+ lines with manual table setup, output format handling, error handling
After: 15 lines with declarative table configuration
2. DETAIL COMMANDS REDUCTION:
Before: 20+ lines with manual output format detection and error handling
After: 5 lines with DetailOutput()
3. ERROR HANDLING CONSISTENCY:
Before: Manual error handling with different formats across commands
After: Automatic error handling via ClientWrapper + OutputManager integration
4. TABLE RENDERING STANDARDIZATION:
Before: Manual pterm.TableData construction and error handling
After: Declarative column configuration with automatic rendering
5. OUTPUT FORMAT DETECTION:
Before: Manual output format checking and conditional logic
After: Automatic detection and appropriate rendering
6. COLOR AND FORMATTING:
Before: Inline color logic scattered throughout commands
After: Centralized formatting functions (FormatOnlineStatus, FormatTime, etc.)
7. CODE REUSABILITY:
Before: Each command implements its own table setup
After: Reusable helper functions (setupNodeTableColumns, etc.)
8. TESTING:
Before: Difficult to test output formatting logic
After: Each component independently testable
TOTAL REDUCTION: ~60-70% fewer lines for typical list/detail commands
MAINTAINABILITY: Centralized output logic, consistent patterns
EXTENSIBILITY: Easy to add new output formats or modify existing ones
*/

View File

@ -1,461 +0,0 @@
package cli
import (
"fmt"
"testing"
"time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewOutputManager(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
om := NewOutputManager(cmd)
assert.NotNil(t, om)
assert.Equal(t, cmd, om.cmd)
assert.Equal(t, "", om.outputFormat) // Default empty format
}
func TestOutputManager_HasMachineOutput(t *testing.T) {
tests := []struct {
name string
outputFormat string
expectedResult bool
}{
{
name: "empty format (human readable)",
outputFormat: "",
expectedResult: false,
},
{
name: "json format",
outputFormat: "json",
expectedResult: true,
},
{
name: "yaml format",
outputFormat: "yaml",
expectedResult: true,
},
{
name: "json-line format",
outputFormat: "json-line",
expectedResult: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
if tt.outputFormat != "" {
err := cmd.Flags().Set("output", tt.outputFormat)
require.NoError(t, err)
}
om := NewOutputManager(cmd)
result := om.HasMachineOutput()
assert.Equal(t, tt.expectedResult, result)
})
}
}
func TestNewTableRenderer(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
om := NewOutputManager(cmd)
tr := NewTableRenderer(om)
assert.NotNil(t, tr)
assert.Equal(t, om, tr.outputManager)
assert.Empty(t, tr.columns)
assert.Empty(t, tr.data)
}
func TestTableRenderer_AddColumn(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
om := NewOutputManager(cmd)
tr := NewTableRenderer(om)
extractFunc := func(item interface{}) string {
return "test"
}
result := tr.AddColumn("Test Header", extractFunc)
// Should return self for chaining
assert.Equal(t, tr, result)
// Should have added column
require.Len(t, tr.columns, 1)
assert.Equal(t, "Test Header", tr.columns[0].Header)
assert.NotNil(t, tr.columns[0].Extract)
assert.Nil(t, tr.columns[0].Color)
}
func TestTableRenderer_AddColoredColumn(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
om := NewOutputManager(cmd)
tr := NewTableRenderer(om)
extractFunc := func(item interface{}) string {
return "test"
}
colorFunc := func(value string) string {
return ColorGreen(value)
}
result := tr.AddColoredColumn("Colored Header", extractFunc, colorFunc)
// Should return self for chaining
assert.Equal(t, tr, result)
// Should have added colored column
require.Len(t, tr.columns, 1)
assert.Equal(t, "Colored Header", tr.columns[0].Header)
assert.NotNil(t, tr.columns[0].Extract)
assert.NotNil(t, tr.columns[0].Color)
}
func TestTableRenderer_SetData(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
om := NewOutputManager(cmd)
tr := NewTableRenderer(om)
testData := []interface{}{"item1", "item2", "item3"}
result := tr.SetData(testData)
// Should return self for chaining
assert.Equal(t, tr, result)
// Should have set data
assert.Equal(t, testData, tr.data)
}
func TestTableRenderer_Chaining(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
om := NewOutputManager(cmd)
testData := []interface{}{"item1", "item2"}
// Test method chaining
tr := NewTableRenderer(om).
AddColumn("Column1", func(item interface{}) string { return "col1" }).
AddColoredColumn("Column2", func(item interface{}) string { return "col2" }, ColorGreen).
SetData(testData)
assert.NotNil(t, tr)
assert.Len(t, tr.columns, 2)
assert.Equal(t, testData, tr.data)
}
func TestColorFunctions(t *testing.T) {
testText := "test"
// Test that color functions return non-empty strings
// We can't test exact output since pterm formatting depends on terminal
assert.NotEmpty(t, ColorGreen(testText))
assert.NotEmpty(t, ColorRed(testText))
assert.NotEmpty(t, ColorYellow(testText))
assert.NotEmpty(t, ColorMagenta(testText))
assert.NotEmpty(t, ColorBlue(testText))
assert.NotEmpty(t, ColorCyan(testText))
// Test that color functions actually modify the input
assert.NotEqual(t, testText, ColorGreen(testText))
assert.NotEqual(t, testText, ColorRed(testText))
}
func TestFormatTime(t *testing.T) {
tests := []struct {
name string
time time.Time
expected string
}{
{
name: "zero time",
time: time.Time{},
expected: "N/A",
},
{
name: "specific time",
time: time.Date(2023, 12, 25, 15, 30, 45, 0, time.UTC),
expected: "2023-12-25 15:30:45",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatTime(tt.time)
assert.Equal(t, tt.expected, result)
})
}
}
func TestFormatTimeColored(t *testing.T) {
now := time.Now()
futureTime := now.Add(time.Hour)
pastTime := now.Add(-time.Hour)
// Test zero time
result := FormatTimeColored(time.Time{})
assert.Equal(t, "N/A", result)
// Test future time (should be green)
futureResult := FormatTimeColored(futureTime)
assert.Contains(t, futureResult, futureTime.Format(HeadscaleDateTimeFormat))
assert.NotEqual(t, futureTime.Format(HeadscaleDateTimeFormat), futureResult) // Should be colored
// Test past time (should be red)
pastResult := FormatTimeColored(pastTime)
assert.Contains(t, pastResult, pastTime.Format(HeadscaleDateTimeFormat))
assert.NotEqual(t, pastTime.Format(HeadscaleDateTimeFormat), pastResult) // Should be colored
}
func TestFormatBool(t *testing.T) {
assert.Equal(t, "true", FormatBool(true))
assert.Equal(t, "false", FormatBool(false))
}
func TestFormatBoolColored(t *testing.T) {
trueResult := FormatBoolColored(true)
falseResult := FormatBoolColored(false)
// Should contain the boolean value
assert.Contains(t, trueResult, "true")
assert.Contains(t, falseResult, "false")
// Should be colored (different from plain text)
assert.NotEqual(t, "true", trueResult)
assert.NotEqual(t, "false", falseResult)
}
func TestFormatYesNo(t *testing.T) {
assert.Equal(t, "Yes", FormatYesNo(true))
assert.Equal(t, "No", FormatYesNo(false))
}
func TestFormatYesNoColored(t *testing.T) {
yesResult := FormatYesNoColored(true)
noResult := FormatYesNoColored(false)
// Should contain the yes/no value
assert.Contains(t, yesResult, "Yes")
assert.Contains(t, noResult, "No")
// Should be colored
assert.NotEqual(t, "Yes", yesResult)
assert.NotEqual(t, "No", noResult)
}
func TestFormatOnlineStatus(t *testing.T) {
onlineResult := FormatOnlineStatus(true)
offlineResult := FormatOnlineStatus(false)
assert.Contains(t, onlineResult, "online")
assert.Contains(t, offlineResult, "offline")
// Should be colored
assert.NotEqual(t, "online", onlineResult)
assert.NotEqual(t, "offline", offlineResult)
}
func TestFormatExpiredStatus(t *testing.T) {
expiredResult := FormatExpiredStatus(true)
notExpiredResult := FormatExpiredStatus(false)
assert.Contains(t, expiredResult, "yes")
assert.Contains(t, notExpiredResult, "no")
// Should be colored
assert.NotEqual(t, "yes", expiredResult)
assert.NotEqual(t, "no", notExpiredResult)
}
func TestFormatStringSlice(t *testing.T) {
tests := []struct {
name string
slice []string
expected string
}{
{
name: "empty slice",
slice: []string{},
expected: "",
},
{
name: "single item",
slice: []string{"item1"},
expected: "item1",
},
{
name: "multiple items",
slice: []string{"item1", "item2", "item3"},
expected: "item1, item2, item3",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatStringSlice(tt.slice)
assert.Equal(t, tt.expected, result)
})
}
}
func TestFormatTagList(t *testing.T) {
tests := []struct {
name string
tags []string
colorFunc func(string) string
expected string
}{
{
name: "empty tags",
tags: []string{},
colorFunc: nil,
expected: "",
},
{
name: "single tag without color",
tags: []string{"tag1"},
colorFunc: nil,
expected: "tag1",
},
{
name: "multiple tags without color",
tags: []string{"tag1", "tag2"},
colorFunc: nil,
expected: "tag1, tag2",
},
{
name: "tags with color function",
tags: []string{"tag1", "tag2"},
colorFunc: func(s string) string { return "[" + s + "]" }, // Mock color function
expected: "[tag1], [tag2]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FormatTagList(tt.tags, tt.colorFunc)
assert.Equal(t, tt.expected, result)
})
}
}
func TestExtractStringField(t *testing.T) {
// Test basic functionality
result := ExtractStringField("test string", "field")
assert.Equal(t, "test string", result)
// Test with number
result = ExtractStringField(123, "field")
assert.Equal(t, "123", result)
// Test with boolean
result = ExtractStringField(true, "field")
assert.Equal(t, "true", result)
}
func TestOutputManagerIntegration(t *testing.T) {
// Test integration between OutputManager and other components
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
// Test with different output formats
formats := []string{"", "json", "yaml", "json-line"}
for _, format := range formats {
t.Run("format_"+format, func(t *testing.T) {
if format != "" {
err := cmd.Flags().Set("output", format)
require.NoError(t, err)
}
om := NewOutputManager(cmd)
// Verify output format detection
expectedHasMachine := format != ""
assert.Equal(t, expectedHasMachine, om.HasMachineOutput())
// Test table renderer creation
tr := NewTableRenderer(om)
assert.NotNil(t, tr)
assert.Equal(t, om, tr.outputManager)
})
}
}
func TestTableRendererCompleteWorkflow(t *testing.T) {
// Test complete table rendering workflow
cmd := &cobra.Command{Use: "test"}
AddOutputFlag(cmd)
om := NewOutputManager(cmd)
// Mock data
type TestItem struct {
ID int
Name string
Active bool
}
testData := []interface{}{
TestItem{ID: 1, Name: "Item1", Active: true},
TestItem{ID: 2, Name: "Item2", Active: false},
}
// Create and configure table
tr := NewTableRenderer(om).
AddColumn("ID", func(item interface{}) string {
if testItem, ok := item.(TestItem); ok {
return FormatStringField(testItem.ID)
}
return ""
}).
AddColumn("Name", func(item interface{}) string {
if testItem, ok := item.(TestItem); ok {
return testItem.Name
}
return ""
}).
AddColoredColumn("Status", func(item interface{}) string {
if testItem, ok := item.(TestItem); ok {
return FormatYesNo(testItem.Active)
}
return ""
}, func(value string) string {
if value == "Yes" {
return ColorGreen(value)
}
return ColorRed(value)
}).
SetData(testData)
// Verify configuration
assert.Len(t, tr.columns, 3)
assert.Equal(t, testData, tr.data)
assert.Equal(t, "ID", tr.columns[0].Header)
assert.Equal(t, "Name", tr.columns[1].Header)
assert.Equal(t, "Status", tr.columns[2].Header)
}
// Helper function for tests
func FormatStringField(value interface{}) string {
return fmt.Sprintf("%v", value)
}

View File

@ -1,379 +0,0 @@
package cli
import (
"errors"
"testing"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
func TestResolveUserByNameOrID(t *testing.T) {
tests := []struct {
name string
identifier string
users []*v1.User
expected *v1.User
expectError bool
}{
{
name: "resolve by ID",
identifier: "123",
users: []*v1.User{
{Id: 123, Name: "testuser", Email: "test@example.com"},
},
expected: &v1.User{Id: 123, Name: "testuser", Email: "test@example.com"},
},
{
name: "resolve by name",
identifier: "testuser",
users: []*v1.User{
{Id: 123, Name: "testuser", Email: "test@example.com"},
},
expected: &v1.User{Id: 123, Name: "testuser", Email: "test@example.com"},
},
{
name: "resolve by email",
identifier: "test@example.com",
users: []*v1.User{
{Id: 123, Name: "testuser", Email: "test@example.com"},
},
expected: &v1.User{Id: 123, Name: "testuser", Email: "test@example.com"},
},
{
name: "not found",
identifier: "nonexistent",
users: []*v1.User{},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// We can't easily test the actual resolution without a real client
// but we can test the logic structure
assert.NotNil(t, ResolveUserByNameOrID)
})
}
}
func TestResolveNodeByIdentifier(t *testing.T) {
tests := []struct {
name string
identifier string
nodes []*v1.Node
expected *v1.Node
expectError bool
}{
{
name: "resolve by ID",
identifier: "123",
nodes: []*v1.Node{
{Id: 123, Name: "testnode", GivenName: "test-device", IpAddresses: []string{"192.168.1.1"}},
},
expected: &v1.Node{Id: 123, Name: "testnode", GivenName: "test-device", IpAddresses: []string{"192.168.1.1"}},
},
{
name: "resolve by hostname",
identifier: "testnode",
nodes: []*v1.Node{
{Id: 123, Name: "testnode", GivenName: "test-device", IpAddresses: []string{"192.168.1.1"}},
},
expected: &v1.Node{Id: 123, Name: "testnode", GivenName: "test-device", IpAddresses: []string{"192.168.1.1"}},
},
{
name: "not found",
identifier: "nonexistent",
nodes: []*v1.Node{},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Test that the function exists and has the right signature
assert.NotNil(t, ResolveNodeByIdentifier)
})
}
}
func TestValidateRequiredArgs(t *testing.T) {
tests := []struct {
name string
args []string
minArgs int
usage string
expectError bool
}{
{
name: "sufficient args",
args: []string{"arg1", "arg2"},
minArgs: 2,
usage: "command <arg1> <arg2>",
expectError: false,
},
{
name: "insufficient args",
args: []string{"arg1"},
minArgs: 2,
usage: "command <arg1> <arg2>",
expectError: true,
},
{
name: "no args required",
args: []string{},
minArgs: 0,
usage: "command",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
validator := ValidateRequiredArgs(tt.minArgs, tt.usage)
err := validator(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), "insufficient arguments")
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateExactArgs(t *testing.T) {
tests := []struct {
name string
args []string
exactArgs int
usage string
expectError bool
}{
{
name: "exact args",
args: []string{"arg1", "arg2"},
exactArgs: 2,
usage: "command <arg1> <arg2>",
expectError: false,
},
{
name: "too few args",
args: []string{"arg1"},
exactArgs: 2,
usage: "command <arg1> <arg2>",
expectError: true,
},
{
name: "too many args",
args: []string{"arg1", "arg2", "arg3"},
exactArgs: 2,
usage: "command <arg1> <arg2>",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
validator := ValidateExactArgs(tt.exactArgs, tt.usage)
err := validator(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), "expected")
} else {
assert.NoError(t, err)
}
})
}
}
func TestProcessMultipleResources(t *testing.T) {
tests := []struct {
name string
items []string
processor func(string) error
continueOnError bool
expectedErrors int
}{
{
name: "all success",
items: []string{"item1", "item2", "item3"},
processor: func(item string) error {
return nil
},
continueOnError: true,
expectedErrors: 0,
},
{
name: "one error, continue",
items: []string{"item1", "error", "item3"},
processor: func(item string) error {
if item == "error" {
return errors.New("test error")
}
return nil
},
continueOnError: true,
expectedErrors: 1,
},
{
name: "one error, stop",
items: []string{"item1", "error", "item3"},
processor: func(item string) error {
if item == "error" {
return errors.New("test error")
}
return nil
},
continueOnError: false,
expectedErrors: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errors := ProcessMultipleResources(tt.items, tt.processor, tt.continueOnError)
assert.Len(t, errors, tt.expectedErrors)
})
}
}
func TestIsValidationError(t *testing.T) {
tests := []struct {
name string
err error
expected bool
}{
{
name: "insufficient arguments error",
err: errors.New("insufficient arguments provided"),
expected: true,
},
{
name: "required flag error",
err: errors.New("required flag not set"),
expected: true,
},
{
name: "not found error",
err: errors.New("not found matching identifier"),
expected: true,
},
{
name: "ambiguous error",
err: errors.New("ambiguous identifier"),
expected: true,
},
{
name: "network error",
err: errors.New("connection refused"),
expected: false,
},
{
name: "random error",
err: errors.New("some other error"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsValidationError(tt.err)
assert.Equal(t, tt.expected, result)
})
}
}
func TestWrapCommandError(t *testing.T) {
cmd := &cobra.Command{Use: "test"}
originalErr := errors.New("original error")
action := "create user"
wrappedErr := WrapCommandError(cmd, originalErr, action)
assert.Error(t, wrappedErr)
assert.Contains(t, wrappedErr.Error(), "failed to create user")
assert.Contains(t, wrappedErr.Error(), "original error")
}
func TestCommandPatternHelpers(t *testing.T) {
// Test that the helper functions exist and return valid function types
// Mock functions for testing
listFunc := func(client *ClientWrapper, cmd *cobra.Command) ([]interface{}, error) {
return []interface{}{}, nil
}
tableSetup := func(tr *TableRenderer) {
// Mock table setup
}
createFunc := func(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) {
return map[string]string{"result": "created"}, nil
}
getFunc := func(client *ClientWrapper, cmd *cobra.Command) (interface{}, error) {
return map[string]string{"result": "found"}, nil
}
deleteFunc := func(client *ClientWrapper, cmd *cobra.Command) (interface{}, error) {
return map[string]string{"result": "deleted"}, nil
}
updateFunc := func(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) {
return map[string]string{"result": "updated"}, nil
}
// Test helper function creation
listCmdFunc := StandardListCommand(listFunc, tableSetup)
assert.NotNil(t, listCmdFunc)
createCmdFunc := StandardCreateCommand(createFunc, "Created successfully")
assert.NotNil(t, createCmdFunc)
deleteCmdFunc := StandardDeleteCommand(getFunc, deleteFunc, "resource")
assert.NotNil(t, deleteCmdFunc)
updateCmdFunc := StandardUpdateCommand(updateFunc, "Updated successfully")
assert.NotNil(t, updateCmdFunc)
}
func TestExecuteListCommand(t *testing.T) {
// Test that ExecuteListCommand function exists
assert.NotNil(t, ExecuteListCommand)
}
func TestExecuteCreateCommand(t *testing.T) {
// Test that ExecuteCreateCommand function exists
assert.NotNil(t, ExecuteCreateCommand)
}
func TestExecuteGetCommand(t *testing.T) {
// Test that ExecuteGetCommand function exists
assert.NotNil(t, ExecuteGetCommand)
}
func TestExecuteUpdateCommand(t *testing.T) {
// Test that ExecuteUpdateCommand function exists
assert.NotNil(t, ExecuteUpdateCommand)
}
func TestExecuteDeleteCommand(t *testing.T) {
// Test that ExecuteDeleteCommand function exists
assert.NotNil(t, ExecuteDeleteCommand)
}
func TestConfirmAction(t *testing.T) {
// Test that ConfirmAction function exists
assert.NotNil(t, ConfirmAction)
}
func TestConfirmDeletion(t *testing.T) {
// Test that ConfirmDeletion function exists
assert.NotNil(t, ConfirmDeletion)
}

View File

@ -1,364 +0,0 @@
package cli
import (
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPolicyCommand(t *testing.T) {
// Test the main policy command
assert.NotNil(t, policyCmd)
assert.Equal(t, "policy", policyCmd.Use)
assert.Equal(t, "Manage the Headscale ACL Policy", policyCmd.Short)
// Test that policy command has subcommands
subcommands := policyCmd.Commands()
assert.Greater(t, len(subcommands), 0, "Policy command should have subcommands")
// Verify expected subcommands exist
subcommandNames := make([]string, len(subcommands))
for i, cmd := range subcommands {
subcommandNames[i] = cmd.Use
}
expectedSubcommands := []string{"get", "set", "check"}
for _, expected := range expectedSubcommands {
found := false
for _, actual := range subcommandNames {
if actual == expected {
found = true
break
}
}
assert.True(t, found, "Expected subcommand '%s' not found", expected)
}
}
func TestGetPolicyCommand(t *testing.T) {
assert.NotNil(t, getPolicy)
assert.Equal(t, "get", getPolicy.Use)
assert.Equal(t, "Print the current ACL Policy", getPolicy.Short)
assert.Equal(t, []string{"show", "view", "fetch"}, getPolicy.Aliases)
// Test that Run function is set
assert.NotNil(t, getPolicy.Run)
}
func TestSetPolicyCommand(t *testing.T) {
assert.NotNil(t, setPolicy)
assert.Equal(t, "set", setPolicy.Use)
assert.Equal(t, "Updates the ACL Policy", setPolicy.Short)
assert.Equal(t, []string{"update", "save", "apply"}, setPolicy.Aliases)
// Test that Run function is set
assert.NotNil(t, setPolicy.Run)
// Test flags
flags := setPolicy.Flags()
assert.NotNil(t, flags.Lookup("file"))
// Test flag properties
fileFlag := flags.Lookup("file")
assert.Equal(t, "f", fileFlag.Shorthand)
assert.Equal(t, "Path to a policy file in HuJSON format", fileFlag.Usage)
// Test that file flag is required
if fileFlag.Annotations != nil {
_, hasRequired := fileFlag.Annotations[cobra.BashCompOneRequiredFlag]
assert.True(t, hasRequired, "file flag should be marked as required")
}
}
func TestCheckPolicyCommand(t *testing.T) {
assert.NotNil(t, checkPolicy)
assert.Equal(t, "check", checkPolicy.Use)
assert.Equal(t, "Check a policy file for syntax or other issues", checkPolicy.Short)
assert.Equal(t, []string{"validate", "test", "verify"}, checkPolicy.Aliases)
// Test that Run function is set
assert.NotNil(t, checkPolicy.Run)
// Test flags
flags := checkPolicy.Flags()
assert.NotNil(t, flags.Lookup("file"))
// Test flag properties
fileFlag := flags.Lookup("file")
assert.Equal(t, "f", fileFlag.Shorthand)
assert.Equal(t, "Path to a policy file in HuJSON format", fileFlag.Usage)
// Test that file flag is required
if fileFlag.Annotations != nil {
_, hasRequired := fileFlag.Annotations[cobra.BashCompOneRequiredFlag]
assert.True(t, hasRequired, "file flag should be marked as required")
}
}
func TestPolicyCommandStructure(t *testing.T) {
// Validate command structure and help text
ValidateCommandStructure(t, policyCmd, "policy", "Manage the Headscale ACL Policy")
ValidateCommandHelp(t, policyCmd)
// Validate subcommands
ValidateCommandStructure(t, getPolicy, "get", "Print the current ACL Policy")
ValidateCommandHelp(t, getPolicy)
ValidateCommandStructure(t, setPolicy, "set", "Updates the ACL Policy")
ValidateCommandHelp(t, setPolicy)
ValidateCommandStructure(t, checkPolicy, "check", "Check a policy file for syntax or other issues")
ValidateCommandHelp(t, checkPolicy)
}
func TestPolicyCommandFlags(t *testing.T) {
// Test set policy command flags
ValidateCommandFlags(t, setPolicy, []string{"file"})
// Test check policy command flags
ValidateCommandFlags(t, checkPolicy, []string{"file"})
}
func TestPolicyCommandIntegration(t *testing.T) {
// Test that policy command is properly integrated into root command
found := false
for _, cmd := range rootCmd.Commands() {
if cmd.Use == "policy" {
found = true
break
}
}
assert.True(t, found, "Policy command should be added to root command")
}
func TestPolicySubcommandIntegration(t *testing.T) {
// Test that all subcommands are properly added to policy command
subcommands := policyCmd.Commands()
expectedCommands := map[string]bool{
"get": false,
"set": false,
"check": false,
}
for _, subcmd := range subcommands {
if _, exists := expectedCommands[subcmd.Use]; exists {
expectedCommands[subcmd.Use] = true
}
}
for cmdName, found := range expectedCommands {
assert.True(t, found, "Subcommand '%s' should be added to policy command", cmdName)
}
}
func TestPolicyCommandAliases(t *testing.T) {
// Test that all aliases are properly set
testCases := []struct {
command *cobra.Command
expectedAliases []string
}{
{
command: getPolicy,
expectedAliases: []string{"show", "view", "fetch"},
},
{
command: setPolicy,
expectedAliases: []string{"update", "save", "apply"},
},
{
command: checkPolicy,
expectedAliases: []string{"validate", "test", "verify"},
},
}
for _, tc := range testCases {
t.Run(tc.command.Use, func(t *testing.T) {
assert.Equal(t, tc.expectedAliases, tc.command.Aliases)
})
}
}
func TestPolicyCommandsHaveOutputFlag(t *testing.T) {
// All policy commands should support output formatting
commands := []*cobra.Command{getPolicy, setPolicy, checkPolicy}
for _, cmd := range commands {
t.Run(cmd.Use, func(t *testing.T) {
// Commands should be able to get output flag (though it might be inherited)
// This tests that the commands are designed to work with output formatting
assert.NotNil(t, cmd.Run, "Command should have a Run function")
})
}
}
func TestPolicyCommandCompleteness(t *testing.T) {
// Test that policy command covers all expected operations
subcommands := policyCmd.Commands()
operations := map[string]bool{
"read": false, // get command
"write": false, // set command
"validate": false, // check command
}
for _, subcmd := range subcommands {
switch subcmd.Use {
case "get":
operations["read"] = true
case "set":
operations["write"] = true
case "check":
operations["validate"] = true
}
}
for op, found := range operations {
assert.True(t, found, "Policy command should support %s operation", op)
}
}
func TestPolicyRequiredFlags(t *testing.T) {
// Test that file flag is required for set and check commands
commandsWithRequiredFile := []*cobra.Command{setPolicy, checkPolicy}
for _, cmd := range commandsWithRequiredFile {
t.Run(cmd.Use+"_file_required", func(t *testing.T) {
fileFlag := cmd.Flags().Lookup("file")
require.NotNil(t, fileFlag)
// Check if flag has required annotation (set by MarkFlagRequired)
if fileFlag.Annotations != nil {
_, hasRequired := fileFlag.Annotations[cobra.BashCompOneRequiredFlag]
assert.True(t, hasRequired, "file flag should be marked as required for %s command", cmd.Use)
}
})
}
}
func TestPolicyFlagShortcuts(t *testing.T) {
// Test that flag shortcuts are properly set
// Set command
fileFlag1 := setPolicy.Flags().Lookup("file")
assert.Equal(t, "f", fileFlag1.Shorthand)
// Check command
fileFlag2 := checkPolicy.Flags().Lookup("file")
assert.Equal(t, "f", fileFlag2.Shorthand)
}
func TestPolicyCommandUsagePatterns(t *testing.T) {
// Test that commands follow consistent usage patterns
// Get command should not require arguments or flags
assert.NotNil(t, getPolicy.Run)
assert.Nil(t, getPolicy.Args) // No args validation means optional args
// Set and check commands require file flag (tested above)
assert.NotNil(t, setPolicy.Run)
assert.NotNil(t, checkPolicy.Run)
}
func TestPolicyCommandDocumentation(t *testing.T) {
// Test that commands have proper documentation
// Main command should reference ACL
assert.Contains(t, policyCmd.Short, "ACL Policy")
// Get command should be about reading
assert.Contains(t, getPolicy.Short, "Print")
assert.Contains(t, getPolicy.Short, "current")
// Set command should be about updating
assert.Contains(t, setPolicy.Short, "Updates")
// Check command should be about validation
assert.Contains(t, checkPolicy.Short, "Check")
assert.Contains(t, checkPolicy.Short, "syntax")
}
func TestPolicyFlagDescriptions(t *testing.T) {
// Test that file flags have helpful descriptions
setFileFlag := setPolicy.Flags().Lookup("file")
assert.Contains(t, setFileFlag.Usage, "Path to a policy file")
assert.Contains(t, setFileFlag.Usage, "HuJSON")
checkFileFlag := checkPolicy.Flags().Lookup("file")
assert.Contains(t, checkFileFlag.Usage, "Path to a policy file")
assert.Contains(t, checkFileFlag.Usage, "HuJSON")
}
func TestPolicyCommandNoAliases(t *testing.T) {
// Main policy command should not have aliases (it's clear enough)
assert.Empty(t, policyCmd.Aliases, "Main policy command should not need aliases")
}
func TestPolicyCommandConsistency(t *testing.T) {
// Test that policy commands follow consistent patterns
// Commands that work with files should use consistent flag naming
fileCommands := []*cobra.Command{setPolicy, checkPolicy}
for _, cmd := range fileCommands {
t.Run(cmd.Use+"_consistent_file_flag", func(t *testing.T) {
fileFlag := cmd.Flags().Lookup("file")
require.NotNil(t, fileFlag, "Command %s should have file flag", cmd.Use)
assert.Equal(t, "f", fileFlag.Shorthand, "File flag should have 'f' shorthand")
assert.Contains(t, fileFlag.Usage, "HuJSON", "File flag should mention HuJSON format")
})
}
}
func TestPolicyCommandMeaningfulAliases(t *testing.T) {
// Test that aliases are meaningful and intuitive
// Get command aliases should be about reading/viewing
getAliases := getPolicy.Aliases
assert.Contains(t, getAliases, "show")
assert.Contains(t, getAliases, "view")
assert.Contains(t, getAliases, "fetch")
// Set command aliases should be about writing/updating
setAliases := setPolicy.Aliases
assert.Contains(t, setAliases, "update")
assert.Contains(t, setAliases, "save")
assert.Contains(t, setAliases, "apply")
// Check command aliases should be about validation
checkAliases := checkPolicy.Aliases
assert.Contains(t, checkAliases, "validate")
assert.Contains(t, checkAliases, "test")
assert.Contains(t, checkAliases, "verify")
}
func TestPolicyWorkflowCompleteness(t *testing.T) {
// Test that policy commands support a complete workflow
// Should be able to: get current policy, check new policy, set new policy
subcommands := policyCmd.Commands()
workflow := map[string]bool{
"get_current": false, // get command
"validate_new": false, // check command
"apply_new": false, // set command
}
for _, subcmd := range subcommands {
switch subcmd.Use {
case "get":
workflow["get_current"] = true
case "check":
workflow["validate_new"] = true
case "set":
workflow["apply_new"] = true
}
}
for step, supported := range workflow {
assert.True(t, supported, "Policy workflow should support %s step", step)
}
}

View File

@ -1,401 +0,0 @@
package cli
import (
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPreAuthKeysCommand(t *testing.T) {
// Test the main preauthkeys command
assert.NotNil(t, preauthkeysCmd)
assert.Equal(t, "preauthkeys", preauthkeysCmd.Use)
assert.Equal(t, "Handle the preauthkeys in Headscale", preauthkeysCmd.Short)
// Test aliases
expectedAliases := []string{"preauthkey", "authkey", "pre"}
assert.Equal(t, expectedAliases, preauthkeysCmd.Aliases)
// Test that preauthkeys command has subcommands
subcommands := preauthkeysCmd.Commands()
assert.Greater(t, len(subcommands), 0, "PreAuth keys command should have subcommands")
// Verify expected subcommands exist
subcommandNames := make([]string, len(subcommands))
for i, cmd := range subcommands {
subcommandNames[i] = cmd.Use
}
expectedSubcommands := []string{"list", "create", "expire"}
for _, expected := range expectedSubcommands {
found := false
for _, actual := range subcommandNames {
if actual == expected {
found = true
break
}
}
assert.True(t, found, "Expected subcommand '%s' not found", expected)
}
}
func TestPreAuthKeysCommandPersistentFlags(t *testing.T) {
// Test persistent flags that apply to all subcommands
flags := preauthkeysCmd.PersistentFlags()
// Test user flag
userFlag := flags.Lookup("user")
assert.NotNil(t, userFlag)
assert.Equal(t, "u", userFlag.Shorthand)
assert.Equal(t, "User identifier (ID)", userFlag.Usage)
// Test that user flag is required
if userFlag.Annotations != nil {
_, hasRequired := userFlag.Annotations[cobra.BashCompOneRequiredFlag]
assert.True(t, hasRequired, "user flag should be marked as required")
}
// Test deprecated namespace flag
namespaceFlag := flags.Lookup("namespace")
assert.NotNil(t, namespaceFlag)
assert.Equal(t, "n", namespaceFlag.Shorthand)
assert.True(t, namespaceFlag.Hidden)
assert.Equal(t, deprecateNamespaceMessage, namespaceFlag.Deprecated)
}
func TestListPreAuthKeysCommand(t *testing.T) {
assert.NotNil(t, listPreAuthKeys)
assert.Equal(t, "list", listPreAuthKeys.Use)
assert.Equal(t, "List the Pre auth keys for the specified user", listPreAuthKeys.Short)
assert.Equal(t, []string{"ls", "show"}, listPreAuthKeys.Aliases)
// Test that Run function is set
assert.NotNil(t, listPreAuthKeys.Run)
}
func TestCreatePreAuthKeyCommand(t *testing.T) {
assert.NotNil(t, createPreAuthKeyCmd)
assert.Equal(t, "create", createPreAuthKeyCmd.Use)
assert.Equal(t, "Creates a new Pre Auth Key", createPreAuthKeyCmd.Short)
assert.Equal(t, []string{"c", "new"}, createPreAuthKeyCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, createPreAuthKeyCmd.Run)
// Test persistent flags (reusable, ephemeral)
persistentFlags := createPreAuthKeyCmd.PersistentFlags()
assert.NotNil(t, persistentFlags.Lookup("reusable"))
assert.NotNil(t, persistentFlags.Lookup("ephemeral"))
// Test regular flags (expiration, tags)
flags := createPreAuthKeyCmd.Flags()
assert.NotNil(t, flags.Lookup("expiration"))
assert.NotNil(t, flags.Lookup("tags"))
// Test flag properties
expirationFlag := flags.Lookup("expiration")
assert.Equal(t, "e", expirationFlag.Shorthand)
assert.Equal(t, DefaultPreAuthKeyExpiry, expirationFlag.DefValue)
reusableFlag := persistentFlags.Lookup("reusable")
assert.Equal(t, "false", reusableFlag.DefValue)
ephemeralFlag := persistentFlags.Lookup("ephemeral")
assert.Equal(t, "false", ephemeralFlag.DefValue)
}
func TestExpirePreAuthKeyCommand(t *testing.T) {
assert.NotNil(t, expirePreAuthKeyCmd)
assert.Equal(t, "expire", expirePreAuthKeyCmd.Use)
assert.Equal(t, "Expire a Pre Auth Key", expirePreAuthKeyCmd.Short)
assert.Equal(t, []string{"revoke", "exp", "e"}, expirePreAuthKeyCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, expirePreAuthKeyCmd.Run)
// Test that Args validation function is set
assert.NotNil(t, expirePreAuthKeyCmd.Args)
}
func TestPreAuthKeyConstants(t *testing.T) {
// Test that constants are defined
assert.Equal(t, "1h", DefaultPreAuthKeyExpiry)
}
func TestPreAuthKeyCommandStructure(t *testing.T) {
// Validate command structure and help text
ValidateCommandStructure(t, preauthkeysCmd, "preauthkeys", "Handle the preauthkeys in Headscale")
ValidateCommandHelp(t, preauthkeysCmd)
// Validate subcommands
ValidateCommandStructure(t, listPreAuthKeys, "list", "List the Pre auth keys for the specified user")
ValidateCommandHelp(t, listPreAuthKeys)
ValidateCommandStructure(t, createPreAuthKeyCmd, "create", "Creates a new Pre Auth Key")
ValidateCommandHelp(t, createPreAuthKeyCmd)
ValidateCommandStructure(t, expirePreAuthKeyCmd, "expire", "Expire a Pre Auth Key")
ValidateCommandHelp(t, expirePreAuthKeyCmd)
}
func TestPreAuthKeyCommandFlags(t *testing.T) {
// Test preauthkeys command persistent flags
ValidateCommandFlags(t, preauthkeysCmd, []string{"user", "namespace"})
// Test create command flags
ValidateCommandFlags(t, createPreAuthKeyCmd, []string{"reusable", "ephemeral", "expiration", "tags"})
}
func TestPreAuthKeyCommandIntegration(t *testing.T) {
// Test that preauthkeys command is properly integrated into root command
found := false
for _, cmd := range rootCmd.Commands() {
if cmd.Use == "preauthkeys" {
found = true
break
}
}
assert.True(t, found, "PreAuth keys command should be added to root command")
}
func TestPreAuthKeySubcommandIntegration(t *testing.T) {
// Test that all subcommands are properly added to preauthkeys command
subcommands := preauthkeysCmd.Commands()
expectedCommands := map[string]bool{
"list": false,
"create": false,
"expire": false,
}
for _, subcmd := range subcommands {
if _, exists := expectedCommands[subcmd.Use]; exists {
expectedCommands[subcmd.Use] = true
}
}
for cmdName, found := range expectedCommands {
assert.True(t, found, "Subcommand '%s' should be added to preauthkeys command", cmdName)
}
}
func TestPreAuthKeyCommandAliases(t *testing.T) {
// Test that all aliases are properly set
testCases := []struct {
command *cobra.Command
expectedAliases []string
}{
{
command: preauthkeysCmd,
expectedAliases: []string{"preauthkey", "authkey", "pre"},
},
{
command: listPreAuthKeys,
expectedAliases: []string{"ls", "show"},
},
{
command: createPreAuthKeyCmd,
expectedAliases: []string{"c", "new"},
},
{
command: expirePreAuthKeyCmd,
expectedAliases: []string{"revoke", "exp", "e"},
},
}
for _, tc := range testCases {
t.Run(tc.command.Use, func(t *testing.T) {
assert.Equal(t, tc.expectedAliases, tc.command.Aliases)
})
}
}
func TestPreAuthKeyFlagDefaults(t *testing.T) {
// Test create command flag defaults
// Test persistent flags
persistentFlags := createPreAuthKeyCmd.PersistentFlags()
reusable, err := persistentFlags.GetBool("reusable")
assert.NoError(t, err)
assert.False(t, reusable)
ephemeral, err := persistentFlags.GetBool("ephemeral")
assert.NoError(t, err)
assert.False(t, ephemeral)
// Test regular flags
flags := createPreAuthKeyCmd.Flags()
expiration, err := flags.GetString("expiration")
assert.NoError(t, err)
assert.Equal(t, DefaultPreAuthKeyExpiry, expiration)
tags, err := flags.GetStringSlice("tags")
assert.NoError(t, err)
assert.Empty(t, tags)
}
func TestPreAuthKeyFlagShortcuts(t *testing.T) {
// Test that flag shortcuts are properly set
// Persistent flags
userFlag := preauthkeysCmd.PersistentFlags().Lookup("user")
assert.Equal(t, "u", userFlag.Shorthand)
namespaceFlag := preauthkeysCmd.PersistentFlags().Lookup("namespace")
assert.Equal(t, "n", namespaceFlag.Shorthand)
// Create command flags
expirationFlag := createPreAuthKeyCmd.Flags().Lookup("expiration")
assert.Equal(t, "e", expirationFlag.Shorthand)
}
func TestPreAuthKeyCommandsHaveOutputFlag(t *testing.T) {
// All preauth key commands should support output formatting
commands := []*cobra.Command{listPreAuthKeys, createPreAuthKeyCmd, expirePreAuthKeyCmd}
for _, cmd := range commands {
t.Run(cmd.Use, func(t *testing.T) {
// Commands should be able to get output flag (though it might be inherited)
// This tests that the commands are designed to work with output formatting
assert.NotNil(t, cmd.Run, "Command should have a Run function")
})
}
}
func TestPreAuthKeyCommandCompleteness(t *testing.T) {
// Test that preauth key command covers all expected CRUD operations
subcommands := preauthkeysCmd.Commands()
operations := map[string]bool{
"create": false,
"read": false, // list command
"update": false, // expire command (updates state)
"delete": false, // expire is the equivalent of delete for preauth keys
}
for _, subcmd := range subcommands {
switch subcmd.Use {
case "create":
operations["create"] = true
case "list":
operations["read"] = true
case "expire":
operations["update"] = true
operations["delete"] = true // expire serves as delete for preauth keys
}
}
for op, found := range operations {
assert.True(t, found, "PreAuth key command should support %s operation", op)
}
}
func TestPreAuthKeyRequiredFlags(t *testing.T) {
// Test that user flag is required on parent command
userFlag := preauthkeysCmd.PersistentFlags().Lookup("user")
require.NotNil(t, userFlag)
// Check if flag has required annotation (set by MarkPersistentFlagRequired)
if userFlag.Annotations != nil {
_, hasRequired := userFlag.Annotations[cobra.BashCompOneRequiredFlag]
assert.True(t, hasRequired, "user flag should be marked as required")
}
}
func TestPreAuthKeyDeprecatedFlags(t *testing.T) {
// Test deprecated namespace flag
namespaceFlag := preauthkeysCmd.PersistentFlags().Lookup("namespace")
require.NotNil(t, namespaceFlag)
assert.True(t, namespaceFlag.Hidden, "Namespace flag should be hidden")
assert.Equal(t, deprecateNamespaceMessage, namespaceFlag.Deprecated)
}
func TestPreAuthKeyCommandUsagePatterns(t *testing.T) {
// Test that commands follow consistent usage patterns
// List and create commands should not require positional arguments
assert.NotNil(t, listPreAuthKeys.Run)
assert.Nil(t, listPreAuthKeys.Args) // No args validation means optional args
assert.NotNil(t, createPreAuthKeyCmd.Run)
assert.Nil(t, createPreAuthKeyCmd.Args)
// Expire command requires key argument
assert.NotNil(t, expirePreAuthKeyCmd.Run)
assert.NotNil(t, expirePreAuthKeyCmd.Args)
}
func TestPreAuthKeyFlagTypes(t *testing.T) {
// Test that flags have correct types
// User flag should be uint64
userFlag := preauthkeysCmd.PersistentFlags().Lookup("user")
require.NotNil(t, userFlag)
assert.Equal(t, "uint64", userFlag.Value.Type())
// Boolean flags
reusableFlag := createPreAuthKeyCmd.PersistentFlags().Lookup("reusable")
require.NotNil(t, reusableFlag)
assert.Equal(t, "bool", reusableFlag.Value.Type())
ephemeralFlag := createPreAuthKeyCmd.PersistentFlags().Lookup("ephemeral")
require.NotNil(t, ephemeralFlag)
assert.Equal(t, "bool", ephemeralFlag.Value.Type())
// String flags
expirationFlag := createPreAuthKeyCmd.Flags().Lookup("expiration")
require.NotNil(t, expirationFlag)
assert.Equal(t, "string", expirationFlag.Value.Type())
// String slice flags
tagsFlag := createPreAuthKeyCmd.Flags().Lookup("tags")
require.NotNil(t, tagsFlag)
assert.Equal(t, "stringSlice", tagsFlag.Value.Type())
}
func TestPreAuthKeyDefaultExpiry(t *testing.T) {
// Test that the default expiry constant is reasonable
assert.Equal(t, "1h", DefaultPreAuthKeyExpiry)
// Test that it can be used in flag defaults
expirationFlag := createPreAuthKeyCmd.Flags().Lookup("expiration")
assert.Equal(t, DefaultPreAuthKeyExpiry, expirationFlag.DefValue)
}
func TestPreAuthKeyCommandDocumentation(t *testing.T) {
// Test that commands have proper documentation
// Main command should have clear description
assert.Contains(t, preauthkeysCmd.Short, "preauthkeys")
assert.Contains(t, preauthkeysCmd.Short, "Headscale")
// Subcommands should have descriptive names
assert.Equal(t, "List the Pre auth keys for the specified user", listPreAuthKeys.Short)
assert.Equal(t, "Creates a new Pre Auth Key", createPreAuthKeyCmd.Short)
assert.Equal(t, "Expire a Pre Auth Key", expirePreAuthKeyCmd.Short)
}
func TestPreAuthKeyFlagDescriptions(t *testing.T) {
// Test that flags have helpful descriptions
userFlag := preauthkeysCmd.PersistentFlags().Lookup("user")
assert.Contains(t, userFlag.Usage, "User identifier")
reusableFlag := createPreAuthKeyCmd.PersistentFlags().Lookup("reusable")
assert.Contains(t, reusableFlag.Usage, "reusable")
ephemeralFlag := createPreAuthKeyCmd.PersistentFlags().Lookup("ephemeral")
assert.Contains(t, ephemeralFlag.Usage, "ephemeral")
expirationFlag := createPreAuthKeyCmd.Flags().Lookup("expiration")
assert.Contains(t, expirationFlag.Usage, "Human-readable")
assert.Contains(t, expirationFlag.Usage, "expiration")
tagsFlag := createPreAuthKeyCmd.Flags().Lookup("tags")
assert.Contains(t, tagsFlag.Usage, "Tags")
assert.Contains(t, tagsFlag.Usage, "automatically assign")
}

View File

@ -1,145 +0,0 @@
package cli
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestColourTime(t *testing.T) {
tests := []struct {
name string
date time.Time
expectedText string
expectRed bool
expectGreen bool
}{
{
name: "future date should be green",
date: time.Now().Add(1 * time.Hour),
expectedText: time.Now().Add(1 * time.Hour).Format("2006-01-02 15:04:05"),
expectGreen: true,
expectRed: false,
},
{
name: "past date should be red",
date: time.Now().Add(-1 * time.Hour),
expectedText: time.Now().Add(-1 * time.Hour).Format("2006-01-02 15:04:05"),
expectGreen: false,
expectRed: true,
},
{
name: "very old date should be red",
date: time.Date(2020, 1, 1, 12, 0, 0, 0, time.UTC),
expectedText: "2020-01-01 12:00:00",
expectGreen: false,
expectRed: true,
},
{
name: "far future date should be green",
date: time.Date(2030, 12, 31, 23, 59, 59, 0, time.UTC),
expectedText: "2030-12-31 23:59:59",
expectGreen: true,
expectRed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ColourTime(tt.date)
// Check that the formatted time string is present
assert.Contains(t, result, tt.expectedText)
// Check for color codes based on expectation
if tt.expectGreen {
// pterm.LightGreen adds color codes, check for green color escape sequences
assert.Contains(t, result, "\033[92m", "Expected green color codes")
}
if tt.expectRed {
// pterm.LightRed adds color codes, check for red color escape sequences
assert.Contains(t, result, "\033[91m", "Expected red color codes")
}
})
}
}
func TestColourTimeFormatting(t *testing.T) {
// Test that the date format is correct
testDate := time.Date(2023, 6, 15, 14, 30, 45, 0, time.UTC)
result := ColourTime(testDate)
// Should contain the correctly formatted date
assert.Contains(t, result, "2023-06-15 14:30:45")
}
func TestColourTimeWithTimezones(t *testing.T) {
// Test with different timezones
utc := time.Now().UTC()
local := utc.In(time.Local)
resultUTC := ColourTime(utc)
resultLocal := ColourTime(local)
// Both should format to the same time (since it's the same instant)
// but may have different colors depending on when "now" is
utcFormatted := utc.Format("2006-01-02 15:04:05")
localFormatted := local.Format("2006-01-02 15:04:05")
assert.Contains(t, resultUTC, utcFormatted)
assert.Contains(t, resultLocal, localFormatted)
}
func TestColourTimeEdgeCases(t *testing.T) {
// Test with zero time
zeroTime := time.Time{}
result := ColourTime(zeroTime)
assert.Contains(t, result, "0001-01-01 00:00:00")
// Zero time is definitely in the past, so should be red
assert.Contains(t, result, "\033[91m", "Zero time should be red")
}
func TestColourTimeConsistency(t *testing.T) {
// Test that calling the function multiple times with the same input
// produces consistent results (within a reasonable time window)
testDate := time.Now().Add(-5 * time.Minute) // 5 minutes ago
result1 := ColourTime(testDate)
time.Sleep(10 * time.Millisecond) // Small delay
result2 := ColourTime(testDate)
// Results should be identical since the input date hasn't changed
// and it's still in the past relative to "now"
assert.Equal(t, result1, result2)
}
func TestColourTimeNearCurrentTime(t *testing.T) {
// Test dates very close to current time
now := time.Now()
// 1 second in the past
pastResult := ColourTime(now.Add(-1 * time.Second))
assert.Contains(t, pastResult, "\033[91m", "1 second ago should be red")
// 1 second in the future
futureResult := ColourTime(now.Add(1 * time.Second))
assert.Contains(t, futureResult, "\033[92m", "1 second in future should be green")
}
func TestColourTimeStringContainsNoUnexpectedCharacters(t *testing.T) {
// Test that the result doesn't contain unexpected characters
testDate := time.Now()
result := ColourTime(testDate)
// Should not contain newlines or other unexpected characters
assert.False(t, strings.Contains(result, "\n"), "Result should not contain newlines")
assert.False(t, strings.Contains(result, "\r"), "Result should not contain carriage returns")
// Should contain the expected format
dateStr := testDate.Format("2006-01-02 15:04:05")
assert.Contains(t, result, dateStr)
}

View File

@ -1,604 +0,0 @@
package cli
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/timestamppb"
"gopkg.in/yaml.v3"
)
// MockHeadscaleServiceClient provides a mock implementation of the HeadscaleServiceClient
// for testing CLI commands without requiring a real server
type MockHeadscaleServiceClient struct {
// Configurable responses for all gRPC methods
ListUsersResponse *v1.ListUsersResponse
CreateUserResponse *v1.CreateUserResponse
RenameUserResponse *v1.RenameUserResponse
DeleteUserResponse *v1.DeleteUserResponse
ListNodesResponse *v1.ListNodesResponse
RegisterNodeResponse *v1.RegisterNodeResponse
DeleteNodeResponse *v1.DeleteNodeResponse
ExpireNodeResponse *v1.ExpireNodeResponse
RenameNodeResponse *v1.RenameNodeResponse
MoveNodeResponse *v1.MoveNodeResponse
GetNodeResponse *v1.GetNodeResponse
SetTagsResponse *v1.SetTagsResponse
SetApprovedRoutesResponse *v1.SetApprovedRoutesResponse
BackfillNodeIPsResponse *v1.BackfillNodeIPsResponse
ListApiKeysResponse *v1.ListApiKeysResponse
CreateApiKeyResponse *v1.CreateApiKeyResponse
ExpireApiKeyResponse *v1.ExpireApiKeyResponse
DeleteApiKeyResponse *v1.DeleteApiKeyResponse
ListPreAuthKeysResponse *v1.ListPreAuthKeysResponse
CreatePreAuthKeyResponse *v1.CreatePreAuthKeyResponse
ExpirePreAuthKeyResponse *v1.ExpirePreAuthKeyResponse
GetPolicyResponse *v1.GetPolicyResponse
SetPolicyResponse *v1.SetPolicyResponse
DebugCreateNodeResponse *v1.DebugCreateNodeResponse
// Error responses for testing error conditions
ListUsersError error
CreateUserError error
RenameUserError error
DeleteUserError error
ListNodesError error
RegisterNodeError error
DeleteNodeError error
ExpireNodeError error
RenameNodeError error
MoveNodeError error
GetNodeError error
SetTagsError error
SetApprovedRoutesError error
BackfillNodeIPsError error
ListApiKeysError error
CreateApiKeyError error
ExpireApiKeyError error
DeleteApiKeyError error
ListPreAuthKeysError error
CreatePreAuthKeyError error
ExpirePreAuthKeyError error
GetPolicyError error
SetPolicyError error
DebugCreateNodeError error
// Call tracking
LastRequest interface{}
CallCount map[string]int
}
// NewMockHeadscaleServiceClient creates a new mock client with default responses
func NewMockHeadscaleServiceClient() *MockHeadscaleServiceClient {
return &MockHeadscaleServiceClient{
CallCount: make(map[string]int),
// Default successful responses
ListUsersResponse: &v1.ListUsersResponse{Users: []*v1.User{NewTestUser(1, "testuser"), NewTestUser(2, "olduser")}},
CreateUserResponse: &v1.CreateUserResponse{User: NewTestUser(1, "testuser")},
RenameUserResponse: &v1.RenameUserResponse{User: NewTestUser(1, "renamed-user")},
DeleteUserResponse: &v1.DeleteUserResponse{},
ListNodesResponse: &v1.ListNodesResponse{Nodes: []*v1.Node{}},
RegisterNodeResponse: &v1.RegisterNodeResponse{Node: NewTestNode(1, "testnode", NewTestUser(1, "testuser"))},
DeleteNodeResponse: &v1.DeleteNodeResponse{},
ExpireNodeResponse: &v1.ExpireNodeResponse{Node: NewTestNode(1, "testnode", NewTestUser(1, "testuser"))},
RenameNodeResponse: &v1.RenameNodeResponse{Node: NewTestNode(1, "renamed-node", NewTestUser(1, "testuser"))},
MoveNodeResponse: &v1.MoveNodeResponse{Node: NewTestNode(1, "testnode", NewTestUser(2, "newuser"))},
GetNodeResponse: &v1.GetNodeResponse{Node: NewTestNode(1, "testnode", NewTestUser(1, "testuser"))},
SetTagsResponse: &v1.SetTagsResponse{Node: NewTestNode(1, "testnode", NewTestUser(1, "testuser"))},
SetApprovedRoutesResponse: &v1.SetApprovedRoutesResponse{Node: NewTestNode(1, "testnode", NewTestUser(1, "testuser"))},
BackfillNodeIPsResponse: &v1.BackfillNodeIPsResponse{Changes: []string{"192.168.1.1"}},
ListApiKeysResponse: &v1.ListApiKeysResponse{ApiKeys: []*v1.ApiKey{}},
CreateApiKeyResponse: &v1.CreateApiKeyResponse{ApiKey: "testkey_abcdef123456"},
ExpireApiKeyResponse: &v1.ExpireApiKeyResponse{},
DeleteApiKeyResponse: &v1.DeleteApiKeyResponse{},
ListPreAuthKeysResponse: &v1.ListPreAuthKeysResponse{PreAuthKeys: []*v1.PreAuthKey{}},
CreatePreAuthKeyResponse: &v1.CreatePreAuthKeyResponse{PreAuthKey: NewTestPreAuthKey(1, 1)},
ExpirePreAuthKeyResponse: &v1.ExpirePreAuthKeyResponse{},
GetPolicyResponse: &v1.GetPolicyResponse{Policy: "{}"},
SetPolicyResponse: &v1.SetPolicyResponse{Policy: "{}"},
DebugCreateNodeResponse: &v1.DebugCreateNodeResponse{Node: NewTestNode(1, "debug-node", NewTestUser(1, "testuser"))},
}
}
// NewMockClientWrapper creates a ClientWrapper with a mock client for testing
func NewMockClientWrapper() *ClientWrapper {
mockClient := NewMockHeadscaleServiceClient()
return &ClientWrapper{
client: mockClient,
}
}
// Implement all v1.HeadscaleServiceClient methods
func (m *MockHeadscaleServiceClient) ListUsers(ctx context.Context, req *v1.ListUsersRequest, opts ...grpc.CallOption) (*v1.ListUsersResponse, error) {
m.CallCount["ListUsers"]++
m.LastRequest = req
if m.ListUsersError != nil {
return nil, m.ListUsersError
}
return m.ListUsersResponse, nil
}
func (m *MockHeadscaleServiceClient) CreateUser(ctx context.Context, req *v1.CreateUserRequest, opts ...grpc.CallOption) (*v1.CreateUserResponse, error) {
m.CallCount["CreateUser"]++
m.LastRequest = req
if m.CreateUserError != nil {
return nil, m.CreateUserError
}
return m.CreateUserResponse, nil
}
func (m *MockHeadscaleServiceClient) RenameUser(ctx context.Context, req *v1.RenameUserRequest, opts ...grpc.CallOption) (*v1.RenameUserResponse, error) {
m.CallCount["RenameUser"]++
m.LastRequest = req
if m.RenameUserError != nil {
return nil, m.RenameUserError
}
return m.RenameUserResponse, nil
}
func (m *MockHeadscaleServiceClient) DeleteUser(ctx context.Context, req *v1.DeleteUserRequest, opts ...grpc.CallOption) (*v1.DeleteUserResponse, error) {
m.CallCount["DeleteUser"]++
m.LastRequest = req
if m.DeleteUserError != nil {
return nil, m.DeleteUserError
}
return m.DeleteUserResponse, nil
}
func (m *MockHeadscaleServiceClient) ListNodes(ctx context.Context, req *v1.ListNodesRequest, opts ...grpc.CallOption) (*v1.ListNodesResponse, error) {
m.CallCount["ListNodes"]++
m.LastRequest = req
if m.ListNodesError != nil {
return nil, m.ListNodesError
}
return m.ListNodesResponse, nil
}
func (m *MockHeadscaleServiceClient) RegisterNode(ctx context.Context, req *v1.RegisterNodeRequest, opts ...grpc.CallOption) (*v1.RegisterNodeResponse, error) {
m.CallCount["RegisterNode"]++
m.LastRequest = req
if m.RegisterNodeError != nil {
return nil, m.RegisterNodeError
}
return m.RegisterNodeResponse, nil
}
func (m *MockHeadscaleServiceClient) DeleteNode(ctx context.Context, req *v1.DeleteNodeRequest, opts ...grpc.CallOption) (*v1.DeleteNodeResponse, error) {
m.CallCount["DeleteNode"]++
m.LastRequest = req
if m.DeleteNodeError != nil {
return nil, m.DeleteNodeError
}
return m.DeleteNodeResponse, nil
}
func (m *MockHeadscaleServiceClient) ExpireNode(ctx context.Context, req *v1.ExpireNodeRequest, opts ...grpc.CallOption) (*v1.ExpireNodeResponse, error) {
m.CallCount["ExpireNode"]++
m.LastRequest = req
if m.ExpireNodeError != nil {
return nil, m.ExpireNodeError
}
return m.ExpireNodeResponse, nil
}
func (m *MockHeadscaleServiceClient) RenameNode(ctx context.Context, req *v1.RenameNodeRequest, opts ...grpc.CallOption) (*v1.RenameNodeResponse, error) {
m.CallCount["RenameNode"]++
m.LastRequest = req
if m.RenameNodeError != nil {
return nil, m.RenameNodeError
}
return m.RenameNodeResponse, nil
}
func (m *MockHeadscaleServiceClient) MoveNode(ctx context.Context, req *v1.MoveNodeRequest, opts ...grpc.CallOption) (*v1.MoveNodeResponse, error) {
m.CallCount["MoveNode"]++
m.LastRequest = req
if m.MoveNodeError != nil {
return nil, m.MoveNodeError
}
return m.MoveNodeResponse, nil
}
func (m *MockHeadscaleServiceClient) GetNode(ctx context.Context, req *v1.GetNodeRequest, opts ...grpc.CallOption) (*v1.GetNodeResponse, error) {
m.CallCount["GetNode"]++
m.LastRequest = req
if m.GetNodeError != nil {
return nil, m.GetNodeError
}
return m.GetNodeResponse, nil
}
func (m *MockHeadscaleServiceClient) SetTags(ctx context.Context, req *v1.SetTagsRequest, opts ...grpc.CallOption) (*v1.SetTagsResponse, error) {
m.CallCount["SetTags"]++
m.LastRequest = req
if m.SetTagsError != nil {
return nil, m.SetTagsError
}
return m.SetTagsResponse, nil
}
func (m *MockHeadscaleServiceClient) SetApprovedRoutes(ctx context.Context, req *v1.SetApprovedRoutesRequest, opts ...grpc.CallOption) (*v1.SetApprovedRoutesResponse, error) {
m.CallCount["SetApprovedRoutes"]++
m.LastRequest = req
if m.SetApprovedRoutesError != nil {
return nil, m.SetApprovedRoutesError
}
return m.SetApprovedRoutesResponse, nil
}
func (m *MockHeadscaleServiceClient) BackfillNodeIPs(ctx context.Context, req *v1.BackfillNodeIPsRequest, opts ...grpc.CallOption) (*v1.BackfillNodeIPsResponse, error) {
m.CallCount["BackfillNodeIPs"]++
m.LastRequest = req
if m.BackfillNodeIPsError != nil {
return nil, m.BackfillNodeIPsError
}
return m.BackfillNodeIPsResponse, nil
}
func (m *MockHeadscaleServiceClient) ListApiKeys(ctx context.Context, req *v1.ListApiKeysRequest, opts ...grpc.CallOption) (*v1.ListApiKeysResponse, error) {
m.CallCount["ListApiKeys"]++
m.LastRequest = req
if m.ListApiKeysError != nil {
return nil, m.ListApiKeysError
}
return m.ListApiKeysResponse, nil
}
func (m *MockHeadscaleServiceClient) CreateApiKey(ctx context.Context, req *v1.CreateApiKeyRequest, opts ...grpc.CallOption) (*v1.CreateApiKeyResponse, error) {
m.CallCount["CreateApiKey"]++
m.LastRequest = req
if m.CreateApiKeyError != nil {
return nil, m.CreateApiKeyError
}
return m.CreateApiKeyResponse, nil
}
func (m *MockHeadscaleServiceClient) ExpireApiKey(ctx context.Context, req *v1.ExpireApiKeyRequest, opts ...grpc.CallOption) (*v1.ExpireApiKeyResponse, error) {
m.CallCount["ExpireApiKey"]++
m.LastRequest = req
if m.ExpireApiKeyError != nil {
return nil, m.ExpireApiKeyError
}
return m.ExpireApiKeyResponse, nil
}
func (m *MockHeadscaleServiceClient) DeleteApiKey(ctx context.Context, req *v1.DeleteApiKeyRequest, opts ...grpc.CallOption) (*v1.DeleteApiKeyResponse, error) {
m.CallCount["DeleteApiKey"]++
m.LastRequest = req
if m.DeleteApiKeyError != nil {
return nil, m.DeleteApiKeyError
}
return m.DeleteApiKeyResponse, nil
}
func (m *MockHeadscaleServiceClient) ListPreAuthKeys(ctx context.Context, req *v1.ListPreAuthKeysRequest, opts ...grpc.CallOption) (*v1.ListPreAuthKeysResponse, error) {
m.CallCount["ListPreAuthKeys"]++
m.LastRequest = req
if m.ListPreAuthKeysError != nil {
return nil, m.ListPreAuthKeysError
}
return m.ListPreAuthKeysResponse, nil
}
func (m *MockHeadscaleServiceClient) CreatePreAuthKey(ctx context.Context, req *v1.CreatePreAuthKeyRequest, opts ...grpc.CallOption) (*v1.CreatePreAuthKeyResponse, error) {
m.CallCount["CreatePreAuthKey"]++
m.LastRequest = req
if m.CreatePreAuthKeyError != nil {
return nil, m.CreatePreAuthKeyError
}
return m.CreatePreAuthKeyResponse, nil
}
func (m *MockHeadscaleServiceClient) ExpirePreAuthKey(ctx context.Context, req *v1.ExpirePreAuthKeyRequest, opts ...grpc.CallOption) (*v1.ExpirePreAuthKeyResponse, error) {
m.CallCount["ExpirePreAuthKey"]++
m.LastRequest = req
if m.ExpirePreAuthKeyError != nil {
return nil, m.ExpirePreAuthKeyError
}
return m.ExpirePreAuthKeyResponse, nil
}
func (m *MockHeadscaleServiceClient) GetPolicy(ctx context.Context, req *v1.GetPolicyRequest, opts ...grpc.CallOption) (*v1.GetPolicyResponse, error) {
m.CallCount["GetPolicy"]++
m.LastRequest = req
if m.GetPolicyError != nil {
return nil, m.GetPolicyError
}
return m.GetPolicyResponse, nil
}
func (m *MockHeadscaleServiceClient) SetPolicy(ctx context.Context, req *v1.SetPolicyRequest, opts ...grpc.CallOption) (*v1.SetPolicyResponse, error) {
m.CallCount["SetPolicy"]++
m.LastRequest = req
if m.SetPolicyError != nil {
return nil, m.SetPolicyError
}
return m.SetPolicyResponse, nil
}
func (m *MockHeadscaleServiceClient) DebugCreateNode(ctx context.Context, req *v1.DebugCreateNodeRequest, opts ...grpc.CallOption) (*v1.DebugCreateNodeResponse, error) {
m.CallCount["DebugCreateNode"]++
m.LastRequest = req
if m.DebugCreateNodeError != nil {
return nil, m.DebugCreateNodeError
}
return m.DebugCreateNodeResponse, nil
}
// MockClientWrapper wraps MockHeadscaleServiceClient for testing
type MockClientWrapper struct {
MockClient *MockHeadscaleServiceClient
ctx context.Context
cancel context.CancelFunc
}
// NewMockClientWrapperOld creates a new mock client wrapper for testing (legacy)
func NewMockClientWrapperOld() *MockClientWrapper {
ctx, cancel := context.WithCancel(context.Background())
return &MockClientWrapper{
MockClient: NewMockHeadscaleServiceClient(),
ctx: ctx,
cancel: cancel,
}
}
// Close implements the ClientWrapper interface
func (m *MockClientWrapper) Close() {
if m.cancel != nil {
m.cancel()
}
}
// CLI test execution helpers
// ExecuteCommand executes a command and captures its output
func ExecuteCommand(cmd *cobra.Command, args []string) (string, error) {
return ExecuteCommandWithInput(cmd, args, "")
}
// ExecuteCommandWithInput executes a command with input and captures its output
func ExecuteCommandWithInput(cmd *cobra.Command, args []string, input string) (string, error) {
// Create buffers for capturing output
oldStdout := os.Stdout
oldStderr := os.Stderr
oldStdin := os.Stdin
// Create pipes for capturing output
r, w, _ := os.Pipe()
os.Stdout = w
os.Stderr = w
// Set up input if provided
if input != "" {
tmpfile, err := os.CreateTemp("", "test-input")
if err != nil {
return "", err
}
defer os.Remove(tmpfile.Name())
tmpfile.WriteString(input)
tmpfile.Seek(0, 0)
os.Stdin = tmpfile
}
// Capture output
var buf bytes.Buffer
done := make(chan bool)
go func() {
io.Copy(&buf, r)
done <- true
}()
// Execute command
cmd.SetArgs(args)
err := cmd.Execute()
// Restore original streams
w.Close()
os.Stdout = oldStdout
os.Stderr = oldStderr
os.Stdin = oldStdin
// Wait for output capture to complete
<-done
return buf.String(), err
}
// AssertCommandSuccess executes a command and asserts it succeeds
func AssertCommandSuccess(t interface{}, cmd *cobra.Command, args []string) {
output, err := ExecuteCommand(cmd, args)
if err != nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Command failed: %v\nOutput: %s", err, output)
}
}
// AssertCommandError executes a command and asserts it fails with expected error
func AssertCommandError(t interface{}, cmd *cobra.Command, args []string, expectedError string) {
output, err := ExecuteCommand(cmd, args)
if err == nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Expected command to fail but it succeeded\nOutput: %s", output)
}
if !strings.Contains(err.Error(), expectedError) {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Expected error to contain '%s' but got: %v", expectedError, err)
}
}
// Output format testing
// ValidateJSONOutput validates that output is valid JSON and matches expected structure
func ValidateJSONOutput(t interface{}, output string, expected interface{}) {
var actual interface{}
err := json.Unmarshal([]byte(output), &actual)
if err != nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Invalid JSON output: %v\nOutput: %s", err, output)
}
// Convert expected to JSON and back for comparison
expectedJSON, err := json.Marshal(expected)
if err != nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Failed to marshal expected JSON: %v", err)
}
var expectedParsed interface{}
err = json.Unmarshal(expectedJSON, &expectedParsed)
if err != nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Failed to unmarshal expected JSON: %v", err)
}
// Compare structures (basic comparison)
actualJSON, _ := json.Marshal(actual)
if string(actualJSON) != string(expectedJSON) {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("JSON output mismatch.\nExpected: %s\nActual: %s", expectedJSON, actualJSON)
}
}
// ValidateYAMLOutput validates that output is valid YAML and matches expected structure
func ValidateYAMLOutput(t interface{}, output string, expected interface{}) {
var actual interface{}
err := yaml.Unmarshal([]byte(output), &actual)
if err != nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Invalid YAML output: %v\nOutput: %s", err, output)
}
// Convert expected to YAML for comparison
expectedYAML, err := yaml.Marshal(expected)
if err != nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Failed to marshal expected YAML: %v", err)
}
actualYAML, err := yaml.Marshal(actual)
if err != nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Failed to marshal actual YAML: %v", err)
}
if string(actualYAML) != string(expectedYAML) {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("YAML output mismatch.\nExpected: %s\nActual: %s", expectedYAML, actualYAML)
}
}
// ValidateTableOutput validates that output contains expected table headers
func ValidateTableOutput(t interface{}, output string, expectedHeaders []string) {
for _, header := range expectedHeaders {
if !strings.Contains(output, header) {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Table output missing expected header '%s'\nOutput: %s", header, output)
}
}
}
// Test fixtures and data helpers
// NewTestUser creates a test user with the given ID and name
func NewTestUser(id uint64, name string) *v1.User {
return &v1.User{
Id: id,
Name: name,
Email: fmt.Sprintf("%s@example.com", name),
CreatedAt: timestamppb.Now(),
}
}
// NewTestNode creates a test node with the given ID, name, and user
func NewTestNode(id uint64, name string, user *v1.User) *v1.Node {
return &v1.Node{
Id: id,
Name: name,
GivenName: fmt.Sprintf("%s-device", name),
User: user,
IpAddresses: []string{fmt.Sprintf("192.168.1.%d", id)},
Online: true,
ValidTags: []string{},
CreatedAt: timestamppb.Now(),
LastSeen: timestamppb.Now(),
}
}
// NewTestApiKey creates a test API key with the given ID and prefix
func NewTestApiKey(id uint64, prefix string) *v1.ApiKey {
return &v1.ApiKey{
Id: id,
Prefix: prefix,
CreatedAt: timestamppb.Now(),
}
}
// NewTestPreAuthKey creates a test preauth key with the given ID and user ID
func NewTestPreAuthKey(id uint64, userID uint64) *v1.PreAuthKey {
return &v1.PreAuthKey{
Id: id,
Key: fmt.Sprintf("preauthkey-%d-abcdef", id),
User: NewTestUser(userID, fmt.Sprintf("user%d", userID)),
Reusable: false,
Ephemeral: false,
Used: false,
CreatedAt: timestamppb.Now(),
}
}
// CreateTestCommand creates a basic test command with common flags
func CreateTestCommand(name string) *cobra.Command {
cmd := &cobra.Command{
Use: name,
Short: fmt.Sprintf("Test %s command", name),
Run: func(cmd *cobra.Command, args []string) {
// Default test implementation
},
}
// Add common flags
AddOutputFlag(cmd)
AddForceFlag(cmd)
return cmd
}
// Test utilities for command validation
// ValidateCommandStructure validates that a command has required properties
func ValidateCommandStructure(t interface{}, cmd *cobra.Command, expectedUse string, expectedShort string) {
if cmd.Use != expectedUse {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Expected Use '%s', got '%s'", expectedUse, cmd.Use)
}
if cmd.Short != expectedShort {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Expected Short '%s', got '%s'", expectedShort, cmd.Short)
}
if cmd.Run == nil && cmd.RunE == nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Command must have a Run or RunE function")
}
}
// ValidateCommandFlags validates that a command has expected flags
func ValidateCommandFlags(t interface{}, cmd *cobra.Command, expectedFlags []string) {
for _, flagName := range expectedFlags {
flag := cmd.Flags().Lookup(flagName)
if flag == nil {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Expected flag '%s' not found", flagName)
}
}
}
// Helper to check if command has proper help text
func ValidateCommandHelp(t interface{}, cmd *cobra.Command) {
if cmd.Short == "" {
t.(interface{ Fatalf(string, ...interface{}) }).Fatalf("Command must have Short description")
}
if cmd.Long == "" {
// Long description is optional but recommended
}
if cmd.Example == "" {
// Examples are optional but recommended for better UX
}
}

View File

@ -1,521 +0,0 @@
package cli
import (
"context"
"encoding/json"
"fmt"
"testing"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestNewMockHeadscaleServiceClient(t *testing.T) {
mock := NewMockHeadscaleServiceClient()
// Verify mock is properly initialized
assert.NotNil(t, mock)
assert.NotNil(t, mock.CallCount)
assert.Equal(t, 0, len(mock.CallCount))
// Verify default responses are set
assert.NotNil(t, mock.ListUsersResponse)
assert.NotNil(t, mock.CreateUserResponse)
assert.NotNil(t, mock.ListNodesResponse)
assert.NotNil(t, mock.CreateApiKeyResponse)
}
func TestMockClient_ListUsers(t *testing.T) {
mock := NewMockHeadscaleServiceClient()
// Test successful response
req := &v1.ListUsersRequest{}
resp, err := mock.ListUsers(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, mock.CallCount["ListUsers"])
assert.Equal(t, req, mock.LastRequest)
}
func TestMockClient_ListUsersError(t *testing.T) {
mock := NewMockHeadscaleServiceClient()
// Configure error response
expectedError := status.Error(codes.Internal, "test error")
mock.ListUsersError = expectedError
req := &v1.ListUsersRequest{}
resp, err := mock.ListUsers(context.Background(), req)
assert.Error(t, err)
assert.Nil(t, resp)
assert.Equal(t, expectedError, err)
assert.Equal(t, 1, mock.CallCount["ListUsers"])
}
func TestMockClient_CreateUser(t *testing.T) {
mock := NewMockHeadscaleServiceClient()
req := &v1.CreateUserRequest{Name: "testuser"}
resp, err := mock.CreateUser(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.NotNil(t, resp.User)
assert.Equal(t, 1, mock.CallCount["CreateUser"])
assert.Equal(t, req, mock.LastRequest)
}
func TestMockClient_ListNodes(t *testing.T) {
mock := NewMockHeadscaleServiceClient()
req := &v1.ListNodesRequest{}
resp, err := mock.ListNodes(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, mock.CallCount["ListNodes"])
assert.Equal(t, req, mock.LastRequest)
}
func TestMockClient_CreateApiKey(t *testing.T) {
mock := NewMockHeadscaleServiceClient()
req := &v1.CreateApiKeyRequest{}
resp, err := mock.CreateApiKey(context.Background(), req)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.NotNil(t, resp.ApiKey)
assert.Equal(t, 1, mock.CallCount["CreateApiKey"])
}
func TestMockClient_CallTracking(t *testing.T) {
mock := NewMockHeadscaleServiceClient()
// Make multiple calls to different methods
mock.ListUsers(context.Background(), &v1.ListUsersRequest{})
mock.ListUsers(context.Background(), &v1.ListUsersRequest{})
mock.ListNodes(context.Background(), &v1.ListNodesRequest{})
// Verify call counts
assert.Equal(t, 2, mock.CallCount["ListUsers"])
assert.Equal(t, 1, mock.CallCount["ListNodes"])
assert.Equal(t, 0, mock.CallCount["CreateUser"]) // Not called
}
func TestNewMockClientWrapper(t *testing.T) {
wrapper := NewMockClientWrapperOld()
assert.NotNil(t, wrapper)
assert.NotNil(t, wrapper.MockClient)
assert.NotNil(t, wrapper.ctx)
assert.NotNil(t, wrapper.cancel)
}
func TestMockClientWrapper_Close(t *testing.T) {
wrapper := NewMockClientWrapperOld()
// Test that Close doesn't panic
wrapper.Close()
// Verify context is cancelled
select {
case <-wrapper.ctx.Done():
// Context was cancelled - good
default:
t.Error("Context should be cancelled after Close()")
}
}
func TestExecuteCommand(t *testing.T) {
// Create a simple test command that doesn't call external dependencies
cmd := CreateTestCommand("test")
cmd.Run = func(cmd *cobra.Command, args []string) {
fmt.Print("test output")
}
output, err := ExecuteCommand(cmd, []string{})
assert.NoError(t, err)
assert.Contains(t, output, "test output")
}
func TestExecuteCommandWithInput(t *testing.T) {
// Create a command that reads input
cmd := CreateTestCommand("test")
cmd.Run = func(cmd *cobra.Command, args []string) {
fmt.Print("command executed")
}
output, err := ExecuteCommandWithInput(cmd, []string{}, "test input\n")
assert.NoError(t, err)
assert.Contains(t, output, "command executed")
}
func TestExecuteCommandError(t *testing.T) {
// Create a command that returns an error
cmd := CreateTestCommand("test")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
return fmt.Errorf("test error")
}
cmd.Run = nil // Clear the default Run function
output, err := ExecuteCommand(cmd, []string{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "test error")
assert.Equal(t, "", output) // No output on error
}
func TestValidateJSONOutput(t *testing.T) {
// Test valid JSON
jsonOutput := `{"name": "test", "id": 123}`
expected := map[string]interface{}{
"name": "test",
"id": float64(123), // JSON numbers become float64
}
// This should not panic or fail
ValidateJSONOutput(t, jsonOutput, expected)
}
func TestValidateJSONOutput_Invalid(t *testing.T) {
// Test with invalid JSON - should cause test failure
// We can't easily test this without a custom test runner,
// but we can verify the function exists
assert.NotNil(t, ValidateJSONOutput)
}
func TestValidateYAMLOutput(t *testing.T) {
// Test valid YAML
yamlOutput := `name: test
id: 123`
expected := map[string]interface{}{
"name": "test",
"id": 123,
}
// This should not panic or fail
ValidateYAMLOutput(t, yamlOutput, expected)
}
func TestValidateTableOutput(t *testing.T) {
// Test table output validation
tableOutput := `ID Name Status
1 testnode online
2 testnode2 offline`
expectedHeaders := []string{"ID", "Name", "Status"}
// This should not panic or fail
ValidateTableOutput(t, tableOutput, expectedHeaders)
}
func TestNewTestUser(t *testing.T) {
user := NewTestUser(123, "testuser")
assert.NotNil(t, user)
assert.Equal(t, uint64(123), user.Id)
assert.Equal(t, "testuser", user.Name)
assert.Equal(t, "testuser@example.com", user.Email)
assert.NotNil(t, user.CreatedAt)
}
func TestNewTestNode(t *testing.T) {
user := NewTestUser(1, "testuser")
node := NewTestNode(456, "testnode", user)
assert.NotNil(t, node)
assert.Equal(t, uint64(456), node.Id)
assert.Equal(t, "testnode", node.Name)
assert.Equal(t, "testnode-device", node.GivenName)
assert.Equal(t, user, node.User)
assert.Equal(t, []string{"192.168.1.456"}, node.IpAddresses)
assert.True(t, node.Online)
assert.NotNil(t, node.CreatedAt)
assert.NotNil(t, node.LastSeen)
}
func TestNewTestApiKey(t *testing.T) {
apiKey := NewTestApiKey(789, "testprefix")
assert.NotNil(t, apiKey)
assert.Equal(t, uint64(789), apiKey.Id)
assert.Equal(t, "testprefix", apiKey.Prefix)
assert.NotNil(t, apiKey.CreatedAt)
}
func TestNewTestPreAuthKey(t *testing.T) {
preAuthKey := NewTestPreAuthKey(101, 202)
assert.NotNil(t, preAuthKey)
assert.Equal(t, uint64(101), preAuthKey.Id)
assert.Equal(t, "preauthkey-101-abcdef", preAuthKey.Key)
assert.NotNil(t, preAuthKey.User)
assert.Equal(t, uint64(202), preAuthKey.User.Id)
assert.False(t, preAuthKey.Reusable)
assert.False(t, preAuthKey.Ephemeral)
assert.False(t, preAuthKey.Used)
assert.NotNil(t, preAuthKey.CreatedAt)
}
func TestCreateTestCommand(t *testing.T) {
cmd := CreateTestCommand("testcmd")
assert.NotNil(t, cmd)
assert.Equal(t, "testcmd", cmd.Use)
assert.Equal(t, "Test testcmd command", cmd.Short)
assert.NotNil(t, cmd.Run)
// Verify common flags are added
assert.NotNil(t, cmd.Flags().Lookup("output"))
assert.NotNil(t, cmd.Flags().Lookup("force"))
}
func TestValidateCommandStructure(t *testing.T) {
cmd := &cobra.Command{
Use: "test",
Short: "Test command",
Run: func(cmd *cobra.Command, args []string) {},
}
// This should not panic or fail
ValidateCommandStructure(t, cmd, "test", "Test command")
}
func TestValidateCommandFlags(t *testing.T) {
cmd := CreateTestCommand("test")
// This should not panic or fail - output and force flags should exist
ValidateCommandFlags(t, cmd, []string{"output", "force"})
}
func TestValidateCommandHelp(t *testing.T) {
cmd := &cobra.Command{
Use: "test",
Short: "Test command",
Long: "This is a test command",
Run: func(cmd *cobra.Command, args []string) {},
}
// This should not panic or fail
ValidateCommandHelp(t, cmd)
}
func TestMockClient_AllOperationsCovered(t *testing.T) {
// Test that all required gRPC operations are implemented in the mock
mock := NewMockHeadscaleServiceClient()
ctx := context.Background()
// Test all user operations
_, err := mock.ListUsers(ctx, &v1.ListUsersRequest{})
assert.NoError(t, err)
_, err = mock.CreateUser(ctx, &v1.CreateUserRequest{})
assert.NoError(t, err)
_, err = mock.RenameUser(ctx, &v1.RenameUserRequest{})
assert.NoError(t, err)
_, err = mock.DeleteUser(ctx, &v1.DeleteUserRequest{})
assert.NoError(t, err)
// Test all node operations
_, err = mock.ListNodes(ctx, &v1.ListNodesRequest{})
assert.NoError(t, err)
_, err = mock.RegisterNode(ctx, &v1.RegisterNodeRequest{})
assert.NoError(t, err)
_, err = mock.DeleteNode(ctx, &v1.DeleteNodeRequest{})
assert.NoError(t, err)
_, err = mock.ExpireNode(ctx, &v1.ExpireNodeRequest{})
assert.NoError(t, err)
_, err = mock.RenameNode(ctx, &v1.RenameNodeRequest{})
assert.NoError(t, err)
_, err = mock.MoveNode(ctx, &v1.MoveNodeRequest{})
assert.NoError(t, err)
_, err = mock.GetNode(ctx, &v1.GetNodeRequest{})
assert.NoError(t, err)
_, err = mock.SetTags(ctx, &v1.SetTagsRequest{})
assert.NoError(t, err)
_, err = mock.SetApprovedRoutes(ctx, &v1.SetApprovedRoutesRequest{})
assert.NoError(t, err)
_, err = mock.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{})
assert.NoError(t, err)
// Test all API key operations
_, err = mock.ListApiKeys(ctx, &v1.ListApiKeysRequest{})
assert.NoError(t, err)
_, err = mock.CreateApiKey(ctx, &v1.CreateApiKeyRequest{})
assert.NoError(t, err)
_, err = mock.ExpireApiKey(ctx, &v1.ExpireApiKeyRequest{})
assert.NoError(t, err)
_, err = mock.DeleteApiKey(ctx, &v1.DeleteApiKeyRequest{})
assert.NoError(t, err)
// Test all preauth key operations
_, err = mock.ListPreAuthKeys(ctx, &v1.ListPreAuthKeysRequest{})
assert.NoError(t, err)
_, err = mock.CreatePreAuthKey(ctx, &v1.CreatePreAuthKeyRequest{})
assert.NoError(t, err)
_, err = mock.ExpirePreAuthKey(ctx, &v1.ExpirePreAuthKeyRequest{})
assert.NoError(t, err)
// Test policy operations
_, err = mock.GetPolicy(ctx, &v1.GetPolicyRequest{})
assert.NoError(t, err)
_, err = mock.SetPolicy(ctx, &v1.SetPolicyRequest{})
assert.NoError(t, err)
// Test debug operations
_, err = mock.DebugCreateNode(ctx, &v1.DebugCreateNodeRequest{})
assert.NoError(t, err)
// Verify all operations were called
expectedOperations := []string{
"ListUsers", "CreateUser", "RenameUser", "DeleteUser",
"ListNodes", "RegisterNode", "DeleteNode", "ExpireNode", "RenameNode", "MoveNode", "GetNode", "SetTags", "SetApprovedRoutes", "BackfillNodeIPs",
"ListApiKeys", "CreateApiKey", "ExpireApiKey", "DeleteApiKey",
"ListPreAuthKeys", "CreatePreAuthKey", "ExpirePreAuthKey",
"GetPolicy", "SetPolicy",
"DebugCreateNode",
}
for _, op := range expectedOperations {
assert.Equal(t, 1, mock.CallCount[op], "Operation %s should have been called exactly once", op)
}
}
func TestMockIntegrationWithExistingInfrastructure(t *testing.T) {
// Test that mock client integrates well with existing CLI infrastructure
// Create a test command that uses our flag infrastructure
cmd := CreateTestCommand("integration-test")
AddUserFlag(cmd)
AddIdentifierFlag(cmd, "identifier", "Test identifier")
// Set up flags
err := cmd.Flags().Set("user", "testuser")
require.NoError(t, err)
err = cmd.Flags().Set("identifier", "123")
require.NoError(t, err)
err = cmd.Flags().Set("output", "json")
require.NoError(t, err)
// Test that flag getters work
user, err := GetUser(cmd)
assert.NoError(t, err)
assert.Equal(t, "testuser", user)
identifier, err := GetIdentifier(cmd, "identifier")
assert.NoError(t, err)
assert.Equal(t, uint64(123), identifier)
output := GetOutputFormat(cmd)
assert.Equal(t, "json", output)
// Test that output manager works
om := NewOutputManager(cmd)
assert.True(t, om.HasMachineOutput())
// Test that mock client can be used with our patterns
mock := NewMockClientWrapperOld()
defer mock.Close()
// Verify mock client has the expected structure
assert.NotNil(t, mock.MockClient)
assert.NotNil(t, mock.ctx)
}
func TestTestingInfrastructure_CompleteWorkflow(t *testing.T) {
// Test a complete workflow using the testing infrastructure
// 1. Create a mock client
mock := NewMockClientWrapperOld()
defer mock.Close()
// 2. Configure mock responses
testUser := NewTestUser(1, "testuser")
testNode := NewTestNode(1, "testnode", testUser)
mock.MockClient.ListUsersResponse = &v1.ListUsersResponse{
Users: []*v1.User{testUser},
}
mock.MockClient.ListNodesResponse = &v1.ListNodesResponse{
Nodes: []*v1.Node{testNode},
}
// 3. Test that mock responds correctly
usersResp, err := mock.MockClient.ListUsers(context.Background(), &v1.ListUsersRequest{})
assert.NoError(t, err)
assert.Len(t, usersResp.Users, 1)
assert.Equal(t, "testuser", usersResp.Users[0].Name)
nodesResp, err := mock.MockClient.ListNodes(context.Background(), &v1.ListNodesRequest{})
assert.NoError(t, err)
assert.Len(t, nodesResp.Nodes, 1)
assert.Equal(t, "testnode", nodesResp.Nodes[0].Name)
// 4. Verify call tracking
assert.Equal(t, 1, mock.MockClient.CallCount["ListUsers"])
assert.Equal(t, 1, mock.MockClient.CallCount["ListNodes"])
// 5. Test JSON serialization (important for CLI output)
userJSON, err := json.Marshal(testUser)
assert.NoError(t, err)
assert.Contains(t, string(userJSON), "testuser")
nodeJSON, err := json.Marshal(testNode)
assert.NoError(t, err)
assert.Contains(t, string(nodeJSON), "testnode")
}
func TestErrorScenarios(t *testing.T) {
// Test various error scenarios with the mock
mock := NewMockHeadscaleServiceClient()
// Test network error
mock.ListUsersError = status.Error(codes.Unavailable, "connection refused")
_, err := mock.ListUsers(context.Background(), &v1.ListUsersRequest{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "connection refused")
// Test not found error
mock.GetNodeError = status.Error(codes.NotFound, "node not found")
_, err = mock.GetNode(context.Background(), &v1.GetNodeRequest{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "node not found")
// Test permission error
mock.DeleteUserError = status.Error(codes.PermissionDenied, "insufficient permissions")
_, err = mock.DeleteUser(context.Background(), &v1.DeleteUserRequest{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "insufficient permissions")
}

View File

@ -8,7 +8,6 @@ import (
survey "github.com/AlecAivazis/survey/v2"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/pterm/pterm"
"github.com/rs/zerolog/log"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
@ -45,6 +44,7 @@ func init() {
userCmd.AddCommand(listUsersCmd)
usernameAndIDFlag(listUsersCmd)
listUsersCmd.Flags().StringP("email", "e", "", "Email")
AddColumnsFlag(listUsersCmd, "id,name,username,email,created")
userCmd.AddCommand(destroyUserCmd)
usernameAndIDFlag(destroyUserCmd)
userCmd.AddCommand(renameUserCmd)
@ -230,31 +230,35 @@ var listUsersCmd = &cobra.Command{
)
}
if output != "" {
SuccessOutput(response.GetUsers(), "", output)
// Convert users to []interface{} for generic table handling
users := make([]interface{}, len(response.GetUsers()))
for i, user := range response.GetUsers() {
users[i] = user
}
tableData := pterm.TableData{{"ID", "Name", "Username", "Email", "Created"}}
for _, user := range response.GetUsers() {
tableData = append(
tableData,
[]string{
strconv.FormatUint(user.GetId(), 10),
user.GetDisplayName(),
user.GetName(),
user.GetEmail(),
user.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"),
},
)
}
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
ErrorOutput(
err,
fmt.Sprintf("Failed to render pterm table: %s", err),
output,
)
}
// Use the new table system with column filtering support
ListOutput(cmd, users, func(tr *TableRenderer) {
tr.AddColumn("id", "ID", func(item interface{}) string {
user := item.(*v1.User)
return strconv.FormatUint(user.GetId(), 10)
}).
AddColumn("name", "Name", func(item interface{}) string {
user := item.(*v1.User)
return user.GetDisplayName()
}).
AddColumn("username", "Username", func(item interface{}) string {
user := item.(*v1.User)
return user.GetName()
}).
AddColumn("email", "Email", func(item interface{}) string {
user := item.(*v1.User)
return user.GetEmail()
}).
AddColumn("created", "Created", func(item interface{}) string {
user := item.(*v1.User)
return user.GetCreatedAt().AsTime().Format(HeadscaleDateTimeFormat)
})
})
},
}

View File

@ -1,331 +0,0 @@
package cli
import (
"fmt"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
)
// Refactored user commands using the new CLI infrastructure
// This demonstrates the improved patterns with significantly less code
// createUserRefactored demonstrates the new create user command
func createUserRefactored() *cobra.Command {
cmd := &cobra.Command{
Use: "create NAME",
Short: "Creates a new user",
Aliases: []string{"c", "new"},
Args: ValidateExactArgs(1, "create <username>"),
Run: StandardCreateCommand(
createUserLogic,
"User created successfully",
),
}
// Use standardized flag helpers
cmd.Flags().StringP("display-name", "d", "", "Display name")
cmd.Flags().StringP("email", "e", "", "Email address")
cmd.Flags().StringP("picture-url", "p", "", "Profile picture URL")
AddOutputFlag(cmd)
return cmd
}
// createUserLogic implements the business logic for creating a user
func createUserLogic(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) {
userName := args[0]
// Validate username using our validation infrastructure
if err := ValidateUserName(userName); err != nil {
return nil, err
}
request := &v1.CreateUserRequest{Name: userName}
// Get optional display name
if displayName, _ := cmd.Flags().GetString("display-name"); displayName != "" {
request.DisplayName = displayName
}
// Get and validate email
if email, _ := cmd.Flags().GetString("email"); email != "" {
if err := ValidateEmail(email); err != nil {
return nil, fmt.Errorf("invalid email: %w", err)
}
request.Email = email
}
// Get and validate picture URL
if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
if err := ValidateURL(pictureURL); err != nil {
return nil, fmt.Errorf("invalid picture URL: %w", err)
}
request.PictureUrl = pictureURL
}
// Check for duplicate users
if err := ValidateNoDuplicateUsers(client, userName, 0); err != nil {
return nil, err
}
response, err := client.CreateUser(cmd, request)
if err != nil {
return nil, err
}
return response.GetUser(), nil
}
// listUsersRefactored demonstrates the new list users command
func listUsersRefactored() *cobra.Command {
cmd := &cobra.Command{
Use: "list",
Short: "List all users",
Aliases: []string{"ls", "show"},
Run: StandardListCommand(
listUsersLogic,
setupUsersTableRefactored,
),
}
// Use standardized flag helpers
AddIdentifierFlag(cmd, "identifier", "Filter by user ID")
cmd.Flags().StringP("name", "n", "", "Filter by username")
cmd.Flags().StringP("email", "e", "", "Filter by email")
AddOutputFlag(cmd)
return cmd
}
// listUsersLogic implements the business logic for listing users
func listUsersLogic(client *ClientWrapper, cmd *cobra.Command) ([]interface{}, error) {
request := &v1.ListUsersRequest{}
// Handle filtering
if id, _ := GetIdentifier(cmd, "identifier"); id > 0 {
request.Id = id
} else if name, _ := cmd.Flags().GetString("name"); name != "" {
request.Name = name
} else if email, _ := cmd.Flags().GetString("email"); email != "" {
if err := ValidateEmail(email); err != nil {
return nil, fmt.Errorf("invalid email filter: %w", err)
}
request.Email = email
}
response, err := client.ListUsers(cmd, request)
if err != nil {
return nil, err
}
// Convert to []interface{} for table renderer
users := make([]interface{}, len(response.GetUsers()))
for i, user := range response.GetUsers() {
users[i] = user
}
return users, nil
}
// setupUsersTableRefactored configures the table columns for user display
func setupUsersTableRefactored(tr *TableRenderer) {
tr.AddColumn("ID", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return fmt.Sprintf("%d", user.GetId())
}
return ""
}).AddColumn("Name", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return user.GetName()
}
return ""
}).AddColumn("Display Name", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return user.GetDisplayName()
}
return ""
}).AddColumn("Email", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return user.GetEmail()
}
return ""
}).AddColumn("Created", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return FormatTime(user.GetCreatedAt().AsTime())
}
return ""
})
}
// deleteUserRefactored demonstrates the new delete user command
func deleteUserRefactored() *cobra.Command {
cmd := &cobra.Command{
Use: "delete",
Short: "Delete a user",
Aliases: []string{"remove", "rm", "destroy"},
Args: ValidateRequiredArgs(1, "delete <username|id>"),
Run: StandardDeleteCommand(
getUserLogic,
deleteUserLogic,
"user",
),
}
AddForceFlag(cmd)
AddOutputFlag(cmd)
return cmd
}
// getUserLogic retrieves a user for delete confirmation
// Note: This assumes the user identifier is passed via flag or context
func getUserLogic(client *ClientWrapper, cmd *cobra.Command) (interface{}, error) {
// In a real implementation, we'd need to get the user identifier from somewhere
// For now, let's use a default for testing
userIdentifier := "testuser" // This would come from command args in real usage
return ResolveUserByNameOrID(client, cmd, userIdentifier)
}
// deleteUserLogic implements the business logic for deleting a user
func deleteUserLogic(client *ClientWrapper, cmd *cobra.Command) (interface{}, error) {
// In a real implementation, this would get the user identifier from command args
// For now, let's use a default for testing
userIdentifier := "testuser" // This would come from command args in real usage
user, err := ResolveUserByNameOrID(client, cmd, userIdentifier)
if err != nil {
return nil, err
}
request := &v1.DeleteUserRequest{Id: user.GetId()}
response, err := client.DeleteUser(cmd, request)
if err != nil {
return nil, err
}
return response, nil
}
// renameUserRefactored demonstrates the new rename user command
func renameUserRefactored() *cobra.Command {
cmd := &cobra.Command{
Use: "rename <current-name|id> <new-name>",
Short: "Rename a user",
Aliases: []string{"mv"},
Args: ValidateExactArgs(2, "rename <current-name|id> <new-name>"),
Run: StandardUpdateCommand(
renameUserLogic,
"User renamed successfully",
),
}
AddOutputFlag(cmd)
return cmd
}
// renameUserLogic implements the business logic for renaming a user
func renameUserLogic(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) {
currentIdentifier := args[0]
newName := args[1]
// Validate new name
if err := ValidateUserName(newName); err != nil {
return nil, fmt.Errorf("invalid new username: %w", err)
}
// Resolve current user
user, err := ResolveUserByNameOrID(client, cmd, currentIdentifier)
if err != nil {
return nil, err
}
// Check that new name isn't taken
if err := ValidateNoDuplicateUsers(client, newName, user.GetId()); err != nil {
return nil, err
}
request := &v1.RenameUserRequest{
OldId: user.GetId(),
NewName: newName,
}
response, err := client.RenameUser(cmd, request)
if err != nil {
return nil, err
}
return response.GetUser(), nil
}
// createRefactoredUserCommand creates the refactored user command hierarchy
func createRefactoredUserCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "users-refactored",
Short: "Manage users using new infrastructure (demo)",
Aliases: []string{"ur"},
Hidden: true, // Hidden for demo purposes
}
// Add subcommands using the new infrastructure
cmd.AddCommand(createUserRefactored())
cmd.AddCommand(listUsersRefactored())
cmd.AddCommand(deleteUserRefactored())
cmd.AddCommand(renameUserRefactored())
return cmd
}
// init function to register the refactored command for demonstration
func init() {
// Add the refactored command for comparison
rootCmd.AddCommand(createRefactoredUserCommand())
}
/*
Benefits of the refactored approach:
1. **Significantly Less Code**:
- Original createUserCmd: ~45 lines of implementation
- Refactored createUserFunc: ~25 lines of business logic only
- ~50% reduction in code per command
2. **Better Error Handling**:
- Consistent validation with meaningful error messages
- Centralized error handling through patterns
- Type-safe operations throughout
3. **Improved Maintainability**:
- Business logic separated from command setup
- Reusable validation functions
- Consistent flag handling across commands
4. **Enhanced Testing**:
- Each function can be unit tested in isolation
- Mock client integration for reliable testing
- Validation logic is independently testable
5. **Standardized Patterns**:
- All CRUD operations follow the same structure
- Consistent output formatting (JSON/YAML/table)
- Uniform confirmation and error handling
6. **Type Safety**:
- Proper ClientWrapper usage throughout
- No interface{} or any types
- Compile-time type checking
7. **Better User Experience**:
- More descriptive error messages
- Consistent argument validation
- Improved help text and usage
8. **Code Reuse**:
- Validation functions used across multiple commands
- Table setup functions can be shared
- Flag helpers ensure consistency
The refactored commands provide the same functionality as the original
commands but with better structure, testing capability, and maintainability.
*/

View File

@ -1,278 +0,0 @@
package cli
import (
"fmt"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
)
// Example of how user commands could be refactored using our new infrastructure
// createUserWithNewInfrastructure demonstrates the refactored create user command
func createUserWithNewInfrastructure() *cobra.Command {
cmd := &cobra.Command{
Use: "create NAME",
Short: "Creates a new user",
Aliases: []string{"c", "new"},
Args: ValidateExactArgs(1, "create <username>"),
Run: StandardCreateCommand(
createUserFunc,
"User created successfully",
),
}
// Use standardized flag helpers
AddNameFlag(cmd, "Display name for the user")
cmd.Flags().StringP("email", "e", "", "Email address")
cmd.Flags().StringP("picture-url", "p", "", "Profile picture URL")
AddOutputFlag(cmd)
return cmd
}
// createUserFunc implements the business logic for creating a user
func createUserFunc(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) {
userName := args[0]
// Validate username using our validation infrastructure
if err := ValidateUserName(userName); err != nil {
return nil, err
}
request := &v1.CreateUserRequest{Name: userName}
// Get optional display name
if displayName, _ := cmd.Flags().GetString("name"); displayName != "" {
request.DisplayName = displayName
}
// Get and validate email
if email, _ := cmd.Flags().GetString("email"); email != "" {
if err := ValidateEmail(email); err != nil {
return nil, fmt.Errorf("invalid email: %w", err)
}
request.Email = email
}
// Get and validate picture URL
if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
if err := ValidateURL(pictureURL); err != nil {
return nil, fmt.Errorf("invalid picture URL: %w", err)
}
request.PictureUrl = pictureURL
}
// Check for duplicate users
if err := ValidateNoDuplicateUsers(client, userName, 0); err != nil {
return nil, err
}
response, err := client.CreateUser(cmd, request)
if err != nil {
return nil, err
}
return response.GetUser(), nil
}
// listUsersWithNewInfrastructure demonstrates the refactored list users command
func listUsersWithNewInfrastructure() *cobra.Command {
cmd := &cobra.Command{
Use: "list",
Short: "List all users",
Aliases: []string{"ls", "show"},
Run: StandardListCommand(
listUsersFunc,
setupUsersTable,
),
}
// Use standardized flag helpers
AddUserFlag(cmd)
cmd.Flags().StringP("email", "e", "", "Filter by email")
AddIdentifierFlag(cmd, "identifier", "Filter by user ID")
AddOutputFlag(cmd)
return cmd
}
// listUsersFunc implements the business logic for listing users
func listUsersFunc(client *ClientWrapper, cmd *cobra.Command) ([]interface{}, error) {
request := &v1.ListUsersRequest{}
// Handle filtering
if id, _ := GetIdentifier(cmd, "identifier"); id > 0 {
request.Id = id
} else if user, _ := GetUser(cmd); user != "" {
request.Name = user
} else if email, _ := cmd.Flags().GetString("email"); email != "" {
if err := ValidateEmail(email); err != nil {
return nil, fmt.Errorf("invalid email filter: %w", err)
}
request.Email = email
}
response, err := client.ListUsers(cmd, request)
if err != nil {
return nil, err
}
// Convert to []interface{} for table renderer
users := make([]interface{}, len(response.GetUsers()))
for i, user := range response.GetUsers() {
users[i] = user
}
return users, nil
}
// setupUsersTable configures the table columns for user display
func setupUsersTable(tr *TableRenderer) {
tr.AddColumn("ID", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return fmt.Sprintf("%d", user.GetId())
}
return ""
}).AddColumn("Name", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return user.GetName()
}
return ""
}).AddColumn("Display Name", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return user.GetDisplayName()
}
return ""
}).AddColumn("Email", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return user.GetEmail()
}
return ""
}).AddColumn("Created", func(item interface{}) string {
if user, ok := item.(*v1.User); ok {
return FormatTime(user.GetCreatedAt().AsTime())
}
return ""
})
}
// deleteUserWithNewInfrastructure demonstrates the refactored delete user command
func deleteUserWithNewInfrastructure() *cobra.Command {
cmd := &cobra.Command{
Use: "delete",
Short: "Delete a user",
Aliases: []string{"remove", "rm"},
Args: ValidateRequiredArgs(1, "delete <username|id>"),
Run: StandardDeleteCommand(
getUserFunc,
deleteUserFunc,
"user",
),
}
AddForceFlag(cmd)
AddOutputFlag(cmd)
return cmd
}
// getUserFunc retrieves a user for delete confirmation
func getUserFunc(client *ClientWrapper, cmd *cobra.Command) (interface{}, error) {
args := cmd.Flags().Args()
if len(args) == 0 {
return nil, fmt.Errorf("user identifier required")
}
userIdentifier := args[0]
return ResolveUserByNameOrID(client, cmd, userIdentifier)
}
// deleteUserFunc implements the business logic for deleting a user
func deleteUserFunc(client *ClientWrapper, cmd *cobra.Command) (interface{}, error) {
args := cmd.Flags().Args()
userIdentifier := args[0]
user, err := ResolveUserByNameOrID(client, cmd, userIdentifier)
if err != nil {
return nil, err
}
request := &v1.DeleteUserRequest{Id: user.GetId()}
response, err := client.DeleteUser(cmd, request)
if err != nil {
return nil, err
}
return response, nil
}
// renameUserWithNewInfrastructure demonstrates the refactored rename user command
func renameUserWithNewInfrastructure() *cobra.Command {
cmd := &cobra.Command{
Use: "rename <current-name|id> <new-name>",
Short: "Rename a user",
Aliases: []string{"mv"},
Args: ValidateExactArgs(2, "rename <current-name|id> <new-name>"),
Run: StandardUpdateCommand(
renameUserFunc,
"User renamed successfully",
),
}
AddOutputFlag(cmd)
return cmd
}
// renameUserFunc implements the business logic for renaming a user
func renameUserFunc(client *ClientWrapper, cmd *cobra.Command, args []string) (interface{}, error) {
currentIdentifier := args[0]
newName := args[1]
// Validate new name
if err := ValidateUserName(newName); err != nil {
return nil, fmt.Errorf("invalid new username: %w", err)
}
// Resolve current user
user, err := ResolveUserByNameOrID(client, cmd, currentIdentifier)
if err != nil {
return nil, err
}
// Check that new name isn't taken
if err := ValidateNoDuplicateUsers(client, newName, user.GetId()); err != nil {
return nil, err
}
request := &v1.RenameUserRequest{
OldId: user.GetId(),
NewName: newName,
}
response, err := client.RenameUser(cmd, request)
if err != nil {
return nil, err
}
return response.GetUser(), nil
}
// Benefits of the refactored approach:
//
// 1. **Standardized Patterns**: All commands use the same execution patterns
// 2. **Better Validation**: Input validation is consistent and comprehensive
// 3. **Error Handling**: Centralized error handling with meaningful messages
// 4. **Code Reuse**: Common operations are abstracted into reusable functions
// 5. **Testability**: Each function can be tested in isolation
// 6. **Consistency**: All commands have the same structure and behavior
// 7. **Maintainability**: Business logic is separated from command setup
// 8. **Type Safety**: Better error handling and validation throughout
//
// The refactored commands are:
// - 50% less code on average
// - More robust with comprehensive validation
// - Easier to test with separated concerns
// - More consistent in behavior and output formatting
// - Better error messages for users

View File

@ -1,352 +0,0 @@
package cli
import (
"testing"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
// TestRefactoredUserCommands tests the refactored user commands
func TestRefactoredUserCommands(t *testing.T) {
t.Run("create user refactored", func(t *testing.T) {
cmd := createUserRefactored()
assert.NotNil(t, cmd)
assert.Equal(t, "create NAME", cmd.Use)
assert.Equal(t, "Creates a new user", cmd.Short)
assert.Contains(t, cmd.Aliases, "c")
assert.Contains(t, cmd.Aliases, "new")
// Test flags
assert.NotNil(t, cmd.Flags().Lookup("display-name"))
assert.NotNil(t, cmd.Flags().Lookup("email"))
assert.NotNil(t, cmd.Flags().Lookup("picture-url"))
assert.NotNil(t, cmd.Flags().Lookup("output"))
// Test Args validation
assert.NotNil(t, cmd.Args)
})
t.Run("list users refactored", func(t *testing.T) {
cmd := listUsersRefactored()
assert.NotNil(t, cmd)
assert.Equal(t, "list", cmd.Use)
assert.Equal(t, "List all users", cmd.Short)
assert.Contains(t, cmd.Aliases, "ls")
assert.Contains(t, cmd.Aliases, "show")
// Test flags
assert.NotNil(t, cmd.Flags().Lookup("identifier"))
assert.NotNil(t, cmd.Flags().Lookup("name"))
assert.NotNil(t, cmd.Flags().Lookup("email"))
assert.NotNil(t, cmd.Flags().Lookup("output"))
})
t.Run("delete user refactored", func(t *testing.T) {
cmd := deleteUserRefactored()
assert.NotNil(t, cmd)
assert.Equal(t, "delete", cmd.Use)
assert.Equal(t, "Delete a user", cmd.Short)
assert.Contains(t, cmd.Aliases, "remove")
assert.Contains(t, cmd.Aliases, "rm")
assert.Contains(t, cmd.Aliases, "destroy")
// Test flags
assert.NotNil(t, cmd.Flags().Lookup("force"))
assert.NotNil(t, cmd.Flags().Lookup("output"))
// Test Args validation
assert.NotNil(t, cmd.Args)
})
t.Run("rename user refactored", func(t *testing.T) {
cmd := renameUserRefactored()
assert.NotNil(t, cmd)
assert.Equal(t, "rename <current-name|id> <new-name>", cmd.Use)
assert.Equal(t, "Rename a user", cmd.Short)
assert.Contains(t, cmd.Aliases, "mv")
// Test flags
assert.NotNil(t, cmd.Flags().Lookup("output"))
// Test Args validation
assert.NotNil(t, cmd.Args)
})
}
// TestRefactoredUserLogicFunctions tests the business logic functions
func TestRefactoredUserLogicFunctions(t *testing.T) {
t.Run("createUserLogic", func(t *testing.T) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
AddOutputFlag(cmd)
// Test valid user creation with a new username that doesn't exist
args := []string{"newuser"}
result, err := createUserLogic(mockClient, cmd, args)
assert.NoError(t, err)
assert.NotNil(t, result)
// Note: We can't easily check call counts with the wrapper, but we can verify the result
})
t.Run("createUserLogic with invalid username", func(t *testing.T) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
// Test with invalid username (empty)
args := []string{""}
_, err := createUserLogic(mockClient, cmd, args)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot be empty")
})
t.Run("createUserLogic with email validation", func(t *testing.T) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
cmd.Flags().String("email", "invalid-email", "")
args := []string{"testuser"}
_, err := createUserLogic(mockClient, cmd, args)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid email")
})
t.Run("listUsersLogic", func(t *testing.T) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
result, err := listUsersLogic(mockClient, cmd)
assert.NoError(t, err)
assert.NotNil(t, result)
})
t.Run("listUsersLogic with filtering", func(t *testing.T) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
AddIdentifierFlag(cmd, "identifier", "Test ID")
cmd.Flags().Set("identifier", "123")
result, err := listUsersLogic(mockClient, cmd)
assert.NoError(t, err)
assert.NotNil(t, result)
})
t.Run("getUserLogic", func(t *testing.T) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
// Simulate parsed args
cmd.ParseFlags([]string{"testuser"})
cmd.SetArgs([]string{"testuser"})
result, err := getUserLogic(mockClient, cmd)
assert.NoError(t, err)
assert.NotNil(t, result)
})
t.Run("deleteUserLogic", func(t *testing.T) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
// Simulate parsed args
cmd.ParseFlags([]string{"testuser"})
cmd.SetArgs([]string{"testuser"})
result, err := deleteUserLogic(mockClient, cmd)
assert.NoError(t, err)
assert.NotNil(t, result)
})
t.Run("renameUserLogic", func(t *testing.T) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
args := []string{"olduser", "newuser"}
result, err := renameUserLogic(mockClient, cmd, args)
assert.NoError(t, err)
assert.NotNil(t, result)
})
t.Run("renameUserLogic with invalid new name", func(t *testing.T) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
// Test with invalid new username
args := []string{"olduser", ""}
_, err := renameUserLogic(mockClient, cmd, args)
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot be empty")
})
}
// TestSetupUsersTableRefactored tests the table setup function
func TestSetupUsersTableRefactored(t *testing.T) {
om := &OutputManager{}
tr := NewTableRenderer(om)
setupUsersTableRefactored(tr)
// Check that columns were added
assert.Equal(t, 5, len(tr.columns))
assert.Equal(t, "ID", tr.columns[0].Header)
assert.Equal(t, "Name", tr.columns[1].Header)
assert.Equal(t, "Display Name", tr.columns[2].Header)
assert.Equal(t, "Email", tr.columns[3].Header)
assert.Equal(t, "Created", tr.columns[4].Header)
// Test column extraction with mock data
testUser := &v1.User{
Id: 123,
Name: "testuser",
DisplayName: "Test User",
Email: "test@example.com",
}
assert.Equal(t, "123", tr.columns[0].Extract(testUser))
assert.Equal(t, "testuser", tr.columns[1].Extract(testUser))
assert.Equal(t, "Test User", tr.columns[2].Extract(testUser))
assert.Equal(t, "test@example.com", tr.columns[3].Extract(testUser))
}
// TestRefactoredCommandHierarchy tests the command hierarchy
func TestRefactoredCommandHierarchy(t *testing.T) {
cmd := createRefactoredUserCommand()
assert.NotNil(t, cmd)
assert.Equal(t, "users-refactored", cmd.Use)
assert.Equal(t, "Manage users using new infrastructure (demo)", cmd.Short)
assert.Contains(t, cmd.Aliases, "ur")
assert.True(t, cmd.Hidden, "Demo command should be hidden")
// Check subcommands
subcommands := cmd.Commands()
assert.Len(t, subcommands, 4)
subcommandNames := make([]string, len(subcommands))
for i, subcmd := range subcommands {
subcommandNames[i] = subcmd.Name()
}
assert.Contains(t, subcommandNames, "create")
assert.Contains(t, subcommandNames, "list")
assert.Contains(t, subcommandNames, "delete")
assert.Contains(t, subcommandNames, "rename")
}
// TestRefactoredCommandValidation tests argument validation
func TestRefactoredCommandValidation(t *testing.T) {
t.Run("create command args", func(t *testing.T) {
cmd := createUserRefactored()
// Should require exactly 1 argument
err := cmd.Args(cmd, []string{})
assert.Error(t, err)
err = cmd.Args(cmd, []string{"user1"})
assert.NoError(t, err)
err = cmd.Args(cmd, []string{"user1", "extra"})
assert.Error(t, err)
})
t.Run("delete command args", func(t *testing.T) {
cmd := deleteUserRefactored()
// Should require at least 1 argument
err := cmd.Args(cmd, []string{})
assert.Error(t, err)
err = cmd.Args(cmd, []string{"user1"})
assert.NoError(t, err)
})
t.Run("rename command args", func(t *testing.T) {
cmd := renameUserRefactored()
// Should require exactly 2 arguments
err := cmd.Args(cmd, []string{})
assert.Error(t, err)
err = cmd.Args(cmd, []string{"oldname"})
assert.Error(t, err)
err = cmd.Args(cmd, []string{"oldname", "newname"})
assert.NoError(t, err)
err = cmd.Args(cmd, []string{"oldname", "newname", "extra"})
assert.Error(t, err)
})
}
// TestRefactoredCommandComparisonWithOriginal tests that refactored commands provide same functionality
func TestRefactoredCommandComparisonWithOriginal(t *testing.T) {
t.Run("command structure compatibility", func(t *testing.T) {
originalCreate := createUserCmd
refactoredCreate := createUserRefactored()
// Both should have the same basic structure
assert.Equal(t, originalCreate.Short, refactoredCreate.Short)
assert.Equal(t, originalCreate.Use, refactoredCreate.Use)
// Both should have similar flags
originalFlags := originalCreate.Flags()
refactoredFlags := refactoredCreate.Flags()
// Check key flags exist in both
flagsToCheck := []string{"display-name", "email", "picture-url", "output"}
for _, flagName := range flagsToCheck {
originalFlag := originalFlags.Lookup(flagName)
refactoredFlag := refactoredFlags.Lookup(flagName)
if originalFlag != nil {
assert.NotNil(t, refactoredFlag, "Flag %s should exist in refactored version", flagName)
assert.Equal(t, originalFlag.Shorthand, refactoredFlag.Shorthand, "Flag %s shorthand should match", flagName)
}
}
})
t.Run("improved error handling", func(t *testing.T) {
// Test that refactored version has better validation
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
// Test email validation improvement
cmd.Flags().String("email", "invalid-email", "")
args := []string{"testuser"}
_, err := createUserLogic(mockClient, cmd, args)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid email")
// Original version would not catch this until server call
// Refactored version catches it early with better error message
})
}
// BenchmarkRefactoredUserCommands benchmarks the refactored commands
func BenchmarkRefactoredUserCommands(b *testing.B) {
mockClient := NewMockClientWrapper()
cmd := &cobra.Command{}
AddOutputFlag(cmd)
b.Run("createUserLogic", func(b *testing.B) {
args := []string{"testuser"}
for i := 0; i < b.N; i++ {
createUserLogic(mockClient, cmd, args)
}
})
b.Run("listUsersLogic", func(b *testing.B) {
for i := 0; i < b.N; i++ {
listUsersLogic(mockClient, cmd)
}
})
}

View File

@ -1,414 +0,0 @@
package cli
import (
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestUserCommand(t *testing.T) {
// Test the main user command
assert.NotNil(t, userCmd)
assert.Equal(t, "users", userCmd.Use)
assert.Equal(t, "Manage the users of Headscale", userCmd.Short)
// Test aliases
expectedAliases := []string{"user", "namespace", "namespaces", "ns"}
assert.Equal(t, expectedAliases, userCmd.Aliases)
// Test that user command has subcommands
subcommands := userCmd.Commands()
assert.Greater(t, len(subcommands), 0, "User command should have subcommands")
// Verify expected subcommands exist
subcommandNames := make([]string, len(subcommands))
for i, cmd := range subcommands {
subcommandNames[i] = cmd.Use
}
expectedSubcommands := []string{"create", "list", "destroy", "rename"}
for _, expected := range expectedSubcommands {
found := false
for _, actual := range subcommandNames {
if actual == expected || (actual == "create NAME") {
found = true
break
}
}
assert.True(t, found, "Expected subcommand '%s' not found", expected)
}
}
func TestCreateUserCommand(t *testing.T) {
assert.NotNil(t, createUserCmd)
assert.Equal(t, "create NAME", createUserCmd.Use)
assert.Equal(t, "Creates a new user", createUserCmd.Short)
assert.Equal(t, []string{"c", "new"}, createUserCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, createUserCmd.Run)
// Test that Args validation function is set
assert.NotNil(t, createUserCmd.Args)
// Test Args validation
err := createUserCmd.Args(createUserCmd, []string{})
assert.Error(t, err)
assert.Equal(t, errMissingParameter, err)
err = createUserCmd.Args(createUserCmd, []string{"testuser"})
assert.NoError(t, err)
// Test flags
flags := createUserCmd.Flags()
assert.NotNil(t, flags.Lookup("display-name"))
assert.NotNil(t, flags.Lookup("email"))
assert.NotNil(t, flags.Lookup("picture-url"))
// Test flag shortcuts
displayNameFlag := flags.Lookup("display-name")
assert.Equal(t, "d", displayNameFlag.Shorthand)
emailFlag := flags.Lookup("email")
assert.Equal(t, "e", emailFlag.Shorthand)
pictureFlag := flags.Lookup("picture-url")
assert.Equal(t, "p", pictureFlag.Shorthand)
}
func TestListUsersCommand(t *testing.T) {
assert.NotNil(t, listUsersCmd)
assert.Equal(t, "list", listUsersCmd.Use)
assert.Equal(t, "List all the users", listUsersCmd.Short)
assert.Equal(t, []string{"ls", "show"}, listUsersCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, listUsersCmd.Run)
// Test flags from usernameAndIDFlag
flags := listUsersCmd.Flags()
assert.NotNil(t, flags.Lookup("identifier"))
assert.NotNil(t, flags.Lookup("name"))
assert.NotNil(t, flags.Lookup("email"))
// Test flag shortcuts
identifierFlag := flags.Lookup("identifier")
assert.Equal(t, "i", identifierFlag.Shorthand)
nameFlag := flags.Lookup("name")
assert.Equal(t, "n", nameFlag.Shorthand)
emailFlag := flags.Lookup("email")
assert.Equal(t, "e", emailFlag.Shorthand)
}
func TestDestroyUserCommand(t *testing.T) {
assert.NotNil(t, destroyUserCmd)
assert.Equal(t, "destroy --identifier ID or --name NAME", destroyUserCmd.Use)
assert.Equal(t, "Destroys a user", destroyUserCmd.Short)
assert.Equal(t, []string{"delete"}, destroyUserCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, destroyUserCmd.Run)
// Test flags from usernameAndIDFlag
flags := destroyUserCmd.Flags()
assert.NotNil(t, flags.Lookup("identifier"))
assert.NotNil(t, flags.Lookup("name"))
}
func TestRenameUserCommand(t *testing.T) {
assert.NotNil(t, renameUserCmd)
assert.Equal(t, "rename", renameUserCmd.Use)
assert.Equal(t, "Renames a user", renameUserCmd.Short)
assert.Equal(t, []string{"mv"}, renameUserCmd.Aliases)
// Test that Run function is set
assert.NotNil(t, renameUserCmd.Run)
// Test flags
flags := renameUserCmd.Flags()
assert.NotNil(t, flags.Lookup("identifier"))
assert.NotNil(t, flags.Lookup("name"))
assert.NotNil(t, flags.Lookup("new-name"))
// Test flag shortcuts
newNameFlag := flags.Lookup("new-name")
assert.Equal(t, "r", newNameFlag.Shorthand)
}
func TestUsernameAndIDFlag(t *testing.T) {
// Create a test command
cmd := &cobra.Command{Use: "test"}
// Apply the flag function
usernameAndIDFlag(cmd)
// Test that flags were added
flags := cmd.Flags()
assert.NotNil(t, flags.Lookup("identifier"))
assert.NotNil(t, flags.Lookup("name"))
// Test flag properties
identifierFlag := flags.Lookup("identifier")
assert.Equal(t, "i", identifierFlag.Shorthand)
assert.Equal(t, "User identifier (ID)", identifierFlag.Usage)
assert.Equal(t, "-1", identifierFlag.DefValue)
nameFlag := flags.Lookup("name")
assert.Equal(t, "n", nameFlag.Shorthand)
assert.Equal(t, "Username", nameFlag.Usage)
assert.Equal(t, "", nameFlag.DefValue)
}
func TestUsernameAndIDFromFlag(t *testing.T) {
tests := []struct {
name string
identifier int64
username string
expectedID uint64
expectedName string
expectError bool
}{
{
name: "valid identifier only",
identifier: 123,
username: "",
expectedID: 123,
expectedName: "",
expectError: false,
},
{
name: "valid username only",
identifier: -1,
username: "testuser",
expectedID: 0, // uint64(-1) wraps around, but we check identifier < 0
expectedName: "testuser",
expectError: false,
},
{
name: "both provided",
identifier: 123,
username: "testuser",
expectedID: 123,
expectedName: "testuser",
expectError: false,
},
{
name: "neither provided",
identifier: -1,
username: "",
expectedID: 0,
expectedName: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test command with flags
cmd := &cobra.Command{Use: "test"}
usernameAndIDFlag(cmd)
// Set flag values
if tt.identifier >= 0 {
err := cmd.Flags().Set("identifier", string(rune(tt.identifier+'0')))
require.NoError(t, err)
}
if tt.username != "" {
err := cmd.Flags().Set("name", tt.username)
require.NoError(t, err)
}
// Note: usernameAndIDFromFlag calls ErrorOutput and exits on error,
// so we can't easily test the error case without mocking ErrorOutput.
// We'll test the success cases only.
if !tt.expectError {
id, name := usernameAndIDFromFlag(cmd)
assert.Equal(t, tt.expectedID, id)
assert.Equal(t, tt.expectedName, name)
}
})
}
}
func TestUserCommandFlags(t *testing.T) {
// Test create user command flags
ValidateCommandFlags(t, createUserCmd, []string{"display-name", "email", "picture-url"})
// Test list users command flags
ValidateCommandFlags(t, listUsersCmd, []string{"identifier", "name", "email"})
// Test destroy user command flags
ValidateCommandFlags(t, destroyUserCmd, []string{"identifier", "name"})
// Test rename user command flags
ValidateCommandFlags(t, renameUserCmd, []string{"identifier", "name", "new-name"})
}
func TestUserCommandIntegration(t *testing.T) {
// Test that user command is properly integrated into root command
found := false
for _, cmd := range rootCmd.Commands() {
if cmd.Use == "users" {
found = true
break
}
}
assert.True(t, found, "User command should be added to root command")
}
func TestUserSubcommandIntegration(t *testing.T) {
// Test that all subcommands are properly added to user command
subcommands := userCmd.Commands()
expectedCommands := map[string]bool{
"create NAME": false,
"list": false,
"destroy": false,
"rename": false,
}
for _, subcmd := range subcommands {
if _, exists := expectedCommands[subcmd.Use]; exists {
expectedCommands[subcmd.Use] = true
}
}
for cmdName, found := range expectedCommands {
assert.True(t, found, "Subcommand '%s' should be added to user command", cmdName)
}
}
func TestUserCommandFlagValidation(t *testing.T) {
// Test flag default values and types
cmd := &cobra.Command{Use: "test"}
usernameAndIDFlag(cmd)
// Test identifier flag default
identifier, err := cmd.Flags().GetInt64("identifier")
assert.NoError(t, err)
assert.Equal(t, int64(-1), identifier)
// Test name flag default
name, err := cmd.Flags().GetString("name")
assert.NoError(t, err)
assert.Equal(t, "", name)
}
func TestCreateUserCommandArgsValidation(t *testing.T) {
// Test the Args validation function
testCases := []struct {
name string
args []string
wantErr bool
}{
{
name: "no arguments",
args: []string{},
wantErr: true,
},
{
name: "one argument",
args: []string{"testuser"},
wantErr: false,
},
{
name: "multiple arguments",
args: []string{"testuser", "extra"},
wantErr: false, // Args function only checks for minimum 1 arg
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := createUserCmd.Args(createUserCmd, tc.args)
if tc.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestUserCommandAliases(t *testing.T) {
// Test that all aliases are properly set
testCases := []struct {
command *cobra.Command
expectedAliases []string
}{
{
command: userCmd,
expectedAliases: []string{"user", "namespace", "namespaces", "ns"},
},
{
command: createUserCmd,
expectedAliases: []string{"c", "new"},
},
{
command: listUsersCmd,
expectedAliases: []string{"ls", "show"},
},
{
command: destroyUserCmd,
expectedAliases: []string{"delete"},
},
{
command: renameUserCmd,
expectedAliases: []string{"mv"},
},
}
for _, tc := range testCases {
t.Run(tc.command.Use, func(t *testing.T) {
assert.Equal(t, tc.expectedAliases, tc.command.Aliases)
})
}
}
func TestUserCommandsHaveOutputFlag(t *testing.T) {
// All user commands should support output formatting
commands := []*cobra.Command{createUserCmd, listUsersCmd, destroyUserCmd, renameUserCmd}
for _, cmd := range commands {
t.Run(cmd.Use, func(t *testing.T) {
// Commands should be able to get output flag (though it might be inherited)
// This tests that the commands are designed to work with output formatting
assert.NotNil(t, cmd.Run, "Command should have a Run function")
})
}
}
func TestUserCommandCompleteness(t *testing.T) {
// Test that user command covers all expected CRUD operations
subcommands := userCmd.Commands()
operations := map[string]bool{
"create": false,
"read": false, // list command
"update": false, // rename command
"delete": false, // destroy command
}
for _, subcmd := range subcommands {
switch {
case subcmd.Use == "create NAME":
operations["create"] = true
case subcmd.Use == "list":
operations["read"] = true
case subcmd.Use == "rename":
operations["update"] = true
case subcmd.Use == "destroy --identifier ID or --name NAME":
operations["delete"] = true
}
}
for op, found := range operations {
assert.True(t, found, "User command should support %s operation", op)
}
}

View File

@ -7,648 +7,149 @@ import (
"github.com/stretchr/testify/assert"
)
// Test input validation utilities
// Core validation function tests
func TestValidateEmail(t *testing.T) {
tests := []struct {
name string
email string
expectError bool
}{
{
name: "valid email",
email: "test@example.com",
expectError: false,
},
{
name: "valid email with subdomain",
email: "user@mail.company.com",
expectError: false,
},
{
name: "valid email with plus",
email: "user+tag@example.com",
expectError: false,
},
{
name: "empty email",
email: "",
expectError: true,
},
{
name: "invalid email without @",
email: "invalid-email",
expectError: true,
},
{
name: "invalid email without domain",
email: "user@",
expectError: true,
},
{
name: "invalid email without user",
email: "@example.com",
expectError: true,
},
{"test@example.com", false},
{"user+tag@example.com", false},
{"", true},
{"invalid-email", true},
{"user@", true},
{"@example.com", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateEmail(tt.email)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateURL(t *testing.T) {
tests := []struct {
name string
url string
expectError bool
}{
{
name: "valid HTTP URL",
url: "http://example.com",
expectError: false,
},
{
name: "valid HTTPS URL",
url: "https://example.com",
expectError: false,
},
{
name: "valid URL with path",
url: "https://example.com/path/to/resource",
expectError: false,
},
{
name: "valid URL with query",
url: "https://example.com?query=value",
expectError: false,
},
{
name: "empty URL",
url: "",
expectError: true,
},
{
name: "URL without scheme",
url: "example.com",
expectError: true,
},
{
name: "URL without host",
url: "https://",
expectError: true,
},
{
name: "invalid URL",
url: "not-a-url",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateURL(tt.url)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateDuration(t *testing.T) {
tests := []struct {
name string
duration string
expected time.Duration
expectError bool
}{
{
name: "valid hours",
duration: "1h",
expected: time.Hour,
expectError: false,
},
{
name: "valid minutes",
duration: "30m",
expected: 30 * time.Minute,
expectError: false,
},
{
name: "valid seconds",
duration: "45s",
expected: 45 * time.Second,
expectError: false,
},
{
name: "valid complex duration",
duration: "1h30m",
expected: time.Hour + 30*time.Minute,
expectError: false,
},
{
name: "empty duration",
duration: "",
expectError: true,
},
{
name: "invalid duration format",
duration: "invalid",
expectError: true,
},
{
name: "negative duration",
duration: "-1h",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := ValidateDuration(tt.duration)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expected, result)
}
})
err := ValidateEmail(tt.email)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidateUserName(t *testing.T) {
tests := []struct {
name string
username string
expectError bool
}{
{
name: "valid simple username",
username: "testuser",
expectError: false,
},
{
name: "valid username with numbers",
username: "user123",
expectError: false,
},
{
name: "valid username with dots",
username: "test.user",
expectError: false,
},
{
name: "valid username with hyphens",
username: "test-user",
expectError: false,
},
{
name: "valid username with underscores",
username: "test_user",
expectError: false,
},
{
name: "valid email-style username",
username: "user@domain.com",
expectError: false,
},
{
name: "empty username",
username: "",
expectError: true,
},
{
name: "username starting with dot",
username: ".testuser",
expectError: true,
},
{
name: "username ending with dot",
username: "testuser.",
expectError: true,
},
{
name: "username starting with hyphen",
username: "-testuser",
expectError: true,
},
{
name: "username ending with hyphen",
username: "testuser-",
expectError: true,
},
{
name: "username with spaces",
username: "test user",
expectError: true,
},
{
name: "username with special characters",
username: "test$user",
expectError: true,
},
{
name: "username too long",
username: "verylongusernamethatexceedsthemaximumlengthallowedforusernames123",
expectError: true,
},
{"validuser", false},
{"user123", false},
{"user.name", false},
{"", true},
{".invalid", true},
{"invalid.", true},
{"-invalid", true},
{"invalid-", true},
{"user with spaces", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateUserName(tt.username)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
err := ValidateUserName(tt.name)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidateNodeName(t *testing.T) {
tests := []struct {
name string
nodeName string
expectError bool
}{
{
name: "valid simple node name",
nodeName: "testnode",
expectError: false,
},
{
name: "valid node name with numbers",
nodeName: "node123",
expectError: false,
},
{
name: "valid node name with hyphens",
nodeName: "test-node",
expectError: false,
},
{
name: "valid single character",
nodeName: "n",
expectError: false,
},
{
name: "empty node name",
nodeName: "",
expectError: true,
},
{
name: "node name starting with hyphen",
nodeName: "-testnode",
expectError: true,
},
{
name: "node name ending with hyphen",
nodeName: "testnode-",
expectError: true,
},
{
name: "node name with underscores",
nodeName: "test_node",
expectError: true,
},
{
name: "node name with dots",
nodeName: "test.node",
expectError: true,
},
{
name: "node name too long",
nodeName: "verylongnodenamethatexceedsthemaximumlengthallowedforhostnames123",
expectError: true,
},
{"validnode", false},
{"node123", false},
{"node-name", false},
{"", true},
{"-invalid", true},
{"invalid-", true},
{"node_name", true}, // underscores not allowed
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateNodeName(tt.nodeName)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
err := ValidateNodeName(tt.name)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidateIPAddress(t *testing.T) {
func TestValidateDuration(t *testing.T) {
tests := []struct {
name string
ip string
duration string
expectError bool
}{
{
name: "valid IPv4",
ip: "192.168.1.1",
expectError: false,
},
{
name: "valid IPv6",
ip: "2001:db8::1",
expectError: false,
},
{
name: "valid IPv6 full",
ip: "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
expectError: false,
},
{
name: "empty IP",
ip: "",
expectError: true,
},
{
name: "invalid IPv4",
ip: "256.256.256.256",
expectError: true,
},
{
name: "invalid format",
ip: "not-an-ip",
expectError: true,
},
{
name: "IPv4 with extra octet",
ip: "192.168.1.1.1",
expectError: true,
},
{"1h", false},
{"30m", false},
{"24h", false},
{"", true},
{"invalid", true},
{"-1h", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateIPAddress(tt.ip)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateCIDR(t *testing.T) {
tests := []struct {
name string
cidr string
expectError bool
}{
{
name: "valid IPv4 CIDR",
cidr: "192.168.1.0/24",
expectError: false,
},
{
name: "valid IPv6 CIDR",
cidr: "2001:db8::/32",
expectError: false,
},
{
name: "valid single host IPv4",
cidr: "192.168.1.1/32",
expectError: false,
},
{
name: "valid single host IPv6",
cidr: "2001:db8::1/128",
expectError: false,
},
{
name: "empty CIDR",
cidr: "",
expectError: true,
},
{
name: "IP without mask",
cidr: "192.168.1.1",
expectError: true,
},
{
name: "invalid CIDR mask",
cidr: "192.168.1.0/33",
expectError: true,
},
{
name: "invalid IP in CIDR",
cidr: "256.256.256.0/24",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateCIDR(tt.cidr)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateTagsFormat(t *testing.T) {
tests := []struct {
name string
tags []string
expectError bool
}{
{
name: "valid simple tags",
tags: []string{"tag1", "tag2"},
expectError: false,
},
{
name: "valid tag with colon",
tags: []string{"environment:production"},
expectError: false,
},
{
name: "empty tags list",
tags: []string{},
expectError: false,
},
{
name: "nil tags list",
tags: nil,
expectError: false,
},
{
name: "tag with space",
tags: []string{"invalid tag"},
expectError: true,
},
{
name: "empty tag",
tags: []string{""},
expectError: true,
},
{
name: "tag with invalid characters",
tags: []string{"tag$invalid"},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateTagsFormat(tt.tags)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
_, err := ValidateDuration(tt.duration)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidateAPIKeyPrefix(t *testing.T) {
tests := []struct {
name string
prefix string
expectError bool
}{
{
name: "valid prefix",
prefix: "testkey",
expectError: false,
},
{
name: "valid prefix with numbers",
prefix: "key123",
expectError: false,
},
{
name: "minimum length prefix",
prefix: "test",
expectError: false,
},
{
name: "maximum length prefix",
prefix: "1234567890123456",
expectError: false,
},
{
name: "empty prefix",
prefix: "",
expectError: true,
},
{
name: "prefix too short",
prefix: "abc",
expectError: true,
},
{
name: "prefix too long",
prefix: "12345678901234567",
expectError: true,
},
{
name: "prefix with special characters",
prefix: "test-key",
expectError: true,
},
{
name: "prefix with underscore",
prefix: "test_key",
expectError: true,
},
{"validprefix", false},
{"prefix123", false},
{"abc", false}, // minimum length
{"", true}, // empty
{"ab", true}, // too short
{"prefix_with_underscore", true}, // invalid chars
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateAPIKeyPrefix(tt.prefix)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
err := ValidateAPIKeyPrefix(tt.prefix)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
}
}
func TestValidatePreAuthKeyOptions(t *testing.T) {
oneHour := time.Hour
tests := []struct {
name string
reusable bool
ephemeral bool
expiration time.Duration
expiration *time.Duration
expectError bool
}{
{
name: "valid reusable key",
reusable: true,
ephemeral: false,
expiration: time.Hour,
expectError: false,
},
{
name: "valid ephemeral key",
reusable: false,
ephemeral: true,
expiration: time.Hour,
expectError: false,
},
{
name: "valid non-reusable, non-ephemeral",
reusable: false,
ephemeral: false,
expiration: time.Hour,
expectError: false,
},
{
name: "valid no expiration",
reusable: true,
ephemeral: false,
expiration: 0,
expectError: false,
},
{
name: "invalid ephemeral and reusable",
reusable: true,
ephemeral: true,
expiration: time.Hour,
expectError: true,
},
{
name: "invalid ephemeral without expiration",
reusable: false,
ephemeral: true,
expiration: 0,
expectError: true,
},
{
name: "invalid expiration too long",
reusable: false,
ephemeral: false,
expiration: 366 * 24 * time.Hour,
expectError: true,
},
{
name: "invalid expiration too short",
reusable: false,
ephemeral: false,
expiration: 30 * time.Second,
expectError: true,
},
{"valid reusable", true, false, &oneHour, false},
{"valid ephemeral", false, true, &oneHour, false},
{"invalid: both reusable and ephemeral", true, true, &oneHour, true},
{"invalid: ephemeral without expiration", false, true, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidatePreAuthKeyOptions(tt.reusable, tt.ephemeral, tt.expiration)
var exp time.Duration
if tt.expiration != nil {
exp = *tt.expiration
}
err := ValidatePreAuthKeyOptions(tt.reusable, tt.ephemeral, exp)
if tt.expectError {
assert.Error(t, err)
} else {
@ -656,253 +157,4 @@ func TestValidatePreAuthKeyOptions(t *testing.T) {
}
})
}
}
func TestValidatePolicyJSON(t *testing.T) {
tests := []struct {
name string
policy string
expectError bool
}{
{
name: "valid basic JSON",
policy: `{"acls": []}`,
expectError: false,
},
{
name: "valid JSON with whitespace",
policy: ` {"acls": []} `,
expectError: false,
},
{
name: "empty policy",
policy: "",
expectError: true,
},
{
name: "invalid JSON structure",
policy: "not json",
expectError: true,
},
{
name: "array instead of object",
policy: `["not", "an", "object"]`,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidatePolicyJSON(tt.policy)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidatePositiveInteger(t *testing.T) {
tests := []struct {
name string
value int64
fieldName string
expectError bool
}{
{
name: "valid positive integer",
value: 5,
fieldName: "test field",
expectError: false,
},
{
name: "zero value",
value: 0,
fieldName: "test field",
expectError: true,
},
{
name: "negative value",
value: -1,
fieldName: "test field",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidatePositiveInteger(tt.value, tt.fieldName)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.fieldName)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateNonNegativeInteger(t *testing.T) {
tests := []struct {
name string
value int64
fieldName string
expectError bool
}{
{
name: "valid positive integer",
value: 5,
fieldName: "test field",
expectError: false,
},
{
name: "zero value",
value: 0,
fieldName: "test field",
expectError: false,
},
{
name: "negative value",
value: -1,
fieldName: "test field",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateNonNegativeInteger(tt.value, tt.fieldName)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.fieldName)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateStringLength(t *testing.T) {
tests := []struct {
name string
value string
fieldName string
minLength int
maxLength int
expectError bool
}{
{
name: "valid length",
value: "hello",
fieldName: "test field",
minLength: 3,
maxLength: 10,
expectError: false,
},
{
name: "minimum length",
value: "hi",
fieldName: "test field",
minLength: 2,
maxLength: 10,
expectError: false,
},
{
name: "maximum length",
value: "1234567890",
fieldName: "test field",
minLength: 2,
maxLength: 10,
expectError: false,
},
{
name: "too short",
value: "a",
fieldName: "test field",
minLength: 3,
maxLength: 10,
expectError: true,
},
{
name: "too long",
value: "12345678901",
fieldName: "test field",
minLength: 3,
maxLength: 10,
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateStringLength(tt.value, tt.fieldName, tt.minLength, tt.maxLength)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.fieldName)
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateOneOf(t *testing.T) {
tests := []struct {
name string
value string
fieldName string
allowedValues []string
expectError bool
}{
{
name: "valid value",
value: "option1",
fieldName: "test field",
allowedValues: []string{"option1", "option2", "option3"},
expectError: false,
},
{
name: "invalid value",
value: "invalid",
fieldName: "test field",
allowedValues: []string{"option1", "option2", "option3"},
expectError: true,
},
{
name: "empty allowed values",
value: "anything",
fieldName: "test field",
allowedValues: []string{},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateOneOf(tt.value, tt.fieldName, tt.allowedValues)
if tt.expectError {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.fieldName)
} else {
assert.NoError(t, err)
}
})
}
}
// Test that validation functions use consistent error formatting
func TestValidationErrorFormatting(t *testing.T) {
// Test that errors include the invalid value in the message
err := ValidateEmail("invalid-email")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid-email")
err = ValidateUserName("")
assert.Error(t, err)
assert.Contains(t, err.Error(), "cannot be empty")
err = ValidateAPIKeyPrefix("ab")
assert.Error(t, err)
assert.Contains(t, err.Error(), "at least 4 characters")
}