// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package controlclient import ( "context" "crypto/tls" "errors" "flag" "fmt" "io" "net" "net/http" "net/netip" "net/url" "reflect" "slices" "sync/atomic" "testing" "time" "tailscale.com/control/controlknobs" "tailscale.com/health" "tailscale.com/net/bakedroots" "tailscale.com/net/connectproxy" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/integration/testcontrol" "tailscale.com/tstest/tlstest" "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/persist" ) func fieldsOf(t reflect.Type) (fields []string) { for i := range t.NumField() { if name := t.Field(i).Name; name != "_" { fields = append(fields, name) } } return } func TestStatusEqual(t *testing.T) { // Verify that the Equal method stays in sync with reality equalHandles := []string{"Err", "URL", "NetMap", "Persist", "state"} if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, equalHandles) { t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n", have, equalHandles) } tests := []struct { a, b *Status want bool }{ { &Status{}, nil, false, }, { nil, &Status{}, false, }, { nil, nil, true, }, { &Status{}, &Status{}, true, }, { &Status{}, &Status{state: StateAuthenticated}, false, }, } for i, tt := range tests { got := tt.a.Equal(tt.b) if got != tt.want { t.Errorf("%d. Equal = %v; want %v", i, got, tt.want) } } } // tests [canSkipStatus]. func TestCanSkipStatus(t *testing.T) { st := new(Status) nm1 := &netmap.NetworkMap{} nm2 := &netmap.NetworkMap{} tests := []struct { name string s1, s2 *Status want bool }{ { name: "nil-s2", s1: st, s2: nil, want: false, }, { name: "equal", s1: st, s2: st, want: false, }, { name: "s1-error", s1: &Status{Err: io.EOF, NetMap: nm1}, s2: &Status{NetMap: nm2}, want: false, }, { name: "s1-url", s1: &Status{URL: "foo", NetMap: nm1}, s2: &Status{NetMap: nm2}, want: false, }, { name: "s1-persist-diff", s1: &Status{Persist: new(persist.Persist).View(), NetMap: nm1}, s2: &Status{NetMap: nm2}, want: false, }, { name: "s1-state-diff", s1: &Status{state: 123, NetMap: nm1}, s2: &Status{NetMap: nm2}, want: false, }, { name: "s1-no-netmap1", s1: &Status{NetMap: nil}, s2: &Status{NetMap: nm2}, want: false, }, { name: "s1-no-netmap2", s1: &Status{NetMap: nm1}, s2: &Status{NetMap: nil}, want: false, }, { name: "skip", s1: &Status{NetMap: nm1}, s2: &Status{NetMap: nm2}, want: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := canSkipStatus(tt.s1, tt.s2); got != tt.want { t.Errorf("canSkipStatus = %v, want %v", got, tt.want) } }) } want := []string{"Err", "URL", "NetMap", "Persist", "state"} if f := fieldsOf(reflect.TypeFor[Status]()); !slices.Equal(f, want) { t.Errorf("Status fields = %q; this code was only written to handle fields %q", f, want) } } func TestRetryableErrors(t *testing.T) { errorTests := []struct { err error want bool }{ {errNoNoiseClient, true}, {errNoNodeKey, true}, {fmt.Errorf("%w: %w", errNoNoiseClient, errors.New("no noise")), true}, {fmt.Errorf("%w: %w", errHTTPPostFailure, errors.New("bad post")), true}, {fmt.Errorf("%w: %w", errNoNodeKey, errors.New("not node key")), true}, {errBadHTTPResponse(429, "too may requests"), true}, {errBadHTTPResponse(500, "internal server eror"), true}, {errBadHTTPResponse(502, "bad gateway"), true}, {errBadHTTPResponse(503, "service unavailable"), true}, {errBadHTTPResponse(504, "gateway timeout"), true}, {errBadHTTPResponse(1234, "random error"), false}, } for _, tt := range errorTests { t.Run(tt.err.Error(), func(t *testing.T) { if isRetryableErrorForTest(tt.err) != tt.want { t.Fatalf("retriable: got %v, want %v", tt.err, tt.want) } }) } } type retryableForTest interface { Retryable() bool } func isRetryableErrorForTest(err error) bool { var ae retryableForTest if errors.As(err, &ae) { return ae.Retryable() } return false } var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests") func TestDirectProxyManual(t *testing.T) { if !*liveNetworkTest { t.Skip("skipping without --live-network-test") } dialer := &tsdial.Dialer{} dialer.SetNetMon(netmon.NewStatic()) opts := Options{ Persist: persist.Persist{}, GetMachinePrivateKey: func() (key.MachinePrivate, error) { return key.NewMachine(), nil }, ServerURL: "https://controlplane.tailscale.com", Clock: tstime.StdClock{}, Hostinfo: &tailcfg.Hostinfo{ BackendLogID: "test-backend-log-id", }, DiscoPublicKey: key.NewDisco().Public(), Logf: t.Logf, HealthTracker: &health.Tracker{}, PopBrowserURL: func(url string) { t.Logf("PopBrowserURL: %q", url) }, Dialer: dialer, ControlKnobs: &controlknobs.Knobs{}, } d, err := NewDirect(opts) if err != nil { t.Fatalf("NewDirect: %v", err) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() url, err := d.TryLogin(ctx, LoginEphemeral) if err != nil { t.Fatalf("TryLogin: %v", err) } t.Logf("URL: %q", url) } func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) } // TestTLSWithProxy verifies we can connect to the control plane via // an HTTPS proxy. func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) } func testHTTPS(t *testing.T, withProxy bool) { bakedroots.ResetForTest(t, tlstest.TestRootCA()) controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlaneKeyPair.ServerTLSConfig()) if err != nil { t.Fatal(err) } defer controlLn.Close() proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServerKeyPair.ServerTLSConfig()) if err != nil { t.Fatal(err) } defer proxyLn.Close() const requiredAuthKey = "hunter2" const someUsername = "testuser" const somePassword = "testpass" testControl := &testcontrol.Server{ Logf: tstest.WhileTestRunningLogger(t), RequireAuthKey: requiredAuthKey, } controlSrv := &http.Server{ Handler: testControl, ErrorLog: logger.StdLogger(t.Logf), } go controlSrv.Serve(controlLn) const fakeControlIP = "1.2.3.4" const fakeProxyIP = "5.6.7.8" dialer := &tsdial.Dialer{} dialer.SetNetMon(netmon.NewStatic()) dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) { host, _, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err) } var d net.Dialer if host == fakeControlIP { return d.DialContext(ctx, network, controlLn.Addr().String()) } if host == fakeProxyIP { return d.DialContext(ctx, network, proxyLn.Addr().String()) } return nil, fmt.Errorf("unexpected dial to %q", addr) }) opts := Options{ Persist: persist.Persist{}, GetMachinePrivateKey: func() (key.MachinePrivate, error) { return key.NewMachine(), nil }, AuthKey: requiredAuthKey, ServerURL: "https://controlplane.tstest", Clock: tstime.StdClock{}, Hostinfo: &tailcfg.Hostinfo{ BackendLogID: "test-backend-log-id", }, DiscoPublicKey: key.NewDisco().Public(), Logf: t.Logf, HealthTracker: &health.Tracker{}, PopBrowserURL: func(url string) { t.Logf("PopBrowserURL: %q", url) }, Dialer: dialer, } d, err := NewDirect(opts) if err != nil { t.Fatalf("NewDirect: %v", err) } d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) { switch host { case "controlplane.tstest": return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil case "proxy.tstest": if !withProxy { t.Errorf("unexpected DNS lookup for %q with proxy disabled", host) return nil, fmt.Errorf("unexpected DNS lookup for %q", host) } return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil } t.Errorf("unexpected DNS query for %q", host) return []netip.Addr{}, nil } var proxyReqs atomic.Int64 if withProxy { d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) { t.Logf("using proxy for %q", req.URL) u := &url.URL{ Scheme: "https", Host: "proxy.tstest:443", User: url.UserPassword(someUsername, somePassword), } return u, nil } connectProxy := &http.Server{ Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs), } go connectProxy.Serve(proxyLn) } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() url, err := d.TryLogin(ctx, LoginEphemeral) if err != nil { t.Fatalf("TryLogin: %v", err) } if url != "" { t.Errorf("got URL %q, want empty", url) } if withProxy { if got, want := proxyReqs.Load(), int64(1); got != want { t.Errorf("proxy CONNECT requests = %d; want %d", got, want) } } } func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.RequestURI != target { t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target) http.Error(w, "bad target", http.StatusBadRequest) return } r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy. user, pass, ok := r.BasicAuth() if !ok || user != "testuser" || pass != "testpass" { t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass") http.Error(w, "bad auth", http.StatusUnauthorized) return } (&connectproxy.Handler{ Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { var d net.Dialer c, err := d.DialContext(ctx, network, backendAddrPort) if err == nil { reqs.Add(1) } return c, err }, Logf: t.Logf, }).ServeHTTP(w, r) }) }