clientupdate/distsign: resume partial downloads

Fixes #12573

Change-Id: If2684a2987ec95b893ba9b7c71dc0cb850765b18
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2024-06-21 09:24:08 -07:00
parent 5ec01bf3ce
commit dc38eaf551
2 changed files with 106 additions and 19 deletions

View File

@@ -229,8 +229,6 @@ func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error {
c.logf("Downloading %q", sigURL)
sig, err := fetch(sigURL, signatureSizeLimit)
if err != nil {
// Best-effort clean up of downloaded package.
os.Remove(dstPathUnverified)
return err
}
msg := binary.LittleEndian.AppendUint64(hash, uint64(len))
@@ -326,6 +324,8 @@ func fetch(url string, limit int64) ([]byte, error) {
return io.ReadAll(io.LimitReader(resp.Body, limit))
}
var onResponseForTest = func(*http.Response) {}
// download writes the response body of url into a local file at dst, up to
// limit bytes. On success, the returned value is a BLAKE2s hash of the file.
func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) {
@@ -349,31 +349,61 @@ func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]
return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength)
}
c.logf("Download size: %v", res.ContentLength)
h := NewPackageHash()
dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil))
var skip int64
if fi, err := os.Stat(dst); err == nil {
if fi.Size() == res.ContentLength {
// Assume it got corrupted previously and the earlier attempt failed
// the checksum. Delete it and start over.
if err := os.Remove(dst); err != nil {
return nil, 0, fmt.Errorf("error deleting previous assumed-bad download: %w", err)
}
} else if fi.Size() > 0 && fi.Size() < res.ContentLength {
c.logf("Existing file size: %v", fi.Size())
skip = fi.Size()
dlReq.Header.Add("Range", fmt.Sprintf("bytes=%d-", skip))
}
}
dlRes, err := hc.Do(dlReq)
if err != nil {
return nil, 0, err
}
onResponseForTest(dlRes)
defer dlRes.Body.Close()
// TODO(bradfitz): resume from existing partial file on disk
if dlRes.StatusCode != http.StatusOK {
var of *os.File
wantResponseLength := res.ContentLength
switch dlRes.StatusCode {
case http.StatusOK:
if skip > 0 {
os.Remove(dst) // best effort; the Create will fail anyway if this would
}
of, err = os.Create(dst)
case http.StatusPartialContent:
wantResponseLength = res.ContentLength - skip
of, err = os.OpenFile(dst, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644)
if err == nil {
// Re-hash the previously downloaded chunk.
_, err = io.Copy(h, io.NewSectionReader(of, 0, skip))
}
default:
return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status)
}
of, err := os.Create(dst)
if err != nil {
return nil, 0, err
}
defer of.Close()
pw := &progressWriter{total: res.ContentLength, logf: c.logf}
h := NewPackageHash()
pw := &progressWriter{total: res.ContentLength, done: skip, logf: c.logf}
n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit))
if err != nil {
return nil, n, err
}
if n != res.ContentLength {
return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength)
if n != wantResponseLength {
return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, wantResponseLength)
}
if err := dlRes.Body.Close(); err != nil {
return nil, n, err

View File

@@ -5,6 +5,7 @@ package distsign
import (
"bytes"
"cmp"
"context"
"crypto/ed25519"
"net/http"
@@ -12,10 +13,13 @@ import (
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
"golang.org/x/crypto/blake2s"
"tailscale.com/tstest"
)
func TestDownload(t *testing.T) {
@@ -23,11 +27,13 @@ func TestDownload(t *testing.T) {
c := srv.client(t)
tests := []struct {
desc string
before func(*testing.T)
src string
want []byte
wantErr bool
desc string
before func(*testing.T)
existing []byte // optional existing data on disk to resume from
src string
want []byte
wantErr bool
wantCode int // HTTP status code of download to expect; 0 means http.StatusOK
}{
{
desc: "missing file",
@@ -43,6 +49,45 @@ func TestDownload(t *testing.T) {
src: "hello",
want: []byte("world"),
},
{
desc: "success-resume",
before: func(*testing.T) {
srv.addSigned("hello", []byte("world"))
},
src: "hello",
existing: []byte("wo"),
want: []byte("world"),
wantCode: http.StatusPartialContent,
},
{
desc: "success-resume-ignore-matching-size",
before: func(*testing.T) {
srv.addSigned("hello", []byte("world"))
},
src: "hello",
existing: []byte("WORLD"), // same size as world
want: []byte("world"),
wantCode: http.StatusOK,
},
{
desc: "success-resume-ignore-existing-too-big",
before: func(*testing.T) {
srv.addSigned("hello", []byte("world"))
},
src: "hello",
existing: []byte("longer-than-world"), // len greater than len("world")
want: []byte("world"),
wantCode: http.StatusOK,
},
{
desc: "resume-corrupt",
before: func(*testing.T) {
srv.addSigned("hello", []byte("world"))
},
src: "hello",
existing: []byte("WO"), // previous download was bad
wantErr: true,
},
{
desc: "no signature",
before: func(*testing.T) {
@@ -94,10 +139,17 @@ func TestDownload(t *testing.T) {
srv.reset()
tt.before(t)
dst := filepath.Join(t.TempDir(), tt.src)
t.Cleanup(func() {
os.Remove(dst)
var gotCodes []int
tstest.Replace(t, &onResponseForTest, func(res *http.Response) {
gotCodes = append(gotCodes, res.StatusCode)
})
dst := filepath.Join(t.TempDir(), tt.src)
if len(tt.existing) > 0 {
if err := os.WriteFile(dst+".unverified", tt.existing, 0644); err != nil {
t.Fatal(err)
}
}
err := c.Download(context.Background(), tt.src, dst)
if err != nil {
if tt.wantErr {
@@ -107,6 +159,11 @@ func TestDownload(t *testing.T) {
}
if tt.wantErr {
t.Fatalf("Download(%q) succeeded, expected an error", tt.src)
} else {
wantCodes := []int{cmp.Or(tt.wantCode, http.StatusOK)}
if !reflect.DeepEqual(gotCodes, wantCodes) {
t.Errorf("HTTP response status code = %v; want %v", gotCodes, wantCodes)
}
}
got, err := os.ReadFile(dst)
if err != nil {
@@ -486,7 +543,7 @@ func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
return
}
w.Write(data)
http.ServeContent(w, r, path, time.Time{}, bytes.NewReader(data))
}
func (s *testServer) addSigned(name string, data []byte) {