mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-24 17:48:57 +00:00 
			
		
		
		
	clientupdate/distsign: resume partial downloads
Fixes #12573 Change-Id: If2684a2987ec95b893ba9b7c71dc0cb850765b18 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
		| @@ -229,8 +229,6 @@ func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { | |||||||
| 	c.logf("Downloading %q", sigURL) | 	c.logf("Downloading %q", sigURL) | ||||||
| 	sig, err := fetch(sigURL, signatureSizeLimit) | 	sig, err := fetch(sigURL, signatureSizeLimit) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		// Best-effort clean up of downloaded package. |  | ||||||
| 		os.Remove(dstPathUnverified) |  | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) | 	msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) | ||||||
| @@ -326,6 +324,8 @@ func fetch(url string, limit int64) ([]byte, error) { | |||||||
| 	return io.ReadAll(io.LimitReader(resp.Body, limit)) | 	return io.ReadAll(io.LimitReader(resp.Body, limit)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | var onResponseForTest = func(*http.Response) {} | ||||||
|  | 
 | ||||||
| // download writes the response body of url into a local file at dst, up to | // download writes the response body of url into a local file at dst, up to | ||||||
| // limit bytes. On success, the returned value is a BLAKE2s hash of the file. | // limit bytes. On success, the returned value is a BLAKE2s hash of the file. | ||||||
| func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { | func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { | ||||||
| @@ -349,31 +349,61 @@ func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([] | |||||||
| 		return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) | 		return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) | ||||||
| 	} | 	} | ||||||
| 	c.logf("Download size: %v", res.ContentLength) | 	c.logf("Download size: %v", res.ContentLength) | ||||||
|  | 	h := NewPackageHash() | ||||||
| 
 | 
 | ||||||
| 	dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) | 	dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) | ||||||
|  | 
 | ||||||
|  | 	var skip int64 | ||||||
|  | 	if fi, err := os.Stat(dst); err == nil { | ||||||
|  | 		if fi.Size() == res.ContentLength { | ||||||
|  | 			// Assume it got corrupted previously and the earlier attempt failed | ||||||
|  | 			// the checksum. Delete it and start over. | ||||||
|  | 			if err := os.Remove(dst); err != nil { | ||||||
|  | 				return nil, 0, fmt.Errorf("error deleting previous assumed-bad download: %w", err) | ||||||
|  | 			} | ||||||
|  | 		} else if fi.Size() > 0 && fi.Size() < res.ContentLength { | ||||||
|  | 			c.logf("Existing file size: %v", fi.Size()) | ||||||
|  | 			skip = fi.Size() | ||||||
|  | 			dlReq.Header.Add("Range", fmt.Sprintf("bytes=%d-", skip)) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	dlRes, err := hc.Do(dlReq) | 	dlRes, err := hc.Do(dlReq) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, 0, err | 		return nil, 0, err | ||||||
| 	} | 	} | ||||||
|  | 	onResponseForTest(dlRes) | ||||||
| 	defer dlRes.Body.Close() | 	defer dlRes.Body.Close() | ||||||
| 	// TODO(bradfitz): resume from existing partial file on disk | 
 | ||||||
| 	if dlRes.StatusCode != http.StatusOK { | 	var of *os.File | ||||||
|  | 	wantResponseLength := res.ContentLength | ||||||
|  | 	switch dlRes.StatusCode { | ||||||
|  | 	case http.StatusOK: | ||||||
|  | 		if skip > 0 { | ||||||
|  | 			os.Remove(dst) // best effort; the Create will fail anyway if this would | ||||||
|  | 		} | ||||||
|  | 		of, err = os.Create(dst) | ||||||
|  | 	case http.StatusPartialContent: | ||||||
|  | 		wantResponseLength = res.ContentLength - skip | ||||||
|  | 		of, err = os.OpenFile(dst, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644) | ||||||
|  | 		if err == nil { | ||||||
|  | 			// Re-hash the previously downloaded chunk. | ||||||
|  | 			_, err = io.Copy(h, io.NewSectionReader(of, 0, skip)) | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
| 		return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) | 		return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) | ||||||
| 	} | 	} | ||||||
| 
 |  | ||||||
