Remove all references to Machine.Registered

This commit is contained in:
Kristoffer Dalby 2022-02-28 16:55:57 +00:00
parent 35616eb861
commit 16b21e8158
2 changed files with 44 additions and 61 deletions

View File

@ -72,11 +72,6 @@ type (
MachinesP []*Machine MachinesP []*Machine
) )
// For the time being this method is rather naive.
func (machine Machine) isRegistered() bool {
return machine.Registered
}
type MachineAddresses []netaddr.IP type MachineAddresses []netaddr.IP
func (ma MachineAddresses) ToStringSlice() []string { func (ma MachineAddresses) ToStringSlice() []string {
@ -221,7 +216,7 @@ func (h *Headscale) ListPeers(machine *Machine) (Machines, error) {
Msg("Finding direct peers") Msg("Finding direct peers")
machines := Machines{} machines := Machines{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ? AND registered", if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where("machine_key <> ?",
machine.MachineKey).Find(&machines).Error; err != nil { machine.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db") log.Error().Err(err).Msg("Error accessing db")
@ -284,7 +279,7 @@ func (h *Headscale) getValidPeers(machine *Machine) (Machines, error) {
} }
for _, peer := range peers { for _, peer := range peers {
if peer.isRegistered() && !peer.isExpired() { if !peer.isExpired() {
validPeers = append(validPeers, peer) validPeers = append(validPeers, peer)
} }
} }
@ -373,8 +368,6 @@ func (h *Headscale) RefreshMachine(machine *Machine, expiry time.Time) {
// DeleteMachine softs deletes a Machine from the database. // DeleteMachine softs deletes a Machine from the database.
func (h *Headscale) DeleteMachine(machine *Machine) error { func (h *Headscale) DeleteMachine(machine *Machine) error {
machine.Registered = false
h.db.Save(&machine) // we mark it as unregistered, just in case
if err := h.db.Delete(&machine).Error; err != nil { if err := h.db.Delete(&machine).Error; err != nil {
return err return err
} }
@ -642,7 +635,7 @@ func (machine Machine) toNode(
LastSeen: machine.LastSeen, LastSeen: machine.LastSeen,
KeepAlive: true, KeepAlive: true,
MachineAuthorized: machine.Registered, MachineAuthorized: !machine.isExpired(),
Capabilities: []string{tailcfg.CapabilityFileSharing}, Capabilities: []string{tailcfg.CapabilityFileSharing},
} }
@ -660,8 +653,6 @@ func (machine *Machine) toProto() *v1.Machine {
Name: machine.Name, Name: machine.Name,
Namespace: machine.Namespace.toProto(), Namespace: machine.Namespace.toProto(),
Registered: machine.Registered,
// TODO(kradalby): Implement register method enum converter // TODO(kradalby): Implement register method enum converter
// RegisterMethod: , // RegisterMethod: ,
@ -738,7 +729,7 @@ func (h *Headscale) RegisterMachine(machine Machine,
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), []byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
) )
machineFromDatabase, _ := h.GetMachineByMachineKey(machineKey) machineFromDatabase, _ := h.GetMachineByMachineKey(machineKey)
if machine.isRegistered() { if machineFromDatabase != nil {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
@ -774,13 +765,6 @@ func (h *Headscale) RegisterMachine(machine Machine,
// move it to tags instead of having a column. // move it to tags instead of having a column.
// machine.RegisterMethod = registrationMethod // machine.RegisterMethod = registrationMethod
// TODO(kradalby): Registered is a very frustrating value
// to keep up to date, and it makes is have to care if a
// machine is registered, authenticated and expired.
// Let us simplify the model, a machine is _only_ saved if
// it is registered.
machine.Registered = true
h.db.Save(&machine) h.db.Save(&machine)
log.Trace(). log.Trace().

81
oidc.go
View File

@ -248,7 +248,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
if machine.isRegistered() { if machine != nil {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Name). Str("machine", machine.Name).
@ -291,58 +291,57 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
// register the machine if it's new // register the machine if it's new
if !machine.Registered { log.Debug().Msg("Registering new machine after successful callback")
log.Debug().Msg("Registering new machine after successful callback")
namespace, err := h.GetNamespace(namespaceName) namespace, err := h.GetNamespace(namespaceName)
if errors.Is(err, errNamespaceNotFound) { if errors.Is(err, errNamespaceNotFound) {
namespace, err = h.CreateNamespace(namespaceName) namespace, err = h.CreateNamespace(namespaceName)
if err != nil {
log.Error().
Err(err).
Caller().
Msgf("could not create new namespace '%s'", namespaceName)
ctx.String(
http.StatusInternalServerError,
"could not create new namespace",
)
return
}
} else if err != nil {
log.Error().
Caller().
Err(err).
Str("namespace", namespaceName).
Msg("could not find or create namespace")
ctx.String(
http.StatusInternalServerError,
"could not find or create namespace",
)
return
}
_, err = h.RegisterMachineFromAuthCallback(
machineKeyStr,
namespace.Name,
RegisterMethodOIDC,
&requestedTime,
)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller().
Err(err). Err(err).
Msg("could not register machine") Caller().
Msgf("could not create new namespace '%s'", namespaceName)
ctx.String( ctx.String(
http.StatusInternalServerError, http.StatusInternalServerError,
"could not register machine", "could not create new namespace",
) )
return return
} }
} else if err != nil {
log.Error().
Caller().
Err(err).
Str("namespace", namespaceName).
Msg("could not find or create namespace")
ctx.String(
http.StatusInternalServerError,
"could not find or create namespace",
)
return
}
_, err = h.RegisterMachineFromAuthCallback(
machineKeyStr,
namespace.Name,
RegisterMethodOIDC,
&requestedTime,
)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("could not register machine")
ctx.String(
http.StatusInternalServerError,
"could not register machine",
)
return
} }
var content bytes.Buffer var content bytes.Buffer