diff --git a/machine.go b/machine.go index 0b08d3f9..5b8a3ae5 100644 --- a/machine.go +++ b/machine.go @@ -26,6 +26,7 @@ const ( ) errCouldNotConvertMachineInterface = Error("failed to convert machine interface") errHostnameTooLong = Error("Hostname too long") + MachineGivenNameHashLength = 8 ) const ( @@ -813,3 +814,32 @@ func (machine *Machine) RoutesToProto() *v1.Routes { EnabledRoutes: ipPrefixToString(enabledRoutes), } } + +func (h *Headscale) GenerateGivenName(suppliedName string) (string, error) { + // If a hostname is or will be longer than 63 chars after adding the hash, + // it needs to be trimmed. + trimmedHostnameLength := labelHostnameLength - MachineGivenNameHashLength - 2 + + normalizedHostname, err := NormalizeToFQDNRules( + suppliedName, + h.cfg.OIDC.StripEmaildomain, + ) + if err != nil { + return "", err + } + + postfix, err := GenerateRandomStringDNSSafe(MachineGivenNameHashLength) + if err != nil { + return "", err + } + + // Verify that that the new unique name is shorter than the maximum allowed + // DNS segment. + if len(normalizedHostname) <= trimmedHostnameLength { + normalizedHostname = fmt.Sprintf("%s-%s", normalizedHostname, postfix) + } else { + normalizedHostname = fmt.Sprintf("%s-%s", normalizedHostname[:trimmedHostnameLength], postfix) + } + + return normalizedHostname, nil +} diff --git a/machine_test.go b/machine_test.go index 02f2d0d7..e2d1f395 100644 --- a/machine_test.go +++ b/machine_test.go @@ -4,6 +4,7 @@ import ( "fmt" "reflect" "strconv" + "strings" "testing" "time" @@ -654,3 +655,136 @@ func Test_getFilteredByACLPeers(t *testing.T) { }) } } + +func TestHeadscale_GenerateGivenName(t *testing.T) { + type args struct { + suppliedName string + } + tests := []struct { + name string + h *Headscale + args args + want string + wantErr bool + }{ + { + name: "simple machine name generation", + h: &Headscale{ + cfg: Config{ + OIDC: OIDCConfig{ + StripEmaildomain: true, + }, + }, + }, + args: args{ + suppliedName: "testmachine", + }, + want: "testmachine", + wantErr: false, + }, + { + name: "machine name with 53 chars", + h: &Headscale{ + cfg: Config{ + OIDC: OIDCConfig{ + StripEmaildomain: true, + }, + }, + }, + args: args{ + suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", + }, + want: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", + wantErr: false, + }, + { + name: "machine name with 60 chars", + h: &Headscale{ + cfg: Config{ + OIDC: OIDCConfig{ + StripEmaildomain: true, + }, + }, + }, + args: args{ + suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine1234567", + }, + want: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine", + wantErr: false, + }, + { + name: "machine name with 63 chars", + h: &Headscale{ + cfg: Config{ + OIDC: OIDCConfig{ + StripEmaildomain: true, + }, + }, + }, + args: args{ + suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine1234567890", + }, + want: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + wantErr: false, + }, + { + name: "machine name with 64 chars", + h: &Headscale{ + cfg: Config{ + OIDC: OIDCConfig{ + StripEmaildomain: true, + }, + }, + }, + args: args{ + suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine1234567891", + }, + want: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + wantErr: false, + }, + { + name: "machine name with 73 chars", + h: &Headscale{ + cfg: Config{ + OIDC: OIDCConfig{ + StripEmaildomain: true, + }, + }, + }, + args: args{ + suppliedName: "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine12345678901234567890", + }, + want: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.h.GenerateGivenName(tt.args.suppliedName) + if (err != nil) != tt.wantErr { + t.Errorf( + "Headscale.GenerateGivenName() error = %v, wantErr %v", + err, + tt.wantErr, + ) + return + } + + if tt.want != "" && strings.Contains(tt.want, got) { + t.Errorf( + "Headscale.GenerateGivenName() = %v, is not a substring of %v", + tt.want, + got, + ) + } + + if len(got) > labelHostnameLength { + t.Errorf( + "Headscale.GenerateGivenName() = %v is larger than allowed DNS segment %d", + got, + labelHostnameLength, + ) + } + }) + } +} diff --git a/utils.go b/utils.go index af267eb3..615ca46c 100644 --- a/utils.go +++ b/utils.go @@ -317,3 +317,16 @@ func GenerateRandomStringURLSafe(n int) (string, error) { return base64.RawURLEncoding.EncodeToString(b), err } + +// GenerateRandomStringDNSSafe returns a DNS-safe +// securely generated random string. +// It will return an error if the system's secure random +// number generator fails to function correctly, in which +// case the caller should not continue. +func GenerateRandomStringDNSSafe(n int) (string, error) { + str, err := GenerateRandomStringURLSafe(n) + + str = strings.ReplaceAll(str, "_", "-") + + return str[:n], err +}