clientupdate/distsign: resume partial downloads

Fixes #12573

Change-Id: If2684a2987ec95b893ba9b7c71dc0cb850765b18
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick 2024-06-21 09:24:08 -07:00
parent 5ec01bf3ce
commit dc38eaf551
2 changed files with 106 additions and 19 deletions

View File

@ -229,8 +229,6 @@ func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error {
c.logf("Downloading %q", sigURL) c.logf("Downloading %q", sigURL)
sig, err := fetch(sigURL, signatureSizeLimit) sig, err := fetch(sigURL, signatureSizeLimit)
if err != nil { if err != nil {
// Best-effort clean up of downloaded package.
os.Remove(dstPathUnverified)
return err return err
} }
msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) msg := binary.LittleEndian.AppendUint64(hash, uint64(len))
@ -326,6 +324,8 @@ func fetch(url string, limit int64) ([]byte, error) {
return io.ReadAll(io.LimitReader(resp.Body, limit)) return io.ReadAll(io.LimitReader(resp.Body, limit))
} }
var onResponseForTest = func(*http.Response) {}
// download writes the response body of url into a local file at dst, up to // download writes the response body of url into a local file at dst, up to
// limit bytes. On success, the returned value is a BLAKE2s hash of the file. // limit bytes. On success, the returned value is a BLAKE2s hash of the file.
func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) {
@ -349,31 +349,61 @@ func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]
return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength)
} }
c.logf("Download size: %v", res.ContentLength) c.logf("Download size: %v", res.ContentLength)
h := NewPackageHash()
dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil))
var skip int64
if fi, err := os.Stat(dst); err == nil {
if fi.Size() == res.ContentLength {
// Assume it got corrupted previously and the earlier attempt failed
// the checksum. Delete it and start over.
if err := os.Remove(dst); err != nil {
return nil, 0, fmt.Errorf("error deleting previous assumed-bad download: %w", err)
}
} else if fi.Size() > 0 && fi.Size() < res.ContentLength {
c.logf("Existing file size: %v", fi.Size())
skip = fi.Size()
dlReq.Header.Add("Range", fmt.Sprintf("bytes=%d-", skip))
}
}
dlRes, err := hc.Do(dlReq) dlRes, err := hc.Do(dlReq)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
onResponseForTest(dlRes)
defer dlRes.Body.Close() defer dlRes.Body.Close()
// TODO(bradfitz): resume from existing partial file on disk
if dlRes.StatusCode != http.StatusOK { var of *os.File
wantResponseLength := res.ContentLength
switch dlRes.StatusCode {
case http.StatusOK:
if skip > 0 {
os.Remove(dst) // best effort; the Create will fail anyway if this would
}
of, err = os.Create(dst)
case http.StatusPartialContent:
wantResponseLength = res.ContentLength - skip
of, err = os.OpenFile(dst, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644)
if err == nil {
// Re-hash the previously downloaded chunk.
_, err = io.Copy(h, io.NewSectionReader(of, 0, skip))
}
default:
return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status)
} }
of, err := os.Create(dst)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
defer of.Close() defer of.Close()
pw := &progressWriter{total: res.ContentLength, logf: c.logf} pw := &progressWriter{total: res.ContentLength, done: skip, logf: c.logf}
h := NewPackageHash()
n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit))
if err != nil { if err != nil {
return nil, n, err return nil, n, err
} }
if n != res.ContentLength { if n != wantResponseLength {
return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, wantResponseLength)
} }
if err := dlRes.Body.Close(); err != nil { if err := dlRes.Body.Close(); err != nil {
return nil, n, err return nil, n, err

View File

@ -5,6 +5,7 @@ package distsign
import ( import (
"bytes" "bytes"
"cmp"
"context" "context"
"crypto/ed25519" "crypto/ed25519"
"net/http" "net/http"
@ -12,10 +13,13 @@ import (
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"testing" "testing"
"time"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"tailscale.com/tstest"
) )
func TestDownload(t *testing.T) { func TestDownload(t *testing.T) {
@ -25,9 +29,11 @@ func TestDownload(t *testing.T) {
tests := []struct { tests := []struct {
desc string desc string
before func(*testing.T) before func(*testing.T)
existing []byte // optional existing data on disk to resume from
src string src string
want []byte want []byte
wantErr bool wantErr bool
wantCode int // HTTP status code of download to expect; 0 means http.StatusOK
}{ }{
{ {
desc: "missing file", desc: "missing file",
@ -43,6 +49,45 @@ func TestDownload(t *testing.T) {
src: "hello", src: "hello",
want: []byte("world"), want: []byte("world"),
}, },
{
desc: "success-resume",
before: func(*testing.T) {
srv.addSigned("hello", []byte("world"))
},
src: "hello",
existing: []byte("wo"),
want: []byte("world"),
wantCode: http.StatusPartialContent,
},
{
desc: "success-resume-ignore-matching-size",
before: func(*testing.T) {
srv.addSigned("hello", []byte("world"))
},
src: "hello",
existing: []byte("WORLD"), // same size as world
want: []byte("world"),
wantCode: http.StatusOK,
},
{
desc: "success-resume-ignore-existing-too-big",
before: func(*testing.T) {
srv.addSigned("hello", []byte("world"))
},
src: "hello",
existing: []byte("longer-than-world"), // len greater than len("world")
want: []byte("world"),
wantCode: http.StatusOK,
},
{
desc: "resume-corrupt",
before: func(*testing.T) {
srv.addSigned("hello", []byte("world"))
},
src: "hello",
existing: []byte("WO"), // previous download was bad
wantErr: true,
},
{ {
desc: "no signature", desc: "no signature",
before: func(*testing.T) { before: func(*testing.T) {
@ -94,10 +139,17 @@ func TestDownload(t *testing.T) {
srv.reset() srv.reset()
tt.before(t) tt.before(t)
dst := filepath.Join(t.TempDir(), tt.src) var gotCodes []int
t.Cleanup(func() { tstest.Replace(t, &onResponseForTest, func(res *http.Response) {
os.Remove(dst) gotCodes = append(gotCodes, res.StatusCode)
}) })
dst := filepath.Join(t.TempDir(), tt.src)
if len(tt.existing) > 0 {
if err := os.WriteFile(dst+".unverified", tt.existing, 0644); err != nil {
t.Fatal(err)
}
}
err := c.Download(context.Background(), tt.src, dst) err := c.Download(context.Background(), tt.src, dst)
if err != nil { if err != nil {
if tt.wantErr { if tt.wantErr {
@ -107,6 +159,11 @@ func TestDownload(t *testing.T) {
} }
if tt.wantErr { if tt.wantErr {
t.Fatalf("Download(%q) succeeded, expected an error", tt.src) t.Fatalf("Download(%q) succeeded, expected an error", tt.src)
} else {
wantCodes := []int{cmp.Or(tt.wantCode, http.StatusOK)}
if !reflect.DeepEqual(gotCodes, wantCodes) {
t.Errorf("HTTP response status code = %v; want %v", gotCodes, wantCodes)
}
} }
got, err := os.ReadFile(dst) got, err := os.ReadFile(dst)
if err != nil { if err != nil {
@ -486,7 +543,7 @@ func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r) http.NotFound(w, r)
return return
} }
w.Write(data) http.ServeContent(w, r, path, time.Time{}, bytes.NewReader(data))
} }
func (s *testServer) addSigned(name string, data []byte) { func (s *testServer) addSigned(name string, data []byte) {