ipn/ipnlocal: support serving files/directories too

Updates tailscale/corp#7515

Change-Id: I7b4c924005274ba57763264313d70d2a0c55da30
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2022-11-10 14:16:37 -08:00 committed by Brad Fitzpatrick
parent 446057d613
commit 7b5866ac0a
3 changed files with 309 additions and 11 deletions

View File

@ -1980,7 +1980,6 @@ func (b *LocalBackend) loadStateLocked(key ipn.StateKey, prefs *ipn.Prefs) (err
func (b *LocalBackend) setTCPPortsIntercepted(ports []uint16) { func (b *LocalBackend) setTCPPortsIntercepted(ports []uint16) {
slices.Sort(ports) slices.Sort(ports)
uniq.ModifySlice(&ports) uniq.ModifySlice(&ports)
b.logf("localbackend: handling TCP ports = %v", ports)
var f func(uint16) bool var f func(uint16) bool
switch len(ports) { switch len(ports) {
case 0: case 0:

View File

@ -15,9 +15,12 @@
"net/http/httputil" "net/http/httputil"
"net/netip" "net/netip"
"net/url" "net/url"
"os"
"path"
pathpkg "path" pathpkg "path"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"tailscale.com/ipn" "tailscale.com/ipn"
@ -151,37 +154,44 @@ func (b *LocalBackend) HandleInterceptedTCPConn(dport uint16, srcAddr netip.Addr
sendRST() sendRST()
} }
func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, ok bool) { func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, at string, ok bool) {
var z ipn.HTTPHandlerView // zero value var z ipn.HTTPHandlerView // zero value
if r.TLS == nil { if r.TLS == nil {
return z, false return z, "", false
} }
sctx, ok := r.Context().Value(serveHTTPContextKey{}).(*serveHTTPContext) sctx, ok := r.Context().Value(serveHTTPContextKey{}).(*serveHTTPContext)
if !ok { if !ok {
b.logf("[unexpected] localbackend: no serveHTTPContext in request") b.logf("[unexpected] localbackend: no serveHTTPContext in request")
return z, false return z, "", false
} }
wsc, ok := b.webServerConfig(r.TLS.ServerName, sctx.DestPort) wsc, ok := b.webServerConfig(r.TLS.ServerName, sctx.DestPort)
if !ok { if !ok {
return z, false return z, "", false
} }
path := r.URL.Path if h, ok := wsc.Handlers().GetOk(r.URL.Path); ok {
return h, r.URL.Path, true
}
path := path.Clean(r.URL.Path)
for { for {
withSlash := path + "/"
if h, ok := wsc.Handlers().GetOk(withSlash); ok {
return h, withSlash, true
}
if h, ok := wsc.Handlers().GetOk(path); ok { if h, ok := wsc.Handlers().GetOk(path); ok {
return h, true return h, path, true
} }
if path == "/" { if path == "/" {
return z, false return z, "", false
} }
path = pathpkg.Dir(path) path = pathpkg.Dir(path)
} }
} }
func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) { func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) {
h, ok := b.getServeHandler(r) h, mountPoint, ok := b.getServeHandler(r)
if !ok { if !ok {
http.NotFound(w, r) http.NotFound(w, r)
return return
@ -192,7 +202,7 @@ func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if v := h.Path(); v != "" { if v := h.Path(); v != "" {
io.WriteString(w, "TODO(bradfitz): serve file") b.serveFileOrDirectory(w, r, v, mountPoint)
return return
} }
if v := h.Proxy(); v != "" { if v := h.Proxy(); v != "" {
@ -219,6 +229,74 @@ func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "empty handler", 500) http.Error(w, "empty handler", 500)
} }
func (b *LocalBackend) serveFileOrDirectory(w http.ResponseWriter, r *http.Request, fileOrDir, mountPoint string) {
fi, err := os.Stat(fileOrDir)
if err != nil {
if os.IsNotExist(err) {
http.NotFound(w, r)
return
}
http.Error(w, err.Error(), 500)
return
}
if fi.Mode().IsRegular() {
if mountPoint != r.URL.Path {
http.NotFound(w, r)
return
}
f, err := os.Open(fileOrDir)
if err != nil {
http.Error(w, err.Error(), 500)
return
}
defer f.Close()
http.ServeContent(w, r, path.Base(mountPoint), fi.ModTime(), f)
return
}
if !fi.IsDir() {
http.Error(w, "not a file or directory", 500)
return
}
if len(r.URL.Path) < len(mountPoint) && r.URL.Path+"/" == mountPoint {
http.Redirect(w, r, mountPoint, http.StatusFound)
return
}
var fs http.Handler = http.FileServer(http.Dir(fileOrDir))
if mountPoint != "/" {
fs = http.StripPrefix(strings.TrimSuffix(mountPoint, "/"), fs)
}
fs.ServeHTTP(&fixLocationHeaderResponseWriter{
ResponseWriter: w,
mountPoint: mountPoint,
}, r)
}
// fixLocationHeaderResponseWriter is an http.ResponseWriter wrapper that, upon
// flushing HTTP headers, prefixes any Location header with the mount point.
type fixLocationHeaderResponseWriter struct {
http.ResponseWriter
mountPoint string
fixOnce sync.Once // guards call to fix
}
func (w *fixLocationHeaderResponseWriter) fix() {
h := w.ResponseWriter.Header()
if v := h.Get("Location"); v != "" {
h.Set("Location", w.mountPoint+v)
}
}
func (w *fixLocationHeaderResponseWriter) WriteHeader(code int) {
w.fixOnce.Do(w.fix)
w.ResponseWriter.WriteHeader(code)
}
func (w *fixLocationHeaderResponseWriter) Write(p []byte) (int, error) {
w.fixOnce.Do(w.fix)
return w.ResponseWriter.Write(p)
}
// expandProxyArg returns a URL from s, where s can be of form: // expandProxyArg returns a URL from s, where s can be of form:
// //
// * port number ("8080") // * port number ("8080")

View File

@ -4,7 +4,20 @@
package ipnlocal package ipnlocal
import "testing" import (
"bytes"
"context"
"crypto/tls"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"tailscale.com/ipn"
)
func TestExpandProxyArg(t *testing.T) { func TestExpandProxyArg(t *testing.T) {
type res struct { type res struct {
@ -31,3 +44,211 @@ type res struct {
} }
} }
} }
func TestGetServeHandler(t *testing.T) {
const serverName = "example.ts.net"
conf1 := &ipn.ServeConfig{
Web: map[ipn.HostPort]*ipn.WebServerConfig{
serverName + ":443": {
Handlers: map[string]*ipn.HTTPHandler{
"/": {},
"/bar": {},
"/foo/": {},
"/foo/bar": {},
"/foo/bar/": {},
},
},
},
}
tests := []struct {
name string
port uint16 // or 443 is zero
path string // http.Request.URL.Path
conf *ipn.ServeConfig
want string // mountPoint
}{
{
name: "nothing",
path: "/",
conf: nil,
want: "",
},
{
name: "root",
conf: conf1,
path: "/",
want: "/",
},
{
name: "root-other",
conf: conf1,
path: "/other",
want: "/",
},
{
name: "bar",
conf: conf1,
path: "/bar",
want: "/bar",
},
{
name: "foo-bar",
conf: conf1,
path: "/foo/bar",
want: "/foo/bar",
},
{
name: "foo-bar-slash",
conf: conf1,
path: "/foo/bar/",
want: "/foo/bar/",
},
{
name: "foo-bar-other",
conf: conf1,
path: "/foo/bar/other",
want: "/foo/bar/",
},
{
name: "foo-other",
conf: conf1,
path: "/foo/other",
want: "/foo/",
},
{
name: "foo-no-trailing-slash",
conf: conf1,
path: "/foo",
want: "/foo/",
},
{
name: "dot-dots",
conf: conf1,
path: "/foo/../../../../../../../../etc/passwd",
want: "/",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
b := &LocalBackend{
serveConfig: tt.conf.View(),
logf: t.Logf,
}
req := &http.Request{
URL: &url.URL{
Path: tt.path,
},
TLS: &tls.ConnectionState{ServerName: serverName},
}
port := tt.port
if port == 0 {
port = 443
}
req = req.WithContext(context.WithValue(req.Context(), serveHTTPContextKey{}, &serveHTTPContext{
DestPort: port,
}))
h, got, ok := b.getServeHandler(req)
if (got != "") != ok {
t.Fatalf("got ok=%v, but got mountPoint=%q", ok, got)
}
if h.Valid() != ok {
t.Fatalf("got ok=%v, but valid=%v", ok, h.Valid())
}
if got != tt.want {
t.Errorf("got handler at mount %q, want %q", got, tt.want)
}
})
}
}
func TestServeFileOrDirectory(t *testing.T) {
td := t.TempDir()
writeFile := func(suffix, contents string) {
if err := os.WriteFile(filepath.Join(td, suffix), []byte(contents), 0600); err != nil {
t.Fatal(err)
}
}
writeFile("foo", "this is foo")
writeFile("bar", "this is bar")
os.MkdirAll(filepath.Join(td, "subdir"), 0700)
writeFile("subdir/file-a", "this is A")
writeFile("subdir/file-b", "this is B")
writeFile("subdir/file-c", "this is C")
contains := func(subs ...string) func([]byte, *http.Response) error {
return func(resBody []byte, res *http.Response) error {
for _, sub := range subs {
if !bytes.Contains(resBody, []byte(sub)) {
return fmt.Errorf("response body does not contain %q: %s", sub, resBody)
}
}
return nil
}
}
isStatus := func(wantCode int) func([]byte, *http.Response) error {
return func(resBody []byte, res *http.Response) error {
if res.StatusCode != wantCode {
return fmt.Errorf("response status = %d; want %d", res.StatusCode, wantCode)
}
return nil
}
}
isRedirect := func(wantLocation string) func([]byte, *http.Response) error {
return func(resBody []byte, res *http.Response) error {
switch res.StatusCode {
case 301, 302, 303, 307, 308:
if got := res.Header.Get("Location"); got != wantLocation {
return fmt.Errorf("got Location = %q; want %q", got, wantLocation)
}
default:
return fmt.Errorf("response status = %d; want redirect. body: %s", res.StatusCode, resBody)
}
return nil
}
}
b := &LocalBackend{}
tests := []struct {
req string
mount string
want func(resBody []byte, res *http.Response) error
}{
// Mounted at /
{"/", "/", contains("foo", "bar", "subdir")},
{"/../../.../../../../../../../etc/passwd", "/", isStatus(404)},
{"/foo", "/", contains("this is foo")},
{"/bar", "/", contains("this is bar")},
{"/bar/inside-file", "/", isStatus(404)},
{"/subdir", "/", isRedirect("/subdir/")},
{"/subdir/", "/", contains("file-a", "file-b", "file-c")},
{"/subdir/file-a", "/", contains("this is A")},
{"/subdir/file-z", "/", isStatus(404)},
{"/doc", "/doc/", isRedirect("/doc/")},
{"/doc/", "/doc/", contains("foo", "bar", "subdir")},
{"/doc/../../.../../../../../../../etc/passwd", "/doc/", isStatus(404)},
{"/doc/foo", "/doc/", contains("this is foo")},
{"/doc/bar", "/doc/", contains("this is bar")},
{"/doc/bar/inside-file", "/doc/", isStatus(404)},
{"/doc/subdir", "/doc/", isRedirect("/doc/subdir/")},
{"/doc/subdir/", "/doc/", contains("file-a", "file-b", "file-c")},
{"/doc/subdir/file-a", "/doc/", contains("this is A")},
{"/doc/subdir/file-z", "/doc/", isStatus(404)},
}
for _, tt := range tests {
rec := httptest.NewRecorder()
req := httptest.NewRequest("GET", tt.req, nil)
b.serveFileOrDirectory(rec, req, td, tt.mount)
if tt.want == nil {
t.Errorf("no want for path %q", tt.req)
return
}
if err := tt.want(rec.Body.Bytes(), rec.Result()); err != nil {
t.Errorf("error for req %q (mount %v): %v", tt.req, tt.mount, err)
}
}
}