Implement namespace matching

This commit is contained in:
Kristoffer Dalby 2021-10-18 19:27:52 +00:00
parent a347d276bd
commit 677bd9b657
5 changed files with 267 additions and 55 deletions

4
api.go
View File

@ -170,7 +170,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("machine", m.Name). Str("machine", m.Name).
Msg("Machine registration has expired. Sending a authurl to register") Msg("Machine registration has expired. Sending a authurl to register")
if h.cfg.OIDCIssuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} else { } else {
@ -225,7 +225,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
Str("handler", "Registration"). Str("handler", "Registration").
Str("machine", m.Name). Str("machine", m.Name).
Msg("The node is sending us a new NodeKey, sending auth url") Msg("The node is sending us a new NodeKey, sending auth url")
if h.cfg.OIDCIssuer != "" { if h.cfg.OIDC.Issuer != "" {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString()) resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
} else { } else {
resp.AuthURL = fmt.Sprintf("%s/register?key=%s", resp.AuthURL = fmt.Sprintf("%s/register?key=%s",

26
app.go
View File

@ -3,9 +3,6 @@ package headscale
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"net/http" "net/http"
"os" "os"
"sort" "sort"
@ -13,6 +10,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -57,14 +58,19 @@ type Config struct {
DNSConfig *tailcfg.DNSConfig DNSConfig *tailcfg.DNSConfig
OIDCIssuer string OIDC OIDCConfig
OIDCClientID string
OIDCClientSecret string
MaxMachineRegistrationDuration time.Duration MaxMachineRegistrationDuration time.Duration
DefaultMachineRegistrationDuration time.Duration DefaultMachineRegistrationDuration time.Duration
} }
type OIDCConfig struct {
Issuer string
ClientID string
ClientSecret string
MatchMap map[string]string
}
// Headscale represents the base app of the service // Headscale represents the base app of the service
type Headscale struct { type Headscale struct {
cfg Config cfg Config
@ -122,14 +128,14 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
return nil, err return nil, err
} }
if cfg.OIDCIssuer != "" { if cfg.OIDC.Issuer != "" {
err = h.initOIDC() err = h.initOIDC()
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain) magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
if err != nil { if err != nil {
return nil, err return nil, err
@ -294,7 +300,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
times = append(times, lastChange) times = append(times, lastChange)
} }
} }
sort.Slice(times, func(i, j int) bool { sort.Slice(times, func(i, j int) bool {
@ -305,7 +310,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
if len(times) == 0 { if len(times) == 0 {
return time.Now().UTC() return time.Now().UTC()
} else { } else {
return times[0] return times[0]
} }

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"regexp"
"strings" "strings"
"time" "time"
@ -73,7 +74,6 @@ func LoadConfig(path string) error {
} else { } else {
return nil return nil
} }
} }
func GetDNSConfig() (*tailcfg.DNSConfig, string) { func GetDNSConfig() (*tailcfg.DNSConfig, string) {
@ -206,15 +206,19 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
ACMEEmail: viper.GetString("acme_email"), ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"), ACMEURL: viper.GetString("acme_url"),
OIDCIssuer: viper.GetString("oidc_issuer"), OIDC: headscale.OIDCConfig{
OIDCClientID: viper.GetString("oidc_client_id"), Issuer: viper.GetString("oidc.issuer"),
OIDCClientSecret: viper.GetString("oidc_client_secret"), ClientID: viper.GetString("oidc.client_id"),
ClientSecret: viper.GetString("oidc.client_secret"),
},
MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time
DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration
} }
cfg.OIDC.MatchMap = loadOIDCMatchMap()
h, err := headscale.NewHeadscale(cfg) h, err := headscale.NewHeadscale(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
@ -291,3 +295,15 @@ func HasJsonOutputFlag() bool {
} }
return false return false
} }
// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in
// the match map is valid regex strings.
func loadOIDCMatchMap() map[string]string {
strMap := viper.GetStringMapString("oidc.domain_map")
for oidcMatcher := range strMap {
_ = regexp.MustCompile(oidcMatcher)
}
return strMap
}

93
oidc.go
View File

@ -5,14 +5,16 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/http"
"regexp"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"net/http"
"strings"
"time"
) )
type IDTokenClaims struct { type IDTokenClaims struct {
@ -26,7 +28,7 @@ func (h *Headscale) initOIDC() error {
var err error var err error
// grab oidc config if it hasn't been already // grab oidc config if it hasn't been already
if h.oauth2Config == nil { if h.oauth2Config == nil {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer) h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
if err != nil { if err != nil {
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error()) log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
@ -34,8 +36,8 @@ func (h *Headscale) initOIDC() error {
} }
h.oauth2Config = &oauth2.Config{ h.oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDCClientID, ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDCClientSecret, ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(), Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")), RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
@ -62,7 +64,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
b := make([]byte, 16) b := make([]byte, 16)
_, err := rand.Read(b) _, err := rand.Read(b)
if err != nil { if err != nil {
log.Error().Msg("could not read 16 bytes from rand") log.Error().Msg("could not read 16 bytes from rand")
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand") c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
@ -86,7 +87,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
// TODO: Add groups information from OIDC tokens into machine HostInfo // TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback // Listens in /oidc/callback
func (h *Headscale) OIDCCallback(c *gin.Context) { func (h *Headscale) OIDCCallback(c *gin.Context) {
code := c.Query("code") code := c.Query("code")
state := c.Query("state") state := c.Query("state")
@ -109,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return return
} }
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID}) verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
idToken, err := verifier.Verify(context.Background(), rawIDToken) idToken, err := verifier.Verify(context.Background(), rawIDToken)
if err != nil { if err != nil {
@ -131,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
return return
} }
//retrieve machinekey from state cache // retrieve machinekey from state cache
mKeyIf, mKeyFound := h.oidcStateCache.Get(state) mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
if !mKeyFound { if !mKeyFound {
@ -149,7 +149,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
// retrieve machine information // retrieve machine information
m, err := h.GetMachineByMachineKey(mKeyStr) m, err := h.GetMachineByMachineKey(mKeyStr)
if err != nil { if err != nil {
log.Error().Msg("machine key not found in database") log.Error().Msg("machine key not found in database")
c.String(http.StatusInternalServerError, "could not get machine info from database") c.String(http.StatusInternalServerError, "could not get machine info from database")
@ -158,40 +157,40 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
now := time.Now().UTC() now := time.Now().UTC()
// register the machine if it's new if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
if !m.Registered { // register the machine if it's new
nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation if !m.Registered {
log.Debug().Msg("Registering new machine after successful callback") log.Debug().Msg("Registering new machine after successful callback")
ns, err := h.GetNamespace(nsName)
if err != nil {
ns, err = h.CreateNamespace(nsName)
ns, err := h.GetNamespace(nsName)
if err != nil { if err != nil {
log.Error().Msgf("could not create new namespace '%s'", claims.Email) ns, err = h.CreateNamespace(nsName)
c.String(http.StatusInternalServerError, "could not create new namespace")
if err != nil {
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
c.String(http.StatusInternalServerError, "could not create new namespace")
return
}
}
ip, err := h.getAvailableIP()
if err != nil {
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
return return
} }
m.IPAddress = ip.String()
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "oidc"
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
} }
ip, err := h.getAvailableIP() h.updateMachineExpiry(m)
if err != nil {
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
return
}
m.IPAddress = ip.String() c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
m.NamespaceID = ns.ID
m.Registered = true
m.RegisterMethod = "oidc"
m.LastSuccessfulUpdate = &now
h.db.Save(&m)
}
h.updateMachineExpiry(m)
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
<html> <html>
<body> <body>
<h1>headscale</h1> <h1>headscale</h1>
@ -202,4 +201,24 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
</html> </html>
`, claims.Email))) `, claims.Email)))
}
log.Error().
Str("email", claims.Email).
Str("username", claims.Username).
Str("machine", m.Name).
Msg("Email could not be mapped to a namespace")
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
}
func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) {
for match, namespace := range h.cfg.OIDC.MatchMap {
regex := regexp.MustCompile(match)
if regex.MatchString(email) {
return namespace, true
}
}
return "", false
} }

173
oidc_test.go Normal file
View File

@ -0,0 +1,173 @@
package headscale
import (
"sync"
"testing"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/patrickmn/go-cache"
"golang.org/x/oauth2"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/wgkey"
)
func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
type fields struct {
cfg Config
db *gorm.DB
dbString string
dbType string
dbDebug bool
publicKey *wgkey.Key
privateKey *wgkey.Private
aclPolicy *ACLPolicy
aclRules *[]tailcfg.FilterRule
lastStateChange sync.Map
oidcProvider *oidc.Provider
oauth2Config *oauth2.Config
oidcStateCache *cache.Cache
}
type args struct {
email string
}
tests := []struct {
name string
fields fields
args args
want string
want1 bool
}{
{
name: "match all",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*": "space",
},
},
},
},
args: args{
email: "test@example.no",
},
want: "space",
want1: true,
},
{
name: "match user",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
"specific@user\\.no": "user-namespace",
},
},
},
},
args: args{
email: "specific@user.no",
},
want: "user-namespace",
want1: true,
},
{
name: "match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@example\\.no": "example",
},
},
},
},
args: args{
email: "test@example.no",
},
want: "example",
want1: true,
},
{
name: "multi match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@example\\.no": "exammple",
".*@gmail\\.com": "gmail",
},
},
},
},
args: args{
email: "someuser@gmail.com",
},
want: "gmail",
want1: true,
},
{
name: "no match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@dontknow.no": "never",
},
},
},
},
args: args{
email: "test@wedontknow.no",
},
want: "",
want1: false,
},
{
name: "multi no match domain",
fields: fields{
cfg: Config{
OIDC: OIDCConfig{
MatchMap: map[string]string{
".*@dontknow.no": "never",
".*@wedontknow.no": "other",
".*\\.no": "stuffy",
},
},
},
},
args: args{
email: "tasy@nonofthem.com",
},
want: "",
want1: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := &Headscale{
cfg: tt.fields.cfg,
db: tt.fields.db,
dbString: tt.fields.dbString,
dbType: tt.fields.dbType,
dbDebug: tt.fields.dbDebug,
publicKey: tt.fields.publicKey,
privateKey: tt.fields.privateKey,
aclPolicy: tt.fields.aclPolicy,
aclRules: tt.fields.aclRules,
lastStateChange: tt.fields.lastStateChange,
oidcProvider: tt.fields.oidcProvider,
oauth2Config: tt.fields.oauth2Config,
oidcStateCache: tt.fields.oidcStateCache,
}
got, got1 := h.getNamespaceFromEmail(tt.args.email)
if got != tt.want {
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
}
if got1 != tt.want1 {
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
}
})
}
}