diff --git a/clientupdate/clientupdate.go b/clientupdate/clientupdate.go index 47627ddf8..50bcf12dd 100644 --- a/clientupdate/clientupdate.go +++ b/clientupdate/clientupdate.go @@ -37,11 +37,18 @@ ) const ( - CurrentTrack = "" StableTrack = "stable" UnstableTrack = "unstable" ) +var CurrentTrack = func() string { + if version.IsUnstableBuild() { + return UnstableTrack + } else { + return StableTrack + } +}() + func versionToTrack(v string) (string, error) { _, rest, ok := strings.Cut(v, ".") if !ok { @@ -106,7 +113,7 @@ func (args Arguments) validate() error { return fmt.Errorf("only one of Version(%q) or Track(%q) can be set", args.Version, args.Track) } switch args.Track { - case StableTrack, UnstableTrack, CurrentTrack: + case StableTrack, UnstableTrack, "": // All valid values. default: return fmt.Errorf("unsupported track %q", args.Track) @@ -119,11 +126,17 @@ type Updater struct { // Update is a platform-specific method that updates the installation. May be // nil (not all platforms support updates from within Tailscale). Update func() error + + // currentVersion is the short form of the current client version as + // returned by version.Short(), typically "x.y.z". Used for tests to + // override the actual current version. + currentVersion string } func NewUpdater(args Arguments) (*Updater, error) { up := Updater{ - Arguments: args, + Arguments: args, + currentVersion: version.Short(), } if up.Stdout == nil { up.Stdout = os.Stdout @@ -139,18 +152,15 @@ func NewUpdater(args Arguments) (*Updater, error) { if args.ForAutoUpdate && !canAutoUpdate { return nil, errors.ErrUnsupported } - if up.Track == CurrentTrack { - switch { - case up.Version != "": + if up.Track == "" { + if up.Version != "" { var err error up.Track, err = versionToTrack(args.Version) if err != nil { return nil, err } - case version.IsUnstableBuild(): - up.Track = UnstableTrack - default: - up.Track = StableTrack + } else { + up.Track = CurrentTrack } } if up.Arguments.PkgsAddr == "" { @@ -259,13 +269,16 @@ func Update(args Arguments) error { } func (up *Updater) confirm(ver string) bool { - switch cmpver.Compare(version.Short(), ver) { - case 0: - up.Logf("already running %v version %v; no update needed", up.Track, ver) - return false - case 1: - up.Logf("installed %v version %v is newer than the latest available version %v; no update needed", up.Track, version.Short(), ver) - return false + // Only check version when we're not switching tracks. + if up.Track == "" || up.Track == CurrentTrack { + switch c := cmpver.Compare(up.currentVersion, ver); { + case c == 0: + up.Logf("already running %v version %v; no update needed", up.Track, ver) + return false + case c > 0: + up.Logf("installed %v version %v is newer than the latest available version %v; no update needed", up.Track, up.currentVersion, ver) + return false + } } if up.Confirm != nil { return up.Confirm(ver) @@ -681,7 +694,7 @@ func parseAlpinePackageVersion(out []byte) (string, error) { return "", fmt.Errorf("malformed info line: %q", line) } ver := parts[1] - if cmpver.Compare(ver, maxVer) == 1 { + if cmpver.Compare(ver, maxVer) > 0 { maxVer = ver } } @@ -880,7 +893,7 @@ func (up *Updater) installMSI(msi string) error { break } up.Logf("Install attempt failed: %v", err) - uninstallVersion := version.Short() + uninstallVersion := up.currentVersion if v := os.Getenv("TS_DEBUG_UNINSTALL_VERSION"); v != "" { uninstallVersion = v } @@ -1331,12 +1344,8 @@ func requestedTailscaleVersion(ver, track string) (string, error) { // LatestTailscaleVersion returns the latest released version for the given // track from pkgs.tailscale.com. func LatestTailscaleVersion(track string) (string, error) { - if track == CurrentTrack { - if version.IsUnstableBuild() { - track = UnstableTrack - } else { - track = StableTrack - } + if track == "" { + track = CurrentTrack } latest, err := latestPackages(track) diff --git a/clientupdate/clientupdate_test.go b/clientupdate/clientupdate_test.go index dc8f66fd6..b265d5641 100644 --- a/clientupdate/clientupdate_test.go +++ b/clientupdate/clientupdate_test.go @@ -846,3 +846,107 @@ func TestParseUnraidPluginVersion(t *testing.T) { }) } } + +func TestConfirm(t *testing.T) { + curTrack := CurrentTrack + defer func() { CurrentTrack = curTrack }() + + tests := []struct { + desc string + fromTrack string + toTrack string + fromVer string + toVer string + confirm func(string) bool + want bool + }{ + { + desc: "on latest stable", + fromTrack: StableTrack, + toTrack: StableTrack, + fromVer: "1.66.0", + toVer: "1.66.0", + want: false, + }, + { + desc: "stable upgrade", + fromTrack: StableTrack, + toTrack: StableTrack, + fromVer: "1.66.0", + toVer: "1.68.0", + want: true, + }, + { + desc: "unstable upgrade", + fromTrack: UnstableTrack, + toTrack: UnstableTrack, + fromVer: "1.67.1", + toVer: "1.67.2", + want: true, + }, + { + desc: "from stable to unstable", + fromTrack: StableTrack, + toTrack: UnstableTrack, + fromVer: "1.66.0", + toVer: "1.67.1", + want: true, + }, + { + desc: "from unstable to stable", + fromTrack: UnstableTrack, + toTrack: StableTrack, + fromVer: "1.67.1", + toVer: "1.66.0", + want: true, + }, + { + desc: "confirm callback rejects", + fromTrack: StableTrack, + toTrack: StableTrack, + fromVer: "1.66.0", + toVer: "1.66.1", + confirm: func(string) bool { + return false + }, + want: false, + }, + { + desc: "confirm callback allows", + fromTrack: StableTrack, + toTrack: StableTrack, + fromVer: "1.66.0", + toVer: "1.66.1", + confirm: func(string) bool { + return true + }, + want: true, + }, + { + desc: "downgrade", + fromTrack: StableTrack, + toTrack: StableTrack, + fromVer: "1.66.1", + toVer: "1.66.0", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + CurrentTrack = tt.fromTrack + up := Updater{ + currentVersion: tt.fromVer, + Arguments: Arguments{ + Track: tt.toTrack, + Confirm: tt.confirm, + Logf: t.Logf, + }, + } + + if got := up.confirm(tt.toVer); got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +}