diff --git a/clientupdate/distsign/distsign.go b/clientupdate/distsign/distsign.go index 3dad1f4c4..b48321f8f 100644 --- a/clientupdate/distsign/distsign.go +++ b/clientupdate/distsign/distsign.go @@ -247,6 +247,48 @@ func (c *Client) Download(ctx context.Context, srcPath, dstPath string) error { return nil } +// ValidateLocalBinary fetches the latest signature associated with the binary +// at srcURLPath and uses it to validate the file located on disk via +// localFilePath. ValidateLocalBinary returns an error if anything goes wrong +// with the signature download or with signature validation. +func (c *Client) ValidateLocalBinary(srcURLPath, localFilePath string) error { + // Always fetch a fresh signing key. + sigPub, err := c.signingKeys() + if err != nil { + return err + } + + srcURL := c.url(srcURLPath) + sigURL := srcURL + ".sig" + + localFile, err := os.Open(localFilePath) + if err != nil { + return err + } + defer localFile.Close() + + h := NewPackageHash() + _, err = io.Copy(h, localFile) + if err != nil { + return err + } + hash, hashLen := h.Sum(nil), h.Len() + + c.logf("Downloading %q", sigURL) + sig, err := fetch(sigURL, signatureSizeLimit) + if err != nil { + return err + } + + msg := binary.LittleEndian.AppendUint64(hash, uint64(hashLen)) + if !VerifyAny(sigPub, msg, sig) { + return fmt.Errorf("signature %q for file %q does not validate with the current release signing key; either you are under attack, or attempting to download an old version of Tailscale which was signed with an older signing key", sigURL, localFilePath) + } + c.logf("Signature OK") + + return nil +} + // signingKeys fetches current signing keys from the server and validates them // against the roots. Should be called before validation of any downloaded file // to get the fresh keys. diff --git a/clientupdate/distsign/distsign_test.go b/clientupdate/distsign/distsign_test.go index c7f5f023d..c00cc6c15 100644 --- a/clientupdate/distsign/distsign_test.go +++ b/clientupdate/distsign/distsign_test.go @@ -119,6 +119,124 @@ func TestDownload(t *testing.T) { } } +func TestValidateLocalBinary(t *testing.T) { + srv := newTestServer(t) + c := srv.client(t) + + tests := []struct { + desc string + before func(*testing.T) + src string + wantErr bool + }{ + { + desc: "missing file", + before: func(*testing.T) {}, + src: "hello", + wantErr: true, + }, + { + desc: "success", + before: func(*testing.T) { + srv.addSigned("hello", []byte("world")) + }, + src: "hello", + }, + { + desc: "contents changed", + before: func(*testing.T) { + srv.addSigned("hello", []byte("new world")) + }, + src: "hello", + wantErr: true, + }, + { + desc: "no signature", + before: func(*testing.T) { + srv.add("hello", []byte("world")) + }, + src: "hello", + wantErr: true, + }, + { + desc: "bad signature", + before: func(*testing.T) { + srv.add("hello", []byte("world")) + srv.add("hello.sig", []byte("potato")) + }, + src: "hello", + wantErr: true, + }, + { + desc: "signed with untrusted key", + before: func(t *testing.T) { + srv.add("hello", []byte("world")) + srv.add("hello.sig", newSigningKeyPair(t).sign([]byte("world"))) + }, + src: "hello", + wantErr: true, + }, + { + desc: "signed with root key", + before: func(t *testing.T) { + srv.add("hello", []byte("world")) + srv.add("hello.sig", ed25519.Sign(srv.roots[0].k, []byte("world"))) + }, + src: "hello", + wantErr: true, + }, + { + desc: "bad signing key signature", + before: func(t *testing.T) { + srv.add("distsign.pub.sig", []byte("potato")) + srv.addSigned("hello", []byte("world")) + }, + src: "hello", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + srv.reset() + + // First just do a successful Download. + want := []byte("world") + srv.addSigned("hello", want) + dst := filepath.Join(t.TempDir(), tt.src) + t.Cleanup(func() { + os.Remove(dst) + }) + err := c.Download(context.Background(), tt.src, dst) + if err != nil { + t.Fatalf("unexpected error from Download(%q): %v", tt.src, err) + } + got, err := os.ReadFile(dst) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(want, got) { + t.Errorf("Download(%q): got %q, want %q", tt.src, got, want) + } + + // Now we reset srv with the test case and validate against the local dst. + srv.reset() + tt.before(t) + + err = c.ValidateLocalBinary(tt.src, dst) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("unexpected error from ValidateLocalBinary(%q): %v", tt.src, err) + } + if tt.wantErr { + t.Fatalf("ValidateLocalBinary(%q) succeeded, expected an error", tt.src) + } + }) + } +} + func TestRotateRoot(t *testing.T) { srv := newTestServer(t) c1 := srv.client(t)