This commit is contained in:
Kristoffer Dalby
2025-07-15 14:51:23 +00:00
parent 024ed59ea9
commit 8253d588c6
31 changed files with 300 additions and 364 deletions

View File

@@ -15,7 +15,6 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)
func init() {
rootCmd.AddCommand(apiKeysCmd)
apiKeysCmd.AddCommand(listAPIKeys)
@@ -98,7 +97,6 @@ var listAPIKeys = &cobra.Command{
}
return nil
})
if err != nil {
return
}
@@ -148,7 +146,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`,
SuccessOutput(response.GetApiKey(), response.GetApiKey(), output)
return nil
})
if err != nil {
return
}
@@ -185,7 +182,6 @@ var expireAPIKeyCmd = &cobra.Command{
SuccessOutput(response, "Key expired", output)
return nil
})
if err != nil {
return
}
@@ -222,7 +218,6 @@ var deleteAPIKeyCmd = &cobra.Command{
SuccessOutput(response, "Key deleted", output)
return nil
})
if err != nil {
return
}

View File

@@ -2,7 +2,7 @@ package cli
import (
"context"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
)
@@ -11,6 +11,6 @@ func WithClient(fn func(context.Context, v1.HeadscaleServiceClient) error) error
ctx, client, conn, cancel := newHeadscaleCLIWithConfig()
defer cancel()
defer conn.Close()
return fn(ctx, client)
}
}

View File

@@ -37,10 +37,10 @@ func TestConfigTestCommandHelp(t *testing.T) {
// 1. It depends on configuration files being present
// 2. It calls log.Fatal() which would exit the test process
// 3. It tries to initialize a full Headscale server
//
//
// In a real refactor, we would:
// 1. Extract the configuration validation logic to a testable function
// 2. Return errors instead of calling log.Fatal()
// 3. Accept configuration as a parameter instead of loading from global state
//
// For now, we test the command structure and that it's properly wired up.
// For now, we test the command structure and that it's properly wired up.

View File

@@ -15,11 +15,6 @@ const (
errPreAuthKeyMalformed = Error("key is malformed. expected 64 hex characters with `nodekey` prefix")
)
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
type Error string
func (e Error) Error() string { return string(e) }
func init() {
rootCmd.AddCommand(debugCmd)
@@ -30,11 +25,6 @@ func init() {
}
createNodeCmd.Flags().StringP("user", "u", "", "User")
createNodeCmd.Flags().StringP("namespace", "n", "", "User")
createNodeNamespaceFlag := createNodeCmd.Flags().Lookup("namespace")
createNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage
createNodeNamespaceFlag.Hidden = true
err = createNodeCmd.MarkFlagRequired("user")
if err != nil {
log.Fatal().Err(err).Msg("")
@@ -60,7 +50,7 @@ var createNodeCmd = &cobra.Command{
Use: "create-node",
Short: "Create a node that can be registered with `nodes register <>` command",
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
user, err := cmd.Flags().GetString("user")
if err != nil {
@@ -129,7 +119,6 @@ var createNodeCmd = &cobra.Command{
SuccessOutput(response.GetNode(), "Node created", output)
return nil
})
if err != nil {
return
}

View File

@@ -41,7 +41,7 @@ func TestCreateNodeCommandInDebugCommand(t *testing.T) {
func TestCreateNodeCommandFlags(t *testing.T) {
// Test that create-node has the required flags
// Test name flag
nameFlag := createNodeCmd.Flags().Lookup("name")
assert.NotNil(t, nameFlag)
@@ -63,22 +63,16 @@ func TestCreateNodeCommandFlags(t *testing.T) {
assert.NotNil(t, routeFlag)
assert.Equal(t, "r", routeFlag.Shorthand)
// Test deprecated namespace flag
namespaceFlag := createNodeCmd.Flags().Lookup("namespace")
assert.NotNil(t, namespaceFlag)
assert.Equal(t, "n", namespaceFlag.Shorthand)
assert.True(t, namespaceFlag.Hidden, "Namespace flag should be hidden")
assert.Equal(t, deprecateNamespaceMessage, namespaceFlag.Deprecated)
}
func TestCreateNodeCommandRequiredFlags(t *testing.T) {
// Test that required flags are marked as required
// We can't easily test the actual requirement enforcement without executing the command
// But we can test that the flags exist and have the expected properties
// These flags should be required based on the init() function
requiredFlags := []string{"name", "user", "key"}
for _, flagName := range requiredFlags {
flag := createNodeCmd.Flags().Lookup(flagName)
assert.NotNil(t, flag, "Required flag %s should exist", flagName)
@@ -134,8 +128,6 @@ func TestCreateNodeCommandFlagDescriptions(t *testing.T) {
routeFlag := createNodeCmd.Flags().Lookup("route")
assert.Contains(t, routeFlag.Usage, "routes to advertise")
namespaceFlag := createNodeCmd.Flags().Lookup("namespace")
assert.Equal(t, "User", namespaceFlag.Usage) // Same as user flag
}
// Note: We can't easily test the actual execution of create-node because:
@@ -149,4 +141,4 @@ func TestCreateNodeCommandFlagDescriptions(t *testing.T) {
// 3. Return errors instead of calling ErrorOutput/SuccessOutput
// 4. Add validation functions that can be tested independently
//
// For now, we test the command structure and flag configuration.
// For now, we test the command structure and flag configuration.

View File

@@ -22,7 +22,7 @@ var generatePrivateKeyCmd = &cobra.Command{
Use: "private-key",
Short: "Generate a private key for the headscale server",
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
machineKey := key.NewMachine()
machineKeyStr, err := machineKey.MarshalText()

View File

@@ -18,17 +18,17 @@ func TestGenerateCommand(t *testing.T) {
Use: "headscale",
Short: "headscale - a Tailscale control server",
}
cmd.AddCommand(generateCmd)
out := new(bytes.Buffer)
cmd.SetOut(out)
cmd.SetErr(out)
cmd.SetArgs([]string{"generate", "--help"})
err := cmd.Execute()
require.NoError(t, err)
outStr := out.String()
assert.Contains(t, outStr, "Generate commands")
assert.Contains(t, outStr, "private-key")
@@ -42,17 +42,17 @@ func TestGenerateCommandAlias(t *testing.T) {
Use: "headscale",
Short: "headscale - a Tailscale control server",
}
cmd.AddCommand(generateCmd)
out := new(bytes.Buffer)
cmd.SetOut(out)
cmd.SetErr(out)
cmd.SetArgs([]string{"gen", "--help"})
err := cmd.Execute()
require.NoError(t, err)
outStr := out.String()
assert.Contains(t, outStr, "Generate commands")
}
@@ -77,7 +77,7 @@ func TestGeneratePrivateKeyCommand(t *testing.T) {
expectYAML: false,
},
{
name: "yaml output",
name: "yaml output",
args: []string{"generate", "private-key", "--output", "yaml"},
expectJSON: false,
expectYAML: true,
@@ -89,15 +89,15 @@ func TestGeneratePrivateKeyCommand(t *testing.T) {
// Note: This command calls SuccessOutput which exits the process
// We can't test the actual execution easily without mocking
// Instead, we test the command structure and that it exists
cmd := &cobra.Command{
Use: "headscale",
Short: "headscale - a Tailscale control server",
}
cmd.AddCommand(generateCmd)
cmd.PersistentFlags().StringP("output", "o", "", "Output format")
// Test that the command exists and can be found
privateKeyCmd, _, err := cmd.Find([]string{"generate", "private-key"})
require.NoError(t, err)
@@ -112,17 +112,17 @@ func TestGeneratePrivateKeyHelp(t *testing.T) {
Use: "headscale",
Short: "headscale - a Tailscale control server",
}
cmd.AddCommand(generateCmd)
out := new(bytes.Buffer)
cmd.SetOut(out)
cmd.SetErr(out)
cmd.SetArgs([]string{"generate", "private-key", "--help"})
err := cmd.Execute()
require.NoError(t, err)
outStr := out.String()
assert.Contains(t, outStr, "Generate a private key for the headscale server")
assert.Contains(t, outStr, "Usage:")
@@ -132,10 +132,10 @@ func TestGeneratePrivateKeyHelp(t *testing.T) {
func TestPrivateKeyGeneration(t *testing.T) {
// We can't easily test the full command because it calls SuccessOutput which exits
// But we can test that the key generation produces valid output format
// This is testing the core logic that would be in the command
// In a real refactor, we'd extract this to a testable function
// For now, we can test that the command structure is correct
assert.NotNil(t, generatePrivateKeyCmd)
assert.Equal(t, "private-key", generatePrivateKeyCmd.Use)
@@ -148,7 +148,7 @@ func TestGenerateCommandStructure(t *testing.T) {
assert.Equal(t, "generate", generateCmd.Use)
assert.Equal(t, "Generate commands", generateCmd.Short)
assert.Contains(t, generateCmd.Aliases, "gen")
// Test that private-key is a subcommand
found := false
for _, subcmd := range generateCmd.Commands() {
@@ -167,31 +167,31 @@ func validatePrivateKeyOutput(t *testing.T, output string, format string) {
var result map[string]interface{}
err := json.Unmarshal([]byte(output), &result)
require.NoError(t, err, "Output should be valid JSON")
privateKey, exists := result["private_key"]
require.True(t, exists, "JSON should contain private_key field")
keyStr, ok := privateKey.(string)
require.True(t, ok, "private_key should be a string")
require.NotEmpty(t, keyStr, "private_key should not be empty")
// Basic validation that it looks like a machine key
assert.True(t, strings.HasPrefix(keyStr, "mkey:"), "Machine key should start with mkey:")
case "yaml":
var result map[string]interface{}
err := yaml.Unmarshal([]byte(output), &result)
require.NoError(t, err, "Output should be valid YAML")
privateKey, exists := result["private_key"]
require.True(t, exists, "YAML should contain private_key field")
keyStr, ok := privateKey.(string)
require.True(t, ok, "private_key should be a string")
require.NotEmpty(t, keyStr, "private_key should not be empty")
assert.True(t, strings.HasPrefix(keyStr, "mkey:"), "Machine key should start with mkey:")
default:
// Default format should just be the key itself
assert.True(t, strings.HasPrefix(output, "mkey:"), "Default output should be the machine key")
@@ -203,7 +203,7 @@ func validatePrivateKeyOutput(t *testing.T, output string, format string) {
func TestPrivateKeyOutputFormats(t *testing.T) {
// Test cases for different output formats
// These test the validation logic we would use after refactoring
tests := []struct {
format string
sample string
@@ -213,7 +213,7 @@ func TestPrivateKeyOutputFormats(t *testing.T) {
sample: `{"private_key": "mkey:abcd1234567890abcd1234567890abcd1234567890abcd1234567890abcd1234"}`,
},
{
format: "yaml",
format: "yaml",
sample: "private_key: mkey:abcd1234567890abcd1234567890abcd1234567890abcd1234567890abcd1234\n",
},
{
@@ -221,10 +221,10 @@ func TestPrivateKeyOutputFormats(t *testing.T) {
sample: "mkey:abcd1234567890abcd1234567890abcd1234567890abcd1234567890abcd1234",
},
}
for _, tt := range tests {
t.Run("format_"+tt.format, func(t *testing.T) {
validatePrivateKeyOutput(t, tt.sample, tt.format)
})
}
}
}

View File

@@ -15,6 +15,11 @@ import (
"github.com/spf13/cobra"
)
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
type Error string
func (e Error) Error() string { return string(e) }
const (
errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined")
errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")

View File

@@ -32,27 +32,13 @@ func init() {
// Display options
listNodesCmd.Flags().BoolP("tags", "t", false, "Show tags")
listNodesCmd.Flags().String("columns", "", "Comma-separated list of columns to display")
// Backward compatibility
listNodesCmd.Flags().StringP("namespace", "n", "", "User")
listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace")
listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage
listNodesNamespaceFlag.Hidden = true
nodeCmd.AddCommand(listNodesCmd)
listNodeRoutesCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID) - deprecated, use --node")
identifierFlag := listNodeRoutesCmd.Flags().Lookup("identifier")
identifierFlag.Deprecated = "use --node"
identifierFlag.Hidden = true
nodeCmd.AddCommand(listNodeRoutesCmd)
registerNodeCmd.Flags().StringP("user", "u", "", "User")
registerNodeCmd.Flags().StringP("namespace", "n", "", "User")
registerNodeNamespaceFlag := registerNodeCmd.Flags().Lookup("namespace")
registerNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage
registerNodeNamespaceFlag.Hidden = true
err := registerNodeCmd.MarkFlagRequired("user")
if err != nil {
log.Fatal(err.Error())
@@ -65,40 +51,24 @@ func init() {
nodeCmd.AddCommand(registerNodeCmd)
expireNodeCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID) - deprecated, use --node")
identifierFlag = expireNodeCmd.Flags().Lookup("identifier")
identifierFlag.Deprecated = "use --node"
identifierFlag.Hidden = true
if err != nil {
log.Fatal(err.Error())
}
nodeCmd.AddCommand(expireNodeCmd)
renameNodeCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID) - deprecated, use --node")
identifierFlag = renameNodeCmd.Flags().Lookup("identifier")
identifierFlag.Deprecated = "use --node"
identifierFlag.Hidden = true
if err != nil {
log.Fatal(err.Error())
}
nodeCmd.AddCommand(renameNodeCmd)
deleteNodeCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID) - deprecated, use --node")
identifierFlag = deleteNodeCmd.Flags().Lookup("identifier")
identifierFlag.Deprecated = "use --node"
identifierFlag.Hidden = true
if err != nil {
log.Fatal(err.Error())
}
nodeCmd.AddCommand(deleteNodeCmd)
moveNodeCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
moveNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID) - deprecated, use --node")
identifierFlag = moveNodeCmd.Flags().Lookup("identifier")
identifierFlag.Deprecated = "use --node"
identifierFlag.Hidden = true
if err != nil {
log.Fatal(err.Error())
@@ -106,24 +76,19 @@ func init() {
moveNodeCmd.Flags().Uint64P("user", "u", 0, "New user")
moveNodeCmd.Flags().StringP("namespace", "n", "", "User")
moveNodeNamespaceFlag := moveNodeCmd.Flags().Lookup("namespace")
moveNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage
moveNodeNamespaceFlag.Hidden = true
err = moveNodeCmd.MarkFlagRequired("user")
if err != nil {
log.Fatal(err.Error())
}
nodeCmd.AddCommand(moveNodeCmd)
tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
tagCmd.MarkFlagRequired("identifier")
tagCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
tagCmd.MarkFlagRequired("node")
tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
nodeCmd.AddCommand(tagCmd)
approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
approveRoutesCmd.MarkFlagRequired("identifier")
approveRoutesCmd.Flags().StringP("node", "n", "", "Node identifier (ID, name, hostname, or IP)")
approveRoutesCmd.MarkFlagRequired("node")
approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
nodeCmd.AddCommand(approveRoutesCmd)
@@ -140,7 +105,7 @@ var registerNodeCmd = &cobra.Command{
Use: "register",
Short: "Registers a node to your network",
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
user, err := cmd.Flags().GetString("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
@@ -181,7 +146,6 @@ var registerNodeCmd = &cobra.Command{
fmt.Sprintf("Node %s registered", response.GetNode().GetGivenName()), output)
return nil
})
if err != nil {
return
}
@@ -202,15 +166,12 @@ var listNodesCmd = &cobra.Command{
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ListNodesRequest{}
// Handle user filtering (existing functionality)
if user, _ := cmd.Flags().GetString("user"); user != "" {
request.User = user
}
if namespace, _ := cmd.Flags().GetString("namespace"); namespace != "" {
request.User = namespace // backward compatibility
}
// Handle node filtering (new functionality)
if nodeFlag, _ := cmd.Flags().GetString("node"); nodeFlag != "" {
// Use smart lookup to determine filter type
@@ -267,7 +228,6 @@ var listNodesCmd = &cobra.Command{
}
return nil
})
if err != nil {
return
}
@@ -279,7 +239,7 @@ var listNodeRoutesCmd = &cobra.Command{
Short: "List routes available on nodes",
Aliases: []string{"lsr", "routes"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
identifier, err := GetNodeIdentifier(cmd)
if err != nil {
ErrorOutput(
@@ -339,7 +299,6 @@ var listNodeRoutesCmd = &cobra.Command{
}
return nil
})
if err != nil {
return
}
@@ -352,7 +311,7 @@ var expireNodeCmd = &cobra.Command{
Long: "Expiring a node will keep the node in the database and force it to reauthenticate.",
Aliases: []string{"logout", "exp", "e"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
identifier, err := GetNodeIdentifier(cmd)
if err != nil {
@@ -385,7 +344,6 @@ var expireNodeCmd = &cobra.Command{
SuccessOutput(response.GetNode(), "Node expired", output)
return nil
})
if err != nil {
return
}
@@ -396,7 +354,7 @@ var renameNodeCmd = &cobra.Command{
Use: "rename NEW_NAME",
Short: "Renames a node in your network",
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
identifier, err := GetNodeIdentifier(cmd)
if err != nil {
@@ -412,7 +370,7 @@ var renameNodeCmd = &cobra.Command{
if len(args) > 0 {
newName = args[0]
}
err = WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.RenameNodeRequest{
NodeId: identifier,
@@ -435,7 +393,6 @@ var renameNodeCmd = &cobra.Command{
SuccessOutput(response.GetNode(), "Node renamed", output)
return nil
})
if err != nil {
return
}
@@ -447,7 +404,7 @@ var deleteNodeCmd = &cobra.Command{
Short: "Delete a node",
Aliases: []string{"del"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
identifier, err := GetNodeIdentifier(cmd)
if err != nil {
@@ -477,7 +434,6 @@ var deleteNodeCmd = &cobra.Command{
nodeName = getResponse.GetNode().GetName()
return nil
})
if err != nil {
return
}
@@ -502,7 +458,7 @@ var deleteNodeCmd = &cobra.Command{
deleteRequest := &v1.DeleteNodeRequest{
NodeId: identifier,
}
response, err := client.DeleteNode(ctx, deleteRequest)
if output != "" {
SuccessOutput(response, "", output)
@@ -523,7 +479,6 @@ var deleteNodeCmd = &cobra.Command{
)
return nil
})
if err != nil {
return
}
@@ -538,7 +493,7 @@ var moveNodeCmd = &cobra.Command{
Short: "Move node to another user",
Aliases: []string{"mv"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
identifier, err := GetNodeIdentifier(cmd)
if err != nil {
@@ -593,7 +548,6 @@ var moveNodeCmd = &cobra.Command{
SuccessOutput(moveResponse.GetNode(), "Node moved to another user", output)
return nil
})
if err != nil {
return
}
@@ -618,7 +572,7 @@ it can be run to remove the IPs that should no longer
be assigned to nodes.`,
Run: func(cmd *cobra.Command, args []string) {
var err error
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
confirm := false
prompt := &survey.Confirm{
@@ -643,7 +597,6 @@ be assigned to nodes.`,
SuccessOutput(changes, "Node IPs backfilled successfully", output)
return nil
})
if err != nil {
return
}
@@ -829,8 +782,8 @@ var tagCmd = &cobra.Command{
Short: "Manage the tags of a node",
Aliases: []string{"tags", "t"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
// retrieve flags from CLI
identifier, err := GetNodeIdentifier(cmd)
if err != nil {
@@ -876,7 +829,6 @@ var tagCmd = &cobra.Command{
}
return nil
})
if err != nil {
return
}
@@ -887,8 +839,8 @@ var approveRoutesCmd = &cobra.Command{
Use: "approve-routes",
Short: "Manage the approved routes of a node",
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
// retrieve flags from CLI
identifier, err := GetNodeIdentifier(cmd)
if err != nil {
@@ -934,7 +886,6 @@ var approveRoutesCmd = &cobra.Command{
}
return nil
})
if err != nil {
return
}

View File

@@ -41,8 +41,8 @@ var getPolicy = &cobra.Command{
Short: "Print the current ACL Policy",
Aliases: []string{"show", "view", "fetch"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.GetPolicyRequest{}
@@ -58,7 +58,6 @@ var getPolicy = &cobra.Command{
SuccessOutput("", response.GetPolicy(), "")
return nil
})
if err != nil {
return
}
@@ -73,7 +72,7 @@ var setPolicy = &cobra.Command{
This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`,
Aliases: []string{"put", "update"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
policyPath, _ := cmd.Flags().GetString("file")
f, err := os.Open(policyPath)
@@ -100,7 +99,6 @@ var setPolicy = &cobra.Command{
SuccessOutput(nil, "Policy updated.", "")
return nil
})
if err != nil {
return
}
@@ -111,23 +109,26 @@ var checkPolicy = &cobra.Command{
Use: "check",
Short: "Check the Policy file for errors",
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
policyPath, _ := cmd.Flags().GetString("file")
f, err := os.Open(policyPath)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output)
return
}
defer f.Close()
policyBytes, err := io.ReadAll(f)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output)
return
}
_, err = policy.NewPolicyManager(policyBytes, nil, views.Slice[types.NodeView]{})
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output)
return
}
SuccessOutput(nil, "Policy is valid", "")

