Remove redundant caches

This commit removes the two extra caches (oidc, requested time) and uses
the new central registration cache instead. The requested time is
unified into the main machine object and the oidc key is just added to
the same cache, as a string with the state as a key instead of machine
key.
This commit is contained in:
Kristoffer Dalby 2022-02-28 22:42:30 +00:00
parent e64bee778f
commit 5e92ddad43
6 changed files with 27 additions and 84 deletions

25
api.go
View File

@ -140,17 +140,25 @@ func (h *Headscale) RegistrationHandler(ctx *gin.Context) {
// We create the machine and then keep it around until a callback // We create the machine and then keep it around until a callback
// happens // happens
newMachine := Machine{ newMachine := Machine{
Expiry: &time.Time{},
MachineKey: machineKeyStr, MachineKey: machineKeyStr,
Name: req.Hostinfo.Hostname, Name: req.Hostinfo.Hostname,
NodeKey: NodePublicKeyStripPrefix(req.NodeKey), NodeKey: NodePublicKeyStripPrefix(req.NodeKey),
LastSeen: &now, LastSeen: &now,
} }
if !req.Expiry.IsZero() {
log.Trace().
Caller().
Str("machine", req.Hostinfo.Hostname).
Time("expiry", req.Expiry).
Msg("Non-zero expiry time requested")
newMachine.Expiry = &req.Expiry
}
h.registrationCache.Set( h.registrationCache.Set(
machineKeyStr, machineKeyStr,
newMachine, newMachine,
requestedExpiryCacheExpiration, registerCacheExpiration,
) )
h.handleMachineRegistrationNew(ctx, machineKey, req) h.handleMachineRegistrationNew(ctx, machineKey, req)
@ -490,19 +498,6 @@ func (h *Headscale) handleMachineRegistrationNew(
strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey)) strings.TrimSuffix(h.cfg.ServerURL, "/"), MachinePublicKeyStripPrefix(machineKey))
} }
if !registerRequest.Expiry.IsZero() {
log.Trace().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested, adding to cache")
h.requestedExpiryCache.Set(
machineKey.String(),
registerRequest.Expiry,
requestedExpiryCacheExpiration,
)
}
respBody, err := encode(resp, &machineKey, h.privateKey) respBody, err := encode(resp, &machineKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().

17
app.go
View File

@ -55,8 +55,8 @@ const (
HTTPReadTimeout = 30 * time.Second HTTPReadTimeout = 30 * time.Second
privateKeyFileMode = 0o600 privateKeyFileMode = 0o600
requestedExpiryCacheExpiration = time.Minute * 5 registerCacheExpiration = time.Minute * 15
requestedExpiryCacheCleanupInterval = time.Minute * 10 registerCacheCleanup = time.Minute * 20
errUnsupportedDatabase = Error("unsupported DB") errUnsupportedDatabase = Error("unsupported DB")
errUnsupportedLetsEncryptChallengeType = Error( errUnsupportedLetsEncryptChallengeType = Error(
@ -150,9 +150,6 @@ type Headscale struct {
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
oidcStateCache *cache.Cache
requestedExpiryCache *cache.Cache
registrationCache *cache.Cache registrationCache *cache.Cache
@ -204,15 +201,10 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, errUnsupportedDatabase return nil, errUnsupportedDatabase
} }
requestedExpiryCache := cache.New(
requestedExpiryCacheExpiration,
requestedExpiryCacheCleanupInterval,
)
registrationCache := cache.New( registrationCache := cache.New(
// TODO(kradalby): Add unified cache expiry config options // TODO(kradalby): Add unified cache expiry config options
requestedExpiryCacheExpiration, registerCacheExpiration,
requestedExpiryCacheCleanupInterval, registerCacheCleanup,
) )
app := Headscale{ app := Headscale{
@ -221,7 +213,6 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
dbString: dbString, dbString: dbString,
privateKey: privKey, privateKey: privKey,
aclRules: tailcfg.FilterAllowAll, // default allowall aclRules: tailcfg.FilterAllowAll, // default allowall
requestedExpiryCache: requestedExpiryCache,
registrationCache: registrationCache, registrationCache: registrationCache,
} }

View File

@ -5,7 +5,6 @@ import (
"os" "os"
"testing" "testing"
"github.com/patrickmn/go-cache"
"gopkg.in/check.v1" "gopkg.in/check.v1"
"inet.af/netaddr" "inet.af/netaddr"
) )
@ -50,10 +49,6 @@ func (s *Suite) ResetDB(c *check.C) {
cfg: cfg, cfg: cfg,
dbType: "sqlite3", dbType: "sqlite3",
dbString: tmpDir + "/headscale_test.db", dbString: tmpDir + "/headscale_test.db",
requestedExpiryCache: cache.New(
requestedExpiryCacheExpiration,
requestedExpiryCacheCleanupInterval,
),
} }
err = app.initDB() err = app.initDB()
if err != nil { if err != nil {

View File

@ -160,25 +160,10 @@ func (api headscaleV1APIServer) RegisterMachine(
Str("machine_key", request.GetKey()). Str("machine_key", request.GetKey()).
Msg("Registering machine") Msg("Registering machine")
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
// This means that if a user is to slow with register a machine, it will possibly not
// have the correct expiry.
requestedTime := time.Time{}
if requestedTimeIf, found := api.h.requestedExpiryCache.Get(request.GetKey()); found {
log.Trace().
Caller().
Str("machine", request.Key).
Msg("Expiry time found in cache, assigning to node")
if reqTime, ok := requestedTimeIf.(time.Time); ok {
requestedTime = reqTime
}
}
machine, err := api.h.RegisterMachineFromAuthCallback( machine, err := api.h.RegisterMachineFromAuthCallback(
request.GetKey(), request.GetKey(),
request.GetNamespace(), request.GetNamespace(),
RegisterMethodCLI, RegisterMethodCLI,
&requestedTime,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -418,7 +403,7 @@ func (api headscaleV1APIServer) DebugCreateMachine(
api.h.registrationCache.Set( api.h.registrationCache.Set(
request.GetKey(), request.GetKey(),
newMachine, newMachine,
requestedExpiryCacheExpiration, registerCacheExpiration,
) )
return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil

View File

@ -683,7 +683,6 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
machineKeyStr string, machineKeyStr string,
namespaceName string, namespaceName string,
registrationMethod string, registrationMethod string,
expiry *time.Time,
) (*Machine, error) { ) (*Machine, error) {
if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok { if machineInterface, ok := h.registrationCache.Get(machineKeyStr); ok {
if registrationMachine, ok := machineInterface.(Machine); ok { if registrationMachine, ok := machineInterface.(Machine); ok {
@ -697,7 +696,6 @@ func (h *Headscale) RegisterMachineFromAuthCallback(
registrationMachine.NamespaceID = namespace.ID registrationMachine.NamespaceID = namespace.ID
registrationMachine.RegisterMethod = registrationMethod registrationMachine.RegisterMethod = registrationMethod
registrationMachine.Expiry = expiry
machine, err := h.RegisterMachine( machine, err := h.RegisterMachine(
registrationMachine, registrationMachine,

27
oidc.go
View File

@ -10,19 +10,15 @@ import (
"html/template" "html/template"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"tailscale.com/types/key" "tailscale.com/types/key"
) )
const ( const (
oidcStateCacheExpiration = time.Minute * 5
oidcStateCacheCleanupInterval = time.Minute * 10
randomByteSize = 16 randomByteSize = 16
) )
@ -60,14 +56,6 @@ func (h *Headscale) initOIDC() error {
} }
} }
// init the state cache if it hasn't been already
if h.oidcStateCache == nil {
h.oidcStateCache = cache.New(
oidcStateCacheExpiration,
oidcStateCacheCleanupInterval,
)
}
return nil return nil
} }
@ -100,7 +88,7 @@ func (h *Headscale) RegisterOIDC(ctx *gin.Context) {
stateStr := hex.EncodeToString(randomBlob)[:32] stateStr := hex.EncodeToString(randomBlob)[:32]
// place the machine key into the state cache, so it can be retrieved later // place the machine key into the state cache, so it can be retrieved later
h.oidcStateCache.Set(stateStr, machineKeyStr, oidcStateCacheExpiration) h.registrationCache.Set(stateStr, machineKeyStr, registerCacheExpiration)
authURL := h.oauth2Config.AuthCodeURL(stateStr) authURL := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authURL) log.Debug().Msgf("Redirecting to %s for authentication", authURL)
@ -196,7 +184,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
} }
// retrieve machinekey from state cache // retrieve machinekey from state cache
machineKeyIf, machineKeyFound := h.oidcStateCache.Get(state) machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
if !machineKeyFound { if !machineKeyFound {
log.Error(). log.Error().
@ -228,14 +216,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
return return
} }
// TODO(kradalby): Currently, if it fails to find a requested expiry, non will be set
requestedTime := time.Time{}
if requestedTimeIf, found := h.requestedExpiryCache.Get(machineKey.String()); found {
if reqTime, ok := requestedTimeIf.(time.Time); ok {
requestedTime = reqTime
}
}
// retrieve machine information // retrieve machine information
machine, err := h.GetMachineByMachineKey(machineKey) machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil { if err != nil {
@ -254,7 +234,7 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
Str("machine", machine.Name). Str("machine", machine.Name).
Msg("machine already registered, reauthenticating") Msg("machine already registered, reauthenticating")
h.RefreshMachine(machine, requestedTime) h.RefreshMachine(machine, *machine.Expiry)
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
@ -329,7 +309,6 @@ func (h *Headscale) OIDCCallback(ctx *gin.Context) {
machineKeyStr, machineKeyStr,
namespace.Name, namespace.Name,
RegisterMethodOIDC, RegisterMethodOIDC,
&requestedTime,
) )
if err != nil { if err != nil {
log.Error(). log.Error().