tailscale/net/udprelay/server_test.go
Jordan Whited f701d39ba4
net/udprelay: change Server.AllocateEndpoint existing alloc strategy (#15792)
The previous strategy assumed clients maintained adequate state to
understand the relationship between endpoint allocation and the server
it was allocated on.

magicsock will not be aware of the server's disco key before allocation;
at that point it only knows the peerAPI address. The second client to
allocate on the same server could therefore trigger re-allocation,
breaking a functional relay server endpoint.

If magicsock needs to force reallocation we can add opt-in behaviors
for this later.

Updates tailscale/corp#27502

Signed-off-by: Jordan Whited <jordan@tailscale.com>
2025-04-25 13:09:09 -07:00

309 lines
8.9 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package udprelay
import (
"bytes"
"encoding/json"
"math"
"net"
"net/netip"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"go4.org/mem"
"tailscale.com/disco"
"tailscale.com/net/packet"
"tailscale.com/tstime"
"tailscale.com/types/key"
)
// testClient is a minimal relay client used by the tests in this file. It
// talks to a server endpoint over a connected UDP socket and frames every
// packet it writes with a Geneve header carrying vni.
type testClient struct {
// vni is placed in the Geneve header of every packet written, and is the
// value required in the Geneve header of every packet read.
vni uint32
// local is this client's disco private key.
local key.DiscoPrivate
// server is the disco public key that control messages are sealed to, and
// the only sender key accepted on received control messages.
server key.DiscoPublic
// uc is the UDP socket dialed to the server endpoint.
uc *net.UDPConn
}
// newTestClient dials serverEndpoint over UDP ("udp4") and returns a
// testClient that uses vni in its Geneve headers, local as its own disco
// private key, and server as the peer disco public key. The test fails
// immediately if the dial fails. The caller should call close on the returned
// client when done with it.
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient {
	// Attribute failures to the calling test line rather than to this helper.
	t.Helper()
	uc, err := net.DialUDP("udp4", nil, net.UDPAddrFromAddrPort(serverEndpoint))
	if err != nil {
		t.Fatal(err)
	}
	return &testClient{
		vni:    vni,
		local:  local,
		server: server,
		uc:     uc,
	}
}
// write sends b on the client's UDP socket, failing the test if the write
// returns an error.
func (c *testClient) write(t *testing.T, b []byte) {
	// Attribute failures to the calling line rather than to this helper.
	t.Helper()
	if _, err := c.uc.Write(b); err != nil {
		t.Fatal(err)
	}
}
// read performs a single read from the client's UDP socket and returns the
// bytes received. The read is bounded by a one second deadline so a missing
// reply fails the test instead of hanging it. Any error, including a failure
// to arm the deadline, fails the test.
func (c *testClient) read(t *testing.T) []byte {
	// Attribute failures to the calling line rather than to this helper.
	t.Helper()
	if err := c.uc.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
		t.Fatal(err)
	}
	// Buffer sized at 65535 bytes (1<<16 - 1).
	b := make([]byte, 1<<16-1)
	n, err := c.uc.Read(b)
	if err != nil {
		t.Fatal(err)
	}
	return b[:n]
}
// writeDataPkt sends b to the server as a data packet: a Geneve header with
// Control unset, the client's VNI, and the WireGuard protocol, followed by b.
// Any encode or write error fails the test.
func (c *testClient) writeDataPkt(t *testing.T, b []byte) {
	// Attribute failures to the calling line rather than to this helper.
	t.Helper()
	// Length covers just the header so Encode can fill it in place; capacity
	// leaves room to append the payload without reallocating.
	pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b))
	gh := packet.GeneveHeader{Control: false, VNI: c.vni, Protocol: packet.GeneveProtocolWireGuard}
	if err := gh.Encode(pkt); err != nil {
		t.Fatal(err)
	}
	pkt = append(pkt, b...)
	c.write(t, pkt)
}
// readDataPkt reads one packet from the server and returns the bytes that
// follow its Geneve header. The test fails if the header cannot be decoded, if
// its protocol is not WireGuard, if the Control bit is set, or if its VNI does
// not match the client's.
func (c *testClient) readDataPkt(t *testing.T) []byte {
	// Attribute failures to the calling line rather than to this helper.
	t.Helper()
	b := c.read(t)
	gh := packet.GeneveHeader{}
	if err := gh.Decode(b); err != nil {
		t.Fatal(err)
	}
	if gh.Protocol != packet.GeneveProtocolWireGuard {
		t.Fatal("unexpected geneve protocol")
	}
	if gh.Control {
		t.Fatal("unexpected control")
	}
	if gh.VNI != c.vni {
		t.Fatal("unexpected vni")
	}
	return b[packet.GeneveFixedHeaderLength:]
}
// writeControlDiscoMsg sends msg to the server as a control packet. The packet
// is laid out as: a Geneve header (Control set, the client's VNI, Disco
// protocol), disco.Magic, the client's disco public key, and finally msg
// sealed with the key shared between the client's private key and the server's
// public key. Any encode or write error fails the test.
func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) {
	// Attribute failures to the calling line rather than to this helper.
	t.Helper()
	pkt := make([]byte, packet.GeneveFixedHeaderLength, 512)
	gh := packet.GeneveHeader{Control: true, VNI: c.vni, Protocol: packet.GeneveProtocolDisco}
	if err := gh.Encode(pkt); err != nil {
		t.Fatal(err)
	}
	pkt = append(pkt, disco.Magic...)
	pkt = c.local.Public().AppendTo(pkt)
	box := c.local.Shared(c.server).Seal(msg.AppendMarshal(nil))
	pkt = append(pkt, box...)
	c.write(t, pkt)
}
// readControlDiscoMsg reads one packet from the server and returns the disco
// message it carries. The test fails if the Geneve header cannot be decoded or
// is not a Disco control header for the client's VNI, if the remaining bytes
// are too short to hold disco.Magic plus a disco public key, if the sender key
// is not the server's, if the sealed payload cannot be opened, or if the
// opened payload does not parse as a disco message.
func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message {
	// Attribute failures to the calling line rather than to this helper.
	t.Helper()
	b := c.read(t)
	gh := packet.GeneveHeader{}
	if err := gh.Decode(b); err != nil {
		t.Fatal(err)
	}
	if gh.Protocol != packet.GeneveProtocolDisco {
		t.Fatal("unexpected geneve protocol")
	}
	if !gh.Control {
		t.Fatal("unexpected non-control")
	}
	if gh.VNI != c.vni {
		t.Fatal("unexpected vni")
	}
	b = b[packet.GeneveFixedHeaderLength:]
	// The disco header is the magic bytes followed by the raw sender key.
	headerLen := len(disco.Magic) + key.DiscoPublicRawLen
	if len(b) < headerLen {
		t.Fatal("disco message too short")
	}
	sender := key.DiscoPublicFromRaw32(mem.B(b[len(disco.Magic):headerLen]))
	if sender.Compare(c.server) != 0 {
		t.Fatal("unknown disco public key")
	}
	payload, ok := c.local.Shared(c.server).Open(b[headerLen:])
	if !ok {
		t.Fatal("failed to open sealed disco msg")
	}
	msg, err := disco.Parse(payload)
	if err != nil {
		// Surface the underlying parse error instead of discarding it.
		t.Fatalf("failed to parse disco payload: %v", err)
	}
	return msg
}
// handshake runs the client side of the endpoint bind exchange: it sends a
// BindUDPRelayEndpoint message, expects a BindUDPRelayEndpointChallenge in
// reply, and responds with a BindUDPRelayEndpointAnswer that echoes the
// received challenge. The test fails if the reply is any other message type.
func (c *testClient) handshake(t *testing.T) {
	// Attribute failures to the calling line rather than to this helper.
	t.Helper()
	c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{})
	msg := c.readControlDiscoMsg(t)
	challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge)
	if !ok {
		t.Fatalf("unexpected disco message type: %T", msg)
	}
	c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge})
}
// close closes the client's UDP socket. The error returned by Close is
// discarded.
func (c *testClient) close() {
	_ = c.uc.Close()
}
// TestServer allocates a relay endpoint for a pair of disco keys on a server
// bound to the IPv4 loopback address, checks that allocating again for the
// same pair yields identical endpoint details both before and after the two
// clients complete their handshakes, and then verifies that data packets are
// relayed in both directions.
func TestServer(t *testing.T) {
	privA := key.NewDisco()
	privB := key.NewDisco()

	loopback := netip.MustParseAddr("127.0.0.1")
	server, _, err := NewServer(0, []netip.Addr{loopback})
	if err != nil {
		t.Fatal(err)
	}
	defer server.Close()

	endpoint, err := server.AllocateEndpoint(privA.Public(), privB.Public())
	if err != nil {
		t.Fatal(err)
	}

	equate := cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})
	// requireSameAllocation allocates again for the same key pair and fails
	// the test unless the result matches the first allocation.
	requireSameAllocation := func() {
		dup, err := server.AllocateEndpoint(privA.Public(), privB.Public())
		if err != nil {
			t.Fatal(err)
		}
		if diff := cmp.Diff(dup, endpoint, equate); diff != "" {
			t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
		}
	}

	// The endpoint details must be unchanged before any handshake.
	requireSameAllocation()

	if len(endpoint.AddrPorts) != 1 {
		t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
	}
	addr := endpoint.AddrPorts[0]

	tcA := newTestClient(t, endpoint.VNI, addr, privA, endpoint.ServerDisco)
	defer tcA.close()
	tcB := newTestClient(t, endpoint.VNI, addr, privB, endpoint.ServerDisco)
	defer tcB.close()

	tcA.handshake(t)
	tcB.handshake(t)

	// The endpoint details must also be unchanged after both handshakes.
	requireSameAllocation()

	// Relay one data packet in each direction, A->B first.
	hops := []struct {
		from, to *testClient
		payload  []byte
		failMsg  string
	}{
		{from: tcA, to: tcB, payload: []byte{1, 2, 3}, failMsg: "unexpected msg A->B"},
		{from: tcB, to: tcA, payload: []byte{4, 5, 6}, failMsg: "unexpected msg B->A"},
	}
	for _, hop := range hops {
		hop.from.writeDataPkt(t, hop.payload)
		received := hop.to.readDataPkt(t)
		if !bytes.Equal(hop.payload, received) {
			t.Fatal(hop.failMsg)
		}
	}
}
// TestServerEndpointJSONUnmarshal checks that json.Unmarshal into a
// ServerEndpoint succeeds for a well-formed document and returns an error when
// any single field holds an invalid value.
func TestServerEndpointJSONUnmarshal(t *testing.T) {
	tests := []struct {
		name    string
		json    []byte
		wantErr bool
	}{
		{
			name:    "valid",
			json:    []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
			wantErr: false,
		},
		{
			name:    "invalid ServerDisco",
			json:    []byte(`{"ServerDisco":"1","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
			wantErr: true,
		},
		{
			name:    "invalid LamportID",
			json:    []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":1.1,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
			wantErr: true,
		},
		{
			name:    "invalid AddrPorts",
			json:    []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
			wantErr: true,
		},
		{
			name:    "invalid VNI",
			json:    []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":18446744073709551615,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
			wantErr: true,
		},
		{
			name:    "invalid BindLifetime",
			json:    []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"5","SteadyStateLifetime":"5m0s"}`),
			wantErr: true,
		},
		{
			name:    "invalid SteadyStateLifetime",
			json:    []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5"}`),
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var out ServerEndpoint
			err := json.Unmarshal(tt.json, &out)
			// Only the presence or absence of an error is checked; the
			// decoded value is not inspected. The error itself is printed so
			// an unexpected failure can be diagnosed from the test output.
			if gotErr := err != nil; gotErr != tt.wantErr {
				t.Fatalf("json.Unmarshal() error = %v, wantErr = %v", err, tt.wantErr)
			}
		})
	}
}
// TestServerEndpointJSONMarshal verifies that a ServerEndpoint survives a JSON
// round trip: marshaling it and unmarshaling the result must produce a value
// equal to the original.
func TestServerEndpointJSONMarshal(t *testing.T) {
	cases := []struct {
		name           string
		serverEndpoint ServerEndpoint
	}{
		{
			name: "valid roundtrip",
			serverEndpoint: ServerEndpoint{
				ServerDisco: key.NewDisco().Public(),
				LamportID:   uint64(math.MaxUint64),
				AddrPorts: []netip.AddrPort{
					netip.MustParseAddrPort("127.0.0.1:1"),
					netip.MustParseAddrPort("127.0.0.2:2"),
				},
				VNI:                 1<<24 - 1,
				BindLifetime:        tstime.GoDuration{Duration: defaultBindLifetime},
				SteadyStateLifetime: tstime.GoDuration{Duration: defaultSteadyStateLifetime},
			},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			encoded, err := json.Marshal(&tc.serverEndpoint)
			if err != nil {
				t.Fatal(err)
			}

			var decoded ServerEndpoint
			if err := json.Unmarshal(encoded, &decoded); err != nil {
				t.Fatal(err)
			}

			diff := cmp.Diff(decoded, tc.serverEndpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{}))
			if diff != "" {
				t.Fatalf("ServerEndpoint unequal (-got +want)\n%s", diff)
			}
		})
	}
}