Merge pull request #126 from unreality/main

Initial work on OIDC (SSO) integration
This commit is contained in:
Kristoffer Dalby 2021-10-31 09:33:35 +00:00 committed by GitHub
commit fbdfa55629
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 653 additions and 58 deletions

View File

@ -31,6 +31,7 @@ headscale implements this coordination server.
- [x] Taildrop (File Sharing) - [x] Taildrop (File Sharing)
- [x] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10) - [x] Support for alternative IP ranges in the tailnets (default Tailscale's 100.64.0.0/10)
- [x] DNS (passing DNS servers to nodes) - [x] DNS (passing DNS servers to nodes)
- [x] Single-Sign-On (via Open ID Connect)
- [x] Share nodes between namespaces - [x] Share nodes between namespaces
- [x] MagicDNS (see `docs/`) - [x] MagicDNS (see `docs/`)
@ -49,7 +50,6 @@ headscale implements this coordination server.
Suggestions/PRs welcomed! Suggestions/PRs welcomed!
## Running headscale ## Running headscale
Please have a look at the documentation under [`docs/`](docs/). Please have a look at the documentation under [`docs/`](docs/).

120
api.go
View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -64,7 +65,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot parse machine key") Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Sad!") c.String(http.StatusInternalServerError, "Sad!")
return return
} }
@ -75,37 +76,33 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unkown", "web", "error", "unknown").Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
c.String(http.StatusInternalServerError, "Very sad!") c.String(http.StatusInternalServerError, "Very sad!")
return return
} }
now := time.Now().UTC() now := time.Now().UTC()
var m Machine m, err := h.GetMachineByMachineKey(mKey.HexString())
if result := h.db.Preload("Namespace").First(&m, "machine_key = ?", mKey.HexString()); errors.Is( if errors.Is(err, gorm.ErrRecordNotFound) {
result.Error,
gorm.ErrRecordNotFound,
) {
log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine") log.Info().Str("machine", req.Hostinfo.Hostname).Msg("New machine")
m = Machine{ newMachine := Machine{
Expiry: &req.Expiry, Expiry: &time.Time{},
MachineKey: mKey.HexString(), MachineKey: mKey.HexString(),
Name: req.Hostinfo.Hostname, Name: req.Hostinfo.Hostname,
NodeKey: wgkey.Key(req.NodeKey).HexString(),
LastSuccessfulUpdate: &now,
} }
if err := h.db.Create(&m).Error; err != nil { if err := h.db.Create(&newMachine).Error; err != nil {
log.Error(). log.Error().
Str("handler", "Registration"). Str("handler", "Registration").
Err(err). Err(err).
Msg("Could not create row") Msg("Could not create row")
machineRegistrations.WithLabelValues("unkown", "web", "error", m.Namespace.Name).Inc() machineRegistrations.WithLabelValues("unknown", "web", "error", m.Namespace.Name).Inc()
return return
} }
m = &newMachine
} }
if !m.Registered && req.Auth.AuthKey != "" { if !m.Registered && req.Auth.AuthKey != "" {
h.handleAuthKey(c, h.db, mKey, req, m) h.handleAuthKey(c, h.db, mKey, req, *m)
return return
} }
@ -113,7 +110,36 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
// We have the updated key! // We have the updated key!
if m.NodeKey == wgkey.Key(req.NodeKey).HexString() { if m.NodeKey == wgkey.Key(req.NodeKey).HexString() {
if m.Registered {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !req.Expiry.IsZero() && req.Expiry.UTC().Before(now) {
log.Info().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("Client requested logout")
m.Expiry = &req.Expiry // save the expiry so that the machine is marked as expired
h.db.Save(&m)
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
return
}
c.Data(200, "application/json; charset=utf-8", respBody)
return
}
if m.Registered && m.Expiry.UTC().After(now) {
// The machine registration is valid, respond with redirect to /map
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", m.Name).
@ -122,6 +148,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser() resp.User = *m.Namespace.toUser()
resp.Login = *m.Namespace.toLogin()
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -137,12 +165,30 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return return
} }
// The client has registered before, but has expired
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", m.Name).
Msg("Not registered and not NodeKey rotation. Sending a authurl to register") Msg("Machine registration has expired. Sending a authurl to register")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
h.cfg.ServerURL, mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
}
// When a client connects, it may request a specific expiry time in its
// RegisterRequest (https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L634)
// RequestedExpiry is used to store the clients requested expiry time since the authentication flow is broken
// into two steps (which cant pass arbitrary data between them easily) and needs to be
// retrieved again after the user has authenticated. After the authentication flow
// completes, RequestedExpiry is copied into Expiry.
m.RequestedExpiry = &req.Expiry
h.db.Save(&m)
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -158,8 +204,8 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return return
} }
// The NodeKey we have matches OldNodeKey, which means this is a refresh after an key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() { if m.NodeKey == wgkey.Key(req.OldNodeKey).HexString() && m.Expiry.UTC().After(now) {
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", m.Name).
@ -182,35 +228,23 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
return return
} }
// We arrive here after a client is restarted without finalizing the authentication flow or // The machine registration is new, redirect the client to the registration URL
// when headscale is stopped in the middle of the auth process.
if m.Registered {
log.Debug().
Str("handler", "Registration").
Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, but machine is registered. All clear for /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *m.Namespace.toUser()
respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "Registration").
Err(err).
Msg("Cannot encode message")
c.String(http.StatusInternalServerError, "")
return
}
c.Data(200, "application/json; charset=utf-8", respBody)
return
}
log.Debug(). log.Debug().
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, sending auth url") Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
h.cfg.ServerURL, mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
}
// save the requested expiry time for retrieval later in the authentication flow
m.RequestedExpiry = &req.Expiry
m.NodeKey = wgkey.Key(req.NodeKey).HexString() // save the NodeKey
h.db.Save(&m)
respBody, err := encode(resp, &mKey, h.privateKey) respBody, err := encode(resp, &mKey, h.privateKey)
if err != nil { if err != nil {
log.Error(). log.Error().

29
app.go
View File

@ -14,6 +14,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
apiV1 "github.com/juanfont/headscale/gen/go/v1" apiV1 "github.com/juanfont/headscale/gen/go/v1"
@ -62,6 +66,18 @@ type Config struct {
ACMEEmail string ACMEEmail string
DNSConfig *tailcfg.DNSConfig DNSConfig *tailcfg.DNSConfig
OIDC OIDCConfig
MaxMachineRegistrationDuration time.Duration
DefaultMachineRegistrationDuration time.Duration
}
type OIDCConfig struct {
Issuer string
ClientID string
ClientSecret string
MatchMap map[string]string
} }
type DERPConfig struct { type DERPConfig struct {
@ -87,6 +103,10 @@ type Headscale struct {
aclRules *[]tailcfg.FilterRule aclRules *[]tailcfg.FilterRule
lastStateChange sync.Map lastStateChange sync.Map
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
oidcStateCache *cache.Cache
} }
// NewHeadscale returns the Headscale app. // NewHeadscale returns the Headscale app.
@ -127,6 +147,13 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, err return nil, err
} }
if cfg.OIDC.Issuer != "" {
err = h.initOIDC()
if err != nil {
return nil, err
}
}
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain) magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
if err != nil { if err != nil {
@ -255,6 +282,8 @@ func (h *Headscale) Serve() error {
r.GET("/register", h.RegisterWebAPI) r.GET("/register", h.RegisterWebAPI)
r.POST("/machine/:id/map", h.PollNetMapHandler) r.POST("/machine/:id/map", h.PollNetMapHandler)
r.POST("/machine/:id", h.RegistrationHandler) r.POST("/machine/:id", h.RegistrationHandler)
r.GET("/oidc/register/:mkey", h.RegisterOIDC)
r.GET("/oidc/callback", h.OIDCCallback)
r.GET("/apple", h.AppleMobileConfig) r.GET("/apple", h.AppleMobileConfig)
r.GET("/apple/:platform", h.ApplePlatformConfig) r.GET("/apple/:platform", h.ApplePlatformConfig)

3
cli.go
View File

@ -23,6 +23,8 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
return nil, errors.New("Machine not found") return nil, errors.New("Machine not found")
} }
h.updateMachineExpiry(&m) // update the machine's expiry before bailing if its already registered
if m.isAlreadyRegistered() { if m.isAlreadyRegistered() {
return nil, errors.New("Machine already registered") return nil, errors.New("Machine already registered")
} }
@ -36,5 +38,6 @@ func (h *Headscale) RegisterMachine(key string, namespace string) (*Machine, err
m.Registered = true m.Registered = true
m.RegisterMethod = "cli" m.RegisterMethod = "cli"
h.db.Save(&m) h.db.Save(&m)
return &m, nil return &m, nil
} }

