mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-27 12:05:26 +00:00
Add cache for requested expiry times
This commit adds a sentral cache to keep track of clients whom has requested an expiry time, but were we need to keep hold of it until the second request comes in.
This commit is contained in:
parent
e600ead3e9
commit
021c464148
14
api.go
14
api.go
@ -23,6 +23,9 @@ const (
|
||||
RegisterMethodAuthKey = "authKey"
|
||||
RegisterMethodOIDC = "oidc"
|
||||
RegisterMethodCLI = "cli"
|
||||
ErrRegisterMethodCLIDoesNotSupportExpire = Error(
|
||||
"machines registered with CLI does not support expire",
|
||||
)
|
||||
)
|
||||
|
||||
// KeyHandler provides the Headscale pub key
|
||||
@ -441,7 +444,16 @@ func (h *Headscale) handleMachineRegistrationNew(
|
||||
}
|
||||
|
||||
if !reqisterRequest.Expiry.IsZero() {
|
||||
machine.Expiry = &reqisterRequest.Expiry
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Time("expiry", reqisterRequest.Expiry).
|
||||
Msg("Non-zero expiry time requested, adding to cache")
|
||||
h.requestedExpiryCache.Set(
|
||||
idKey.HexString(),
|
||||
reqisterRequest.Expiry,
|
||||
requestedExpiryCacheExpiration,
|
||||
)
|
||||
}
|
||||
|
||||
machine.NodeKey = wgkey.Key(reqisterRequest.NodeKey).HexString() // save the NodeKey
|
||||
|
11
app.go
11
app.go
@ -53,6 +53,9 @@ const (
|
||||
updateInterval = 5000
|
||||
HTTPReadTimeout = 30 * time.Second
|
||||
|
||||
requestedExpiryCacheExpiration = time.Minute * 5
|
||||
requestedExpiryCacheCleanupInterval = time.Minute * 10
|
||||
|
||||
errUnsupportedDatabase = Error("unsupported DB")
|
||||
errUnsupportedLetsEncryptChallengeType = Error(
|
||||
"unknown value for Lets Encrypt challenge type",
|
||||
@ -139,6 +142,8 @@ type Headscale struct {
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
oidcStateCache *cache.Cache
|
||||
|
||||
requestedExpiryCache *cache.Cache
|
||||
}
|
||||
|
||||
// NewHeadscale returns the Headscale app.
|
||||
@ -171,6 +176,11 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||
return nil, errUnsupportedDatabase
|
||||
}
|
||||
|
||||
requestedExpiryCache := cache.New(
|
||||
requestedExpiryCacheExpiration,
|
||||
requestedExpiryCacheCleanupInterval,
|
||||
)
|
||||
|
||||
app := Headscale{
|
||||
cfg: cfg,
|
||||
dbType: cfg.DBtype,
|
||||
@ -178,6 +188,7 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||
privateKey: privKey,
|
||||
publicKey: &pubKey,
|
||||
aclRules: tailcfg.FilterAllowAll, // default allowall
|
||||
requestedExpiryCache: requestedExpiryCache,
|
||||
}
|
||||
|
||||
err = app.initDB()
|
||||
|
25
machine.go
25
machine.go
@ -616,6 +616,31 @@ func (h *Headscale) RegisterMachine(
|
||||
return nil, errMachineNotFound
|
||||
}
|
||||
|
||||
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
|
||||
// This means that if a user is to slow with register a machine, it will possibly not
|
||||
// have the correct expiry.
|
||||
requestedTime := time.Time{}
|
||||
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.HexString()); found {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Msg("Expiry time found in cache, assigning to node")
|
||||
if reqTime, ok := requestedTimeIf.(time.Time); ok {
|
||||
requestedTime = reqTime
|
||||
}
|
||||
}
|
||||
|
||||
if machine.isRegistered() {
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
Msg("machine already registered, reauthenticating")
|
||||
|
||||
h.RefreshMachine(&machine, requestedTime)
|
||||
|
||||
return &machine, nil
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("machine", machine.Name).
|
||||
|
8
oidc.go
8
oidc.go
@ -199,6 +199,14 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
|
||||
requestedTime := time.Time{}
|
||||
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey); found {
|
||||
if reqTime, ok := requestedTimeIf.(time.Time); ok {
|
||||
requestedTime = reqTime
|
||||
}
|
||||
}
|
||||
|
||||
// retrieve machine information
|
||||
machine, err := h.GetMachineByMachineKey(machineKey)
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user