Generalise the registration method to DRY stuff up

This commit is contained in:
Kristoffer Dalby 2022-02-27 18:40:10 +01:00
parent b1bd17f316
commit c58ce6f60c
2 changed files with 41 additions and 36 deletions

View File

@ -32,6 +32,8 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
machineAfterRegistering, err := app.RegisterMachine( machineAfterRegistering, err := app.RegisterMachine(
machine.MachineKey, machine.MachineKey,
namespace.Name, namespace.Name,
RegisterMethodCLI,
nil, nil, nil, nil,
) )
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(machineAfterRegistering.Registered, check.Equals, true) c.Assert(machineAfterRegistering.Registered, check.Equals, true)

View File

@ -44,6 +44,8 @@ type Machine struct {
Registered bool // temp Registered bool // temp
RegisterMethod string RegisterMethod string
// TODO(kradalby): This seems like irrelevant information?
AuthKeyID uint AuthKeyID uint
AuthKey *PreAuthKey AuthKey *PreAuthKey
@ -686,6 +688,13 @@ func (machine *Machine) toProto() *v1.Machine {
func (h *Headscale) RegisterMachine( func (h *Headscale) RegisterMachine(
machineKeyStr string, machineKeyStr string,
namespaceName string, namespaceName string,
registrationMethod string,
// Optionals
expiry *time.Time,
authKey *PreAuthKey,
nodePublicKey *string,
lastSeen *time.Time,
) (*Machine, error) { ) (*Machine, error) {
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if err != nil { if err != nil {
@ -709,27 +718,13 @@ func (h *Headscale) RegisterMachine(
return nil, err return nil, err
} }
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
// This means that if a user is to slow with register a machine, it will possibly not
// have the correct expiry.
requestedTime := time.Time{}
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found {
log.Trace().
Caller().
Str("machine", machine.Name).
Msg("Expiry time found in cache, assigning to node")
if reqTime, ok := requestedTimeIf.(time.Time); ok {
requestedTime = reqTime
}
}
if machine.isRegistered() { if machine.isRegistered() {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("machine already registered, reauthenticating") Msg("machine already registered, reauthenticating")
h.RefreshMachine(machine, requestedTime) h.RefreshMachine(machine, *expiry)
return machine, nil return machine, nil
} }
@ -739,17 +734,6 @@ func (h *Headscale) RegisterMachine(
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("Attempting to register machine") Msg("Attempting to register machine")
if machine.isRegistered() {
err := errMachineAlreadyRegistered
log.Error().
Caller().
Err(err).
Str("machine", machine.Name).
Msg("Attempting to register machine")
return nil, err
}
h.ipAllocationMutex.Lock() h.ipAllocationMutex.Lock()
defer h.ipAllocationMutex.Unlock() defer h.ipAllocationMutex.Unlock()
@ -764,17 +748,36 @@ func (h *Headscale) RegisterMachine(
return nil, err return nil, err
} }
log.Trace().
Caller().
Str("machine", machine.Name).
Str("ip", strings.Join(ips.ToStringSlice(), ",")).
Msg("Found IP for host")
machine.IPAddresses = ips machine.IPAddresses = ips
if expiry != nil {
machine.Expiry = expiry
}
if authKey != nil {
machine.AuthKeyID = uint(authKey.ID)
}
if nodePublicKey != nil {
machine.NodeKey = *nodePublicKey
}
if lastSeen != nil {
machine.LastSeen = lastSeen
}
machine.NamespaceID = namespace.ID machine.NamespaceID = namespace.ID
// TODO(kradalby): This field is uneccessary metadata,
// move it to tags instead of having a column.
machine.RegisterMethod = registrationMethod
// TODO(kradalby): Registered is a very frustrating value
// to keep up to date, and it makes is have to care if a
// machine is registered, authenticated and expired.
// Let us simplify the model, a machine is _only_ saved if
// it is registered.
machine.Registered = true machine.Registered = true
machine.RegisterMethod = RegisterMethodCLI
machine.Expiry = &requestedTime
h.db.Save(&machine) h.db.Save(&machine)
log.Trace(). log.Trace().