View File

@ -1,6 +1,8 @@
package headscale package headscale
import ( import (
"time"
"gopkg.in/check.v1" "gopkg.in/check.v1"
) )
@ -8,6 +10,8 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
n, err := h.CreateNamespace("test") n, err := h.CreateNamespace("test")
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
now := time.Now().UTC()
m := Machine{ m := Machine{
ID: 0, ID: 0,
MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e", MachineKey: "8ce002a935f8c394e55e78fbbb410576575ff8ec5cfa2e627e4b807f1be15b0e",
@ -16,6 +20,8 @@ func (s *Suite) TestRegisterMachine(c *check.C) {
Name: "testmachine", Name: "testmachine",
NamespaceID: n.ID, NamespaceID: n.ID,
IPAddress: "10.0.0.1", IPAddress: "10.0.0.1",
Expiry: &now,
RequestedExpiry: &now,
} }
h.db.Save(&m) h.db.Save(&m)

View File

@ -7,6 +7,7 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strings" "strings"
"time" "time"
@ -215,6 +216,26 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
return nil, err return nil, err
} }
// maxMachineRegistrationDuration is the maximum time headscale will allow a client to (optionally) request for
// the machine key expiry time. RegisterRequests with Expiry times that are more than
// maxMachineRegistrationDuration in the future will be clamped to (now + maxMachineRegistrationDuration)
maxMachineRegistrationDuration, _ := time.ParseDuration(
"10h",
) // use 10h here because it is the length of a standard business day plus a small amount of leeway
if viper.GetDuration("max_machine_registration_duration") >= time.Second {
maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration")
}
// defaultMachineRegistrationDuration is the default time assigned to a machine registration if one is not
// specified by the tailscale client. It is the default amount of time a machine registration is valid for
// (ie the amount of time before the user has to re-authenticate when requesting a connection)
defaultMachineRegistrationDuration, _ := time.ParseDuration(
"8h",
) // use 8h here because it's the length of a standard business day
if viper.GetDuration("default_machine_registration_duration") >= time.Second {
defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration")
}
dnsConfig, baseDomain := GetDNSConfig() dnsConfig, baseDomain := GetDNSConfig()
derpConfig := GetDERPConfig() derpConfig := GetDERPConfig()
@ -249,8 +270,19 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
ACMEEmail: viper.GetString("acme_email"), ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"), ACMEURL: viper.GetString("acme_url"),
OIDC: headscale.OIDCConfig{
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: viper.GetString("oidc.client_secret"),
},
MaxMachineRegistrationDuration: maxMachineRegistrationDuration,
DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration,
} }
cfg.OIDC.MatchMap = loadOIDCMatchMap()
h, err := headscale.NewHeadscale(cfg) h, err := headscale.NewHeadscale(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
@ -312,3 +344,15 @@ func HasJsonOutputFlag() bool {
} }
return false return false
} }
// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in
// the match map is valid regex strings.
func loadOIDCMatchMap() map[string]string {
strMap := viper.GetStringMapString("oidc.domain_map")
for oidcMatcher := range strMap {
_ = regexp.MustCompile(oidcMatcher)
}
return strMap
}

View File

@ -64,3 +64,18 @@ dns_config:
magic_dns: true magic_dns: true
base_domain: example.com base_domain: example.com
# headscale supports experimental OpenID connect support,
# it is still being tested and might have some bugs, please
# help us test it.
# OpenID Connect
# oidc:
# issuer: "https://your-oidc.issuer.com/path"
# client_id: "your-oidc-client-id"
# client_secret: "your-oidc-client-secret"
#
# # Domain map is used to map incomming users (by their email) to
# # a namespace. The key can be a string, or regex.
# domain_map:
# ".*": default-namespace

3
go.mod
View File

@ -7,6 +7,7 @@ require (
github.com/Microsoft/go-winio v0.5.0 // indirect github.com/Microsoft/go-winio v0.5.0 // indirect
github.com/cenkalti/backoff/v4 v4.1.1 // indirect github.com/cenkalti/backoff/v4 v4.1.1 // indirect
github.com/containerd/continuity v0.1.0 // indirect github.com/containerd/continuity v0.1.0 // indirect
github.com/coreos/go-oidc/v3 v3.1.0
github.com/docker/cli v20.10.8+incompatible // indirect github.com/docker/cli v20.10.8+incompatible // indirect
github.com/docker/docker v20.10.8+incompatible // indirect github.com/docker/docker v20.10.8+incompatible // indirect
github.com/efekarakus/termcolor v1.0.1 github.com/efekarakus/termcolor v1.0.1
@ -23,6 +24,7 @@ require (
github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6 // indirect github.com/moby/term v0.0.0-20210619224110-3f7ff695adc6 // indirect
github.com/opencontainers/runc v1.0.2 // indirect github.com/opencontainers/runc v1.0.2 // indirect
github.com/ory/dockertest/v3 v3.7.0 github.com/ory/dockertest/v3 v3.7.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/prometheus/client_golang v1.11.0 github.com/prometheus/client_golang v1.11.0
github.com/pterm/pterm v0.12.30 github.com/pterm/pterm v0.12.30
github.com/rs/zerolog v1.25.0 github.com/rs/zerolog v1.25.0
@ -36,6 +38,7 @@ require (
github.com/zsais/go-gin-prometheus v0.1.0 github.com/zsais/go-gin-prometheus v0.1.0
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5 golang.org/x/crypto v0.0.0-20210817164053-32db794688a5
golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect
golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect golang.org/x/sys v0.0.0-20210910150752-751e447fb3d0 // indirect
google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83 google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83

9
go.sum
View File

@ -153,6 +153,8 @@ github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkE
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-oidc/v3 v3.1.0 h1:6avEvcdvTa1qYsOZ6I5PRkSYHzpTNWgKYmaJfaYbrRw=
github.com/coreos/go-oidc/v3 v3.1.0/go.mod h1:rEJ/idjfUyfkBit1eI1fvyr+64/g9dcKpAm8MJMesvo=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
@ -767,6 +769,8 @@ github.com/ory/dockertest/v3 v3.7.0 h1:Bijzonc69Ont3OU0a3TWKJ1Rzlh3TsDXP1JrTAkSm
github.com/ory/dockertest/v3 v3.7.0/go.mod h1:PvCCgnP7AfBZeVrzwiUTjZx/IUXlGLC1zQlUQrLIlUE= github.com/ory/dockertest/v3 v3.7.0/go.mod h1:PvCCgnP7AfBZeVrzwiUTjZx/IUXlGLC1zQlUQrLIlUE=
github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM= github.com/pact-foundation/pact-go v1.0.4/go.mod h1:uExwJY4kCzNPcHRj+hCR/HBbOOIwwtUjcrb0b5/5kLM=
github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pborman/getopt v1.1.0/go.mod h1:FxXoW1Re00sQG/+KIkuSqRL/LwQgSkv7uyac+STFsbk= github.com/pborman/getopt v1.1.0/go.mod h1:FxXoW1Re00sQG/+KIkuSqRL/LwQgSkv7uyac+STFsbk=
github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k=
github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic=
@ -1137,6 +1141,7 @@ golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200421231249-e086a090c8fd/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200506145744-7e3656a0809f/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200513185701-a91f0712d120/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
@ -1174,6 +1179,7 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210427180440-81ed05c6b58c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f h1:Qmd2pbz05z7z6lm0DrgQVVPuBm92jqujBKMHMOlOQEw=
golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -1436,6 +1442,7 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7
google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0=
google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
@ -1553,6 +1560,8 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU= gopkg.in/ini.v1 v1.62.0 h1:duBzk771uxoUuOlyRLkHsygud9+5lrlGjdFBb4mSKDU=
gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/ini.v1 v1.62.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w=
gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI= gopkg.in/warnings.v0 v0.1.2/go.mod h1:jksf8JmL6Qr/oQM2OXTHunEvvTAsrWBLb6OOjuVWRNI=
gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=

View File

@ -36,6 +36,7 @@ type Machine struct {
LastSeen *time.Time LastSeen *time.Time
LastSuccessfulUpdate *time.Time LastSuccessfulUpdate *time.Time
Expiry *time.Time Expiry *time.Time
RequestedExpiry *time.Time
HostInfo datatypes.JSON HostInfo datatypes.JSON
Endpoints datatypes.JSON Endpoints datatypes.JSON
@ -56,6 +57,38 @@ func (m Machine) isAlreadyRegistered() bool {
return m.Registered return m.Registered
} }
// isExpired returns whether the machine registration has expired
func (m Machine) isExpired() bool {
return time.Now().UTC().After(*m.Expiry)
}
// If the Machine is expired, updateMachineExpiry updates the Machine Expiry time to the maximum allowed duration,
// or the default duration if no Expiry time was requested by the client. The expiry time here does not (yet) cause
// a client to be disconnected, however they will have to re-auth the machine if they attempt to reconnect after the
// expiry time.
func (h *Headscale) updateMachineExpiry(m *Machine) {
if m.isExpired() {
now := time.Now().UTC()
maxExpiry := now.Add(h.cfg.MaxMachineRegistrationDuration) // calculate the maximum expiry
defaultExpiry := now.Add(h.cfg.DefaultMachineRegistrationDuration) // calculate the default expiry
// clamp the expiry time of the machine registration to the maximum allowed, or use the default if none supplied
if maxExpiry.Before(*m.RequestedExpiry) {
log.Debug().
Msgf("Clamping registration expiry time to maximum: %v (%v)", maxExpiry, h.cfg.MaxMachineRegistrationDuration)
m.Expiry = &maxExpiry
} else if m.RequestedExpiry.IsZero() {
log.Debug().Msgf("Using default machine registration expiry time: %v (%v)", defaultExpiry, h.cfg.DefaultMachineRegistrationDuration)
m.Expiry = &defaultExpiry
} else {
log.Debug().Msgf("Using requested machine registration expiry time: %v", m.RequestedExpiry)
m.Expiry = m.RequestedExpiry
}
h.db.Save(&m)
}
}
func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Str("func", "getDirectPeers"). Str("func", "getDirectPeers").
@ -326,7 +359,11 @@ func (ms MachinesP) String() string {
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
} }
func (ms Machines) toNodes(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) ([]*tailcfg.Node, error) { func (ms Machines) toNodes(
baseDomain string,
dnsConfig *tailcfg.DNSConfig,
includeRoutes bool,
) ([]*tailcfg.Node, error) {
nodes := make([]*tailcfg.Node, len(ms)) nodes := make([]*tailcfg.Node, len(ms))
for index, machine := range ms { for index, machine := range ms {
@ -447,7 +484,9 @@ func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, include
n := tailcfg.Node{ n := tailcfg.Node{
ID: tailcfg.NodeID(m.ID), // this is the actual ID ID: tailcfg.NodeID(m.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent StableID: tailcfg.StableNodeID(
strconv.FormatUint(m.ID, 10),
), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostname, Name: hostname,
User: tailcfg.UserID(m.NamespaceID), User: tailcfg.UserID(m.NamespaceID),
Key: tailcfg.NodeKey(nKey), Key: tailcfg.NodeKey(nKey),

View File

@ -246,6 +246,17 @@ func (n *Namespace) toUser() *tailcfg.User {
return &u return &u
} }
func (n *Namespace) toLogin() *tailcfg.Login {
l := tailcfg.Login{
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
ProfilePicURL: "",
Domain: "headscale.net",
}
return &l
}
func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile { func getMapResponseUserProfiles(m Machine, peers Machines) []tailcfg.UserProfile {
namespaceMap := make(map[string]Namespace) namespaceMap := make(map[string]Namespace)
namespaceMap[m.Namespace.Name] = m.Namespace namespaceMap[m.Namespace.Name] = m.Namespace

228
oidc.go Normal file
View File

@ -0,0 +1,228 @@
package headscale
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
"regexp"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
)
type IDTokenClaims struct {
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
Email string `json:"email"`
Username string `json:"preferred_username,omitempty"`
}
func (h *Headscale) initOIDC() error {
var err error
// grab oidc config if it hasn't been already
if h.oauth2Config == nil {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err
}
h.oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
}
// init the state cache if it hasn't been already
if h.oidcStateCache == nil {
h.oidcStateCache = cache.New(time.Minute*5, time.Minute*10)
}
return nil
}
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts machine key in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:mKey
func (h *Headscale) RegisterOIDC(c *gin.Context) {
mKeyStr := c.Param("mkey")
if mKeyStr == "" {
c.String(http.StatusBadRequest, "Wrong params")
return
}
b := make([]byte, 16)
_, err := rand.Read(b)
if err != nil {
log.Error().Msg("could not read 16 bytes from rand")
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
return
}
stateStr := hex.EncodeToString(b)[:32]
// place the machine key into the state cache, so it can be retrieved later
h.oidcStateCache.Set(stateStr, mKeyStr, time.Minute*5)
authUrl := h.oauth2Config.AuthCodeURL(stateStr)
log.Debug().Msgf("Redirecting to %s for authentication", authUrl)
c.Redirect(http.StatusFound, authUrl)
}
// OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the mkey from the state cache and adds the machine to the users email namespace
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback
func (h *Headscale) OIDCCallback(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
if code == "" || state == "" {
c.String(http.StatusBadRequest, "Wrong params")
return
}
oauth2Token, err := h.oauth2Config.Exchange(context.Background(), code)
if err != nil {
c.String(http.StatusBadRequest, "Could not exchange code for token")
return
}
log.Debug().Msgf("AccessToken: %v", oauth2Token.AccessToken)
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK {
c.String(http.StatusBadRequest, "Could not extract ID Token")
return
}
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil {
c.String(http.StatusBadRequest, "Failed to verify id token: %s", err.Error())
return
}
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
//userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
//if err != nil {
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo: %s", err))
// return
//}
// Extract custom claims
var claims IDTokenClaims
if err = idToken.Claims(&claims); err != nil {
c.String(http.StatusBadRequest, fmt.Sprintf("Failed to decode id token claims: %s", err))
return
}
// retrieve machinekey from state cache
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound {
log.Error().Msg("requested machine state key expired before authorisation completed")
c.String(http.StatusBadRequest, "state has expired")
return
}
mKeyStr, mKeyOK := mKeyIf.(string)
if !mKeyOK {
log.Error().Msg("could not get machine key from cache")
c.String(http.StatusInternalServerError, "could not get machine key from cache")
return
}
// retrieve machine information
m, err := h.GetMachineByMachineKey(mKeyStr)
if err != nil {
log.Error().Msg("machine key not found in database")
c.String(http.StatusInternalServerError, "could not get machine info from database")
return
}
now := time.Now().UTC()
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
// register the machine if it's new
if !m.Registered {
log.Debug().Msg("Registering new machine after successful callback")
ns, err := h.GetNamespace(nsName)
if err != nil {
ns, err = h.CreateNamespace(nsName)
if err != nil {
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
c.String(http.StatusInternalServerError, "could not create new namespace")
return
}
}
ip, err := h.getAvailableIP()
if err != nil {
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
return
}
m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "oidc"
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
}
h.updateMachineExpiry(m)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html>
<body>
<h1>headscale</h1>
<p>
Authenticated as %s, you can now close this window.
</p>
</body>
</html>
`, claims.Email)))
}
log.Error().
Str("email", claims.Email).
Str("username", claims.Username).
Str("machine", m.Name).
Msg("Email could not be mapped to a namespace")
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
}
// getNamespaceFromEmail passes the users email through a list of "matchers"
// and iterates through them until it matches and returns a namespace.
// If no match is found, an empty string will be returned.
// TODO(kradalby): golang Maps key order is not stable, so this list is _not_ deterministic. Find a way to make the list of keys stable, preferably in the order presented in a users configuration.
func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) {
for match, namespace := range h.cfg.OIDC.MatchMap {
regex := regexp.MustCompile(match)
if regex.MatchString(email) {
return namespace, true
}
}
return "", false
}

174
oidc_test.go Normal file
View File

@ -0,0 +1,174 @@
package headscale
import (
"sync"
"testing"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
type fields struct {
cfg Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
publicKey *wgkey.Key
privateKey *wgkey.Private
aclPolicy *ACLPolicy
aclRules *[]tailcfg.FilterRule
lastStateChange sync.Map
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
oidcStateCache *cache.Cache
}
type args struct {
email string
}
tests := []struct {
name string
fields fields
args args
want string
want1 bool
}{
{
name: "match all",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*": "space",
},
},
},
},
args: args{
email: "test@example.no",
},
want: "space",
want1: true,
},
{
name: "match user",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
"specific@user\\.no": "user-namespace",
},
},
},
},
args: args{
email: "specific@user.no",
},
want: "user-namespace",
want1: true,
},
{
name: "match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@example\\.no": "example",
},
},
},
},
args: args{
email: "test@example.no",
},
want: "example",
want1: true,
},
{
name: "multi match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@example\\.no": "exammple",
".*@gmail\\.com": "gmail",
},
},
},
},
args: args{
email: "someuser@gmail.com",
},
want: "gmail",
want1: true,
},
{
name: "no match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@dontknow.no": "never",
},
},
},
},
args: args{
email: "test@wedontknow.no",
},
want: "",
want1: false,
},
{
name: "multi no match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@dontknow.no": "never",
".*@wedontknow.no": "other",
".*\\.no": "stuffy",
},
},
},
},
args: args{
email: "tasy@nonofthem.com",
},
want: "",
want1: false,
},
}
//nolint
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &Headscale{
cfg: tt.fields.cfg,
db: tt.fields.db,
dbString: tt.fields.dbString,
dbType: tt.fields.dbType,
dbDebug: tt.fields.dbDebug,
publicKey: tt.fields.publicKey,
privateKey: tt.fields.privateKey,
aclPolicy: tt.fields.aclPolicy,
aclRules: tt.fields.aclRules,
lastStateChange: tt.fields.lastStateChange,
oidcProvider: tt.fields.oidcProvider,
oauth2Config: tt.fields.oauth2Config,
oidcStateCache: tt.fields.oidcStateCache,
}
got, got1 := h.getNamespaceFromEmail(tt.args.email)
if got != tt.want {
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
}
})
}
}