mirror of
https://github.com/tailscale/tailscale.git
synced 2025-12-01 17:49:02 +00:00
feature/tpm: add swtpm-based integration tests
Add swtpm-based integration tests for the TPM functionality. The test suite checks whether swtpm v0.10.1 is available and if it's being run as root and if so runs a set of tests against swtpm instances of different configurations (1.2, 2.0, uninitialized, etc...) to validate the client behavior. Updates tailscale/corp#34174 Signed-off-by: Patrick O'Doherty <patrick@tailscale.com>
This commit is contained in:
462
feature/tpm/swtpm_test.go
Normal file
462
feature/tpm/swtpm_test.go
Normal file
@@ -0,0 +1,462 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build linux && ts_swtpm
|
||||
|
||||
package tpm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-tpm/tpm2"
|
||||
"github.com/google/go-tpm/tpm2/transport/linuxtpm"
|
||||
)
|
||||
|
||||
// swtpmBinary is the name of the swtpm executable to use
|
||||
const swtpmBinary = "swtpm"
|
||||
const example32ByteKey = "12345678901234567890123456789012"
|
||||
|
||||
type swtpm struct {
|
||||
dataDir string
|
||||
deviceName string
|
||||
devicePath string
|
||||
pidFile string
|
||||
opts *swtpmOptions
|
||||
t testing.TB
|
||||
}
|
||||
|
||||
func newSWTPM(t testing.TB, opts ...swtpmOption) *swtpm {
|
||||
t.Helper()
|
||||
options := &swtpmOptions{
|
||||
version: "2.0",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
dataDir := t.TempDir()
|
||||
|
||||
suffix := make([]byte, 8)
|
||||
if _, err := crand.Read(suffix); err != nil {
|
||||
t.Fatalf("failed to generate random suffix: %v", err)
|
||||
}
|
||||
|
||||
// use unique per-test vtpm device names to avoid conflicts
|
||||
deviceName := fmt.Sprintf("vtpm-%d-%s", time.Now().UnixNano(), hex.EncodeToString(suffix))
|
||||
devicePath := filepath.Join("/dev", deviceName)
|
||||
pidFile := filepath.Join(dataDir, "swtpm.pid")
|
||||
|
||||
s := &swtpm{
|
||||
dataDir: dataDir,
|
||||
deviceName: deviceName,
|
||||
devicePath: devicePath,
|
||||
pidFile: pidFile,
|
||||
opts: options,
|
||||
t: t,
|
||||
}
|
||||
s.t.Logf("created swtpm with device %s, data dir %s", deviceName, dataDir)
|
||||
|
||||
if err := s.start(); err != nil {
|
||||
t.Fatalf("failed to start swtpm: %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.stop()
|
||||
})
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// runSetup initializes the TPM state using swtpm_setup
|
||||
func (s *swtpm) runSetup() error {
|
||||
args := []string{
|
||||
"--tpmstate", s.dataDir,
|
||||
}
|
||||
|
||||
switch s.opts.version {
|
||||
case "1.2":
|
||||
// TPM 1.2 is the default for swtpm_setup, no flag needed
|
||||
case "2.0":
|
||||
args = append(args, "--tpm2")
|
||||
default:
|
||||
s.t.Fatalf("unsupported swtpm version for setup: %q", s.opts.version)
|
||||
}
|
||||
|
||||
fullArgs := append([]string{"swtpm_setup"}, args...)
|
||||
cmd := exec.Command("sudo", fullArgs...)
|
||||
s.t.Logf("running swtpm_setup with args: %v", args)
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("swtpm_setup failed: %w: %s", err, output)
|
||||
} else {
|
||||
s.t.Logf("swtpm_setup output: %s", output)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// start launches the swtpm process with the configured options
|
||||
func (s *swtpm) start() error {
|
||||
// init state if requested
|
||||
if s.opts.withSetup {
|
||||
if err := s.runSetup(); err != nil {
|
||||
return fmt.Errorf("swtpm_setup failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"cuse",
|
||||
"--tpmstate", fmt.Sprintf("dir=%s", s.dataDir),
|
||||
"--name", s.deviceName,
|
||||
"--pid", fmt.Sprintf("file=%s", s.pidFile),
|
||||
}
|
||||
|
||||
switch s.opts.version {
|
||||
case "1.2":
|
||||
case "2.0":
|
||||
args = append(args, "--tpm2")
|
||||
default:
|
||||
return fmt.Errorf("unsupported swtpm version: %s", s.opts.version)
|
||||
}
|
||||
|
||||
// when using swtpm_setup, we need to tell swtpm to send TPM2_Startup(CLEAR)
|
||||
if s.opts.withSetup {
|
||||
args = append(args, "--flags", "startup-clear")
|
||||
}
|
||||
|
||||
if s.opts.flags != nil {
|
||||
args = append(args, s.opts.flags...)
|
||||
}
|
||||
|
||||
fullArgs := append([]string{swtpmBinary}, args...)
|
||||
cmd := exec.Command("sudo", fullArgs...)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start swtpm with args %v: %w: %s", args, err, out)
|
||||
} else {
|
||||
s.t.Logf("swtpm started with args %v output: %s", args, out)
|
||||
}
|
||||
|
||||
s.t.Logf("waiting for swtpm device at %s", s.devicePath)
|
||||
if err := s.waitForDevice(); err != nil {
|
||||
s.stop()
|
||||
return fmt.Errorf("swtpm device not available: %w", err)
|
||||
} else {
|
||||
s.t.Logf("swtpm device available at %s", s.devicePath)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForDevice waits for the swtpm character device to be created
|
||||
func (s *swtpm) waitForDevice() error {
|
||||
timeout := time.After(5 * time.Second)
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
return fmt.Errorf("timeout waiting for device at %s", s.devicePath)
|
||||
case <-ticker.C:
|
||||
if _, err := os.Stat(s.devicePath); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stop terminates the swtpm process and cleans up resources
|
||||
func (s *swtpm) stop() {
|
||||
pidBytes, err := os.ReadFile(s.pidFile)
|
||||
if err != nil {
|
||||
s.t.Logf("failed to read swtpm pid file %s: %v", s.pidFile, err)
|
||||
return
|
||||
}
|
||||
|
||||
var pid int
|
||||
pid, err = strconv.Atoi(string(bytes.TrimSpace(pidBytes)))
|
||||
if err != nil {
|
||||
s.t.Logf("failed to parse swtpm pid %q: %v", string(pidBytes), err)
|
||||
return
|
||||
}
|
||||
var process *os.Process
|
||||
process, err = os.FindProcess(pid)
|
||||
if err != nil {
|
||||
s.t.Logf("failed to find swtpm process with pid %d: %v", pid, err)
|
||||
return
|
||||
}
|
||||
if err := process.Signal(syscall.SIGTERM); err != nil {
|
||||
s.t.Logf("failed to send SIGTERM to swtpm PID %d: %v", pid, err)
|
||||
return
|
||||
}
|
||||
s.t.Logf("sent SIGTERM to swtpm PID %d", pid)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := process.Wait()
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
s.t.Logf("swtpm PID %d exited with error: %v", pid, err)
|
||||
} else {
|
||||
s.t.Logf("swtpm PID %d exited after SIGTERM", pid)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
s.t.Fatalf("timed out waiting for process %d to terminate", pid)
|
||||
}
|
||||
}
|
||||
|
||||
// DevicePath returns the path to the swtpm character device
|
||||
func (s *swtpm) DevicePath() string {
|
||||
return s.devicePath
|
||||
}
|
||||
|
||||
type swtpmOptions struct {
|
||||
version string
|
||||
flags []string
|
||||
withSetup bool
|
||||
}
|
||||
|
||||
type swtpmOption func(*swtpmOptions)
|
||||
|
||||
// withSetup enables TPM initialization via swtpm_setup
|
||||
func withSetup() swtpmOption {
|
||||
return func(o *swtpmOptions) {
|
||||
o.withSetup = true
|
||||
}
|
||||
}
|
||||
|
||||
// withSwtpmVersion sets the TPM version (either "1.2" or "2.0")
|
||||
func withSwtpmVersion(version string) swtpmOption {
|
||||
return func(o *swtpmOptions) {
|
||||
o.version = version
|
||||
}
|
||||
}
|
||||
|
||||
// withTPM12 configures swtpm to use TPM 1.2
|
||||
func withTPM12() swtpmOption {
|
||||
return withSwtpmVersion("1.2")
|
||||
}
|
||||
|
||||
// withTPM20 configures swtpm to use TPM 2.0 (default)
|
||||
func withTPM20() swtpmOption {
|
||||
return withSwtpmVersion("2.0")
|
||||
}
|
||||
|
||||
// checkSWTPMAvailable checks if swtpm is available and errors if not
|
||||
func checkSWTPMAvailable(t testing.TB) {
|
||||
t.Helper()
|
||||
p, err := exec.LookPath(swtpmBinary)
|
||||
if err != nil {
|
||||
t.Fatalf("swtpm binary not found in PATH: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// ensure version 0.10.1
|
||||
cmd := exec.Command(p, "--version")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to execute swtpm --version: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(output, []byte("TPM emulator version 0.10.1")) {
|
||||
t.Fatalf("swtpm version is not compatible: %s", output)
|
||||
return
|
||||
}
|
||||
|
||||
// ensure that we can run swtpm with sudo non-interactive (necessary to create CUSE devices)
|
||||
cmd = exec.Command("sudo", "-n", p, "cuse", "--help")
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatalf("swtpm cannot be run with sudo without password prompt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSWTPM_Integration(t *testing.T) {
|
||||
checkSWTPMAvailable(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts []swtpmOption
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "broken-1.2-no-setup",
|
||||
opts: []swtpmOption{withTPM12()},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "broken-2.0-no-setup",
|
||||
opts: []swtpmOption{withTPM20()},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "working-2.0",
|
||||
opts: []swtpmOption{withTPM20(), withSetup()},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
swtpm := newSWTPM(t, tt.opts...)
|
||||
devicePath := swtpm.DevicePath()
|
||||
|
||||
if _, err := os.Stat(devicePath); err != nil {
|
||||
t.Fatalf("swtpm device does not exist at %s: %v", devicePath, err)
|
||||
}
|
||||
|
||||
tpmDev, err := linuxtpm.Open(devicePath)
|
||||
if err != nil {
|
||||
t.Fatalf("linuxtpm.Open(%s) failed: %v", devicePath, err)
|
||||
}
|
||||
defer tpmDev.Close()
|
||||
|
||||
err = withSRK(t.Logf, tpmDev, func(srk tpm2.AuthHandle) error {
|
||||
t.Logf("Successfully loaded SRK with handle: %v", srk.Handle)
|
||||
return nil
|
||||
})
|
||||
|
||||
if tt.wantErr != (err != nil) {
|
||||
t.Errorf("withSRK() error = %v, wantErr = %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSWTPM_SealUnseal(t *testing.T) {
|
||||
checkSWTPMAvailable(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts []swtpmOption
|
||||
data []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "1.2-fail-no-setup",
|
||||
opts: []swtpmOption{withTPM12()},
|
||||
data: []byte(example32ByteKey),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "1.2-fail-32-byte-key",
|
||||
opts: []swtpmOption{withTPM12(), withSetup()},
|
||||
data: []byte(example32ByteKey),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "2.0-seal-fail-no-setup",
|
||||
opts: []swtpmOption{withTPM20()},
|
||||
data: []byte("test data"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "2.0-seal-unseal-32-byte-key",
|
||||
opts: []swtpmOption{withTPM20(), withSetup()},
|
||||
data: []byte(example32ByteKey),
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
swtpm := newSWTPM(t, tt.opts...)
|
||||
devicePath := swtpm.DevicePath()
|
||||
|
||||
tpmDev, err := linuxtpm.Open(devicePath)
|
||||
if err != nil {
|
||||
t.Fatalf("linuxtpm.Open(%s) failed: %v", devicePath, err)
|
||||
}
|
||||
defer tpmDev.Close()
|
||||
|
||||
sealed, err := tpmSealWithTPM(t.Logf, tpmDev, tt.data)
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("tpmSealWithTPM() expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("tpmSealWithTPM() failed: %v", err)
|
||||
}
|
||||
|
||||
if sealed == nil {
|
||||
t.Fatal("tpmSealWithTPM() returned nil sealed data")
|
||||
}
|
||||
if len(sealed.Private) == 0 {
|
||||
t.Error("sealed.Private is empty")
|
||||
}
|
||||
if len(sealed.Public) == 0 {
|
||||
t.Error("sealed.Public is empty")
|
||||
}
|
||||
|
||||
unsealed, err := tpmUnsealWithTPM(t.Logf, tpmDev, sealed)
|
||||
if err != nil {
|
||||
t.Fatalf("tpmUnsealWithTPM() failed: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(unsealed, tt.data) {
|
||||
t.Errorf("unsealed data mismatch:\ngot: %q\nwant: %q", unsealed, tt.data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSWTPM_SealUnsealCrossDevice(t *testing.T) {
|
||||
checkSWTPMAvailable(t)
|
||||
|
||||
swtpm1 := newSWTPM(t, withTPM20(), withSetup())
|
||||
tpmDev1, err := linuxtpm.Open(swtpm1.DevicePath())
|
||||
if err != nil {
|
||||
t.Fatalf("linuxtpm.Open(%s) failed: %v", swtpm1.DevicePath(), err)
|
||||
}
|
||||
defer tpmDev1.Close()
|
||||
|
||||
logf := func(format string, args ...any) {
|
||||
t.Logf(format, args...)
|
||||
}
|
||||
|
||||
testData := []byte("TPM1 secret data")
|
||||
sealed, err := tpmSealWithTPM(logf, tpmDev1, testData)
|
||||
if err != nil {
|
||||
t.Fatalf("tpmSealWithTPM() on first device failed: %v", err)
|
||||
}
|
||||
|
||||
// round trip on the same TPM
|
||||
unsealed, err := tpmUnsealWithTPM(logf, tpmDev1, sealed)
|
||||
if err != nil {
|
||||
t.Fatalf("tpmUnsealWithTPM() on first device failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(unsealed, testData) {
|
||||
t.Errorf("unsealed data mismatch on first device:\ngot: %q\nwant: %q", unsealed, testData)
|
||||
}
|
||||
|
||||
// create a second device
|
||||
swtpm2 := newSWTPM(t, withTPM20(), withSetup())
|
||||
tpmDev2, err := linuxtpm.Open(swtpm2.DevicePath())
|
||||
if err != nil {
|
||||
t.Fatalf("linuxtpm.Open(%s) failed: %v", swtpm2.DevicePath(), err)
|
||||
}
|
||||
defer tpmDev2.Close()
|
||||
|
||||
// confirm we cannot unseal with the second TPM
|
||||
_, err = tpmUnsealWithTPM(logf, tpmDev2, sealed)
|
||||
if err == nil {
|
||||
t.Error("tpmUnsealWithTPM() on second device should have failed but succeeded")
|
||||
}
|
||||
}
|
||||
@@ -393,8 +393,13 @@ func tpmSeal(logf logger.Logf, data []byte) (*tpmSealedData, error) {
|
||||
}
|
||||
defer tpm.Close()
|
||||
|
||||
return tpmSealWithTPM(logf, tpm, data)
|
||||
}
|
||||
|
||||
// tpmSealWithTPM seals the data using SRK of the provided TPM.
|
||||
func tpmSealWithTPM(logf logger.Logf, tpm transport.TPM, data []byte) (*tpmSealedData, error) {
|
||||
var res *tpmSealedData
|
||||
err = withSRK(logf, tpm, func(srk tpm2.AuthHandle) error {
|
||||
err := withSRK(logf, tpm, func(srk tpm2.AuthHandle) error {
|
||||
sealCmd := tpm2.Create{
|
||||
ParentHandle: srk,
|
||||
InSensitive: tpm2.TPM2BSensitiveCreate{
|
||||
@@ -436,8 +441,13 @@ func tpmUnseal(logf logger.Logf, data *tpmSealedData) ([]byte, error) {
|
||||
}
|
||||
defer tpm.Close()
|
||||
|
||||
return tpmUnsealWithTPM(logf, tpm, data)
|
||||
}
|
||||
|
||||
// tpmUnsealWithTPM unseals the data using SRK of the provided TPM.
|
||||
func tpmUnsealWithTPM(logf logger.Logf, tpm transport.TPM, data *tpmSealedData) ([]byte, error) {
|
||||
var res []byte
|
||||
err = withSRK(logf, tpm, func(srk tpm2.AuthHandle) error {
|
||||
err := withSRK(logf, tpm, func(srk tpm2.AuthHandle) error {
|
||||
// Load the sealed object into the TPM first under SRK.
|
||||
loadCmd := tpm2.Load{
|
||||
ParentHandle: srk,
|
||||
|
||||
Reference in New Issue
Block a user