View File

@@ -15,16 +15,10 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)
func init() {
rootCmd.AddCommand(preauthkeysCmd)
preauthkeysCmd.PersistentFlags().Uint64P("user", "u", 0, "User identifier (ID)")
preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "User")
pakNamespaceFlag := preauthkeysCmd.PersistentFlags().Lookup("namespace")
pakNamespaceFlag.Deprecated = deprecateNamespaceMessage
pakNamespaceFlag.Hidden = true
err := preauthkeysCmd.MarkPersistentFlagRequired("user")
if err != nil {
log.Fatal().Err(err).Msg("")
@@ -53,7 +47,7 @@ var listPreAuthKeys = &cobra.Command{
Short: "List the preauthkeys for this user",
Aliases: []string{"ls", "show"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
user, err := cmd.Flags().GetUint64("user")
if err != nil {
@@ -130,7 +124,6 @@ var listPreAuthKeys = &cobra.Command{
}
return nil
})
if err != nil {
return
}
@@ -142,7 +135,7 @@ var createPreAuthKeyCmd = &cobra.Command{
Short: "Creates a new preauthkey in the specified user",
Aliases: []string{"c", "new"},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
user, err := cmd.Flags().GetUint64("user")
if err != nil {
@@ -195,7 +188,6 @@ var createPreAuthKeyCmd = &cobra.Command{
SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output)
return nil
})
if err != nil {
return
}
@@ -214,7 +206,7 @@ var expirePreAuthKeyCmd = &cobra.Command{
return nil
},
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
user, err := cmd.Flags().GetUint64("user")
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output)
@@ -240,7 +232,6 @@ var expirePreAuthKeyCmd = &cobra.Command{
SuccessOutput(response, "Key expired", output)
return nil
})
if err != nil {
return
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/tcnksm/go-latest"
)
var cfgFile string = ""
func init() {

View File

@@ -28,11 +28,11 @@ func TestServeCommandArgs(t *testing.T) {
// Test that the Args function is defined and accepts any arguments
// The current implementation always returns nil (accepts any args)
assert.NotNil(t, serveCmd.Args)
// Test the args function directly
err := serveCmd.Args(serveCmd, []string{})
assert.NoError(t, err, "Args function should accept empty arguments")
err = serveCmd.Args(serveCmd, []string{"extra", "args"})
assert.NoError(t, err, "Args function should accept extra arguments")
}
@@ -48,7 +48,7 @@ func TestServeCommandStructure(t *testing.T) {
// Test basic command structure
assert.Equal(t, "serve", serveCmd.Name())
assert.Equal(t, "Launches the headscale server", serveCmd.Short)
// Test that it has no subcommands (it's a leaf command)
subcommands := serveCmd.Commands()
assert.Empty(t, subcommands, "Serve command should not have subcommands")
@@ -67,4 +67,4 @@ func TestServeCommandStructure(t *testing.T) {
// 4. Add graceful shutdown capabilities for testing
// 5. Allow server startup to be cancelled via context
//
// For now, we test the command structure and basic properties.
// For now, we test the command structure and basic properties.

View File

@@ -8,10 +8,9 @@ import (
)
const (
deprecateNamespaceMessage = "use --user"
HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
DefaultAPIKeyExpiry = "90d"
DefaultPreAuthKeyExpiry = "1h"
HeadscaleDateTimeFormat = "2006-01-02 15:04:05"
DefaultAPIKeyExpiry = "90d"
DefaultPreAuthKeyExpiry = "1h"
)
// FilterTableColumns filters table columns based on --columns flag
@@ -23,7 +22,7 @@ func FilterTableColumns(cmd *cobra.Command, tableData pterm.TableData) pterm.Tab
headers := tableData[0]
wantedColumns := strings.Split(columns, ",")
// Find column indices
var indices []int
for _, wanted := range wantedColumns {
@@ -53,4 +52,4 @@ func FilterTableColumns(cmd *cobra.Command, tableData pterm.TableData) pterm.Tab
}
return filtered
}
}

View File

@@ -18,16 +18,12 @@ import (
func usernameAndIDFlag(cmd *cobra.Command) {
cmd.Flags().StringP("user", "u", "", "User identifier (ID, name, or email)")
cmd.Flags().Uint64P("identifier", "i", 0, "User identifier (ID) - deprecated, use --user")
identifierFlag := cmd.Flags().Lookup("identifier")
identifierFlag.Deprecated = "use --user"
identifierFlag.Hidden = true
cmd.Flags().StringP("name", "n", "", "Username")
}
// usernameAndIDFromFlag returns the user ID using smart lookup.
// userIDFromFlag returns the user ID using smart lookup.
// If no user is specified, it will exit the program with an error.
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
func userIDFromFlag(cmd *cobra.Command) uint64 {
userID, err := GetUserIdentifier(cmd)
if err != nil {
ErrorOutput(
@@ -37,7 +33,7 @@ func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
)
}
return userID, ""
return userID
}
func init() {
@@ -52,11 +48,6 @@ func init() {
listUsersCmd.Flags().Uint64P("id", "", 0, "Filter by user ID")
listUsersCmd.Flags().StringP("name", "n", "", "Filter by username")
listUsersCmd.Flags().StringP("email", "e", "", "Filter by email address")
// Backward compatibility (deprecated)
listUsersCmd.Flags().Uint64P("identifier", "i", 0, "Filter by user ID - deprecated, use --id")
identifierFlag := listUsersCmd.Flags().Lookup("identifier")
identifierFlag.Deprecated = "use --id"
identifierFlag.Hidden = true
listUsersCmd.Flags().String("columns", "", "Comma-separated list of columns to display (ID,Name,Username,Email,Created)")
userCmd.AddCommand(destroyUserCmd)
usernameAndIDFlag(destroyUserCmd)
@@ -117,7 +108,7 @@ var createUserCmd = &cobra.Command{
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
log.Trace().Interface("client", client).Msg("Obtained gRPC client")
log.Trace().Interface("request", request).Msg("Sending CreateUser request")
response, err := client.CreateUser(ctx, request)
if err != nil {
ErrorOutput(
@@ -131,7 +122,6 @@ var createUserCmd = &cobra.Command{
SuccessOutput(response.GetUser(), "User created", output)
return nil
})
if err != nil {
return
}
@@ -145,10 +135,9 @@ var destroyUserCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output := GetOutputFlag(cmd)
id, username := usernameAndIDFromFlag(cmd)
id := userIDFromFlag(cmd)
request := &v1.ListUsersRequest{
Name: username,
Id: id,
Id: id,
}
var user *v1.User
@@ -176,7 +165,6 @@ var destroyUserCmd = &cobra.Command{
user = users.GetUsers()[0]
return nil
})
if err != nil {
return
}
@@ -212,7 +200,6 @@ var destroyUserCmd = &cobra.Command{
SuccessOutput(response, "User destroyed", output)
return nil
})
if err != nil {
return
}
@@ -247,8 +234,6 @@ var listUsersCmd = &cobra.Command{
// Check specific filter flags
if id, _ := cmd.Flags().GetUint64("id"); id > 0 {
request.Id = id
} else if identifier, _ := cmd.Flags().GetUint64("identifier"); identifier > 0 {
request.Id = identifier // backward compatibility
} else if name, _ := cmd.Flags().GetString("name"); name != "" {
request.Name = name
} else if email, _ := cmd.Flags().GetString("email"); email != "" {
@@ -296,7 +281,6 @@ var listUsersCmd = &cobra.Command{
}
return nil
})
if err != nil {
// Error already handled in closure
return
@@ -311,13 +295,12 @@ var renameUserCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
output := GetOutputFlag(cmd)
id, username := usernameAndIDFromFlag(cmd)
id := userIDFromFlag(cmd)
newName, _ := cmd.Flags().GetString("new-name")
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
listReq := &v1.ListUsersRequest{
Name: username,
Id: id,
Id: id,
}
users, err := client.ListUsers(ctx, listReq)
@@ -358,7 +341,6 @@ var renameUserCmd = &cobra.Command{
SuccessOutput(response.GetUser(), "User renamed", output)
return nil
})
if err != nil {
return
}

View File

@@ -22,10 +22,6 @@ import (
"gopkg.in/yaml.v3"
)
const (
SocketWritePermissions = 0o666
)
func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) {
cfg, err := types.LoadServerConfig()
if err != nil {
@@ -75,7 +71,7 @@ func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *g
// Try to give the user better feedback if we cannot write to the headscale
// socket.
socket, err := os.OpenFile(cfg.UnixSocket, os.O_WRONLY, SocketWritePermissions) // nolint
socket, err := os.OpenFile(cfg.UnixSocket, os.O_WRONLY, 0o666) // nolint
if err != nil {
if os.IsPermission(err) {
log.Fatal().
@@ -210,21 +206,16 @@ func GetOutputFlag(cmd *cobra.Command) string {
return output
}
// GetNodeIdentifier returns the node ID using smart lookup via gRPC ListNodes call
func GetNodeIdentifier(cmd *cobra.Command) (uint64, error) {
nodeFlag, _ := cmd.Flags().GetString("node")
identifierFlag, _ := cmd.Flags().GetUint64("identifier")
// Check if --identifier (deprecated) was used
if identifierFlag > 0 {
return identifierFlag, nil
}
// Use --node flag
if nodeFlag == "" {
return 0, fmt.Errorf("--node flag is required")
}
// Use smart lookup via gRPC
return lookupNodeBySpecifier(nodeFlag)
}
@@ -232,10 +223,10 @@ func GetNodeIdentifier(cmd *cobra.Command) (uint64, error) {
// lookupNodeBySpecifier performs smart lookup of a node by ID, name, hostname, or IP
func lookupNodeBySpecifier(specifier string) (uint64, error) {
var nodeID uint64
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ListNodesRequest{}
// Detect what type of specifier this is and set appropriate filter
if id, err := strconv.ParseUint(specifier, 10, 64); err == nil && id > 0 {
// Looks like a numeric ID
@@ -247,17 +238,17 @@ func lookupNodeBySpecifier(specifier string) (uint64, error) {
// Treat as hostname/name
request.Name = specifier
}
response, err := client.ListNodes(ctx, request)
if err != nil {
return fmt.Errorf("failed to lookup node: %w", err)
}
nodes := response.GetNodes()
if len(nodes) == 0 {
return fmt.Errorf("no node found matching '%s'", specifier)
}
if len(nodes) > 1 {
var nodeInfo []string
for _, node := range nodes {
@@ -265,16 +256,15 @@ func lookupNodeBySpecifier(specifier string) (uint64, error) {
}
return fmt.Errorf("multiple nodes found matching '%s': %s", specifier, strings.Join(nodeInfo, ", "))
}
// Exactly one match - this is what we want
nodeID = nodes[0].GetId()
return nil
})
if err != nil {
return 0, err
}
return nodeID, nil
}
@@ -295,21 +285,18 @@ func isIPAddress(s string) bool {
func GetUserIdentifier(cmd *cobra.Command) (uint64, error) {
userFlag, _ := cmd.Flags().GetString("user")
nameFlag, _ := cmd.Flags().GetString("name")
identifierFlag, _ := cmd.Flags().GetUint64("identifier")
var specifier string
// Determine which flag was used (prefer --user, fall back to legacy flags)
if userFlag != "" {
specifier = userFlag
} else if nameFlag != "" {
specifier = nameFlag
} else if identifierFlag > 0 {
return identifierFlag, nil // Direct ID, no lookup needed
} else {
return 0, fmt.Errorf("--user flag is required")
}
// Use smart lookup via gRPC
return lookupUserBySpecifier(specifier)
}
@@ -317,10 +304,10 @@ func GetUserIdentifier(cmd *cobra.Command) (uint64, error) {
// lookupUserBySpecifier performs smart lookup of a user by ID, name, or email
func lookupUserBySpecifier(specifier string) (uint64, error) {
var userID uint64
err := WithClient(func(ctx context.Context, client v1.HeadscaleServiceClient) error {
request := &v1.ListUsersRequest{}
// Detect what type of specifier this is and set appropriate filter
if id, err := strconv.ParseUint(specifier, 10, 64); err == nil && id > 0 {
// Looks like a numeric ID
@@ -332,17 +319,17 @@ func lookupUserBySpecifier(specifier string) (uint64, error) {
// Treat as username
request.Name = specifier
}
response, err := client.ListUsers(ctx, request)
if err != nil {
return fmt.Errorf("failed to lookup user: %w", err)
}
users := response.GetUsers()
if len(users) == 0 {
return fmt.Errorf("no user found matching '%s'", specifier)
}
if len(users) > 1 {
var userInfo []string
for _, user := range users {
@@ -350,15 +337,14 @@ func lookupUserBySpecifier(specifier string) (uint64, error) {
}
return fmt.Errorf("multiple users found matching '%s': %s", specifier, strings.Join(userInfo, ", "))
}
// Exactly one match - this is what we want
userID = users[0].GetId()
return nil
})
if err != nil {
return 0, err
}
return userID, nil
}

View File

@@ -172,4 +172,4 @@ func TestOutputWithEmptyData(t *testing.T) {
emptyMap := map[string]string{}
result = output(emptyMap, "fallback", "json")
assert.Equal(t, "{}", result)
}
}

View File

@@ -14,7 +14,7 @@ var versionCmd = &cobra.Command{
Short: "Print the version",
Long: "The version of headscale",
Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
output := GetOutputFlag(cmd)
SuccessOutput(map[string]string{
"version": types.Version,
"commit": types.GitCommitHash,

View File

@@ -39,7 +39,7 @@ func TestVersionCommandFlags(t *testing.T) {
func TestVersionCommandRun(t *testing.T) {
// Test that Run function is set
assert.NotNil(t, versionCmd.Run)
// We can't easily test the actual execution without mocking SuccessOutput
// but we can verify the function exists and has the right signature
}
}