diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 87f94eb9..a291ad7d 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "database/sql" "fmt" "io" "net/netip" @@ -257,3 +258,110 @@ func testCopyOfDatabase(src string) (string, error) { func emptyCache() *zcache.Cache[string, types.Node] { return zcache.New[string, types.Node](time.Minute, time.Hour) } + +func TestConstraints(t *testing.T) { + tests := []struct { + name string + run func(*testing.T, *gorm.DB) + }{ + { + name: "no-duplicate-username-if-no-oidc", + run: func(t *testing.T, db *gorm.DB) { + _, err := CreateUser(db, "user1") + require.NoError(t, err) + _, err = CreateUser(db, "user1") + require.Error(t, err) + // assert.Contains(t, err.Error(), "UNIQUE constraint failed: users.username") + require.Contains(t, err.Error(), "user already exists") + }, + }, + { + name: "no-oidc-duplicate-username-and-id", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err := db.Save(&user).Error + require.NoError(t, err) + + user = types.User{ + Model: gorm.Model{ID: 2}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err = db.Save(&user).Error + require.Error(t, err) + require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") + }, + }, + { + name: "no-oidc-duplicate-id", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err := db.Save(&user).Error + require.NoError(t, err) + + user = types.User{ + Model: gorm.Model{ID: 2}, + Name: "user1.1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err = db.Save(&user).Error + require.Error(t, err) + require.Contains(t, err.Error(), "UNIQUE constraint failed: users.provider_identifier") + }, + }, + { + name: "allow-duplicate-username-cli-then-oidc", + run: func(t *testing.T, db *gorm.DB) { + _, err := CreateUser(db, "user1") // Create CLI username + require.NoError(t, err) + + user := types.User{ + Name: "user1", + } + user.ProviderIdentifier.String = "http://test.com/user1" + + err = db.Save(&user).Error + require.NoError(t, err) + }, + }, + { + name: "allow-duplicate-username-oidc-then-cli", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Name: "user1", + } + user.ProviderIdentifier.String = "http://test.com/user1" + + err := db.Save(&user).Error + require.NoError(t, err) + + _, err = CreateUser(db, "user1") // Create CLI username + require.NoError(t, err) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newTestDB() + if err != nil { + t.Fatalf("creating database: %s", err) + } + + tt.run(t, db.DB) + }) + + } +} diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 840d316d..0eaa9ea3 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -28,11 +28,9 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) { if err != nil { return nil, err } - user := types.User{} - if err := tx.Where("name = ?", name).First(&user).Error; err == nil { - return nil, ErrUserExists + user := types.User{ + Name: name, } - user.Name = name if err := tx.Create(&user).Error; err != nil { return nil, fmt.Errorf("creating user: %w", err) } @@ -177,6 +175,10 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { return nil, err } + if len(users) == 0 { + return nil, ErrUserNotFound + } + if len(users) != 1 { return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index fce7e455..e8461967 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -460,7 +460,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( // This is to prevent users that have already been migrated to the new OIDC format // to be updated with the new OIDC identifier inexplicitly which might be the cause of an // account takeover. - if user != nil && user.ProviderIdentifier != "" { + if user != nil && user.ProviderIdentifier.Valid { log.Info().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user found by username, but has provider identifier, creating new user.") user = &types.User{} } diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 8b3d2e83..f36be708 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -2,6 +2,7 @@ package types import ( "cmp" + "database/sql" "strconv" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" @@ -26,7 +27,7 @@ type User struct { // Username for the user, is used if email is empty // Should not be used, please use Username(). - Name string `gorm:"uniqueIndex:idx_name_provider_identifier,index"` + Name string `gorm:"uniqueIndex:idx_name_provider_identifier;index"` // Typically the full name of the user DisplayName string @@ -38,7 +39,7 @@ type User struct { // Unique identifier of the user from OIDC, // comes from `sub` claim in the OIDC token // and is used to lookup the user. - ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"` + ProviderIdentifier sql.NullString `gorm:"uniqueIndex:idx_name_provider_identifier;uniqueIndex:idx_provider_identifier"` // Provider is the origin of the user account, // same as RegistrationMethod, without authkey. @@ -55,7 +56,7 @@ type User struct { // should be used throughout headscale, in information returned to the // user and the Policy engine. func (u *User) Username() string { - username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10)) + username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier.String, strconv.FormatUint(uint64(u.ID), 10)) // TODO(kradalby): Wire up all of this for the future // if !strings.Contains(username, "@") { @@ -118,7 +119,7 @@ func (u *User) Proto() *v1.User { CreatedAt: timestamppb.New(u.CreatedAt), DisplayName: u.DisplayName, Email: u.Email, - ProviderId: u.ProviderIdentifier, + ProviderId: u.ProviderIdentifier.String, Provider: u.Provider, ProfilePicUrl: u.ProfilePicURL, } @@ -145,7 +146,7 @@ func (c *OIDCClaims) Identifier() string { // FromClaim overrides a User from OIDC claims. // All fields will be updated, except for the ID. func (u *User) FromClaim(claims *OIDCClaims) { - u.ProviderIdentifier = claims.Identifier() + u.ProviderIdentifier = sql.NullString{String: claims.Identifier(), Valid: true} u.DisplayName = claims.Name if claims.EmailVerified { u.Email = claims.Email diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 25fb358c..2fbfb555 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -54,7 +54,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { scenario := AuthOIDCScenario{ Scenario: baseScenario, } - defer scenario.ShutdownAssertNoPanics(t) + // defer scenario.ShutdownAssertNoPanics(t) // Logins to MockOIDC is served by a queue with a strict order, // if we use more than one node per user, the order of the logins