mirror of
				https://github.com/tailscale/tailscale.git
				synced 2025-10-25 10:09:17 +00:00 
			
		
		
		
	clientupdate/distsign: resume partial downloads
Fixes #12573 Change-Id: If2684a2987ec95b893ba9b7c71dc0cb850765b18 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
		| @@ -229,8 +229,6 @@ func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { | ||||
| 	c.logf("Downloading %q", sigURL) | ||||
| 	sig, err := fetch(sigURL, signatureSizeLimit) | ||||
| 	if err != nil { | ||||
| 		// Best-effort clean up of downloaded package. | ||||
| 		os.Remove(dstPathUnverified) | ||||
| 		return err | ||||
| 	} | ||||
| 	msg := binary.LittleEndian.AppendUint64(hash, uint64(len)) | ||||
| @@ -326,6 +324,8 @@ func fetch(url string, limit int64) ([]byte, error) { | ||||
| 	return io.ReadAll(io.LimitReader(resp.Body, limit)) | ||||
| } | ||||
| 
 | ||||
| var onResponseForTest = func(*http.Response) {} | ||||
| 
 | ||||
| // download writes the response body of url into a local file at dst, up to | ||||
| // limit bytes. On success, the returned value is a BLAKE2s hash of the file. | ||||
| func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { | ||||
| @@ -349,31 +349,61 @@ func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([] | ||||
| 		return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength) | ||||
| 	} | ||||
| 	c.logf("Download size: %v", res.ContentLength) | ||||
| 	h := NewPackageHash() | ||||
| 
 | ||||
| 	dlReq := must.Get(http.NewRequestWithContext(ctx, httpm.GET, url, nil)) | ||||
| 
 | ||||
| 	var skip int64 | ||||
| 	if fi, err := os.Stat(dst); err == nil { | ||||
| 		if fi.Size() == res.ContentLength { | ||||
| 			// Assume it got corrupted previously and the earlier attempt failed | ||||
| 			// the checksum. Delete it and start over. | ||||
| 			if err := os.Remove(dst); err != nil { | ||||
| 				return nil, 0, fmt.Errorf("error deleting previous assumed-bad download: %w", err) | ||||
| 			} | ||||
| 		} else if fi.Size() > 0 && fi.Size() < res.ContentLength { | ||||
| 			c.logf("Existing file size: %v", fi.Size()) | ||||
| 			skip = fi.Size() | ||||
| 			dlReq.Header.Add("Range", fmt.Sprintf("bytes=%d-", skip)) | ||||
| 		} | ||||
| 	} | ||||
| 
 | ||||
| 	dlRes, err := hc.Do(dlReq) | ||||
| 	if err != nil { | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
| 	onResponseForTest(dlRes) | ||||
| 	defer dlRes.Body.Close() | ||||
| 	// TODO(bradfitz): resume from existing partial file on disk | ||||
| 	if dlRes.StatusCode != http.StatusOK { | ||||
| 
 | ||||
| 	var of *os.File | ||||
| 	wantResponseLength := res.ContentLength | ||||
| 	switch dlRes.StatusCode { | ||||
| 	case http.StatusOK: | ||||
| 		if skip > 0 { | ||||
| 			os.Remove(dst) // best effort; the Create will fail anyway if this would | ||||
| 		} | ||||
| 		of, err = os.Create(dst) | ||||
| 	case http.StatusPartialContent: | ||||
| 		wantResponseLength = res.ContentLength - skip | ||||
| 		of, err = os.OpenFile(dst, os.O_CREATE|os.O_APPEND|os.O_RDWR, 0644) | ||||
| 		if err == nil { | ||||
| 			// Re-hash the previously downloaded chunk. | ||||
| 			_, err = io.Copy(h, io.NewSectionReader(of, 0, skip)) | ||||
| 		} | ||||
| 	default: | ||||
| 		return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status) | ||||
| 	} | ||||
| 
 | ||||
| 	of, err := os.Create(dst) | ||||
| 	if err != nil { | ||||
| 		return nil, 0, err | ||||
| 	} | ||||
| 	defer of.Close() | ||||
| 	pw := &progressWriter{total: res.ContentLength, logf: c.logf} | ||||
| 	h := NewPackageHash() | ||||
| 	pw := &progressWriter{total: res.ContentLength, done: skip, logf: c.logf} | ||||
| 	n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit)) | ||||
| 	if err != nil { | ||||
| 		return nil, n, err | ||||
| 	} | ||||
| 	if n != res.ContentLength { | ||||
| 		return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength) | ||||
| 	if n != wantResponseLength { | ||||
| 		return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, wantResponseLength) | ||||
| 	} | ||||
| 	if err := dlRes.Body.Close(); err != nil { | ||||
| 		return nil, n, err | ||||
|   | ||||
| @@ -5,6 +5,7 @@ package distsign | ||||
| 
 | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"cmp" | ||||
| 	"context" | ||||
| 	"crypto/ed25519" | ||||
| 	"net/http" | ||||
| @@ -12,10 +13,13 @@ import ( | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 
 | ||||
| 	"golang.org/x/crypto/blake2s" | ||||
| 	"tailscale.com/tstest" | ||||
| ) | ||||
| 
 | ||||
| func TestDownload(t *testing.T) { | ||||
| @@ -25,9 +29,11 @@ func TestDownload(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		desc     string | ||||
| 		before   func(*testing.T) | ||||
| 		existing []byte // optional existing data on disk to resume from | ||||
| 		src      string | ||||
| 		want     []byte | ||||
| 		wantErr  bool | ||||
| 		wantCode int // HTTP status code of download to expect; 0 means http.StatusOK | ||||
| 	}{ | ||||
| 		{ | ||||
| 			desc:    "missing file", | ||||
| @@ -43,6 +49,45 @@ func TestDownload(t *testing.T) { | ||||
| 			src:  "hello", | ||||
| 			want: []byte("world"), | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "success-resume", | ||||
| 			before: func(*testing.T) { | ||||
| 				srv.addSigned("hello", []byte("world")) | ||||
| 			}, | ||||
| 			src:      "hello", | ||||
| 			existing: []byte("wo"), | ||||
| 			want:     []byte("world"), | ||||
| 			wantCode: http.StatusPartialContent, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "success-resume-ignore-matching-size", | ||||
| 			before: func(*testing.T) { | ||||
| 				srv.addSigned("hello", []byte("world")) | ||||
| 			}, | ||||
| 			src:      "hello", | ||||
| 			existing: []byte("WORLD"), // same size as world | ||||
| 			want:     []byte("world"), | ||||
| 			wantCode: http.StatusOK, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "success-resume-ignore-existing-too-big", | ||||
| 			before: func(*testing.T) { | ||||
| 				srv.addSigned("hello", []byte("world")) | ||||
| 			}, | ||||
| 			src:      "hello", | ||||
| 			existing: []byte("longer-than-world"), // len greater than len("world") | ||||
| 			want:     []byte("world"), | ||||
| 			wantCode: http.StatusOK, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "resume-corrupt", | ||||
| 			before: func(*testing.T) { | ||||
| 				srv.addSigned("hello", []byte("world")) | ||||
| 			}, | ||||
| 			src:      "hello", | ||||
| 			existing: []byte("WO"), // previous download was bad | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			desc: "no signature", | ||||
| 			before: func(*testing.T) { | ||||
| @@ -94,10 +139,17 @@ func TestDownload(t *testing.T) { | ||||
| 			srv.reset() | ||||
| 			tt.before(t) | ||||
| 
 | ||||
| 			dst := filepath.Join(t.TempDir(), tt.src) | ||||
| 			t.Cleanup(func() { | ||||
| 				os.Remove(dst) | ||||
| 			var gotCodes []int | ||||
| 			tstest.Replace(t, &onResponseForTest, func(res *http.Response) { | ||||
| 				gotCodes = append(gotCodes, res.StatusCode) | ||||
| 			}) | ||||
| 
 | ||||
| 			dst := filepath.Join(t.TempDir(), tt.src) | ||||
| 			if len(tt.existing) > 0 { | ||||
| 				if err := os.WriteFile(dst+".unverified", tt.existing, 0644); err != nil { | ||||
| 					t.Fatal(err) | ||||
| 				} | ||||
| 			} | ||||
| 			err := c.Download(context.Background(), tt.src, dst) | ||||
| 			if err != nil { | ||||
| 				if tt.wantErr { | ||||
| @@ -107,6 +159,11 @@ func TestDownload(t *testing.T) { | ||||
| 			} | ||||
| 			if tt.wantErr { | ||||
| 				t.Fatalf("Download(%q) succeeded, expected an error", tt.src) | ||||
| 			} else { | ||||
| 				wantCodes := []int{cmp.Or(tt.wantCode, http.StatusOK)} | ||||
| 				if !reflect.DeepEqual(gotCodes, wantCodes) { | ||||
| 					t.Errorf("HTTP response status code = %v; want %v", gotCodes, wantCodes) | ||||
| 				} | ||||
| 			} | ||||
| 			got, err := os.ReadFile(dst) | ||||
| 			if err != nil { | ||||
| @@ -486,7 +543,7 @@ func (s *testServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { | ||||
| 		http.NotFound(w, r) | ||||
| 		return | ||||
| 	} | ||||
| 	w.Write(data) | ||||
| 	http.ServeContent(w, r, path, time.Time{}, bytes.NewReader(data)) | ||||
| } | ||||
| 
 | ||||
| func (s *testServer) addSigned(name string, data []byte) { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Brad Fitzpatrick
					Brad Fitzpatrick