Files
headscale/hscontrol/db/preauth_keys.go
Kristoffer Dalby eb788cd007 make tags first class node owner (#2885)
This PR changes tags from being something that exists on nodes in addition to users to being their own thing. It is part of moving our tags support towards the correct Tailscale-compatible implementation.

There are probably rough edges in this PR, but the intention is to get it in, and then start fixing bugs from the 0.28.0 milestone (long-standing tags issue) to discover what works and what doesn't.

Updates #2417
Closes #2619
2025-12-02 12:01:25 +01:00

348 lines
9.5 KiB
Go

package db
import (
"errors"
"fmt"
"slices"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
"tailscale.com/util/set"
)
// Sentinel errors for pre-auth key handling. Callers should match them with
// errors.Is, since some call sites wrap them with additional context.
var (
// ErrPreAuthKeyNotFound is returned by findAuthKey when the database lookup
// for a key (by plaintext key or by prefix) returns an error.
ErrPreAuthKeyNotFound = errors.New("auth-key not found")
// ErrPreAuthKeyExpired is not returned by the code visible in this file;
// presumably callers use it when a key is past its expiration — verify at call sites.
ErrPreAuthKeyExpired = errors.New("auth-key expired")
// ErrSingleUseAuthKeyHasBeenUsed is not returned by the code visible in this
// file; presumably it reports reuse of a non-reusable key — verify at call sites.
ErrSingleUseAuthKeyHasBeenUsed = errors.New("auth-key has already been used")
// ErrUserMismatch is not returned by the code visible in this file.
ErrUserMismatch = errors.New("user mismatch")
// ErrPreAuthKeyACLTagInvalid is wrapped by CreatePreAuthKey when a requested
// tag does not start with "tag:".
ErrPreAuthKeyACLTagInvalid = errors.New("auth-key tag is invalid")
)
// CreatePreAuthKey creates a new pre-auth key by running the package-level
// CreatePreAuthKey through the Write helper on hsdb.DB. All arguments are
// forwarded unchanged; see the package-level function for their semantics.
func (hsdb *HSDatabase) CreatePreAuthKey(
	uid *types.UserID,
	reusable bool,
	ephemeral bool,
	expiration *time.Time,
	aclTags []string,
) (*types.PreAuthKeyNew, error) {
	create := func(tx *gorm.DB) (*types.PreAuthKeyNew, error) {
		return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags)
	}

	return Write(hsdb.DB, create)
}
// Layout of a generated auth key string:
//
//	authKeyPrefix + <random prefix> + "-" + <random secret>
const (
// authKeyPrefix is the literal marker that starts every new-format key.
authKeyPrefix = "hskey-auth-"
// authKeyPrefixLength is the length of the random prefix segment that
// follows authKeyPrefix (it is NOT the length of the authKeyPrefix literal).
// The prefix segment is stored and used to look the key up.
authKeyPrefixLength = 12
// authKeyLength is the length of the random secret segment; only its bcrypt
// hash is stored.
authKeyLength = 64
)
// CreatePreAuthKey generates a new pre-auth key, saves it through tx and
// returns a PreAuthKeyNew carrying the plaintext key string.
//
// A key must be tagged, user-owned, or both: if uid is nil and aclTags is
// empty, ErrPreAuthKeyNotTaggedOrOwned is returned. uid may be nil for
// system-created tagged keys. When uid is set, the user is loaded with
// GetUserByID and recorded on the key: for tagged keys it tracks who created
// the key, for user-owned keys it tracks the node owner.
//
// aclTags are de-duplicated and sorted before use. Every tag must begin with
// "tag:"; otherwise an error wrapping ErrPreAuthKeyACLTagInvalid is returned.
//
// The key string has the form "hskey-auth-<prefix>-<secret>". Only the prefix
// and a bcrypt hash of the secret are stored on the saved record.
func CreatePreAuthKey(
	tx *gorm.DB,
	uid *types.UserID,
	reusable bool,
	ephemeral bool,
	expiration *time.Time,
	aclTags []string,
) (*types.PreAuthKeyNew, error) {
	// Reject keys that are neither tagged nor attached to a user. Having both
	// is fine: the user is then "created by" and the tags define ownership.
	if uid == nil && len(aclTags) == 0 {
		return nil, ErrPreAuthKeyNotTaggedOrOwned
	}

	var (
		owner   *types.User
		ownerID *uint
	)

	if uid != nil {
		found, err := GetUserByID(tx, *uid)
		if err != nil {
			return nil, err
		}

		owner = found
		ownerID = &found.ID
	}

	// Canonicalise the tag list: unique entries in sorted order.
	aclTags = set.SetOf(aclTags).Slice()
	slices.Sort(aclTags)

	// TODO(kradalby): factor out and create a reusable tag validation,
	// check if there is one in Tailscale's lib.
	badIdx := slices.IndexFunc(aclTags, func(tag string) bool {
		return !strings.HasPrefix(tag, "tag:")
	})
	if badIdx >= 0 {
		return nil, fmt.Errorf(
			"%w: '%s' did not begin with 'tag:'",
			ErrPreAuthKeyACLTagInvalid,
			aclTags[badIdx],
		)
	}

	createdAt := time.Now().UTC()

	// Random lookup prefix. The checks below are defensive; the generator is
	// expected to always satisfy them.
	lookupPrefix, err := util.GenerateRandomStringURLSafe(authKeyPrefixLength)
	if err != nil {
		return nil, err
	}

	if got := len(lookupPrefix); got != authKeyPrefixLength {
		return nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyPrefixLength, got)
	}

	if !isValidBase64URLSafe(lookupPrefix) {
		return nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrPreAuthKeyFailedToParse)
	}

	// Random secret, validated the same way as the prefix.
	secret, err := util.GenerateRandomStringURLSafe(authKeyLength)
	if err != nil {
		return nil, err
	}

	if got := len(secret); got != authKeyLength {
		return nil, fmt.Errorf("%w: generated hash has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyLength, got)
	}

	if !isValidBase64URLSafe(secret) {
		return nil, fmt.Errorf("%w: generated hash contains invalid characters", ErrPreAuthKeyFailedToParse)
	}

	// The plaintext key is only handed back to the caller; the record keeps
	// the prefix and the bcrypt hash of the secret.
	plaintext := authKeyPrefix + lookupPrefix + "-" + secret

	hashed, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}

	pak := types.PreAuthKey{
		UserID:     ownerID, // nil for system-created keys, or "created by" for tagged keys
		User:       owner,   // nil for system-created keys
		Reusable:   reusable,
		Ephemeral:  ephemeral,
		CreatedAt:  &createdAt,
		Expiration: expiration,
		Tags:       aclTags, // empty for user-owned keys
		Prefix:     lookupPrefix,
		Hash:       hashed,
	}

	if saveErr := tx.Save(&pak).Error; saveErr != nil {
		return nil, fmt.Errorf("failed to create key in the database: %w", saveErr)
	}

	created := &types.PreAuthKeyNew{
		ID:         pak.ID,
		Key:        plaintext,
		Reusable:   pak.Reusable,
		Ephemeral:  pak.Ephemeral,
		Tags:       pak.Tags,
		Expiration: pak.Expiration,
		CreatedAt:  pak.CreatedAt,
		User:       pak.User,
	}

	return created, nil
}
// ListPreAuthKeys returns the PreAuthKeys belonging to the user identified by
// uid. It runs ListPreAuthKeysByUser through the Read helper on hsdb.DB.
func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) {
	list := func(rx *gorm.DB) ([]types.PreAuthKey, error) {
		return ListPreAuthKeysByUser(rx, uid)
	}

	return Read(hsdb.DB, list)
}
// ListPreAuthKeysByUser returns every PreAuthKey whose UserID matches the
// user identified by uid, with the User relation preloaded.
//
// The user is resolved first via GetUserByID, and its error is returned
// unchanged if that fails. When the user has no keys the result is an empty,
// non-nil slice.
func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) {
	owner, err := GetUserByID(tx, uid)
	if err != nil {
		return nil, err
	}

	filter := &types.PreAuthKey{UserID: &owner.ID}
	found := make([]types.PreAuthKey, 0)

	if err := tx.Preload("User").Where(filter).Find(&found).Error; err != nil {
		return nil, err
	}

	return found, nil
}
var (
// ErrPreAuthKeyFailedToParse is returned (often wrapped with detail) when an
// auth key string is empty or does not match the expected new-format layout.
// CreatePreAuthKey also wraps it when a freshly generated prefix or secret
// fails its defensive validation.
ErrPreAuthKeyFailedToParse = errors.New("failed to parse auth-key")
// ErrPreAuthKeyNotTaggedOrOwned is returned by CreatePreAuthKey when neither
// a user ID nor any tags are supplied.
ErrPreAuthKeyNotTaggedOrOwned = errors.New("auth-key must be either tagged or owned by user")
)
// findAuthKey resolves a presented auth key string to its stored PreAuthKey,
// with the User relation preloaded.
//
// Two formats are accepted:
//   - New format "hskey-auth-{12-char-prefix}-{64-char-secret}": the record is
//     looked up by prefix and the secret is verified against the stored bcrypt
//     hash.
//   - Legacy format (any string that does not start with "hskey-auth-"): the
//     record is looked up by its plaintext "key" column.
//
// The "hskey-auth-" marker must be at the very start of keyStr. A string that
// merely contains the marker somewhere in the middle is treated as a legacy
// key rather than having its leading bytes silently discarded.
//
// Errors:
//   - ErrPreAuthKeyFailedToParse (possibly wrapped) for an empty or malformed key.
//   - ErrPreAuthKeyNotFound when the database lookup fails.
//   - a wrapped bcrypt error when the secret does not match the stored hash.
func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) {
	// Validate input is not empty
	if keyStr == "" {
		return nil, ErrPreAuthKeyFailedToParse
	}

	var pak types.PreAuthKey

	// Anchor the marker at the start of the key: strings.Cut would match it
	// anywhere and drop whatever precedes it.
	prefixAndHash, found := strings.CutPrefix(keyStr, authKeyPrefix)
	if !found {
		// Legacy format (plaintext) - backwards compatibility
		err := tx.Preload("User").First(&pak, "key = ?", keyStr).Error
		if err != nil {
			return nil, ErrPreAuthKeyNotFound
		}

		return &pak, nil
	}

	// New format: hskey-auth-{12-char-prefix}-{64-char-hash}
	// Expected minimum length: 12 (prefix) + 1 (separator) + 64 (hash) = 77
	const expectedMinLength = authKeyPrefixLength + 1 + authKeyLength
	if len(prefixAndHash) < expectedMinLength {
		return nil, fmt.Errorf(
			"%w: key too short, expected at least %d chars after prefix, got %d",
			ErrPreAuthKeyFailedToParse,
			expectedMinLength,
			len(prefixAndHash),
		)
	}

	// Use fixed-length parsing instead of separator-based parsing, because
	// base64 URL-safe segments may themselves contain dashes.
	prefix := prefixAndHash[:authKeyPrefixLength]

	// Validate separator at expected position
	if prefixAndHash[authKeyPrefixLength] != '-' {
		return nil, fmt.Errorf(
			"%w: expected separator '-' at position %d, got '%c'",
			ErrPreAuthKeyFailedToParse,
			authKeyPrefixLength,
			prefixAndHash[authKeyPrefixLength],
		)
	}

	hash := prefixAndHash[authKeyPrefixLength+1:]

	// Validate hash length (rejects trailing data after the secret)
	if len(hash) != authKeyLength {
		return nil, fmt.Errorf(
			"%w: hash length mismatch, expected %d chars, got %d",
			ErrPreAuthKeyFailedToParse,
			authKeyLength,
			len(hash),
		)
	}

	// Validate prefix contains only base64 URL-safe characters
	if !isValidBase64URLSafe(prefix) {
		return nil, fmt.Errorf(
			"%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
			ErrPreAuthKeyFailedToParse,
		)
	}

	// Validate hash contains only base64 URL-safe characters
	if !isValidBase64URLSafe(hash) {
		return nil, fmt.Errorf(
			"%w: hash contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
			ErrPreAuthKeyFailedToParse,
		)
	}

	// Look up key by prefix
	err := tx.Preload("User").First(&pak, "prefix = ?", prefix).Error
	if err != nil {
		return nil, ErrPreAuthKeyNotFound
	}

	// Verify the presented secret against the stored bcrypt hash
	err = bcrypt.CompareHashAndPassword(pak.Hash, []byte(hash))
	if err != nil {
		return nil, fmt.Errorf("invalid auth key: %w", err)
	}

	return &pak, nil
}
// isValidBase64URLSafe reports whether every character of s belongs to the
// base64 URL-safe alphabet (A-Z, a-z, 0-9, '-' and '_'). The empty string is
// considered valid.
func isValidBase64URLSafe(s string) bool {
	for _, r := range s {
		switch {
		case r >= 'A' && r <= 'Z':
		case r >= 'a' && r <= 'z':
		case r >= '0' && r <= '9':
		case r == '-', r == '_':
		default:
			return false
		}
	}

	return true
}
// GetPreAuthKey looks up the PreAuthKey for the given key string using the
// database handle held by hsdb. It delegates to the package-level
// GetPreAuthKey and does not check whether the key is still usable.
func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) {
	db := hsdb.DB

	return GetPreAuthKey(db, key)
}
// GetPreAuthKey resolves key to its stored PreAuthKey via findAuthKey.
// It performs no usability checks: the caller must verify that the key has
// not expired or already been used.
func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) {
	pak, err := findAuthKey(tx, key)

	return pak, err
}
// DestroyPreAuthKey permanently deletes pak inside a transaction.
//
// Before deleting, it sets auth_key_id to NULL on every node that references
// the key. The key row is then removed with an unscoped delete. Any error
// from either statement is returned, the first one wrapped with context.
// The number of deleted rows is not inspected here.
func DestroyPreAuthKey(tx *gorm.DB, pak types.PreAuthKey) error {
	destroy := func(db *gorm.DB) error {
		// Detach nodes first so no node keeps pointing at the removed key.
		detach := db.Model(&types.Node{}).
			Where("auth_key_id = ?", pak.ID).
			Update("auth_key_id", nil)
		if detach.Error != nil {
			return fmt.Errorf("failed to clear auth_key_id on nodes: %w", detach.Error)
		}

		// Then delete the pre-auth key itself.
		return db.Unscoped().Delete(pak).Error
	}

	return tx.Transaction(destroy)
}
// ExpirePreAuthKey marks k as expired by running the package-level
// ExpirePreAuthKey through hsdb.Write.
func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error {
	expire := func(tx *gorm.DB) error {
		return ExpirePreAuthKey(tx, k)
	}

	return hsdb.Write(expire)
}
// DeletePreAuthKey removes the key pointed to by k by running
// DestroyPreAuthKey through hsdb.Write. k is dereferenced, so it must not be
// nil.
func (hsdb *HSDatabase) DeletePreAuthKey(k *types.PreAuthKey) error {
	remove := func(tx *gorm.DB) error {
		return DestroyPreAuthKey(tx, *k)
	}

	return hsdb.Write(remove)
}
// UsePreAuthKey flags k as used: it updates the "used" column for k's record
// and, on success, sets k.Used on the in-memory value as well. A database
// failure is returned wrapped with context and leaves k.Used untouched by
// this function.
func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
	if err := tx.Model(k).Update("used", true).Error; err != nil {
		return fmt.Errorf("failed to update key used status in the database: %w", err)
	}

	k.Used = true

	return nil
}
// ExpirePreAuthKey marks a PreAuthKey as expired by setting the "expiration"
// column of the record with id k.ID to the current time. Only the database
// row is updated; k itself is not modified.
func ExpirePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
	expiredAt := time.Now()
	target := tx.Model(&types.PreAuthKey{}).Where("id = ?", k.ID)

	return target.Update("expiration", expiredAt).Error
}