mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-01 04:28:46 +00:00
204 lines
5.3 KiB
Go
204 lines
5.3 KiB
Go
![]() |
// Copyright (c) Tailscale Inc & AUTHORS
|
||
|
// SPDX-License-Identifier: BSD-3-Clause
|
||
|
|
||
|
//go:build !plan9
|
||
|
|
||
|
package main
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"fmt"
|
||
|
"net/netip"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/google/go-cmp/cmp"
|
||
|
"tailscale.com/ipn"
|
||
|
"tailscale.com/ipn/store"
|
||
|
"tailscale.com/tailcfg"
|
||
|
"tailscale.com/types/logger"
|
||
|
"tailscale.com/types/netmap"
|
||
|
)
|
||
|
|
||
|
func TestSetInitialStateKeys(t *testing.T) {
|
||
|
var (
|
||
|
podUID = []byte("test-pod-uid")
|
||
|
expectedCapVer = fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion)
|
||
|
)
|
||
|
for name, tc := range map[string]struct {
|
||
|
initial map[ipn.StateKey][]byte
|
||
|
expected map[ipn.StateKey][]byte
|
||
|
}{
|
||
|
"empty_initial": {
|
||
|
initial: map[ipn.StateKey][]byte{},
|
||
|
expected: map[ipn.StateKey][]byte{
|
||
|
keyPodUID: podUID,
|
||
|
keyCapVer: expectedCapVer,
|
||
|
},
|
||
|
},
|
||
|
"existing_pod_uid_and_capver": {
|
||
|
initial: map[ipn.StateKey][]byte{
|
||
|
keyPodUID: podUID,
|
||
|
keyCapVer: expectedCapVer,
|
||
|
},
|
||
|
expected: map[ipn.StateKey][]byte{
|
||
|
keyPodUID: podUID,
|
||
|
keyCapVer: expectedCapVer,
|
||
|
},
|
||
|
},
|
||
|
"all_keys_preexisting": {
|
||
|
initial: map[ipn.StateKey][]byte{
|
||
|
keyPodUID: podUID,
|
||
|
keyCapVer: expectedCapVer,
|
||
|
keyDeviceID: []byte("existing-device-id"),
|
||
|
keyDeviceFQDN: []byte("existing-device-fqdn"),
|
||
|
keyDeviceIPs: []byte(`["1.2.3.4"]`),
|
||
|
},
|
||
|
expected: map[ipn.StateKey][]byte{
|
||
|
keyPodUID: podUID,
|
||
|
keyCapVer: expectedCapVer,
|
||
|
keyDeviceID: nil,
|
||
|
keyDeviceFQDN: nil,
|
||
|
keyDeviceIPs: nil,
|
||
|
},
|
||
|
},
|
||
|
} {
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
store, err := store.New(logger.Discard, "mem:")
|
||
|
if err != nil {
|
||
|
t.Fatalf("error creating in-memory store: %v", err)
|
||
|
}
|
||
|
|
||
|
for key, value := range tc.initial {
|
||
|
if err := store.WriteState(key, value); err != nil {
|
||
|
t.Fatalf("error writing initial state key %q: %v", key, err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if err := setInitialStateKeys(store, string(podUID)); err != nil {
|
||
|
t.Fatalf("setInitialStateKeys failed: %v", err)
|
||
|
}
|
||
|
|
||
|
actual := make(map[ipn.StateKey][]byte)
|
||
|
for expectedKey, expectedValue := range tc.expected {
|
||
|
actualValue, err := store.ReadState(expectedKey)
|
||
|
if err != nil {
|
||
|
t.Errorf("error reading state key %q: %v", expectedKey, err)
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
actual[expectedKey] = actualValue
|
||
|
if !bytes.Equal(actualValue, expectedValue) {
|
||
|
t.Errorf("state key %q mismatch: expected %q, got %q", expectedKey, expectedValue, actualValue)
|
||
|
}
|
||
|
}
|
||
|
if diff := cmp.Diff(actual, tc.expected); diff != "" {
|
||
|
t.Errorf("state keys mismatch (-got +want):\n%s", diff)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestKeepStateKeysUpdated(t *testing.T) {
|
||
|
store, err := store.New(logger.Discard, "mem:")
|
||
|
if err != nil {
|
||
|
t.Fatalf("error creating in-memory store: %v", err)
|
||
|
}
|
||
|
|
||
|
nextWaiting := make(chan struct{})
|
||
|
go func() {
|
||
|
<-nextWaiting // Acknowledge the initial signal.
|
||
|
}()
|
||
|
notifyCh := make(chan ipn.Notify)
|
||
|
next := func() (ipn.Notify, error) {
|
||
|
nextWaiting <- struct{}{} // Send signal to test that state is consistent.
|
||
|
return <-notifyCh, nil // Wait for test input.
|
||
|
}
|
||
|
|
||
|
errs := make(chan error, 1)
|
||
|
go func() {
|
||
|
err := keepStateKeysUpdated(store, next)
|
||
|
if err != nil {
|
||
|
errs <- fmt.Errorf("keepStateKeysUpdated returned with error: %w", err)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
for _, tc := range []struct {
|
||
|
name string
|
||
|
notify ipn.Notify
|
||
|
expected map[ipn.StateKey][]byte
|
||
|
}{
|
||
|
{
|
||
|
name: "initial_not_authed",
|
||
|
notify: ipn.Notify{},
|
||
|
expected: map[ipn.StateKey][]byte{
|
||
|
keyDeviceID: nil,
|
||
|
keyDeviceFQDN: nil,
|
||
|
keyDeviceIPs: nil,
|
||
|
},
|
||
|
},
|
||
|
{
|
||
|
name: "authed",
|
||
|
notify: ipn.Notify{
|
||
|
NetMap: &netmap.NetworkMap{
|
||
|
SelfNode: (&tailcfg.Node{
|
||
|
StableID: "TESTCTRL00000001",
|
||
|
Name: "test-node.test.ts.net",
|
||
|
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32"), netip.MustParsePrefix("fd7a:115c:a1e0:ab12:4843:cd96:0:1/128")},
|
||
|
}).View(),
|
||
|
},
|
||
|
},
|
||
|
expected: map[ipn.StateKey][]byte{
|
||
|
keyDeviceID: []byte("TESTCTRL00000001"),
|
||
|
keyDeviceFQDN: []byte("test-node.test.ts.net"),
|
||
|
keyDeviceIPs: []byte(`["100.64.0.1","fd7a:115c:a1e0:ab12:4843:cd96:0:1"]`),
|
||
|
},
|
||
|
},
|
||
|
{
|
||
|
name: "updated_fields",
|
||
|
notify: ipn.Notify{
|
||
|
NetMap: &netmap.NetworkMap{
|
||
|
SelfNode: (&tailcfg.Node{
|
||
|
StableID: "TESTCTRL00000001",
|
||
|
Name: "updated.test.ts.net",
|
||
|
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.250/32")},
|
||
|
}).View(),
|
||
|
},
|
||
|
},
|
||
|
expected: map[ipn.StateKey][]byte{
|
||
|
keyDeviceID: []byte("TESTCTRL00000001"),
|
||
|
keyDeviceFQDN: []byte("updated.test.ts.net"),
|
||
|
keyDeviceIPs: []byte(`["100.64.0.250"]`),
|
||
|
},
|
||
|
},
|
||
|
} {
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
// Send test input.
|
||
|
select {
|
||
|
case notifyCh <- tc.notify:
|
||
|
case <-errs:
|
||
|
t.Fatal("keepStateKeysUpdated returned before test input")
|
||
|
case <-time.After(5 * time.Second):
|
||
|
t.Fatal("timed out waiting for next() to be called again")
|
||
|
}
|
||
|
|
||
|
// Wait for next() to be called again so we know the goroutine has
|
||
|
// processed the event.
|
||
|
select {
|
||
|
case <-nextWaiting:
|
||
|
case <-errs:
|
||
|
t.Fatal("keepStateKeysUpdated returned before test input")
|
||
|
case <-time.After(5 * time.Second):
|
||
|
t.Fatal("timed out waiting for next() to be called again")
|
||
|
}
|
||
|
|
||
|
for key, value := range tc.expected {
|
||
|
got, _ := store.ReadState(key)
|
||
|
if !bytes.Equal(got, value) {
|
||
|
t.Errorf("state key %q mismatch: expected %q, got %q", key, value, got)
|
||
|
}
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|