mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-21 06:01:42 +00:00
clientupdate/distsign: resume partial downloads
Fixes #12573 Change-Id: If2684a2987ec95b893ba9b7c71dc0cb850765b18 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
5ec01bf3ce
commit
dc38eaf551
@ -229,8 +229,6 @@ func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error {
|
|||||||
c.logf("Downloading %q", sigURL)
|
c.logf("Downloading %q", sigURL)
|
||||||
sig, err := fetch(sigURL, signatureSizeLimit)
|
sig, err := fetch(sigURL, signatureSizeLimit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Best-effort clean up of downloaded package.
|
|
||||||
os.Remove(dstPathUnverified)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
msg := binary.LittleEndian.AppendUint64(hash, uint64(len))
|
msg := binary.LittleEndian.AppendUint64(hash, uint64(len))
|
||||||
@ -326,6 +324,8 @@ func fetch(url string, limit int64) ([]byte, error) {
|
|||||||
return io.ReadAll(io.LimitReader(resp.Body, limit))
|
return io.ReadAll(io.LimitReader(resp.Body, limit))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var onResponseForTest = func(*http.Response) {}
|
||||||
|
|
||||||
// download writes the response body of url into a local file at dst, up to
|
// download writes the response body of url into a local file at dst, up to
|
||||||
// limit bytes. On success, the returned value is a BLAKE2s hash of the file.
|
// limit bytes. On success, the returned value is a BLAKE2s hash of the file.
|
||||||
func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) {
|
func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) {
|
||||||
@ -349,31 +349,61 @@ func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]
|
|||||||
return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength)
|
return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength)
|
||||||
}
|
}
|
||||||
c.logf("Download size: %v", res.ContentLength)
|
c.logf("Download size: %v", res.ContentLength)
|
||||||
|
h := NewPackageHash()
|
||||||
|
|
||||||
dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil))
|
dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil))
|
||||||
|
|
||||||
|
var skip int64
|
||||||
|
if fi, err := os.Stat(dst); err == nil {
|
||||||
|
if fi.Size() == res.ContentLength {
|
||||||
|
// Assume it got corrupted previously and the earlier attempt failed
|
||||||
|
// the checksum. Delete it and start over.
|
||||||
|
if err := os.Remove(dst); err != nil {
|
||||||
|
return nil, 0, fmt.Errorf("error deleting previous assumed-bad download: %w", err)
|
||||||
|
}
|
||||||
|
} else if fi.Size() > 0 && fi.Size() < res.ContentLength {
|
||||||
|
c.logf("Existing file size: %v", fi.Size())
|
||||||
|
skip = fi.Size()
|
||||||
|
dlReq.Header.Add("Range", fmt.Sprintf("bytes=%d-", skip))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dlRes, err := hc.Do(dlReq)
|
dlRes, err := hc.Do(dlReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
onResponseForTest(dlRes)
|
||||||
defer dlRes.Body.Close()
|
defer dlRes.Body.Close()
|
||||||
// TODO(bradfitz): resume from existing partial file on disk
|
|
||||||
if dlRes.StatusCode != http.StatusOK {
|
var of *os.File
|
||||||
|
wantResponseLength := res.ContentLength
|
||||||
|
switch dlRes.StatusCode {
|
||||||
|
case http.StatusOK:
|
||||||
|
if skip > 0 {
|
||||||
|
os.Remove(dst) // best effort; the Create will fail anyway if this would
|
||||||
|
}
|
||||||
|
of, err = os.Create(dst)
|
||||||
|
case http.StatusPartialContent:
|
||||||
|
wantResponseLength = res.ContentLength - skip
|
||||||
|
of, err = os.OpenFile(dst, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644)
|
||||||
|
if err == nil {
|
||||||
|
// Re-hash the previously downloaded chunk.
|
||||||
|
_, err = io.Copy(h, io.NewSectionReader(of, 0, skip))
|
||||||
|
}
|
||||||
|
default:
|
||||||
return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status)
|
return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
of, err := os.Create(dst)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
defer of.Close()
|
defer of.Close()
|
||||||
pw := &progressWriter{total: res.ContentLength, logf: c.logf}
|
pw := &progressWriter{total: res.ContentLength, done: skip, logf: c.logf}
|
||||||
h := NewPackageHash()
|
|
||||||
n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit))
|
n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, n, err
|
return nil, n, err
|
||||||
}
|
}
|
||||||
if n != res.ContentLength {
|
if n != wantResponseLength {
|
||||||
return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength)
|
return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, wantResponseLength)
|
||||||
}
|
}
|
||||||
if err := dlRes.Body.Close(); err != nil {
|
if err := dlRes.Body.Close(); err != nil {
|
||||||
return nil, n, err
|
return nil, n, err
|
||||||
|
@ -5,6 +5,7 @@ package distsign
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -12,10 +13,13 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
|
"tailscale.com/tstest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDownload(t *testing.T) {
|
func TestDownload(t *testing.T) {
|
||||||
@ -25,9 +29,11 @@ func TestDownload(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
desc string
|
desc string
|
||||||
before func(*testing.T)
|
before func(*testing.T)
|
||||||
|
existing []byte // optional existing data on disk to resume from
|
||||||
src string
|
src string
|
||||||
want []byte
|
want []byte
|
||||||
wantErr bool
|
wantErr bool
|
||||||
|
wantCode int // HTTP status code of download to expect; 0 means http.StatusOK
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
desc: "missing file",
|
desc: "missing file",
|
||||||
@ -43,6 +49,45 @@ func TestDownload(t *testing.T) {
|
|||||||
src: "hello",
|
src: "hello",
|
||||||
want: []byte("world"),
|
want: []byte("world"),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "success-resume",
|
||||||
|
before: func(*testing.T) {
|
||||||
|
srv.addSigned("hello", []byte("world"))
|
||||||
|
},
|
||||||
|
src: "hello",
|
||||||
|
existing: []byte("wo"),
|
||||||
|
want: []byte("world"),
|
||||||
|
wantCode: http.StatusPartialContent,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "success-resume-ignore-matching-size",
|
||||||
|
before: func(*testing.T) {
|
||||||
|
srv.addSigned("hello", []byte("world"))
|
||||||
|
},
|
||||||
|
src: "hello",
|
||||||
|
existing: []byte("WORLD"), // same size as world
|
||||||
|
want: []byte("world"),
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "success-resume-ignore-existing-too-big",
|
||||||
|
before: func(*testing.T) {
|
||||||
|
srv.addSigned("hello", []byte("world"))
|
||||||
|
},
|
||||||
|
src: "hello",
|
||||||
|
existing: []byte("longer-than-world"), // len greater than len("world")
|
||||||
|
want: []byte("world"),
|
||||||
|
wantCode: http.StatusOK,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "resume-corrupt",
|
||||||
|
before: func(*testing.T) {
|
||||||
|
srv.addSigned("hello", []byte("world"))
|
||||||
|
},
|
||||||
|
src: "hello",
|
||||||
|
existing: []byte("WO"), // previous download was bad
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
desc: "no signature",
|
desc: "no signature",
|
||||||
before: func(*testing.T) {
|
before: func(*testing.T) {
|
||||||
@ -94,10 +139,17 @@ func TestDownload(t *testing.T) {
|
|||||||
srv.reset()
|
srv.reset()
|
||||||
tt.before(t)
|
tt.before(t)
|
||||||
|
|
||||||
dst := filepath.Join(t.TempDir(), tt.src)
|
var gotCodes []int
|
||||||
t.Cleanup(func() {
|
tstest.Replace(t, &onResponseForTest, func(res *http.Response) {
|
||||||
os.Remove(dst)
|
gotCodes = append(gotCodes, res.StatusCode)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
dst := filepath.Join(t.TempDir(), tt.src)
|
||||||
|
if len(tt.existing) > 0 {
|
||||||
|
if err := os.WriteFile(dst+".unverified", tt.existing, 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
err := c.Download(context.Background(), tt.src, dst)
|
err := c.Download(context.Background(), tt.src, dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
@ -107,6 +159,11 @@ func TestDownload(t *testing.T) {
|
|||||||
}
|
}
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
t.Fatalf("Download(%q) succeeded, expected an error", tt.src)
|
t.Fatalf("Download(%q) succeeded, expected an error", tt.src)
|
||||||
|
} else {
|
||||||
|
wantCodes := []int{cmp.Or(tt.wantCode, http.StatusOK)}
|
||||||
|
if !reflect.DeepEqual(gotCodes, wantCodes) {
|
||||||
|
t.Errorf("HTTP response status code = %v; want %v", gotCodes, wantCodes)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
got, err := os.ReadFile(dst)
|
got, err := os.ReadFile(dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -486,7 +543,7 @@ func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.NotFound(w, r)
|
http.NotFound(w, r)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Write(data)
|
http.ServeContent(w, r, path, time.Time{}, bytes.NewReader(data))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *testServer) addSigned(name string, data []byte) {
|
func (s *testServer) addSigned(name string, data []byte) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user