From 7ba0c3d5154bd1ac60f993b12f878e678c98c546 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sun, 17 Nov 2024 19:40:06 -0700 Subject: [PATCH] use userID instead of username everywhere Signed-off-by: Kristoffer Dalby --- hscontrol/db/db_test.go | 4 +- hscontrol/db/node.go | 8 +-- hscontrol/db/node_test.go | 40 +++++------ hscontrol/db/preauth_keys.go | 20 +++--- hscontrol/db/preauth_keys_test.go | 35 +++++----- hscontrol/db/routes_test.go | 16 ++--- hscontrol/db/users.go | 110 +++++++++++++----------------- hscontrol/db/users_test.go | 42 +++++++----- hscontrol/grpcv1.go | 52 +++++++++++--- hscontrol/types/users.go | 2 +- 10 files changed, 178 insertions(+), 151 deletions(-) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index ebc37694..87f94eb9 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -121,12 +121,12 @@ func TestMigrations(t *testing.T) { dbPath: "testdata/0-23-0-to-0-24-0-preauthkey-tags-table.sqlite", wantFunc: func(t *testing.T, h *HSDatabase) { keys, err := Read(h.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { - kratest, err := ListPreAuthKeys(rx, "kratest") + kratest, err := ListPreAuthKeysByUser(rx, 1) // kratest if err != nil { return nil, err } - testkra, err := ListPreAuthKeys(rx, "testkra") + testkra, err := ListPreAuthKeysByUser(rx, 2) // testkra if err != nil { return nil, err } diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 1b6e7538..1c2a165c 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -91,15 +91,15 @@ func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { }) } -func (hsdb *HSDatabase) getNode(user string, name string) (*types.Node, error) { +func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { - return getNode(rx, user, name) + return getNode(rx, uid, name) }) } // getNode finds a Node by name and user and returns the Node struct. -func getNode(tx *gorm.DB, user string, name string) (*types.Node, error) { - nodes, err := ListNodesByUser(tx, user) +func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) { + nodes, err := ListNodesByUser(tx, uid) if err != nil { return nil, err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index a81d8f0f..6c1d1099 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -30,10 +30,10 @@ func (s *Suite) TestGetNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -51,7 +51,7 @@ func (s *Suite) TestGetNode(c *check.C) { trx := db.DB.Save(node) c.Assert(trx.Error, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) } @@ -59,7 +59,7 @@ func (s *Suite) TestGetNodeByID(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -88,7 +88,7 @@ func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -136,7 +136,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) c.Assert(err, check.IsNil) - _, err = db.getNode(user.Name, "testnode3") + _, err = db.getNode(types.UserID(user.ID), "testnode3") c.Assert(err, check.NotNil) } @@ -144,7 +144,7 @@ func (s *Suite) TestListPeers(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) _, err = db.GetNodeByID(0) @@ -190,7 +190,7 @@ func (s *Suite) TestGetACLFilteredPeers(c *check.C) { for _, name := range []string{"test", "admin"} { user, err := db.CreateUser(name) c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) stor = append(stor, base{user, pak}) } @@ -282,10 +282,10 @@ func (s *Suite) TestExpireNode(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -303,7 +303,7 @@ func (s *Suite) TestExpireNode(c *check.C) { } db.DB.Save(node) - nodeFromDB, err := db.getNode("test", "testnode") + nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB, check.NotNil) @@ -313,7 +313,7 @@ func (s *Suite) TestExpireNode(c *check.C) { err = db.NodeSetExpiry(nodeFromDB.ID, now) c.Assert(err, check.IsNil) - nodeFromDB, err = db.getNode("test", "testnode") + nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(nodeFromDB.IsExpired(), check.Equals, true) @@ -323,10 +323,10 @@ func (s *Suite) TestSetTags(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "testnode") + _, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.NotNil) nodeKey := key.NewNode() @@ -349,7 +349,7 @@ func (s *Suite) TestSetTags(c *check.C) { sTags := []string{"tag:test", "tag:foo"} err = db.SetTags(node.ID, sTags) c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") + node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, sTags) @@ -357,7 +357,7 @@ func (s *Suite) TestSetTags(c *check.C) { eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} err = db.SetTags(node.ID, eTags) c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") + node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert( node.ForcedTags, @@ -368,7 +368,7 @@ func (s *Suite) TestSetTags(c *check.C) { // test removing tags err = db.SetTags(node.ID, []string{}) c.Assert(err, check.IsNil) - node, err = db.getNode("test", "testnode") + node, err = db.getNode(types.UserID(user.ID), "testnode") c.Assert(err, check.IsNil) c.Assert(node.ForcedTags, check.DeepEquals, []string{}) } @@ -568,7 +568,7 @@ func TestAutoApproveRoutes(t *testing.T) { user, err := adb.CreateUser("test") require.NoError(t, err) - pak, err := adb.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) require.NoError(t, err) nodeKey := key.NewNode() @@ -700,10 +700,10 @@ func TestListEphemeralNodes(t *testing.T) { user, err := db.CreateUser("test") require.NoError(t, err) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) require.NoError(t, err) - pakEph, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) + pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil) require.NoError(t, err) node := types.Node{ diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index 59bbdf98..aeee5b52 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -23,29 +23,27 @@ var ( ) func (hsdb *HSDatabase) CreatePreAuthKey( - // TODO(kradalby): Should be ID, not name - userName string, + uid types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKey, error) { - return CreatePreAuthKey(tx, userName, reusable, ephemeral, expiration, aclTags) + return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags) }) } // CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func CreatePreAuthKey( tx *gorm.DB, - // TODO(kradalby): Should be ID, not name - userName string, + uid types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, ) (*types.PreAuthKey, error) { - user, err := GetUserByUsername(tx, userName) + user, err := GetUserByID(tx, uid) if err != nil { return nil, err } @@ -89,15 +87,15 @@ func CreatePreAuthKey( return &key, nil } -func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { +func (hsdb *HSDatabase) ListPreAuthKeys(uid types.UserID) ([]types.PreAuthKey, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { - return ListPreAuthKeys(rx, userName) + return ListPreAuthKeysByUser(rx, uid) }) } -// ListPreAuthKeys returns the list of PreAuthKeys for a user. -func ListPreAuthKeys(tx *gorm.DB, userName string) ([]types.PreAuthKey, error) { - user, err := GetUserByUsername(tx, userName) +// ListPreAuthKeysByUser returns the list of PreAuthKeys for a user. +func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, error) { + user, err := GetUserByID(tx, uid) if err != nil { return nil, err } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index ec3f6441..3c56a35e 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -11,14 +11,14 @@ import ( ) func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := db.CreatePreAuthKey("bogus", true, false, nil, nil) - + // ID does not exist + _, err := db.CreatePreAuthKey(12345, true, false, nil, nil) c.Assert(err, check.NotNil) user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) c.Assert(err, check.IsNil) // Did we get a valid key? @@ -26,17 +26,18 @@ func (*Suite) TestCreatePreAuthKey(c *check.C) { c.Assert(len(key.Key), check.Equals, 48) // Make sure the User association is populated - c.Assert(key.User.Name, check.Equals, user.Name) + c.Assert(key.User.ID, check.Equals, user.ID) - _, err = db.ListPreAuthKeys("bogus") + // ID does not exist + _, err = db.ListPreAuthKeys(1000000) c.Assert(err, check.NotNil) - keys, err := db.ListPreAuthKeys(user.Name) + keys, err := db.ListPreAuthKeys(types.UserID(user.ID)) c.Assert(err, check.IsNil) c.Assert(len(keys), check.Equals, 1) // Make sure the User association is populated - c.Assert((keys)[0].User.Name, check.Equals, user.Name) + c.Assert((keys)[0].User.ID, check.Equals, user.ID) } func (*Suite) TestExpiredPreAuthKey(c *check.C) { @@ -44,7 +45,7 @@ func (*Suite) TestExpiredPreAuthKey(c *check.C) { c.Assert(err, check.IsNil) now := time.Now().Add(-5 * time.Second) - pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, &now, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -62,7 +63,7 @@ func (*Suite) TestValidateKeyOk(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -74,7 +75,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { user, err := db.CreateUser("test4") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -96,7 +97,7 @@ func (*Suite) TestReusableBeingUsedKey(c *check.C) { user, err := db.CreateUser("test5") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -118,7 +119,7 @@ func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) key, err := db.ValidatePreAuthKey(pak.Key) @@ -130,7 +131,7 @@ func (*Suite) TestExpirePreauthKey(c *check.C) { user, err := db.CreateUser("test3") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) c.Assert(err, check.IsNil) c.Assert(pak.Expiration, check.IsNil) @@ -147,7 +148,7 @@ func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { user, err := db.CreateUser("test6") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) pak.Used = true db.DB.Save(&pak) @@ -160,15 +161,15 @@ func (*Suite) TestPreAuthKeyACLTags(c *check.C) { user, err := db.CreateUser("test8") c.Assert(err, check.IsNil) - _, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"}) c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected tags := []string{"tag:test1", "tag:test2"} tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) + _, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate) c.Assert(err, check.IsNil) - listedPaks, err := db.ListPreAuthKeys("test8") + listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID)) c.Assert(err, check.IsNil) gotTags := listedPaks[0].Proto().GetAclTags() sort.Sort(sort.StringSlice(gotTags)) diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 5071077c..7b11e136 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -35,10 +35,10 @@ func (s *Suite) TestGetRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_get_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_get_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix("10.0.0.0/24") @@ -79,10 +79,10 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_enable_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -153,10 +153,10 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_enable_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") c.Assert(err, check.NotNil) route, err := netip.ParsePrefix( @@ -234,10 +234,10 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - _, err = db.getNode("test", "test_enable_route_node") + _, err = db.getNode(types.UserID(user.ID), "test_enable_route_node") c.Assert(err, check.NotNil) prefix, err := netip.ParsePrefix( diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 135276c7..840d316d 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -40,21 +40,21 @@ func CreateUser(tx *gorm.DB, name string) (*types.User, error) { return &user, nil } -func (hsdb *HSDatabase) DestroyUser(name string) error { +func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error { return hsdb.Write(func(tx *gorm.DB) error { - return DestroyUser(tx, name) + return DestroyUser(tx, uid) }) } // DestroyUser destroys a User. Returns error if the User does // not exist or if there are nodes associated with it. -func DestroyUser(tx *gorm.DB, name string) error { - user, err := GetUserByUsername(tx, name) +func DestroyUser(tx *gorm.DB, uid types.UserID) error { + user, err := GetUserByID(tx, uid) if err != nil { - return ErrUserNotFound + return err } - nodes, err := ListNodesByUser(tx, name) + nodes, err := ListNodesByUser(tx, uid) if err != nil { return err } @@ -62,7 +62,7 @@ func DestroyUser(tx *gorm.DB, name string) error { return ErrUserStillHasNodes } - keys, err := ListPreAuthKeys(tx, name) + keys, err := ListPreAuthKeysByUser(tx, uid) if err != nil { return err } @@ -80,17 +80,17 @@ func DestroyUser(tx *gorm.DB, name string) error { return nil } -func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { +func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error { return hsdb.Write(func(tx *gorm.DB) error { - return RenameUser(tx, oldName, newName) + return RenameUser(tx, uid, newName) }) } // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. -func RenameUser(tx *gorm.DB, oldName, newName string) error { +func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { var err error - oldUser, err := GetUserByUsername(tx, oldName) + oldUser, err := GetUserByID(tx, uid) if err != nil { return err } @@ -98,50 +98,25 @@ func RenameUser(tx *gorm.DB, oldName, newName string) error { if err != nil { return err } - _, err = GetUserByUsername(tx, newName) - if err == nil { - return ErrUserExists - } - if !errors.Is(err, ErrUserNotFound) { - return err - } oldUser.Name = newName - if result := tx.Save(&oldUser); result.Error != nil { - return result.Error + if err := tx.Save(&oldUser).Error; err != nil { + return err } return nil } -func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { +func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { - return GetUserByUsername(rx, name) + return GetUserByID(rx, uid) }) } -func GetUserByUsername(tx *gorm.DB, name string) (*types.User, error) { +func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) { user := types.User{} - if result := tx.First(&user, "name = ?", name); errors.Is( - result.Error, - gorm.ErrRecordNotFound, - ) { - return nil, ErrUserNotFound - } - - return &user, nil -} - -func (hsdb *HSDatabase) GetUserByID(id types.UserID) (*types.User, error) { - return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { - return GetUserByID(rx, id) - }) -} - -func GetUserByID(tx *gorm.DB, id types.UserID) (*types.User, error) { - user := types.User{} - if result := tx.First(&user, "id = ?", id); errors.Is( + if result := tx.First(&user, "id = ?", uid); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -169,54 +144,65 @@ func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) { return &user, nil } -func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { +func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { return Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { - return ListUsers(rx) + return ListUsers(rx, where...) }) } // ListUsers gets all the existing users. -func ListUsers(tx *gorm.DB) ([]types.User, error) { +func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { + if len(where) > 1 { + return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where)) + } + + var user *types.User + if len(where) == 1 { + user = where[0] + } + users := []types.User{} - if err := tx.Find(&users).Error; err != nil { + if err := tx.Where(user).Find(&users).Error; err != nil { return nil, err } return users, nil } -// ListNodesByUser gets all the nodes in a given user. -func ListNodesByUser(tx *gorm.DB, name string) (types.Nodes, error) { - err := util.CheckForFQDNRules(name) - if err != nil { - return nil, err - } - user, err := GetUserByUsername(tx, name) +// GetUserByName returns a user if the provided username is +// unique, and otherwise an error. +func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { + users, err := hsdb.ListUsers(&types.User{Name: name}) if err != nil { return nil, err } + if len(users) != 1 { + return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) + } + + return &users[0], nil +} + +// ListNodesByUser gets all the nodes in a given user. +func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { nodes := types.Nodes{} - if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { + if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil { return nil, err } return nodes, nil } -func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error { +func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, uid types.UserID) error { return hsdb.Write(func(tx *gorm.DB) error { - return AssignNodeToUser(tx, node, username) + return AssignNodeToUser(tx, node, uid) }) } // AssignNodeToUser assigns a Node to a user. -func AssignNodeToUser(tx *gorm.DB, node *types.Node, username string) error { - err := util.CheckForFQDNRules(username) - if err != nil { - return err - } - user, err := GetUserByUsername(tx, username) +func AssignNodeToUser(tx *gorm.DB, node *types.Node, uid types.UserID) error { + user, err := GetUserByID(tx, uid) if err != nil { return err } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 54399664..6684989e 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -1,6 +1,8 @@ package db import ( + "strings" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "gopkg.in/check.v1" @@ -17,24 +19,24 @@ func (s *Suite) TestCreateAndDestroyUser(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = db.DestroyUser("test") + err = db.DestroyUser(types.UserID(user.ID)) c.Assert(err, check.IsNil) - _, err = db.GetUserByName("test") + _, err = db.GetUserByID(types.UserID(user.ID)) c.Assert(err, check.NotNil) } func (s *Suite) TestDestroyUserErrors(c *check.C) { - err := db.DestroyUser("test") + err := db.DestroyUser(9998) c.Assert(err, check.Equals, ErrUserNotFound) user, err := db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) - err = db.DestroyUser("test") + err = db.DestroyUser(types.UserID(user.ID)) c.Assert(err, check.IsNil) result := db.DB.Preload("User").First(&pak, "key = ?", pak.Key) @@ -44,7 +46,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { user, err = db.CreateUser("test") c.Assert(err, check.IsNil) - pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) + pak, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -57,7 +59,7 @@ func (s *Suite) TestDestroyUserErrors(c *check.C) { trx := db.DB.Save(&node) c.Assert(trx.Error, check.IsNil) - err = db.DestroyUser("test") + err = db.DestroyUser(types.UserID(user.ID)) c.Assert(err, check.Equals, ErrUserStillHasNodes) } @@ -70,24 +72,28 @@ func (s *Suite) TestRenameUser(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(users), check.Equals, 1) - err = db.RenameUser("test", "test-renamed") + err = db.RenameUser(types.UserID(userTest.ID), "test-renamed") c.Assert(err, check.IsNil) - _, err = db.GetUserByName("test") - c.Assert(err, check.Equals, ErrUserNotFound) + users, err = db.ListUsers(&types.User{Name: "test"}) + c.Assert(err, check.Equals, nil) + c.Assert(len(users), check.Equals, 0) - _, err = db.GetUserByName("test-renamed") + users, err = db.ListUsers(&types.User{Name: "test-renamed"}) c.Assert(err, check.IsNil) + c.Assert(len(users), check.Equals, 1) - err = db.RenameUser("test-does-not-exit", "test") + err = db.RenameUser(99988, "test") c.Assert(err, check.Equals, ErrUserNotFound) userTest2, err := db.CreateUser("test2") c.Assert(err, check.IsNil) c.Assert(userTest2.Name, check.Equals, "test2") - err = db.RenameUser("test2", "test-renamed") - c.Assert(err, check.Equals, ErrUserExists) + err = db.RenameUser(types.UserID(userTest2.ID), "test-renamed") + if !strings.Contains(err.Error(), "UNIQUE constraint failed") { + c.Fatalf("expected failure with unique constraint, got: %s", err.Error()) + } } func (s *Suite) TestSetMachineUser(c *check.C) { @@ -97,7 +103,7 @@ func (s *Suite) TestSetMachineUser(c *check.C) { newUser, err := db.CreateUser("new") c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) + pak, err := db.CreatePreAuthKey(types.UserID(oldUser.ID), false, false, nil, nil) c.Assert(err, check.IsNil) node := types.Node{ @@ -111,15 +117,15 @@ func (s *Suite) TestSetMachineUser(c *check.C) { c.Assert(trx.Error, check.IsNil) c.Assert(node.UserID, check.Equals, oldUser.ID) - err = db.AssignNodeToUser(&node, newUser.Name) + err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) c.Assert(err, check.IsNil) c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.User.Name, check.Equals, newUser.Name) - err = db.AssignNodeToUser(&node, "non-existing-user") + err = db.AssignNodeToUser(&node, 9584849) c.Assert(err, check.Equals, ErrUserNotFound) - err = db.AssignNodeToUser(&node, newUser.Name) + err = db.AssignNodeToUser(&node, types.UserID(newUser.ID)) c.Assert(err, check.IsNil) c.Assert(node.UserID, check.Equals, newUser.ID) c.Assert(node.User.Name, check.Equals, newUser.Name) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 68793716..dd7ab03d 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -65,24 +65,34 @@ func (api headscaleV1APIServer) RenameUser( ctx context.Context, request *v1.RenameUserRequest, ) (*v1.RenameUserResponse, error) { - err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName()) + oldUser, err := api.h.db.GetUserByName(request.GetOldName()) if err != nil { return nil, err } - user, err := api.h.db.GetUserByName(request.GetNewName()) + err = api.h.db.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) if err != nil { return nil, err } - return &v1.RenameUserResponse{User: user.Proto()}, nil + newUser, err := api.h.db.GetUserByName(request.GetNewName()) + if err != nil { + return nil, err + } + + return &v1.RenameUserResponse{User: newUser.Proto()}, nil } func (api headscaleV1APIServer) DeleteUser( ctx context.Context, request *v1.DeleteUserRequest, ) (*v1.DeleteUserResponse, error) { - err := api.h.db.DestroyUser(request.GetName()) + user, err := api.h.db.GetUserByName(request.GetName()) + if err != nil { + return nil, err + } + + err = api.h.db.DestroyUser(types.UserID(user.ID)) if err != nil { return nil, err } @@ -131,8 +141,13 @@ func (api headscaleV1APIServer) CreatePreAuthKey( } } + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + preAuthKey, err := api.h.db.CreatePreAuthKey( - request.GetUser(), + types.UserID(user.ID), request.GetReusable(), request.GetEphemeral(), &expiration, @@ -168,7 +183,12 @@ func (api headscaleV1APIServer) ListPreAuthKeys( ctx context.Context, request *v1.ListPreAuthKeysRequest, ) (*v1.ListPreAuthKeysResponse, error) { - preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser()) + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + + preAuthKeys, err := api.h.db.ListPreAuthKeys(types.UserID(user.ID)) if err != nil { return nil, err } @@ -406,10 +426,20 @@ func (api headscaleV1APIServer) ListNodes( ctx context.Context, request *v1.ListNodesRequest, ) (*v1.ListNodesResponse, error) { + // TODO(kradalby): it looks like this can be simplified a lot, + // the filtering of nodes by user, vs nodes as a whole can + // probably be done once. + // TODO(kradalby): This should be done in one tx. + isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap() if request.GetUser() != "" { + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { - return db.ListNodesByUser(rx, request.GetUser()) + return db.ListNodesByUser(rx, types.UserID(user.ID)) }) if err != nil { return nil, err @@ -465,12 +495,18 @@ func (api headscaleV1APIServer) MoveNode( ctx context.Context, request *v1.MoveNodeRequest, ) (*v1.MoveNodeResponse, error) { + // TODO(kradalby): This should be done in one tx. node, err := api.h.db.GetNodeByID(types.NodeID(request.GetNodeId())) if err != nil { return nil, err } - err = api.h.db.AssignNodeToUser(node, request.GetUser()) + user, err := api.h.db.GetUserByName(request.GetUser()) + if err != nil { + return nil, err + } + + err = api.h.db.AssignNodeToUser(node, types.UserID(user.ID)) if err != nil { return nil, err } diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 5b27e671..9e0bfeb0 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -26,7 +26,7 @@ type User struct { // Username for the user, is used if email is empty // Should not be used, please use Username(). - Name string `gorm:"index,uniqueIndex:idx_name_provider_identifier"` + Name string `gorm:"uniqueIndex:idx_name_provider_identifier,index"` // Typically the full name of the user DisplayName string