diff --git a/clientupdate/distsign/distsign.go b/clientupdate/distsign/distsign.go index eba4b9267..2b3f2a97f 100644 --- a/clientupdate/distsign/distsign.go +++ b/clientupdate/distsign/distsign.go @@ -229,8 +229,6 @@ func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { c.logf("Downloading %q", sigURL) sig, err := fetch(sigURL, signatureSizeLimit) if err != nil { - // Best-effort clean up of downloaded package. - os.Remove(dstPathUnverified) return err } msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) @@ -326,6 +324,8 @@ func fetch(url string, limit int64) ([]byte, error) { return io.ReadAll(io.LimitReader(resp.Body, limit)) } +var onResponseForTest = func(*http.Response) {} + // download writes the response body of url into a local file at dst, up to // limit bytes. On success, the returned value is a BLAKE2s hash of the file. func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { @@ -349,31 +349,61 @@ func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([] return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) } c.logf("Download size: %v", res.ContentLength) + h := NewPackageHash() dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) + + var skip int64 + if fi, err := os.Stat(dst); err == nil { + if fi.Size() == res.ContentLength { + // Assume it got corrupted previously and the earlier attempt failed + // the checksum. Delete it and start over. + if err := os.Remove(dst); err != nil { + return nil, 0, fmt.Errorf("error deleting previous assumed-bad download: %w", err) + } + } else if fi.Size() > 0 && fi.Size() < res.ContentLength { + c.logf("Existing file size: %v", fi.Size()) + skip = fi.Size() + dlReq.Header.Add("Range", fmt.Sprintf("bytes=%d-", skip)) + } + } + dlRes, err := hc.Do(dlReq) if err != nil { return nil, 0, err } + onResponseForTest(dlRes) defer dlRes.Body.Close() - // TODO(bradfitz): resume from existing partial file on disk - if dlRes.StatusCode != http.StatusOK { + + var of *os.File + wantResponseLength := res.ContentLength + switch dlRes.StatusCode { + case http.StatusOK: + if skip > 0 { + os.Remove(dst) // best effort; the Create will fail anyway if this would + } + of, err = os.Create(dst) + case http.StatusPartialContent: + wantResponseLength = res.ContentLength - skip + of, err = os.OpenFile(dst, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644) + if err == nil { + // Re-hash the previously downloaded chunk. + _, err = io.Copy(h, io.NewSectionReader(of, 0, skip)) + } + default: return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) } - - of, err := os.Create(dst) if err != nil { return nil, 0, err } defer of.Close() - pw := &progressWriter{total: res.ContentLength, logf: c.logf} - h := NewPackageHash() + pw := &progressWriter{total: res.ContentLength, done: skip, logf: c.logf} n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) if err != nil { return nil, n, err } - if n != res.ContentLength { - return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) + if n != wantResponseLength { + return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, wantResponseLength) } if err := dlRes.Body.Close(); err != nil { return nil, n, err diff --git a/clientupdate/distsign/distsign_test.go b/clientupdate/distsign/distsign_test.go index 09a701f49..8439f471b 100644 --- a/clientupdate/distsign/distsign_test.go +++ b/clientupdate/distsign/distsign_test.go @@ -5,6 +5,7 @@ package distsign import ( "bytes" + "cmp" "context" "crypto/ed25519" "net/http" @@ -12,10 +13,13 @@ import ( "net/url" "os" "path/filepath" + "reflect" "strings" "testing" + "time" "golang.org/x/crypto/blake2s" + "tailscale.com/tstest" ) func TestDownload(t *testing.T) { @@ -23,11 +27,13 @@ func TestDownload(t *testing.T) { c := srv.client(t) tests := []struct { - desc string - before func(*testing.T) - src string - want []byte - wantErr bool + desc string + before func(*testing.T) + existing []byte // optional existing data on disk to resume from + src string + want []byte + wantErr bool + wantCode int // HTTP status code of download to expect; 0 means http.StatusOK }{ { desc: "missing file", @@ -43,6 +49,45 @@ func TestDownload(t *testing.T) { src: "hello", want: []byte("world"), }, + { + desc: "success-resume", + before: func(*testing.T) { + srv.addSigned("hello", []byte("world")) + }, + src: "hello", + existing: []byte("wo"), + want: []byte("world"), + wantCode: http.StatusPartialContent, + }, + { + desc: "success-resume-ignore-matching-size", + before: func(*testing.T) { + srv.addSigned("hello", []byte("world")) + }, + src: "hello", + existing: []byte("WORLD"), // same size as world + want: []byte("world"), + wantCode: http.StatusOK, + }, + { + desc: "success-resume-ignore-existing-too-big", + before: func(*testing.T) { + srv.addSigned("hello", []byte("world")) + }, + src: "hello", + existing: []byte("longer-than-world"), // len greater than len("world") + want: []byte("world"), + wantCode: http.StatusOK, + }, + { + desc: "resume-corrupt", + before: func(*testing.T) { + srv.addSigned("hello", []byte("world")) + }, + src: "hello", + existing: []byte("WO"), // previous download was bad + wantErr: true, + }, { desc: "no signature", before: func(*testing.T) { @@ -94,10 +139,17 @@ func TestDownload(t *testing.T) { srv.reset() tt.before(t) - dst := filepath.Join(t.TempDir(), tt.src) - t.Cleanup(func() { - os.Remove(dst) + var gotCodes []int + tstest.Replace(t, &onResponseForTest, func(res *http.Response) { + gotCodes = append(gotCodes, res.StatusCode) }) + + dst := filepath.Join(t.TempDir(), tt.src) + if len(tt.existing) > 0 { + if err := os.WriteFile(dst+".unverified", tt.existing, 0644); err != nil { + t.Fatal(err) + } + } err := c.Download(context.Background(), tt.src, dst) if err != nil { if tt.wantErr { @@ -107,6 +159,11 @@ func TestDownload(t *testing.T) { } if tt.wantErr { t.Fatalf("Download(%q) succeeded, expected an error", tt.src) + } else { + wantCodes := []int{cmp.Or(tt.wantCode, http.StatusOK)} + if !reflect.DeepEqual(gotCodes, wantCodes) { + t.Errorf("HTTP response status code = %v; want %v", gotCodes, wantCodes) + } } got, err := os.ReadFile(dst) if err != nil { @@ -486,7 +543,7 @@ func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) return } - w.Write(data) + http.ServeContent(w, r, path, time.Time{}, bytes.NewReader(data)) } func (s *testServer) addSigned(name string, data []byte) {