// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package compositedav import ( "fmt" "log" "net/http" "path" "strings" "testing" "time" "github.com/google/go-cmp/cmp" "tailscale.com/tstest" ) var parentPath = "/parent" var childPath = "/parent/child.txt" var parentResponse = `<D:response> <D:href>/parent/</D:href> <D:propstat> <D:prop> <D:getlastmodified>Mon, 29 Apr 2024 19:52:23 GMT</D:getlastmodified> <D:creationdate>Fri, 19 Apr 2024 04:13:34 GMT</D:creationdate> <D:resourcetype> <D:collection xmlns:D="DAV:" /> </D:resourcetype> </D:prop> <D:status>HTTP/1.1 200 OK</D:status> </D:propstat> </D:response>` var childResponse = ` <D:response> <D:href>/parent/child.txt</D:href> <D:propstat> <D:prop> <D:getlastmodified>Mon, 29 Apr 2024 19:52:23 GMT</D:getlastmodified> <D:creationdate>Fri, 19 Apr 2024 04:13:34 GMT</D:creationdate> <D:resourcetype> <D:collection xmlns:D="DAV:" /> </D:resourcetype> </D:prop> <D:status>HTTP/1.1 200 OK</D:status> </D:propstat> </D:response>` var fullParent = []byte( strings.ReplaceAll( fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?><D:multistatus xmlns:D="DAV:">%s%s</D:multistatus>`, parentResponse, childResponse), "\n", "")) var partialParent = []byte( strings.ReplaceAll( fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?><D:multistatus xmlns:D="DAV:">%s</D:multistatus>`, parentResponse), "\n", "")) var fullChild = []byte( strings.ReplaceAll( fmt.Sprintf(`<?xml version="1.0" encoding="UTF-8"?><D:multistatus xmlns:D="DAV:">%s</D:multistatus>`, childResponse), "\n", "")) func TestStatCacheNoTimeout(t *testing.T) { // Make sure we don't leak goroutines tstest.ResourceCheck(t) c := &StatCache{TTL: 5 * time.Second} defer c.stop() // check get before set fetched := c.get(childPath, 0) if fetched != nil { t.Errorf("got %v, want nil", fetched) } // set new stat ce := newCacheEntry(http.StatusMultiStatus, fullChild) c.set(childPath, 0, ce) fetched = c.get(childPath, 0) if diff := cmp.Diff(fetched, ce); diff != "" { t.Errorf("should have gotten cached value; (-got+want):%v", diff) } // fetch stat again, should still be cached fetched = c.get(childPath, 0) if diff := cmp.Diff(fetched, ce); diff != "" { t.Errorf("should still have gotten cached value; (-got+want):%v", diff) } } func TestStatCacheTimeout(t *testing.T) { // Make sure we don't leak goroutines tstest.ResourceCheck(t) c := &StatCache{TTL: 250 * time.Millisecond} defer c.stop() // set new stat ce := newCacheEntry(http.StatusMultiStatus, fullChild) c.set(childPath, 0, ce) fetched := c.get(childPath, 0) if diff := cmp.Diff(fetched, ce); diff != "" { t.Errorf("should have gotten cached value; (-got+want):%v", diff) } // wait for cache to expire and refetch stat, should be empty now time.Sleep(c.TTL * 2) fetched = c.get(childPath, 0) if fetched != nil { t.Errorf("cached value should have expired") } c.set(childPath, 0, ce) // invalidate the cache and make sure nothing is returned c.invalidate() fetched = c.get(childPath, 0) if fetched != nil { t.Errorf("invalidate should have cleared cached value") } } func TestParentChildRelationship(t *testing.T) { // Make sure we don't leak goroutines tstest.ResourceCheck(t) c := &StatCache{TTL: 24 * time.Hour} // don't expire defer c.stop() missingParentPath := "/missingparent" unparseableParentPath := "/unparseable" c.set(parentPath, 1, newCacheEntry(http.StatusMultiStatus, fullParent)) c.set(missingParentPath, 1, newCacheEntry(http.StatusNotFound, nil)) c.set(unparseableParentPath, 1, newCacheEntry(http.StatusMultiStatus, []byte("<this will not parse"))) tests := []struct { path string depth int want *cacheEntry }{ { path: parentPath, depth: 1, want: newCacheEntry(http.StatusMultiStatus, fullParent), }, { path: parentPath, depth: 0, want: newCacheEntry(http.StatusMultiStatus, partialParent), }, { path: childPath, depth: 0, want: newCacheEntry(http.StatusMultiStatus, fullChild), }, { path: path.Join(parentPath, "nonexistent.txt"), depth: 0, want: notFound, }, { path: missingParentPath, depth: 1, want: notFound, }, { path: missingParentPath, depth: 0, want: notFound, }, { path: path.Join(missingParentPath, "filename.txt"), depth: 0, want: notFound, }, { path: unparseableParentPath, depth: 1, want: nil, }, { path: unparseableParentPath, depth: 0, want: nil, }, { path: path.Join(unparseableParentPath, "filename.txt"), depth: 0, want: nil, }, { path: "/unknown", depth: 1, want: nil, }, } for _, test := range tests { t.Run(fmt.Sprintf("%d%s", test.depth, test.path), func(t *testing.T) { got := c.get(test.path, test.depth) if diff := cmp.Diff(got, test.want); diff != "" { t.Errorf("unexpected cached value; (-got+want):%v", diff) log.Printf("want\n%s", test.want.Raw) log.Printf("got\n%s", got.Raw) } }) } }