diff --git a/api.go b/api.go index 7d535322..60ca63f2 100644 --- a/api.go +++ b/api.go @@ -125,25 +125,40 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { machine, err := h.GetMachineByMachineKey(machineKey) if errors.Is(err, gorm.ErrRecordNotFound) { log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") - newMachine := Machine{ - Expiry: &time.Time{}, - MachineKey: MachinePublicKeyStripPrefix(machineKey), - Name: req.Hostinfo.Hostname, - } - if err := h.db.Create(&newMachine).Error; err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Could not create row") - machineRegistrations.WithLabelValues("unknown", "web", "error", machine.Namespace.Name). - Inc() + + machineKeyStr := MachinePublicKeyStripPrefix(machineKey) + + // If the machine has AuthKey set, handle registration via PreAuthKeys + if req.Auth.AuthKey != "" { + h.handleAuthKey(ctx, machineKey, req) return } - machine = &newMachine + + // The machine did not have a key to authenticate, which means + // that we rely on a method that calls back some how (OpenID or CLI) + // We create the machine and then keep it around until a callback + // happens + newMachine := Machine{ + Expiry: &time.Time{}, + MachineKey: machineKeyStr, + Name: req.Hostinfo.Hostname, + NodeKey: NodePublicKeyStripPrefix(req.NodeKey), + LastSeen: &now, + } + + h.registrationCache.Set( + machineKeyStr, + newMachine, + requestedExpiryCacheExpiration, + ) + + h.handleMachineRegistrationNew(ctx, machineKey, req) + return } - if machine.Registered { + // The machine is already registered, so we need to pass through reauth or key update. + if machine != nil { // If the NodeKey stored in headscale is the same as the key presented in a registration // request, then we have a node that is either: // - Trying to log out (sending a expiry in the past) @@ -180,15 +195,6 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) { return } - - // If the machine has AuthKey set, handle registration via PreAuthKeys - if req.Auth.AuthKey != "" { - h.handleAuthKey(ctx, machineKey, req, *machine) - - return - } - - h.handleMachineRegistrationNew(ctx, machineKey, req, *machine) } func (h *Headscale) getMapResponse( @@ -402,7 +408,7 @@ func (h *Headscale) handleMachineExpired( Msg("Machine registration has expired. Sending a authurl to register") if registerRequest.Auth.AuthKey != "" { - h.handleAuthKey(ctx, machineKey, registerRequest, machine) + h.handleAuthKey(ctx, machineKey, registerRequest) return } @@ -465,13 +471,12 @@ func (h *Headscale) handleMachineRegistrationNew( ctx *gin.Context, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, - machine Machine, ) { resp := tailcfg.RegisterResponse{} // The machine registration is new, redirect the client to the registration URL log.Debug(). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Msg("The node is sending us a new NodeKey, sending auth url") if h.cfg.OIDC.Issuer != "" { resp.AuthURL = fmt.Sprintf( @@ -487,7 +492,7 @@ func (h *Headscale) handleMachineRegistrationNew( if !registerRequest.Expiry.IsZero() { log.Trace(). Caller(). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Time("expiry", registerRequest.Expiry). Msg("Non-zero expiry time requested, adding to cache") h.requestedExpiryCache.Set( @@ -497,11 +502,6 @@ func (h *Headscale) handleMachineRegistrationNew( ) } - machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) - - // save the NodeKey - h.db.Save(&machine) - respBody, err := encode(resp, &machineKey, h.privateKey) if err != nil { log.Error(). @@ -520,19 +520,21 @@ func (h *Headscale) handleAuthKey( ctx *gin.Context, machineKey key.MachinePublic, registerRequest tailcfg.RegisterRequest, - machine Machine, ) { + machineKeyStr := MachinePublicKeyStripPrefix(machineKey) + log.Debug(). Str("func", "handleAuthKey"). Str("machine", registerRequest.Hostinfo.Hostname). Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) resp := tailcfg.RegisterResponse{} + pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey) if err != nil { log.Error(). Caller(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Err(err). Msg("Failed authentication via AuthKey") resp.MachineAuthorized = false @@ -541,69 +543,62 @@ func (h *Headscale) handleAuthKey( log.Error(). Caller(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Err(err). Msg("Cannot encode message") ctx.String(http.StatusInternalServerError, "") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name). + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). Inc() return } + ctx.Data(http.StatusUnauthorized, "application/json; charset=utf-8", respBody) log.Error(). Caller(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Msg("Failed authentication via AuthKey") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name). + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). Inc() return } - if machine.isRegistered() { - log.Trace(). + log.Debug(). + Str("func", "handleAuthKey"). + Str("machine", registerRequest.Hostinfo.Hostname). + Msg("Authentication key was valid, proceeding to acquire IP addresses") + + nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey) + now := time.Now().UTC() + + machine, err := h.RegisterMachine( + registerRequest.Hostinfo.Hostname, + machineKeyStr, + pak.Namespace.Name, + RegisterMethodAuthKey, + ®isterRequest.Expiry, + pak, + &nodeKey, + &now, + ) + if err != nil { + log.Error(). Caller(). - Str("machine", machine.Name). - Msg("machine already registered, reauthenticating") - - h.RefreshMachine(&machine, registerRequest.Expiry) - } else { - log.Debug(). - Str("func", "handleAuthKey"). - Str("machine", machine.Name). - Msg("Authentication key was valid, proceeding to acquire IP addresses") - - nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey) - now := time.Now().UTC() - - _, err = h.RegisterMachine( - machine.Name, - machine.Namespace.Name, - RegisterMethodAuthKey, - ®isterRequest.Expiry, - pak, - &nodeKey, - &now, + Err(err). + Msg("could not register machine") + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). + Inc() + ctx.String( + http.StatusInternalServerError, + "could not register machine", ) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("could not register machine") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name).Inc() - ctx.String( - http.StatusInternalServerError, - "could not register machine", - ) - return - } + return } - pak.Used = true - h.db.Save(&pak) + h.UsePreAuthKey(pak) resp.MachineAuthorized = true resp.User = *pak.Namespace.toUser() @@ -612,21 +607,21 @@ func (h *Headscale) handleAuthKey( log.Error(). Caller(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Err(err). Msg("Cannot encode message") - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", machine.Namespace.Name). + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.Namespace.Name). Inc() ctx.String(http.StatusInternalServerError, "Extremely sad!") return } - machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", machine.Namespace.Name). + machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.Namespace.Name). Inc() ctx.Data(http.StatusOK, "application/json; charset=utf-8", respBody) log.Info(). Str("func", "handleAuthKey"). - Str("machine", machine.Name). + Str("machine", registerRequest.Hostinfo.Hostname). Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")). Msg("Successfully authenticated via AuthKey") } diff --git a/app.go b/app.go index eda670ed..63ca10e4 100644 --- a/app.go +++ b/app.go @@ -154,6 +154,8 @@ type Headscale struct { requestedExpiryCache *cache.Cache + registrationCache *cache.Cache + ipAllocationMutex sync.Mutex } @@ -207,6 +209,12 @@ func NewHeadscale(cfg Config) (*Headscale, error) { requestedExpiryCacheCleanupInterval, ) + registrationCache := cache.New( + // TODO(kradalby): Add unified cache expiry config options + requestedExpiryCacheExpiration, + requestedExpiryCacheCleanupInterval, + ) + app := Headscale{ cfg: cfg, dbType: cfg.DBtype, @@ -214,6 +222,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) { privateKey: privKey, aclRules: tailcfg.FilterAllowAll, // default allowall requestedExpiryCache: requestedExpiryCache, + registrationCache: registrationCache, } err = app.initDB() diff --git a/cli_test.go b/cli_test.go index 2eedb5b4..fab8201f 100644 --- a/cli_test.go +++ b/cli_test.go @@ -30,6 +30,7 @@ func (s *Suite) TestRegisterMachine(c *check.C) { c.Assert(err, check.IsNil) machineAfterRegistering, err := app.RegisterMachine( + "testmachine", machine.MachineKey, namespace.Name, RegisterMethodCLI, diff --git a/grpcv1.go b/grpcv1.go index b1396bc8..3a2e19ce 100644 --- a/grpcv1.go +++ b/grpcv1.go @@ -174,12 +174,11 @@ func (api headscaleV1APIServer) RegisterMachine( } } - machine, err := api.h.RegisterMachine( + machine, err := api.h.RegisterMachineFromAuthCallback( request.GetKey(), request.GetNamespace(), RegisterMethodCLI, &requestedTime, - nil, nil, nil, ) if err != nil { return nil, err diff --git a/machine.go b/machine.go index b805de67..74251553 100644 --- a/machine.go +++ b/machine.go @@ -20,11 +20,15 @@ import ( ) const ( - errMachineNotFound = Error("machine not found") - errMachineAlreadyRegistered = Error("machine already registered") - errMachineRouteIsNotAvailable = Error("route is not available on machine") - errMachineAddressesInvalid = Error("failed to parse machine addresses") - errHostnameTooLong = Error("Hostname too long") + errMachineNotFound = Error("machine not found") + errMachineAlreadyRegistered = Error("machine already registered") + errMachineRouteIsNotAvailable = Error("route is not available on machine") + errMachineAddressesInvalid = Error("failed to parse machine addresses") + errMachineNotFoundRegistrationCache = Error( + "machine not found in registration cache", + ) + errCouldNotConvertMachineInterface = Error("failed to convert machine interface") + errHostnameTooLong = Error("Hostname too long") ) const ( @@ -686,14 +690,44 @@ func (machine *Machine) toProto() *v1.Machine { return machineProto } -// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. -func (h *Headscale) RegisterMachine( +func (h *Headscale) RegisterMachineFromAuthCallback( machineKeyStr string, namespaceName string, registrationMethod string, + expiry *time.Time, +) (*Machine, error) { + if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { + if registrationMachine, ok := machineInterface.(Machine); ok { + machine, err := h.RegisterMachine( + registrationMachine.Name, + machineKeyStr, + namespaceName, + registrationMethod, + expiry, + nil, + ®istrationMachine.NodeKey, + registrationMachine.LastSeen, + ) + + return machine, err + + } else { + return nil, errCouldNotConvertMachineInterface + } + } + + return nil, errMachineNotFoundRegistrationCache +} + +// RegisterMachine is executed from the CLI to register a new Machine using its MachineKey. +func (h *Headscale) RegisterMachine( + machineName string, + machineKeyStr string, + namespaceName string, + registrationMethod string, + expiry *time.Time, // Optionals - expiry *time.Time, authKey *PreAuthKey, nodePublicKey *string, lastSeen *time.Time, @@ -768,6 +802,7 @@ func (h *Headscale) RegisterMachine( machine.LastSeen = lastSeen } + machine.Name = machineName machine.NamespaceID = namespace.ID // TODO(kradalby): This field is uneccessary metadata, @@ -780,6 +815,7 @@ func (h *Headscale) RegisterMachine( // Let us simplify the model, a machine is _only_ saved if // it is registered. machine.Registered = true + h.db.Save(&machine) log.Trace(). diff --git a/oidc.go b/oidc.go index 5207c64e..48664aed 100644 --- a/oidc.go +++ b/oidc.go @@ -279,8 +279,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - now := time.Now().UTC() - namespaceName, err := NormalizeNamespaceName( claims.Email, h.cfg.OIDC.StripEmaildomain, @@ -328,14 +326,11 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) { return } - _, err = h.RegisterMachine( + _, err = h.RegisterMachineFromAuthCallback( machineKeyStr, namespace.Name, RegisterMethodOIDC, &requestedTime, - nil, - nil, - &now, ) if err != nil { log.Error(). diff --git a/preauth_keys.go b/preauth_keys.go index 50bc4746..55f62226 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -113,6 +113,12 @@ func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error { return nil } +// UsePreAuthKey marks a PreAuthKey as used. +func (h *Headscale) UsePreAuthKey(k *PreAuthKey) { + k.Used = true + h.db.Save(k) +} + // checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node // If returns no error and a PreAuthKey, it can be used. func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {