diff --git a/hostinfo/hostinfo.go b/hostinfo/hostinfo.go index afb465ece..3e8f2f994 100644 --- a/hostinfo/hostinfo.go +++ b/hostinfo/hostinfo.go @@ -43,7 +43,7 @@ func RegisterHostinfoNewHook(f func(*tailcfg.Hostinfo)) { // New returns a partially populated Hostinfo for the current host. func New() *tailcfg.Hostinfo { - hostname, _ := os.Hostname() + hostname, _ := Hostname() hostname = dnsname.FirstLabel(hostname) hi := &tailcfg.Hostinfo{ IPNVersion: version.Long(), @@ -509,3 +509,21 @@ func IsInVM86() bool { return New().DeviceModel == copyV86DeviceModel }) } + +type hostnameQuery func() (string, error) + +var hostnameFn atomic.Value // of func() (string, error) + +// SetHostNameFn sets a custom function for querying the system hostname. +func SetHostnameFn(fn hostnameQuery) { + hostnameFn.Store(fn) +} + +// Hostname returns the system hostname using the function +// set by SetHostNameFn. We will fallback to os.Hostname. +func Hostname() (string, error) { + if fn, ok := hostnameFn.Load().(hostnameQuery); ok && fn != nil { + return fn() + } + return os.Hostname() +} diff --git a/hostinfo/hostinfo_test.go b/hostinfo/hostinfo_test.go index 9fe32e044..15b6971b6 100644 --- a/hostinfo/hostinfo_test.go +++ b/hostinfo/hostinfo_test.go @@ -5,6 +5,7 @@ package hostinfo import ( "encoding/json" + "os" "strings" "testing" ) @@ -49,3 +50,31 @@ func TestEtcAptSourceFileIsDisabled(t *testing.T) { }) } } + +func TestCustomHostnameFunc(t *testing.T) { + want := "custom-hostname" + SetHostnameFn(func() (string, error) { + return want, nil + }) + + got, err := Hostname() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got != want { + t.Errorf("got %q, want %q", got, want) + } + + SetHostnameFn(os.Hostname) + got, err = Hostname() + want, _ = os.Hostname() + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != want { + t.Errorf("got %q, want %q", got, want) + } + +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 9ec4b4767..e21403fbe 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -1245,7 +1245,7 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { } } else { - ss.HostName, _ = os.Hostname() + ss.HostName, _ = hostinfo.Hostname() } for _, pln := range b.peerAPIListeners { ss.PeerAPIURL = append(ss.PeerAPIURL, pln.urlStr)