clientupdate: download SPK and MSI packages with distsign (#9115)

Reimplement `downloadURLToFile` using `distsign.Download` and move all
of the progress reporting logic over there.

Updates #6995
Updates #755

Signed-off-by: Andrew Lytvynov <awly@tailscale.com>
This commit is contained in:
Andrew Lytvynov
2023-08-28 14:48:33 -06:00
committed by GitHub
parent 0c6fe94cf4
commit 8d2eaa1956
4 changed files with 117 additions and 140 deletions

View File

@@ -38,6 +38,7 @@
package distsign
import (
"context"
"crypto/ed25519"
"crypto/rand"
"encoding/binary"
@@ -46,12 +47,17 @@ import (
"fmt"
"hash"
"io"
"log"
"net/http"
"net/url"
"os"
"time"
"github.com/hdevalence/ed25519consensus"
"golang.org/x/crypto/blake2s"
"tailscale.com/net/tshttpproxy"
"tailscale.com/types/logger"
"tailscale.com/util/must"
)
const (
@@ -177,18 +183,22 @@ func (ph *PackageHash) Len() int64 { return ph.len }
// Client downloads and validates files from a distribution server.
type Client struct {
logf logger.Logf
roots []ed25519.PublicKey
pkgsAddr *url.URL
}
// NewClient returns a new client for distribution server located at pkgsAddr,
// and uses embedded root keys from the roots/ subdirectory of this package.
func NewClient(pkgsAddr string) (*Client, error) {
func NewClient(logf logger.Logf, pkgsAddr string) (*Client, error) {
if logf == nil {
logf = log.Printf
}
u, err := url.Parse(pkgsAddr)
if err != nil {
return nil, fmt.Errorf("invalid pkgsAddr %q: %w", pkgsAddr, err)
}
return &Client{roots: roots(), pkgsAddr: u}, nil
return &Client{logf: logf, roots: roots(), pkgsAddr: u}, nil
}
func (c *Client) url(path string) string {
@@ -199,7 +209,7 @@ func (c *Client) url(path string) string {
// The file is downloaded to dstPath and its signature is validated using the
// embedded root keys. Download returns an error if anything goes wrong with
// the actual file download or with signature validation.
func (c *Client) Download(srcPath, dstPath string) error {
func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error {
// Always fetch a fresh signing key.
sigPub, err := c.signingKeys()
if err != nil {
@@ -209,11 +219,13 @@ func (c *Client) Download(srcPath, dstPath string) error {
srcURL := c.url(srcPath)
sigURL := srcURL + ".sig"
c.logf("Downloading %q", srcURL)
dstPathUnverified := dstPath + ".unverified"
hash, len, err := download(srcURL, dstPathUnverified, downloadSizeLimit)
hash, len, err := c.download(ctx, srcURL, dstPathUnverified, downloadSizeLimit)
if err != nil {
return err
}
c.logf("Downloading %q", sigURL)
sig, err := fetch(sigURL, signatureSizeLimit)
if err != nil {
// Best-effort clean up of downloaded package.
@@ -226,6 +238,7 @@ func (c *Client) Download(srcPath, dstPath string) error {
os.Remove(dstPathUnverified)
return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, srcURL)
}
c.logf("Signature OK")
if err := os.Rename(dstPathUnverified, dstPath); err != nil {
return fmt.Errorf("failed to move %q to %q after signature validation", dstPathUnverified, dstPath)
@@ -272,32 +285,84 @@ func fetch(url string, limit int64) ([]byte, error) {
// download writes the response body of url into a local file at dst, up to
// limit bytes. On success, the returned value is a BLAKE2s hash of the file.
func download(url, dst string, limit int64) ([]byte, int64, error) {
resp, err := http.Get(url)
func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) {
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.Proxy = tshttpproxy.ProxyFromEnvironment
defer tr.CloseIdleConnections()
hc := &http.Client{Transport: tr}
quickCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
headReq := must.Get(http.NewRequestWithContext(quickCtx, http.MethodHead, url, nil))
res, err := hc.Do(headReq)
if err != nil {
return nil, 0, err
}
defer resp.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, 0, fmt.Errorf("HEAD %q: %v", url, res.Status)
}
if res.ContentLength <= 0 {
return nil, 0, fmt.Errorf("HEAD %q: unexpected Content-Length %v", url, res.ContentLength)
}
c.logf("Download size: %v", res.ContentLength)
dlReq := must.Get(http.NewRequestWithContext(ctx, http.MethodGet, url, nil))
dlRes, err := hc.Do(dlReq)
if err != nil {
return nil, 0, err
}
defer dlRes.Body.Close()
// TODO(bradfitz): resume from existing partial file on disk
if dlRes.StatusCode != http.StatusOK {
return nil, 0, fmt.Errorf("GET %q: %v", url, dlRes.Status)
}
of, err := os.Create(dst)
if err != nil {
return nil, 0, err
}
defer of.Close()
pw := &progressWriter{total: res.ContentLength, logf: c.logf}
h := NewPackageHash()
r := io.TeeReader(io.LimitReader(resp.Body, limit), h)
f, err := os.Create(dst)
n, err := io.Copy(io.MultiWriter(of, h, pw), io.LimitReader(dlRes.Body, limit))
if err != nil {
return nil, 0, err
return nil, n, err
}
defer f.Close()
if _, err := io.Copy(f, r); err != nil {
return nil, 0, err
if n != res.ContentLength {
return nil, n, fmt.Errorf("GET %q: downloaded %v, want %v", url, n, res.ContentLength)
}
if err := f.Close(); err != nil {
return nil, 0, err
if err := dlRes.Body.Close(); err != nil {
return nil, n, err
}
if err := of.Close(); err != nil {
return nil, n, err
}
pw.print()
return h.Sum(nil), h.Len(), nil
}
type progressWriter struct {
done int64
total int64
lastPrint time.Time
logf logger.Logf
}
func (pw *progressWriter) Write(p []byte) (n int, err error) {
pw.done += int64(len(p))
if time.Since(pw.lastPrint) > 2*time.Second {
pw.print()
}
return len(p), nil
}
func (pw *progressWriter) print() {
pw.lastPrint = time.Now()
pw.logf("Downloaded %v/%v (%.1f%%)", pw.done, pw.total, float64(pw.done)/float64(pw.total)*100)
}
func parsePrivateKey(data []byte, typeTag string) (ed25519.PrivateKey, error) {
b, rest := pem.Decode(data)
if b == nil {