diff --git a/grpcv1.go b/grpcv1.go index 620b8fe0..cadd6304 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -106,6 +106,18 @@ func (api headscaleV1APIServer) CreatePreAuthKey( expiration = request.GetExpiration().AsTime() } + if len(request.AclTags) > 0 { + for _, tag := range request.AclTags { + err := validateTag(tag) + + if err != nil { + return &v1.CreatePreAuthKeyResponse{ + PreAuthKey: nil, + }, status.Error(codes.InvalidArgument, err.Error()) + } + } + } + preAuthKey, err := api.h.CreatePreAuthKey( request.GetNamespace(), request.GetReusable(), diff --git a/integration_cli_test.go b/integration_cli_test.go index d2e28bee..b016ac29 100644 --- a/integration_cli_test.go +++ b/integration_cli_test.go @@ -260,6 +260,8 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() { "24h", "--output", "json", + "--tags", + "tag:test1,tag:test2", }, []string{}, ) @@ -333,6 +335,11 @@ func (s *IntegrationCLITestSuite) TestPreAuthKeyCommand() { listedPreAuthKeys[4].Expiration.AsTime().Before(time.Now().Add(time.Hour*26)), ) + // Test that tags are present + for i := 0; i < count; i++ { + assert.DeepEquals(listedPreAuthKeys[i].AclTags, []string{"tag:test1,", "tag:test2"}) + } + // Expire three keys for i := 0; i < 3; i++ { _, err := ExecuteCommand( diff --git a/preauth_keys.go b/preauth_keys.go index 7995acb6..6ebd6560 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strconv" + "strings" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -55,6 +56,12 @@ func (h *Headscale) CreatePreAuthKey( return nil, err } + for _, tag := range aclTags { + if !strings.HasPrefix(tag, "tag:") { + return nil, fmt.Errorf("aclTag '%s' did not begin with 'tag:'", tag) + } + } + now := time.Now().UTC() kstr, err := h.generateKey() if err != nil { @@ -77,12 +84,17 @@ func (h *Headscale) CreatePreAuthKey( } if len(aclTags) > 0 { + seenTags := map[string]bool{} + for _, tag := range aclTags { - if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { - return fmt.Errorf( - "failed to create key tag in the database: %w", - err, - ) + if seenTags[tag] == false { + if err := db.Save(&PreAuthKeyAclTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { + return fmt.Errorf( + "failed to ceate key tag in the database: %w", + err, + ) + } + seenTags[tag] = true } } } @@ -222,7 +234,7 @@ func (key *PreAuthKey) toProto() *v1.PreAuthKey { if len(key.AclTags) > 0 { for idx := range key.AclTags { - protoKey.AclTags[idx] = key.AclTags[0].Tag + protoKey.AclTags[idx] = key.AclTags[idx].Tag } } diff --git a/preauth_keys_test.go b/preauth_keys_test.go index fe108337..ffcaf9a4 100644 --- a/preauth_keys_test.go +++ b/preauth_keys_test.go @@ -190,3 +190,20 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { _, err = app.checkKeyValidity(pak.Key) c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) } + +func (*Suite) TestPreAuthKeyAclTags(c *check.C) { + namespace, err := app.CreateNamespace("test8") + c.Assert(err, check.IsNil) + + _, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, []string{"badtag"}) + c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected + + tags := []string{"tag:test1", "tag:test2"} + tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} + _, err = app.CreatePreAuthKey(namespace.Name, false, false, nil, tagsWithDuplicate) + c.Assert(err, check.IsNil) + + listedPaks, err := app.ListPreAuthKeys("test8") + c.Assert(err, check.IsNil) + c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags) +} diff --git a/protocol_common.go b/protocol_common.go index 154c14c9..72b38f4b 100644 --- a/protocol_common.go +++ b/protocol_common.go @@ -345,6 +345,7 @@ func (h *Headscale) handleAuthKeyCommon( machine.NodeKey = nodeKey machine.AuthKeyID = uint(pak.ID) err := h.RefreshMachine(machine, registerRequest.Expiry) + if err != nil { log.Error(). Caller(). @@ -355,6 +356,25 @@ func (h *Headscale) handleAuthKeyCommon( return } + + aclTags := pak.toProto().AclTags + if len(aclTags) > 0 { + // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login + err = h.SetTags(machine, aclTags) + } + + if err != nil { + log.Error(). + Caller(). + Bool("noise", machineKey.IsZero()). + Str("machine", machine.Hostname). + Strs("aclTags", aclTags). + Err(err). + Msg("Failed to set tags after refreshing machine") + + return + } + } else { now := time.Now().UTC() @@ -380,6 +400,7 @@ func (h *Headscale) handleAuthKeyCommon( NodeKey: nodeKey, LastSeen: &now, AuthKeyID: uint(pak.ID), + ForcedTags: pak.toProto().AclTags, } machine, err = h.RegisterMachine(