// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package ipnlocal import ( "bytes" "context" "crypto/sha256" "crypto/tls" "encoding/hex" "encoding/json" "errors" "fmt" "net/http" "net/http/httptest" "net/netip" "net/url" "os" "path/filepath" "strings" "testing" "time" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/util/cmpx" "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/wgengine" ) func TestExpandProxyArg(t *testing.T) { type res struct { target string insecure bool } tests := []struct { in string want res }{ {"", res{}}, {"3030", res{"http://127.0.0.1:3030", false}}, {"localhost:3030", res{"http://localhost:3030", false}}, {"10.2.3.5:3030", res{"http://10.2.3.5:3030", false}}, {"http://foo.com", res{"http://foo.com", false}}, {"https://foo.com", res{"https://foo.com", false}}, {"https+insecure://10.2.3.4", res{"https://10.2.3.4", true}}, } for _, tt := range tests { target, insecure := expandProxyArg(tt.in) got := res{target, insecure} if got != tt.want { t.Errorf("expandProxyArg(%q) = %v, want %v", tt.in, got, tt.want) } } } func TestGetServeHandler(t *testing.T) { const serverName = "example.ts.net" conf1 := &ipn.ServeConfig{ Web: map[ipn.HostPort]*ipn.WebServerConfig{ serverName + ":443": { Handlers: map[string]*ipn.HTTPHandler{ "/": {}, "/bar": {}, "/foo/": {}, "/foo/bar": {}, "/foo/bar/": {}, }, }, }, } tests := []struct { name string port uint16 // or 443 is zero path string // http.Request.URL.Path conf *ipn.ServeConfig want string // mountPoint }{ { name: "nothing", path: "/", conf: nil, want: "", }, { name: "root", conf: conf1, path: "/", want: "/", }, { name: "root-other", conf: conf1, path: "/other", want: "/", }, { name: "bar", conf: conf1, path: "/bar", want: "/bar", }, { name: "foo-bar", conf: conf1, path: "/foo/bar", want: "/foo/bar", }, { name: "foo-bar-slash", conf: conf1, path: "/foo/bar/", want: "/foo/bar/", }, { name: "foo-bar-other", conf: conf1, path: "/foo/bar/other", want: "/foo/bar/", }, { name: "foo-other", conf: conf1, path: "/foo/other", want: "/foo/", }, { name: "foo-no-trailing-slash", conf: conf1, path: "/foo", want: "/foo/", }, { name: "dot-dots", conf: conf1, path: "/foo/../../../../../../../../etc/passwd", want: "/", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b := &LocalBackend{ serveConfig: tt.conf.View(), logf: t.Logf, } req := &http.Request{ URL: &url.URL{ Path: tt.path, }, TLS: &tls.ConnectionState{ServerName: serverName}, } port := cmpx.Or(tt.port, 443) req = req.WithContext(context.WithValue(req.Context(), serveHTTPContextKey{}, &serveHTTPContext{ DestPort: port, })) h, got, ok := b.getServeHandler(req) if (got != "") != ok { t.Fatalf("got ok=%v, but got mountPoint=%q", ok, got) } if h.Valid() != ok { t.Fatalf("got ok=%v, but valid=%v", ok, h.Valid()) } if got != tt.want { t.Errorf("got handler at mount %q, want %q", got, tt.want) } }) } } func getEtag(t *testing.T, b any) string { t.Helper() bts, err := json.Marshal(b) if err != nil { t.Fatal(err) } sum := sha256.Sum256(bts) return hex.EncodeToString(sum[:]) } // TestServeConfigForeground tests the inter-dependency // between a ServeConfig and a WatchIPNBus: // 1. Creating a WatchIPNBus returns a sessionID, that // 2. ServeConfig sets it as the key of the Foreground field. // 3. ServeConfig expects the WatchIPNBus to clean up the Foreground // config when the session is done. // 4. WatchIPNBus expects the ServeConfig to send a signal (close the channel) // if an incoming SetServeConfig removes previous foregrounds. func TestServeConfigForeground(t *testing.T) { b := newTestBackend(t) ch1 := make(chan string, 1) go func() { defer close(ch1) b.WatchNotifications(context.Background(), ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) { if roNotify.SessionID != "" { ch1 <- roNotify.SessionID } return true }) }() ch2 := make(chan string, 1) go func() { b.WatchNotifications(context.Background(), ipn.NotifyInitialState, nil, func(roNotify *ipn.Notify) (keepGoing bool) { if roNotify.SessionID != "" { ch2 <- roNotify.SessionID return true } ch2 <- "again" // let channel know fn was called again return true }) }() var session1 string select { case session1 = <-ch1: case <-time.After(time.Second): t.Fatal("timed out waiting on watch notifications session id") } var session2 string select { case session2 = <-ch2: case <-time.After(time.Second): t.Fatal("timed out waiting on watch notifications session id") } err := b.SetServeConfig(&ipn.ServeConfig{ Foreground: map[string]*ipn.ServeConfig{ session1: {TCP: map[uint16]*ipn.TCPPortHandler{ 443: {TCPForward: "http://localhost:3000"}}, }, session2: {TCP: map[uint16]*ipn.TCPPortHandler{ 999: {TCPForward: "http://localhost:4000"}}, }, }, }, "") if err != nil { t.Fatal(err) } // Setting a new serve config should shut down WatchNotifications // whose session IDs are no longer found: session1 goes, session2 stays. err = b.SetServeConfig(&ipn.ServeConfig{ TCP: map[uint16]*ipn.TCPPortHandler{ 5000: {TCPForward: "http://localhost:5000"}, }, Foreground: map[string]*ipn.ServeConfig{ session2: {TCP: map[uint16]*ipn.TCPPortHandler{ 999: {TCPForward: "http://localhost:4000"}}, }, }, }, "") if err != nil { t.Fatal(err) } select { case _, ok := <-ch1: if ok { t.Fatal("expected channel to be closed") } case <-time.After(time.Second): t.Fatal("timed out waiting on watch notifications closing") } // check that the second session is still running b.send(ipn.Notify{}) select { case _, ok := <-ch2: if !ok { t.Fatal("expected second session to remain open") } case <-time.After(time.Second): t.Fatal("timed out waiting on second session") } } func TestServeConfigETag(t *testing.T) { b := newTestBackend(t) // a nil config with initial etag should succeed err := b.SetServeConfig(nil, getEtag(t, nil)) if err != nil { t.Fatal(err) } // a nil config with an invalid etag should fail err = b.SetServeConfig(nil, "abc") if !errors.Is(err, ErrETagMismatch) { t.Fatal("expected an error but got nil") } // a new config with no etag should succeed conf := &ipn.ServeConfig{ Web: map[ipn.HostPort]*ipn.WebServerConfig{ "example.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ "/": {Proxy: "http://127.0.0.1:3000"}, }}, }, } err = b.SetServeConfig(conf, getEtag(t, nil)) if err != nil { t.Fatal(err) } confView := b.ServeConfig() etag := getEtag(t, confView) if etag == "" { t.Fatal("expected to get an etag but got an empty string") } conf = confView.AsStruct() mak.Set(&conf.AllowFunnel, "example.ts.net:443", true) // replacing an existing config with an invalid etag should fail err = b.SetServeConfig(conf, "invalid etag") if !errors.Is(err, ErrETagMismatch) { t.Fatalf("expected an etag mismatch error but got %v", err) } // replacing an existing config with a valid etag should succeed err = b.SetServeConfig(conf, etag) if err != nil { t.Fatal(err) } // replacing an existing config with a previous etag should fail err = b.SetServeConfig(nil, etag) if !errors.Is(err, ErrETagMismatch) { t.Fatalf("expected an etag mismatch error but got %v", err) } // replacing an existing config with the new etag should succeed newCfg := b.ServeConfig() etag = getEtag(t, newCfg) err = b.SetServeConfig(nil, etag) if err != nil { t.Fatal(err) } } func TestServeHTTPProxy(t *testing.T) { b := newTestBackend(t) // Start test serve endpoint. testServ := httptest.NewServer(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { // Piping all the headers through the response writer // so we can check their values in tests below. for key, val := range r.Header { w.Header().Add(key, strings.Join(val, ",")) } }, )) defer testServ.Close() conf := &ipn.ServeConfig{ Web: map[ipn.HostPort]*ipn.WebServerConfig{ "example.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ "/": {Proxy: testServ.URL}, }}, }, } if err := b.SetServeConfig(conf, ""); err != nil { t.Fatal(err) } type headerCheck struct { header string want string } tests := []struct { name string srcIP string wantHeaders []headerCheck }{ { name: "request-from-user-within-tailnet", srcIP: "100.150.151.152", wantHeaders: []headerCheck{ {"X-Forwarded-Proto", "https"}, {"X-Forwarded-For", "100.150.151.152"}, {"Tailscale-User-Login", "someone@example.com"}, {"Tailscale-User-Name", "Some One"}, {"Tailscale-User-Profile-Pic", "https://example.com/photo.jpg"}, {"Tailscale-Headers-Info", "https://tailscale.com/s/serve-headers"}, }, }, { name: "request-from-tagged-node-within-tailnet", srcIP: "100.150.151.153", wantHeaders: []headerCheck{ {"X-Forwarded-Proto", "https"}, {"X-Forwarded-For", "100.150.151.153"}, {"Tailscale-User-Login", ""}, {"Tailscale-User-Name", ""}, {"Tailscale-User-Profile-Pic", ""}, {"Tailscale-Headers-Info", ""}, }, }, { name: "request-from-outside-tailnet", srcIP: "100.160.161.162", wantHeaders: []headerCheck{ {"X-Forwarded-Proto", "https"}, {"X-Forwarded-For", "100.160.161.162"}, {"Tailscale-User-Login", ""}, {"Tailscale-User-Name", ""}, {"Tailscale-User-Profile-Pic", ""}, {"Tailscale-Headers-Info", ""}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := &http.Request{ URL: &url.URL{Path: "/"}, TLS: &tls.ConnectionState{ServerName: "example.ts.net"}, } req = req.WithContext(context.WithValue(req.Context(), serveHTTPContextKey{}, &serveHTTPContext{ DestPort: 443, SrcAddr: netip.MustParseAddrPort(tt.srcIP + ":1234"), // random src port for tests })) w := httptest.NewRecorder() b.serveWebHandler(w, req) // Verify the headers. h := w.Result().Header for _, c := range tt.wantHeaders { if got := h.Get(c.header); got != c.want { t.Errorf("invalid %q header; want=%q, got=%q", c.header, c.want, got) } } }) } } func newTestBackend(t *testing.T) *LocalBackend { sys := &tsd.System{} e, err := wgengine.NewUserspaceEngine(t.Logf, wgengine.Config{SetSubsystem: sys.Set}) if err != nil { t.Fatal(err) } sys.Set(e) sys.Set(new(mem.Store)) b, err := NewLocalBackend(t.Logf, logid.PublicID{}, sys, 0) if err != nil { t.Fatal(err) } t.Cleanup(b.Shutdown) dir := t.TempDir() b.SetVarRoot(dir) pm := must.Get(newProfileManager(new(mem.Store), t.Logf)) pm.currentProfile = &ipn.LoginProfile{ID: "id0"} b.pm = pm b.netMap = &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{ Name: "example.ts.net", }).View(), UserProfiles: map[tailcfg.UserID]tailcfg.UserProfile{ tailcfg.UserID(1): { LoginName: "someone@example.com", DisplayName: "Some One", ProfilePicURL: "https://example.com/photo.jpg", }, }, } b.peers = map[tailcfg.NodeID]tailcfg.NodeView{ 152: (&tailcfg.Node{ ID: 152, ComputedName: "some-peer", User: tailcfg.UserID(1), }).View(), 153: (&tailcfg.Node{ ID: 153, ComputedName: "some-tagged-peer", Tags: []string{"tag:server", "tag:test"}, User: tailcfg.UserID(1), }).View(), } b.nodeByAddr = map[netip.Addr]tailcfg.NodeID{ netip.MustParseAddr("100.150.151.152"): 152, netip.MustParseAddr("100.150.151.153"): 153, } return b } func TestServeFileOrDirectory(t *testing.T) { td := t.TempDir() writeFile := func(suffix, contents string) { if err := os.WriteFile(filepath.Join(td, suffix), []byte(contents), 0600); err != nil { t.Fatal(err) } } writeFile("foo", "this is foo") writeFile("bar", "this is bar") os.MkdirAll(filepath.Join(td, "subdir"), 0700) writeFile("subdir/file-a", "this is A") writeFile("subdir/file-b", "this is B") writeFile("subdir/file-c", "this is C") contains := func(subs ...string) func([]byte, *http.Response) error { return func(resBody []byte, res *http.Response) error { for _, sub := range subs { if !bytes.Contains(resBody, []byte(sub)) { return fmt.Errorf("response body does not contain %q: %s", sub, resBody) } } return nil } } isStatus := func(wantCode int) func([]byte, *http.Response) error { return func(resBody []byte, res *http.Response) error { if res.StatusCode != wantCode { return fmt.Errorf("response status = %d; want %d", res.StatusCode, wantCode) } return nil } } isRedirect := func(wantLocation string) func([]byte, *http.Response) error { return func(resBody []byte, res *http.Response) error { switch res.StatusCode { case 301, 302, 303, 307, 308: if got := res.Header.Get("Location"); got != wantLocation { return fmt.Errorf("got Location = %q; want %q", got, wantLocation) } default: return fmt.Errorf("response status = %d; want redirect. body: %s", res.StatusCode, resBody) } return nil } } b := &LocalBackend{} tests := []struct { req string mount string want func(resBody []byte, res *http.Response) error }{ // Mounted at / {"/", "/", contains("foo", "bar", "subdir")}, {"/../../.../../../../../../../etc/passwd", "/", isStatus(404)}, {"/foo", "/", contains("this is foo")}, {"/bar", "/", contains("this is bar")}, {"/bar/inside-file", "/", isStatus(404)}, {"/subdir", "/", isRedirect("/subdir/")}, {"/subdir/", "/", contains("file-a", "file-b", "file-c")}, {"/subdir/file-a", "/", contains("this is A")}, {"/subdir/file-z", "/", isStatus(404)}, {"/doc", "/doc/", isRedirect("/doc/")}, {"/doc/", "/doc/", contains("foo", "bar", "subdir")}, {"/doc/../../.../../../../../../../etc/passwd", "/doc/", isStatus(404)}, {"/doc/foo", "/doc/", contains("this is foo")}, {"/doc/bar", "/doc/", contains("this is bar")}, {"/doc/bar/inside-file", "/doc/", isStatus(404)}, {"/doc/subdir", "/doc/", isRedirect("/doc/subdir/")}, {"/doc/subdir/", "/doc/", contains("file-a", "file-b", "file-c")}, {"/doc/subdir/file-a", "/doc/", contains("this is A")}, {"/doc/subdir/file-z", "/doc/", isStatus(404)}, } for _, tt := range tests { rec := httptest.NewRecorder() req := httptest.NewRequest("GET", tt.req, nil) b.serveFileOrDirectory(rec, req, td, tt.mount) if tt.want == nil { t.Errorf("no want for path %q", tt.req) return } if err := tt.want(rec.Body.Bytes(), rec.Result()); err != nil { t.Errorf("error for req %q (mount %v): %v", tt.req, tt.mount, err) } } } func Test_isGRPCContentType(t *testing.T) { tests := map[string]struct { contentType string want bool }{ "application/grpc": { contentType: "application/grpc", want: true, }, "application/grpc;": { contentType: "application/grpc;", want: true, }, "application/grpc+": { contentType: "application/grpc+", want: true, }, "application/grpcfoobar": { contentType: "application/grpcfoobar", }, "application/text": { contentType: "application/text", }, "foobar": { contentType: "foobar", }, } for name, scenario := range tests { t.Run(name, func(t *testing.T) { if got := isGRPCContentType(scenario.contentType); got != scenario.want { t.Errorf("test case %s failed, isGRPCContentType() = %v, want %v", name, got, scenario.want) } }) } }