drive: parse depth 1 PROPFIND results to include children in cache

Clients often perform a PROPFIND for the parent directory before
performing PROPFIND for specific children within that directory.
The PROPFIND for the parent directory is usually done at depth 1,
meaning that we already have information for all of the children.
By immediately adding that to the cache, we save a roundtrip to
the remote peer on the PROPFIND for the specific child.

Updates tailscale/corp#19779

Signed-off-by: Percy Wegmann <percy@tailscale.com>
This commit is contained in:
Percy Wegmann 2024-05-03 17:20:13 -05:00 committed by Percy Wegmann
parent d86d1e7601
commit 7209c4f91e
4 changed files with 368 additions and 39 deletions

View File

@ -81,6 +81,16 @@ type Handler struct {
staticRoot string staticRoot string
} }
// cacheInvalidatingMethods is the set of request methods that are treated as
// modifications. ServeHTTP calls StatCache.invalidate() whenever the request
// method is a key of this map; only key presence is consulted there, so every
// entry is simply mapped to true.
var cacheInvalidatingMethods = map[string]bool{
	"COPY":      true,
	"DELETE":    true,
	"MKCOL":     true,
	"MOVE":      true,
	"POST":      true,
	"PROPPATCH": true,
	"PUT":       true,
}
// ServeHTTP implements http.Handler. // ServeHTTP implements http.Handler.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "PROPFIND" { if r.Method == "PROPFIND" {
@ -88,11 +98,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
if r.Method != "GET" { _, shouldInvalidate := cacheInvalidatingMethods[r.Method]
// If the user is performing a modification (e.g. PUT, MKDIR, etc), if shouldInvalidate {
// If the user is performing a modification (e.g. PUT, MKCOL, etc.),
// we need to invalidate the StatCache to make sure we're not knowingly // we need to invalidate the StatCache to make sure we're not knowingly
// showing stale stats. // showing stale stats.
// TODO(oxtoacart): maybe be more selective about invalidating cache // TODO(oxtoacart): maybe only invalidate specific paths
h.StatCache.invalidate() h.StatCache.invalidate()
} }

View File

@ -4,11 +4,19 @@
package compositedav package compositedav
import ( import (
"bytes"
"encoding/xml"
"log"
"net/http" "net/http"
"sync" "sync"
"time" "time"
"github.com/jellydator/ttlcache/v3" "github.com/jellydator/ttlcache/v3"
"tailscale.com/drive/driveimpl/shared"
)
var (
notFound = newCacheEntry(http.StatusNotFound, nil)
) )
// StatCache provides a cache for directory listings and file metadata. // StatCache provides a cache for directory listings and file metadata.
@ -18,12 +26,38 @@
// This is similar to the DirectoryCacheLifetime setting of Windows' built-in // This is similar to the DirectoryCacheLifetime setting of Windows' built-in
// SMB client, see // SMB client, see
// https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-7/ff686200(v=ws.10) // https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-7/ff686200(v=ws.10)
//
// StatCache is built specifically to cache the results of PROPFIND requests,
// which come back as MultiStatus XML responses. Typical clients will issue two
// kinds of PROPFIND:
//
// The first kind of PROPFIND is a directory listing performed to depth 1. At
// this depth, the resulting XML will contain stats for the requested folder as
// well as for all children of that folder.
//
// The second kind of PROPFIND is a file listing performed to depth 0. At this
// depth, the resulting XML will contain stats only for the requested file.
//
// In order to avoid round-trips, when a PROPFIND at depth 0 is attempted, and
// the requested file is not in the cache, StatCache will check to see if the
// parent folder of that file is cached. If so, StatCache infers the correct
// MultiStatus for the file according to the following logic:
//
// 1. If the parent folder is NotFound (404), treat the file itself as NotFound
// 2. If the parent folder's XML doesn't contain the file, treat it as
// NotFound.
// 3. If the parent folder's XML contains the file, build a MultiStatus for the
// file based on the parent's XML.
//
// To avoid inconsistencies from the perspective of the client, any operations
// that modify the filesystem (e.g. PUT, MKCOL, etc.) should call invalidate()
// to invalidate the cache.
type StatCache struct { type StatCache struct {
TTL time.Duration TTL time.Duration
// mu guards the below values. // mu guards the below values.
mu sync.Mutex mu sync.Mutex
cachesByDepthAndPath map[int]*ttlcache.Cache[string, []byte] cachesByDepthAndPath map[int]*ttlcache.Cache[string, *cacheEntry]
} }
// getOr checks the cache for the named value at the given depth. If a cached // getOr checks the cache for the named value at the given depth. If a cached
@ -32,25 +66,57 @@ type StatCache struct {
// status and value. If the function returned http.StatusMultiStatus, getOr // status and value. If the function returned http.StatusMultiStatus, getOr
// caches the resulting value at the given name and depth before returning. // caches the resulting value at the given name and depth before returning.
func (c *StatCache) getOr(name string, depth int, or func() (int, []byte)) (int, []byte) { func (c *StatCache) getOr(name string, depth int, or func() (int, []byte)) (int, []byte) {
cached := c.get(name, depth) ce := c.get(name, depth)
if cached != nil { if ce == nil {
return http.StatusMultiStatus, cached // Not cached, fetch value.
status, raw := or()
ce = newCacheEntry(status, raw)
if status == http.StatusMultiStatus || status == http.StatusNotFound {
// Got a legit status, cache value
c.set(name, depth, ce)
}
} }
status, next := or() return ce.Status, ce.Raw
if c != nil && status == http.StatusMultiStatus && next != nil {
c.set(name, depth, next)
}
return status, next
} }
func (c *StatCache) get(name string, depth int) []byte { // get retrieves the entry for the named file at the given depth. If no entry
// is found, and depth == 0, get will check to see if the parent path of name
// is present in the cache at depth 1. If so, it will infer that the child does
// not exist and return notFound (404).
func (c *StatCache) get(name string, depth int) *cacheEntry {
if c == nil { if c == nil {
return nil return nil
} }
name = shared.Normalize(name)
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
ce := c.tryGetLocked(name, depth)
if ce != nil {
// Cache hit.
return ce
}
if depth > 0 {
// Cache miss.
return nil
}
// At depth 0, if child's parent is in the cache, and the child isn't
// cached, we can infer that the child is notFound.
p := c.tryGetLocked(shared.Parent(name), 1)
if p != nil {
return notFound
}
// No parent in cache, cache miss.
return nil
}
// tryGetLocked requires that c.mu be held.
func (c *StatCache) tryGetLocked(name string, depth int) *cacheEntry {
if c.cachesByDepthAndPath == nil { if c.cachesByDepthAndPath == nil {
return nil return nil
} }
@ -65,28 +131,80 @@ func (c *StatCache) get(name string, depth int) []byte {
return item.Value() return item.Value()
} }
func (c *StatCache) set(name string, depth int, value []byte) { // set stores the given cacheEntry in the cache at the given name and depth. If
// the depth is 1, set also populates depth 0 entries in the cache for the bare
// name. If status is StatusMultiStatus, set will parse the PROPFIND result and
// store depth 0 entries for all children. If parsing the result fails, nothing
// is cached.
func (c *StatCache) set(name string, depth int, ce *cacheEntry) {
if c == nil { if c == nil {
return return
} }
name = shared.Normalize(name)
var self *cacheEntry
var children map[string]*cacheEntry
if depth == 1 {
switch ce.Status {
case http.StatusNotFound:
// Record notFound as the self entry.
self = ce
case http.StatusMultiStatus:
// Parse the raw MultiStatus and extract specific responses
// corresponding to the self entry (e.g. the directory, but at depth 0)
// and children (e.g. files within the directory) so that subsequent
// requests for these can be satisfied from the cache.
var ms multiStatus
err := xml.Unmarshal(ce.Raw, &ms)
if err != nil {
// unparseable MultiStatus response, don't cache
log.Printf("statcache.set error: %s", err)
return
}
children = make(map[string]*cacheEntry, len(ms.Responses)-1)
for i := 0; i < len(ms.Responses); i++ {
response := ms.Responses[i]
name := shared.Normalize(response.Href)
raw := marshalMultiStatus(response)
entry := newCacheEntry(ce.Status, raw)
if i == 0 {
self = entry
} else {
children[name] = entry
}
}
}
}
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.setLocked(name, depth, ce)
if self != nil {
c.setLocked(name, 0, self)
}
for childName, child := range children {
c.setLocked(childName, 0, child)
}
}
// setLocked requires that c.mu be held.
func (c *StatCache) setLocked(name string, depth int, ce *cacheEntry) {
if c.cachesByDepthAndPath == nil { if c.cachesByDepthAndPath == nil {
c.cachesByDepthAndPath = make(map[int]*ttlcache.Cache[string, []byte]) c.cachesByDepthAndPath = make(map[int]*ttlcache.Cache[string, *cacheEntry])
} }
cache := c.cachesByDepthAndPath[depth] cache := c.cachesByDepthAndPath[depth]
if cache == nil { if cache == nil {
cache = ttlcache.New( cache = ttlcache.New(
ttlcache.WithTTL[string, []byte](c.TTL), ttlcache.WithTTL[string, *cacheEntry](c.TTL),
) )
go cache.Start() go cache.Start()
c.cachesByDepthAndPath[depth] = cache c.cachesByDepthAndPath[depth] = cache
} }
cache.Set(name, value, ttlcache.DefaultTTL) cache.Set(name, ce, ttlcache.DefaultTTL)
} }
// invalidate invalidates the entire cache.
func (c *StatCache) invalidate() { func (c *StatCache) invalidate() {
if c == nil { if c == nil {
return return
@ -108,3 +226,54 @@ func (c *StatCache) stop() {
cache.Stop() cache.Stop()
} }
} }
// cacheEntry is what StatCache stores for a path: the HTTP status of a
// PROPFIND result together with the raw response body that goes with it.
type cacheEntry struct {
	// Status is the HTTP status code of the result (e.g. 207 or 404).
	Status int
	// Raw is the response body as received; it is nil for entries that have
	// no body, such as the shared notFound entry.
	Raw []byte
}

// newCacheEntry allocates a cacheEntry holding the given status and body.
// The raw slice is stored as-is, not copied, so the entry shares its backing
// array with the caller's slice.
func newCacheEntry(status int, raw []byte) *cacheEntry {
	ce := new(cacheEntry)
	ce.Status = status
	ce.Raw = raw
	return ce
}
// propStat holds the content of one <propstat> element exactly as it appeared
// in the parsed document. Capturing it as inner XML, rather than decoding the
// individual properties, lets marshalMultiStatus write it back out verbatim.
type propStat struct {
	InnerXML []byte `xml:",innerxml"`
}

// response is a single <response> element of a MultiStatus document: the href
// of the resource it describes plus that resource's <propstat> elements.
type response struct {
	XMLName   xml.Name    `xml:"response"`
	Href      string      `xml:"href"`
	PropStats []*propStat `xml:"propstat"`
}

// multiStatus is the root <multistatus> element, containing one <response>
// per described resource.
type multiStatus struct {
	XMLName   xml.Name    `xml:"multistatus"`
	Responses []*response `xml:"response"`
}

// marshalMultiStatus performs custom marshalling of a MultiStatus to preserve
// the original formatting, namespacing, etc. Doing this with Go's XML encoder
// is somewhere between difficult and impossible, which is why we use this more
// manual approach.
//
// The result is a complete MultiStatus document containing only the given
// response. Each propstat's inner XML is copied through byte-for-byte.
//
// NOTE(review): the envelope declares only the "D" prefix for the DAV:
// namespace, so propstat content that relies on a different prefix declared
// on the original document's root would end up unbound — confirm that all
// peers emit the "D" prefix.
func marshalMultiStatus(response *response) []byte {
	// TODO(percy): maybe pool these buffers
	var buf bytes.Buffer
	buf.WriteString(multistatusTemplateStart)
	// Href went through xml.Unmarshal, which decodes entities (e.g. "&amp;"
	// becomes "&"). It has to be re-escaped here, otherwise an href that
	// contains '&', '<' or '>' would produce a malformed document.
	// EscapeText only reports errors from the underlying writer, and writes
	// to a bytes.Buffer never fail, so the error is safe to discard.
	_ = xml.EscapeText(&buf, []byte(response.Href))
	buf.WriteString(hrefEnd)
	for _, propStat := range response.PropStats {
		buf.WriteString(propstatStart)
		buf.Write(propStat.InnerXML)
		buf.WriteString(propstatEnd)
	}
	buf.WriteString(multistatusTemplateEnd)
	return buf.Bytes()
}

// Fixed fragments that marshalMultiStatus stitches around the href and the
// propstat content of a response.
const (
	multistatusTemplateStart = `<?xml version="1.0" encoding="UTF-8"?><D:multistatus xmlns:D="DAV:"><D:response><D:href>`
	hrefEnd                  = `</D:href>`
	propstatStart            = `<D:propstat>`
	propstatEnd              = `</D:propstat>`
	multistatusTemplateEnd   = `</D:response></D:multistatus>`
)

View File

@ -4,17 +4,65 @@
package compositedav package compositedav
import ( import (
"bytes" "fmt"
"log"
"net/http"
"path"
"strings"
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"tailscale.com/tstest" "tailscale.com/tstest"
) )
var ( var parentPath = "/parent"
val = []byte("1")
file = "file" var childPath = "/parent/child.txt"
)
// parentResponse is the <D:response> fragment describing the /parent/
// collection itself.
var parentResponse = `<D:response>
<D:href>/parent/</D:href>
<D:propstat>
<D:prop>
<D:getlastmodified>Mon, 29 Apr 2024 19:52:23 GMT</D:getlastmodified>
<D:creationdate>Fri, 19 Apr 2024 04:13:34 GMT</D:creationdate>
<D:resourcetype>
<D:collection xmlns:D="DAV:" />
</D:resourcetype>
</D:prop>
<D:status>HTTP/1.1 200 OK</D:status>
</D:propstat>
</D:response>`

// childResponse is the <D:response> fragment describing /parent/child.txt.
var childResponse = `
<D:response>
<D:href>/parent/child.txt</D:href>
<D:propstat>
<D:prop>
<D:getlastmodified>Mon, 29 Apr 2024 19:52:23 GMT</D:getlastmodified>
<D:creationdate>Fri, 19 Apr 2024 04:13:34 GMT</D:creationdate>
<D:resourcetype>
<D:collection xmlns:D="DAV:" />
</D:resourcetype>
</D:prop>
<D:status>HTTP/1.1 200 OK</D:status>
</D:propstat>
</D:response>`

// testMultiStatusDoc concatenates the given <D:response> fragments, wraps them
// in a MultiStatus envelope and removes every newline from the result.
func testMultiStatusDoc(responses ...string) []byte {
	doc := fmt.Sprintf(
		`<?xml version="1.0" encoding="UTF-8"?><D:multistatus xmlns:D="DAV:">%s</D:multistatus>`,
		strings.Join(responses, ""))
	return []byte(strings.ReplaceAll(doc, "\n", ""))
}

var (
	// fullParent is a depth 1 listing of /parent: the directory followed by
	// its single child.
	fullParent = testMultiStatusDoc(parentResponse, childResponse)

	// partialParent is the listing of /parent containing only the directory.
	partialParent = testMultiStatusDoc(parentResponse)

	// fullChild is the listing containing only /parent/child.txt.
	fullChild = testMultiStatusDoc(childResponse)
)
func TestStatCacheNoTimeout(t *testing.T) { func TestStatCacheNoTimeout(t *testing.T) {
// Make sure we don't leak goroutines // Make sure we don't leak goroutines
@ -24,22 +72,23 @@ func TestStatCacheNoTimeout(t *testing.T) {
defer c.stop() defer c.stop()
// check get before set // check get before set
fetched := c.get(file, 1) fetched := c.get(childPath, 0)
if fetched != nil { if fetched != nil {
t.Errorf("got %q, want nil", fetched) t.Errorf("got %v, want nil", fetched)
} }
// set new stat // set new stat
c.set(file, 1, val) ce := newCacheEntry(http.StatusMultiStatus, fullChild)
fetched = c.get(file, 1) c.set(childPath, 0, ce)
if !bytes.Equal(fetched, val) { fetched = c.get(childPath, 0)
t.Errorf("got %q, want %q", fetched, val) if diff := cmp.Diff(fetched, ce); diff != "" {
t.Errorf("should have gotten cached value; (-got+want):%v", diff)
} }
// fetch stat again, should still be cached // fetch stat again, should still be cached
fetched = c.get(file, 1) fetched = c.get(childPath, 0)
if !bytes.Equal(fetched, val) { if diff := cmp.Diff(fetched, ce); diff != "" {
t.Errorf("got %q, want %q", fetched, val) t.Errorf("should still have gotten cached value; (-got+want):%v", diff)
} }
} }
@ -51,25 +100,114 @@ func TestStatCacheTimeout(t *testing.T) {
defer c.stop() defer c.stop()
// set new stat // set new stat
c.set(file, 1, val) ce := newCacheEntry(http.StatusMultiStatus, fullChild)
fetched := c.get(file, 1) c.set(childPath, 0, ce)
if !bytes.Equal(fetched, val) { fetched := c.get(childPath, 0)
t.Errorf("got %q, want %q", fetched, val) if diff := cmp.Diff(fetched, ce); diff != "" {
t.Errorf("should have gotten cached value; (-got+want):%v", diff)
} }
// wait for cache to expire and refetch stat, should be empty now // wait for cache to expire and refetch stat, should be empty now
time.Sleep(c.TTL * 2) time.Sleep(c.TTL * 2)
fetched = c.get(file, 1) fetched = c.get(childPath, 0)
if fetched != nil { if fetched != nil {
t.Errorf("invalidate should have cleared cached value") t.Errorf("cached value should have expired")
} }
c.set(file, 1, val) c.set(childPath, 0, ce)
// invalidate the cache and make sure nothing is returned // invalidate the cache and make sure nothing is returned
c.invalidate() c.invalidate()
fetched = c.get(file, 1) fetched = c.get(childPath, 0)
if fetched != nil { if fetched != nil {
t.Errorf("invalidate should have cleared cached value") t.Errorf("invalidate should have cleared cached value")
} }
} }
// TestParentChildRelationship checks how entries stored at depth 1 affect
// later lookups: a parsed directory listing must answer depth 0 lookups for
// the directory and its children, a missing (404) directory must make its
// children notFound, and an unparseable listing must not be cached at all.
func TestParentChildRelationship(t *testing.T) {
	// Make sure we don't leak goroutines
	tstest.ResourceCheck(t)

	c := &StatCache{TTL: 24 * time.Hour} // don't expire
	defer c.stop()

	missingParentPath := "/missingparent"
	unparseableParentPath := "/unparseable"

	c.set(parentPath, 1, newCacheEntry(http.StatusMultiStatus, fullParent))
	c.set(missingParentPath, 1, newCacheEntry(http.StatusNotFound, nil))
	c.set(unparseableParentPath, 1, newCacheEntry(http.StatusMultiStatus, []byte("<this will not parse")))

	tests := []struct {
		path  string
		depth int
		want  *cacheEntry
	}{
		{
			path:  parentPath,
			depth: 1,
			want:  newCacheEntry(http.StatusMultiStatus, fullParent),
		},
		{
			path:  parentPath,
			depth: 0,
			want:  newCacheEntry(http.StatusMultiStatus, partialParent),
		},
		{
			path:  childPath,
			depth: 0,
			want:  newCacheEntry(http.StatusMultiStatus, fullChild),
		},
		{
			path:  path.Join(parentPath, "nonexistent.txt"),
			depth: 0,
			want:  notFound,
		},
		{
			path:  missingParentPath,
			depth: 1,
			want:  notFound,
		},
		{
			path:  missingParentPath,
			depth: 0,
			want:  notFound,
		},
		{
			path:  path.Join(missingParentPath, "filename.txt"),
			depth: 0,
			want:  notFound,
		},
		{
			path:  unparseableParentPath,
			depth: 1,
			want:  nil,
		},
		{
			path:  unparseableParentPath,
			depth: 0,
			want:  nil,
		},
		{
			path:  path.Join(unparseableParentPath, "filename.txt"),
			depth: 0,
			want:  nil,
		},
		{
			path:  "/unknown",
			depth: 1,
			want:  nil,
		},
	}

	// rawOf tolerates nil entries. Several cases expect nil, and got may be
	// nil on a miss; dereferencing either while reporting a failure would
	// turn a readable test failure into a nil pointer panic.
	rawOf := func(ce *cacheEntry) []byte {
		if ce == nil {
			return nil
		}
		return ce.Raw
	}

	for _, test := range tests {
		t.Run(fmt.Sprintf("%d%s", test.depth, test.path), func(t *testing.T) {
			got := c.get(test.path, test.depth)
			if diff := cmp.Diff(got, test.want); diff != "" {
				t.Errorf("unexpected cached value; (-got+want):%v", diff)
				log.Printf("want\n%s", rawOf(test.want))
				log.Printf("got\n%s", rawOf(got))
			}
		})
	}
}

View File

@ -26,6 +26,17 @@ func CleanAndSplit(p string) []string {
return strings.Split(strings.Trim(path.Clean(p), sepStringAndDot), sepString) return strings.Split(strings.Trim(path.Clean(p), sepStringAndDot), sepString)
} }
// Normalize normalizes the given path (e.g. dropping trailing slashes).
func Normalize(p string) string {
return Join(CleanAndSplit(p)...)
}
// Parent returns the parent of the given path: the segments of p as produced
// by CleanAndSplit, minus the final one, reassembled with Join.
//
// NOTE(review): for a single-segment path this calls Join with no parts,
// which presumably yields the root — confirm against Join.
func Parent(p string) string {
	segments := CleanAndSplit(p)
	last := len(segments) - 1
	return Join(segments[:last]...)
}
// Join behaves like path.Join() but also includes a leading slash. // Join behaves like path.Join() but also includes a leading slash.
func Join(parts ...string) string { func Join(parts ...string) string {
fullParts := make([]string, 0, len(parts)) fullParts := make([]string, 0, len(parts))