diff --git a/cmd/yggdrasil/chuser_unix.go b/cmd/yggdrasil/chuser_unix.go index 69803773..fc3e5c2c 100644 --- a/cmd/yggdrasil/chuser_unix.go +++ b/cmd/yggdrasil/chuser_unix.go @@ -4,89 +4,53 @@ package main import ( - "errors" "fmt" - "math" - osuser "os/user" + "os/user" "strconv" "strings" - "syscall" + + "golang.org/x/sys/unix" ) -func chuser(user string) error { - group := "" - if i := strings.IndexByte(user, ':'); i >= 0 { - user, group = user[:i], user[i+1:] - } +func chuser(input string) error { + givenUser, givenGroup, _ := strings.Cut(input, ":") - u := (*osuser.User)(nil) - g := (*osuser.Group)(nil) + var ( + err error + usr *user.User + grp *user.Group + uid, gid int + ) - if user != "" { - if _, err := strconv.ParseUint(user, 10, 32); err == nil { - u, err = osuser.LookupId(user) - if err != nil { - return fmt.Errorf("failed to lookup user by id %q: %v", user, err) - } - } else { - u, err = osuser.Lookup(user) - if err != nil { - return fmt.Errorf("failed to lookup user by name %q: %v", user, err) - } + if usr, err = user.LookupId(givenUser); err != nil { + if usr, err = user.Lookup(givenUser); err != nil { + return err } } - if group != "" { - if _, err := strconv.ParseUint(group, 10, 32); err == nil { - g, err = osuser.LookupGroupId(group) - if err != nil { - return fmt.Errorf("failed to lookup group by id %q: %v", user, err) - } - } else { - g, err = osuser.LookupGroup(group) - if err != nil { - return fmt.Errorf("failed to lookup group by name %q: %v", user, err) - } - } + if uid, err = strconv.Atoi(usr.Uid); err != nil { + return err } - if g != nil { - gid, _ := strconv.ParseUint(g.Gid, 10, 32) - var err error - if gid < math.MaxInt { - if err := syscall.Setgroups([]int{int(gid)}); err != nil { - return fmt.Errorf("failed to setgroups %d: %v", gid, err) + if givenGroup != "" { + if grp, err = user.LookupGroupId(givenGroup); err != nil { + if grp, err = user.LookupGroup(givenGroup); err != nil { + return err } - err = syscall.Setgid(int(gid)) - } else { - err = errors.New("gid too big") } - if err != nil { - return fmt.Errorf("failed to setgid %d: %v", gid, err) - } - } else if u != nil { - gid, _ := strconv.ParseUint(u.Gid, 10, 32) - if err := syscall.Setgroups([]int{int(uint32(gid))}); err != nil { - return fmt.Errorf("failed to setgroups %d: %v", gid, err) - } - err := syscall.Setgid(int(uint32(gid))) - if err != nil { - return fmt.Errorf("failed to setgid %d: %v", gid, err) - } + gid, _ = strconv.Atoi(grp.Gid) + } else { + gid, _ = strconv.Atoi(usr.Gid) } - if u != nil { - uid, _ := strconv.ParseUint(u.Uid, 10, 32) - var err error - if uid < math.MaxInt { - err = syscall.Setuid(int(uid)) - } else { - err = errors.New("uid too big") - } - - if err != nil { - return fmt.Errorf("failed to setuid %d: %v", uid, err) - } + if err := unix.Setgroups([]int{gid}); err != nil { + return fmt.Errorf("setgroups: %d: %v", gid, err) + } + if err := unix.Setgid(gid); err != nil { + return fmt.Errorf("setgid: %d: %v", gid, err) + } + if err := unix.Setuid(uid); err != nil { + return fmt.Errorf("setuid: %d: %v", uid, err) } return nil diff --git a/cmd/yggdrasil/chuser_unix_test.go b/cmd/yggdrasil/chuser_unix_test.go new file mode 100644 index 00000000..ad2e3517 --- /dev/null +++ b/cmd/yggdrasil/chuser_unix_test.go @@ -0,0 +1,80 @@ +//go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris + +package main + +import ( + "testing" + "os/user" +) + +// Usernames must not contain a number sign. +func TestEmptyString (t *testing.T) { + if chuser("") == nil { + t.Fatal("the empty string is not a valid user") + } +} + +// Either omit delimiter and group, or omit both. +func TestEmptyGroup (t *testing.T) { + if chuser("0:") == nil { + t.Fatal("the empty group is not allowed") + } +} + +// Either user only or user and group. +func TestGroupOnly (t *testing.T) { + if chuser(":0") == nil { + t.Fatal("group only is not allowed") + } +} + +// Usenames must not contain the number sign. +func TestInvalidUsername (t *testing.T) { + const username = "#user" + if chuser(username) == nil { + t.Fatalf("'%s' is not a valid username", username) + } +} + +// User IDs must be non-negative. +func TestInvalidUserid (t *testing.T) { + if chuser("-1") == nil { + t.Fatal("User ID cannot be negative") + } +} + +// Change to the current user by ID. +func TestCurrentUserid (t *testing.T) { + usr, err := user.Current() + if err != nil { + t.Fatal(err) + } + + if usr.Uid != "0" { + t.Skip("setgroups(2): Only the superuser may set new groups.") + } + + if err = chuser(usr.Uid); err != nil { + t.Fatal(err) + } +} + +// Change to a common user by name. +func TestCommonUsername (t *testing.T) { + usr, err := user.Current() + if err != nil { + t.Fatal(err) + } + + if usr.Uid != "0" { + t.Skip("setgroups(2): Only the superuser may set new groups.") + } + + if err := chuser("nobody"); err != nil { + if _, ok := err.(user.UnknownUserError); ok { + t.Skip(err) + } + t.Fatal(err) + } +}