mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-23 10:05:19 +00:00
Implement namespace matching
This commit is contained in:
parent
a347d276bd
commit
677bd9b657
6
api.go
6
api.go
@ -170,7 +170,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
Str("machine", m.Name).
|
||||
Msg("Machine registration has expired. Sending a authurl to register")
|
||||
|
||||
if h.cfg.OIDCIssuer != "" {
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
|
||||
strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
} else {
|
||||
@ -225,7 +225,7 @@ func (h *Headscale) RegistrationHandler(c *gin.Context) {
|
||||
Str("handler", "Registration").
|
||||
Str("machine", m.Name).
|
||||
Msg("The node is sending us a new NodeKey, sending auth url")
|
||||
if h.cfg.OIDCIssuer != "" {
|
||||
if h.cfg.OIDC.Issuer != "" {
|
||||
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", strings.TrimSuffix(h.cfg.ServerURL, "/"), mKey.HexString())
|
||||
} else {
|
||||
resp.AuthURL = fmt.Sprintf("%s/register?key=%s",
|
||||
@ -424,7 +424,7 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
|
||||
db.Save(&m)
|
||||
|
||||
h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys?
|
||||
|
||||
|
||||
pak.Used = true
|
||||
db.Save(&pak)
|
||||
|
||||
|
26
app.go
26
app.go
@ -3,9 +3,6 @@ package headscale
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"os"
|
||||
"sort"
|
||||
@ -13,6 +10,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -57,14 +58,19 @@ type Config struct {
|
||||
|
||||
DNSConfig *tailcfg.DNSConfig
|
||||
|
||||
OIDCIssuer string
|
||||
OIDCClientID string
|
||||
OIDCClientSecret string
|
||||
OIDC OIDCConfig
|
||||
|
||||
MaxMachineRegistrationDuration time.Duration
|
||||
DefaultMachineRegistrationDuration time.Duration
|
||||
}
|
||||
|
||||
type OIDCConfig struct {
|
||||
Issuer string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
MatchMap map[string]string
|
||||
}
|
||||
|
||||
// Headscale represents the base app of the service
|
||||
type Headscale struct {
|
||||
cfg Config
|
||||
@ -122,14 +128,14 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.OIDCIssuer != "" {
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
err = h.initOIDC()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
|
||||
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -294,7 +300,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
|
||||
|
||||
times = append(times, lastChange)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
sort.Slice(times, func(i, j int) bool {
|
||||
@ -305,7 +310,6 @@ func (h *Headscale) getLastStateChange(namespaces ...string) time.Time {
|
||||
|
||||
if len(times) == 0 {
|
||||
return time.Now().UTC()
|
||||
|
||||
} else {
|
||||
return times[0]
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@ -73,7 +74,6 @@ func LoadConfig(path string) error {
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
|
||||
@ -206,15 +206,19 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
|
||||
ACMEEmail: viper.GetString("acme_email"),
|
||||
ACMEURL: viper.GetString("acme_url"),
|
||||
|
||||
OIDCIssuer: viper.GetString("oidc_issuer"),
|
||||
OIDCClientID: viper.GetString("oidc_client_id"),
|
||||
OIDCClientSecret: viper.GetString("oidc_client_secret"),
|
||||
OIDC: headscale.OIDCConfig{
|
||||
Issuer: viper.GetString("oidc.issuer"),
|
||||
ClientID: viper.GetString("oidc.client_id"),
|
||||
ClientSecret: viper.GetString("oidc.client_secret"),
|
||||
},
|
||||
|
||||
MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time
|
||||
DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration
|
||||
|
||||
}
|
||||
|
||||
cfg.OIDC.MatchMap = loadOIDCMatchMap()
|
||||
|
||||
h, err := headscale.NewHeadscale(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -291,3 +295,15 @@ func HasJsonOutputFlag() bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// loadOIDCMatchMap is a wrapper around viper to verifies that the keys in
|
||||
// the match map is valid regex strings.
|
||||
func loadOIDCMatchMap() map[string]string {
|
||||
strMap := viper.GetStringMapString("oidc.domain_map")
|
||||
|
||||
for oidcMatcher := range strMap {
|
||||
_ = regexp.MustCompile(oidcMatcher)
|
||||
}
|
||||
|
||||
return strMap
|
||||
}
|
||||
|
93
oidc.go
93
oidc.go
@ -5,14 +5,16 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type IDTokenClaims struct {
|
||||
@ -26,7 +28,7 @@ func (h *Headscale) initOIDC() error {
|
||||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
if h.oauth2Config == nil {
|
||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDCIssuer)
|
||||
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("Could not retrieve OIDC Config: %s", err.Error())
|
||||
@ -34,8 +36,8 @@ func (h *Headscale) initOIDC() error {
|
||||
}
|
||||
|
||||
h.oauth2Config = &oauth2.Config{
|
||||
ClientID: h.cfg.OIDCClientID,
|
||||
ClientSecret: h.cfg.OIDCClientSecret,
|
||||
ClientID: h.cfg.OIDC.ClientID,
|
||||
ClientSecret: h.cfg.OIDC.ClientSecret,
|
||||
Endpoint: h.oidcProvider.Endpoint(),
|
||||
RedirectURL: fmt.Sprintf("%s/oidc/callback", strings.TrimSuffix(h.cfg.ServerURL, "/")),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
@ -62,7 +64,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||
|
||||
b := make([]byte, 16)
|
||||
_, err := rand.Read(b)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msg("could not read 16 bytes from rand")
|
||||
c.String(http.StatusInternalServerError, "could not read 16 bytes from rand")
|
||||
@ -86,7 +87,6 @@ func (h *Headscale) RegisterOIDC(c *gin.Context) {
|
||||
// TODO: Add groups information from OIDC tokens into machine HostInfo
|
||||
// Listens in /oidc/callback
|
||||
func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
|
||||
@ -109,7 +109,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDCClientID})
|
||||
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
|
||||
|
||||
idToken, err := verifier.Verify(context.Background(), rawIDToken)
|
||||
if err != nil {
|
||||
@ -131,7 +131,7 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
//retrieve machinekey from state cache
|
||||
// retrieve machinekey from state cache
|
||||
mKeyIf, mKeyFound := h.oidcStateCache.Get(state)
|
||||
|
||||
if !mKeyFound {
|
||||
@ -149,7 +149,6 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
|
||||
// retrieve machine information
|
||||
m, err := h.GetMachineByMachineKey(mKeyStr)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msg("machine key not found in database")
|
||||
c.String(http.StatusInternalServerError, "could not get machine info from database")
|
||||
@ -158,40 +157,40 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
// register the machine if it's new
|
||||
if !m.Registered {
|
||||
nsName := strings.ReplaceAll(claims.Email, "@", "-") // TODO: Implement a better email sanitisation
|
||||
if nsName, ok := h.getNamespaceFromEmail(claims.Email); ok {
|
||||
// register the machine if it's new
|
||||
if !m.Registered {
|
||||
|
||||
log.Debug().Msg("Registering new machine after successful callback")
|
||||
|
||||
ns, err := h.GetNamespace(nsName)
|
||||
if err != nil {
|
||||
ns, err = h.CreateNamespace(nsName)
|
||||
log.Debug().Msg("Registering new machine after successful callback")
|
||||
|
||||
ns, err := h.GetNamespace(nsName)
|
||||
if err != nil {
|
||||
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
|
||||
c.String(http.StatusInternalServerError, "could not create new namespace")
|
||||
ns, err = h.CreateNamespace(nsName)
|
||||
|
||||
if err != nil {
|
||||
log.Error().Msgf("could not create new namespace '%s'", claims.Email)
|
||||
c.String(http.StatusInternalServerError, "could not create new namespace")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
|
||||
return
|
||||
}
|
||||
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = ns.ID
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "oidc"
|
||||
m.LastSuccessfulUpdate = &now
|
||||
h.db.Save(&m)
|
||||
}
|
||||
|
||||
ip, err := h.getAvailableIP()
|
||||
if err != nil {
|
||||
c.String(http.StatusInternalServerError, "could not get an IP from the pool")
|
||||
return
|
||||
}
|
||||
h.updateMachineExpiry(m)
|
||||
|
||||
m.IPAddress = ip.String()
|
||||
m.NamespaceID = ns.ID
|
||||
m.Registered = true
|
||||
m.RegisterMethod = "oidc"
|
||||
m.LastSuccessfulUpdate = &now
|
||||
h.db.Save(&m)
|
||||
}
|
||||
|
||||
h.updateMachineExpiry(m)
|
||||
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(fmt.Sprintf(`
|
||||
<html>
|
||||
<body>
|
||||
<h1>headscale</h1>
|
||||
@ -202,4 +201,24 @@ func (h *Headscale) OIDCCallback(c *gin.Context) {
|
||||
</html>
|
||||
|
||||
`, claims.Email)))
|
||||
|
||||
}
|
||||
|
||||
log.Error().
|
||||
Str("email", claims.Email).
|
||||
Str("username", claims.Username).
|
||||
Str("machine", m.Name).
|
||||
Msg("Email could not be mapped to a namespace")
|
||||
c.String(http.StatusBadRequest, "email from claim could not be mapped to a namespace")
|
||||
}
|
||||
|
||||
func (h *Headscale) getNamespaceFromEmail(email string) (string, bool) {
|
||||
for match, namespace := range h.cfg.OIDC.MatchMap {
|
||||
regex := regexp.MustCompile(match)
|
||||
if regex.MatchString(email) {
|
||||
return namespace, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
173
oidc_test.go
Normal file
173
oidc_test.go
Normal file
@ -0,0 +1,173 @@
|
||||
package headscale
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/oauth2"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/wgkey"
|
||||
)
|
||||
|
||||
func TestHeadscale_getNamespaceFromEmail(t *testing.T) {
|
||||
type fields struct {
|
||||
cfg Config
|
||||
db *gorm.DB
|
||||
dbString string
|
||||
dbType string
|
||||
dbDebug bool
|
||||
publicKey *wgkey.Key
|
||||
privateKey *wgkey.Private
|
||||
aclPolicy *ACLPolicy
|
||||
aclRules *[]tailcfg.FilterRule
|
||||
lastStateChange sync.Map
|
||||
oidcProvider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
oidcStateCache *cache.Cache
|
||||
}
|
||||
type args struct {
|
||||
email string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want string
|
||||
want1 bool
|
||||
}{
|
||||
{
|
||||
name: "match all",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*": "space",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "test@example.no",
|
||||
},
|
||||
want: "space",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "match user",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
"specific@user\\.no": "user-namespace",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "specific@user.no",
|
||||
},
|
||||
want: "user-namespace",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "match domain",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*@example\\.no": "example",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "test@example.no",
|
||||
},
|
||||
want: "example",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "multi match domain",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*@example\\.no": "exammple",
|
||||
".*@gmail\\.com": "gmail",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "someuser@gmail.com",
|
||||
},
|
||||
want: "gmail",
|
||||
want1: true,
|
||||
},
|
||||
{
|
||||
name: "no match domain",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*@dontknow.no": "never",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "test@wedontknow.no",
|
||||
},
|
||||
want: "",
|
||||
want1: false,
|
||||
},
|
||||
{
|
||||
name: "multi no match domain",
|
||||
fields: fields{
|
||||
cfg: Config{
|
||||
OIDC: OIDCConfig{
|
||||
MatchMap: map[string]string{
|
||||
".*@dontknow.no": "never",
|
||||
".*@wedontknow.no": "other",
|
||||
".*\\.no": "stuffy",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
email: "tasy@nonofthem.com",
|
||||
},
|
||||
want: "",
|
||||
want1: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := &Headscale{
|
||||
cfg: tt.fields.cfg,
|
||||
db: tt.fields.db,
|
||||
dbString: tt.fields.dbString,
|
||||
dbType: tt.fields.dbType,
|
||||
dbDebug: tt.fields.dbDebug,
|
||||
publicKey: tt.fields.publicKey,
|
||||
privateKey: tt.fields.privateKey,
|
||||
aclPolicy: tt.fields.aclPolicy,
|
||||
aclRules: tt.fields.aclRules,
|
||||
lastStateChange: tt.fields.lastStateChange,
|
||||
oidcProvider: tt.fields.oidcProvider,
|
||||
oauth2Config: tt.fields.oauth2Config,
|
||||
oidcStateCache: tt.fields.oidcStateCache,
|
||||
}
|
||||
got, got1 := h.getNamespaceFromEmail(tt.args.email)
|
||||
if got != tt.want {
|
||||
t.Errorf("Headscale.getNamespaceFromEmail() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
if got1 != tt.want1 {
|
||||
t.Errorf("Headscale.getNamespaceFromEmail() got1 = %v, want %v", got1, tt.want1)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user