| 	of, err := os.Create(dst) |  | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, 0, err | 		return nil, 0, err | ||||||
| 	} | 	} | ||||||
| 	defer of.Close() | 	defer of.Close() | ||||||
| 	pw := &progressWriter{total: res.ContentLength, logf: c.logf} | 	pw := &progressWriter{total: res.ContentLength, done: skip, logf: c.logf} | ||||||
| 	h := NewPackageHash() |  | ||||||
| 	n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) | 	n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, n, err | 		return nil, n, err | ||||||
| 	} | 	} | ||||||
| 	if n != res.ContentLength { | 	if n != wantResponseLength { | ||||||
| 		return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) | 		return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, wantResponseLength) | ||||||
| 	} | 	} | ||||||
| 	if err := dlRes.Body.Close(); err != nil { | 	if err := dlRes.Body.Close(); err != nil { | ||||||
| 		return nil, n, err | 		return nil, n, err | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ package distsign | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"bytes" | 	"bytes" | ||||||
|  | 	"cmp" | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/ed25519" | 	"crypto/ed25519" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| @@ -12,10 +13,13 @@ import ( | |||||||
| 	"net/url" | 	"net/url" | ||||||
| 	"os" | 	"os" | ||||||
| 	"path/filepath" | 	"path/filepath" | ||||||
|  | 	"reflect" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"testing" | 	"testing" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"golang.org/x/crypto/blake2s" | 	"golang.org/x/crypto/blake2s" | ||||||
|  | 	"tailscale.com/tstest" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestDownload(t *testing.T) { | func TestDownload(t *testing.T) { | ||||||
| @@ -25,9 +29,11 @@ func TestDownload(t *testing.T) { | |||||||
| 	tests := []struct { | 	tests := []struct { | ||||||
| 		desc     string | 		desc     string | ||||||
| 		before   func(*testing.T) | 		before   func(*testing.T) | ||||||
|  | 		existing []byte // optional existing data on disk to resume from | ||||||
| 		src      string | 		src      string | ||||||
| 		want     []byte | 		want     []byte | ||||||
| 		wantErr  bool | 		wantErr  bool | ||||||
|  | 		wantCode int // HTTP status code of download to expect; 0 means http.StatusOK | ||||||
| 	}{ | 	}{ | ||||||
| 		{ | 		{ | ||||||
| 			desc:    "missing file", | 			desc:    "missing file", | ||||||
| @@ -43,6 +49,45 @@ func TestDownload(t *testing.T) { | |||||||
| 			src:  "hello", | 			src:  "hello", | ||||||
| 			want: []byte("world"), | 			want: []byte("world"), | ||||||
| 		}, | 		}, | ||||||
|  | 		{ | ||||||
|  | 			desc: "success-resume", | ||||||
|  | 			before: func(*testing.T) { | ||||||
|  | 				srv.addSigned("hello", []byte("world")) | ||||||
|  | 			}, | ||||||
|  | 			src:      "hello", | ||||||
|  | 			existing: []byte("wo"), | ||||||
|  | 			want:     []byte("world"), | ||||||
|  | 			wantCode: http.StatusPartialContent, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			desc: "success-resume-ignore-matching-size", | ||||||
|  | 			before: func(*testing.T) { | ||||||
|  | 				srv.addSigned("hello", []byte("world")) | ||||||
|  | 			}, | ||||||
|  | 			src:      "hello", | ||||||
|  | 			existing: []byte("WORLD"), // same size as world | ||||||
|  | 			want:     []byte("world"), | ||||||
|  | 			wantCode: http.StatusOK, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			desc: "success-resume-ignore-existing-too-big", | ||||||
|  | 			before: func(*testing.T) { | ||||||
|  | 				srv.addSigned("hello", []byte("world")) | ||||||
|  | 			}, | ||||||
|  | 			src:      "hello", | ||||||
|  | 			existing: []byte("longer-than-world"), // len greater than len("world") | ||||||
|  | 			want:     []byte("world"), | ||||||
|  | 			wantCode: http.StatusOK, | ||||||
|  | 		}, | ||||||
|  | 		{ | ||||||
|  | 			desc: "resume-corrupt", | ||||||
|  | 			before: func(*testing.T) { | ||||||
|  | 				srv.addSigned("hello", []byte("world")) | ||||||
|  | 			}, | ||||||
|  | 			src:      "hello", | ||||||
|  | 			existing: []byte("WO"), // previous download was bad | ||||||
|  | 			wantErr:  true, | ||||||
|  | 		}, | ||||||
| 		{ | 		{ | ||||||
| 			desc: "no signature", | 			desc: "no signature", | ||||||
| 			before: func(*testing.T) { | 			before: func(*testing.T) { | ||||||
| @@ -94,10 +139,17 @@ func TestDownload(t *testing.T) { | |||||||
| 			srv.reset() | 			srv.reset() | ||||||
| 			tt.before(t) | 			tt.before(t) | ||||||
| 
 | 
 | ||||||
| 			dst := filepath.Join(t.TempDir(), tt.src) | 			var gotCodes []int | ||||||
| 			t.Cleanup(func() { | 			tstest.Replace(t, &onResponseForTest, func(res *http.Response) { | ||||||
| 				os.Remove(dst) | 				gotCodes = append(gotCodes, res.StatusCode) | ||||||
| 			}) | 			}) | ||||||
|  | 
 | ||||||
|  | 			dst := filepath.Join(t.TempDir(), tt.src) | ||||||
|  | 			if len(tt.existing) > 0 { | ||||||
|  | 				if err := os.WriteFile(dst+".unverified", tt.existing, 0644); err != nil { | ||||||
|  | 					t.Fatal(err) | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
| 			err := c.Download(context.Background(), tt.src, dst) | 			err := c.Download(context.Background(), tt.src, dst) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| 				if tt.wantErr { | 				if tt.wantErr { | ||||||
| @@ -107,6 +159,11 @@ func TestDownload(t *testing.T) { | |||||||
| 			} | 			} | ||||||
| 			if tt.wantErr { | 			if tt.wantErr { | ||||||
| 				t.Fatalf("Download(%q) succeeded, expected an error", tt.src) | 				t.Fatalf("Download(%q) succeeded, expected an error", tt.src) | ||||||
|  | 			} else { | ||||||
|  | 				wantCodes := []int{cmp.Or(tt.wantCode, http.StatusOK)} | ||||||
|  | 				if !reflect.DeepEqual(gotCodes, wantCodes) { | ||||||
|  | 					t.Errorf("HTTP response status code = %v; want %v", gotCodes, wantCodes) | ||||||
|  | 				} | ||||||
| 			} | 			} | ||||||
| 			got, err := os.ReadFile(dst) | 			got, err := os.ReadFile(dst) | ||||||
| 			if err != nil { | 			if err != nil { | ||||||
| @@ -486,7 +543,7 @@ func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { | |||||||
| 		http.NotFound(w, r) | 		http.NotFound(w, r) | ||||||
| 		return | 		return | ||||||
| 	} | 	} | ||||||
| 	w.Write(data) | 	http.ServeContent(w, r, path, time.Time{}, bytes.NewReader(data)) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (s *testServer) addSigned(name string, data []byte) { | func (s *testServer) addSigned(name string, data []byte) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Brad Fitzpatrick
					Brad Fitzpatrick