tailscale/net/udprelay/server_default_test.go
Jordan Whited 559e548e8b
ipn/ipnlocal: add peerapi endpoint for relay server endpoint allocation
Relay server initialization is performed as part of handling peerapi
endpoint allocation requests in similar fashion to SSH server
init. The relay server is not supported on iOS for now.

Updates tailscale/corp#27502

Signed-off-by: Jordan Whited <jordan@tailscale.com>
2025-04-09 15:44:26 -07:00

207 lines
5.1 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !ios
package udprelay
import (
"bytes"
"net"
"net/netip"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"go4.org/mem"
"tailscale.com/disco"
"tailscale.com/net/packet"
"tailscale.com/types/key"
)
// testClient is a minimal relay client used by the tests in this file to
// talk to the relay server over a UDP socket.
type testClient struct {
	// uc is the UDP socket, dialed to the server endpoint, that all reads
	// and writes go through.
	uc *net.UDPConn
	// vni is the Geneve VNI written into outgoing headers and required on
	// incoming ones.
	vni uint32
	// local is this client's disco private key.
	local key.DiscoPrivate
	// server is the disco public key that control messages are sealed to
	// and that incoming control messages must be sent from.
	server key.DiscoPublic
}
// newTestClient dials serverEndpoint over UDP/IPv4 and returns a testClient
// configured with the given VNI, local disco private key, and server disco
// public key. The test is failed immediately if the socket cannot be dialed.
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	uc, err := net.DialUDP("udp4", nil, net.UDPAddrFromAddrPort(serverEndpoint))
	if err != nil {
		t.Fatal(err)
	}
	return &testClient{
		vni:    vni,
		local:  local,
		server: server,
		uc:     uc,
	}
}
// write sends b on the client's UDP socket, failing the test if the write
// returns an error.
func (c *testClient) write(t *testing.T, b []byte) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	if _, err := c.uc.Write(b); err != nil {
		t.Fatal(err)
	}
}
// read performs a single read from the client's UDP socket and returns the
// bytes received. The read is bounded by a one second deadline; the test is
// failed if the deadline cannot be set or the read returns an error
// (including a timeout).
func (c *testClient) read(t *testing.T) []byte {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	// A failure to arm the deadline would let Read block indefinitely, so
	// treat it as fatal instead of ignoring it.
	if err := c.uc.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
		t.Fatal(err)
	}
	b := make([]byte, 1<<16-1)
	n, err := c.uc.Read(b)
	if err != nil {
		t.Fatal(err)
	}
	return b[:n]
}
// writeDataPkt sends b to the server as a data packet: a Geneve header
// (non-control, WireGuard protocol, the client's VNI) followed by b. The test
// is failed if the header cannot be encoded or the write fails.
func (c *testClient) writeDataPkt(t *testing.T, b []byte) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	// Length covers only the header so Encode writes into it; capacity is
	// sized so appending the payload does not reallocate.
	pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b))
	gh := packet.GeneveHeader{Control: false, VNI: c.vni, Protocol: packet.GeneveProtocolWireGuard}
	if err := gh.Encode(pkt); err != nil {
		t.Fatal(err)
	}
	pkt = append(pkt, b...)
	c.write(t, pkt)
}
// readDataPkt reads one packet from the server, validates its Geneve header,
// and returns the bytes following the fixed-length header. The test is failed
// if the header cannot be decoded, the protocol is not WireGuard, the control
// bit is set, or the VNI differs from the client's.
func (c *testClient) readDataPkt(t *testing.T) []byte {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	b := c.read(t)
	gh := packet.GeneveHeader{}
	if err := gh.Decode(b); err != nil {
		t.Fatal(err)
	}
	if gh.Protocol != packet.GeneveProtocolWireGuard {
		t.Fatal("unexpected geneve protocol")
	}
	if gh.Control {
		t.Fatal("unexpected control")
	}
	if gh.VNI != c.vni {
		t.Fatal("unexpected vni")
	}
	return b[packet.GeneveFixedHeaderLength:]
}
// writeControlDiscoMsg sends msg to the server as a Geneve control packet.
// The packet consists of a Geneve header (control bit set, disco protocol,
// the client's VNI), disco.Magic, the client's disco public key, and msg
// sealed with the shared key derived from the client's private key and the
// server's public key. The test is failed if header encoding or the write
// fails.
func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	pkt := make([]byte, packet.GeneveFixedHeaderLength, 512)
	gh := packet.GeneveHeader{Control: true, VNI: c.vni, Protocol: packet.GeneveProtocolDisco}
	if err := gh.Encode(pkt); err != nil {
		t.Fatal(err)
	}
	pkt = append(pkt, disco.Magic...)
	pkt = c.local.Public().AppendTo(pkt)
	box := c.local.Shared(c.server).Seal(msg.AppendMarshal(nil))
	pkt = append(pkt, box...)
	c.write(t, pkt)
}
// readControlDiscoMsg reads one packet from the server and returns the disco
// message it carries. The test is failed unless the packet has a Geneve
// header with the control bit set, the disco protocol, and the client's VNI,
// followed by disco.Magic, the server's disco public key, and a sealed
// payload that opens with the client/server shared key and parses as a disco
// message.
func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	b := c.read(t)
	gh := packet.GeneveHeader{}
	if err := gh.Decode(b); err != nil {
		t.Fatal(err)
	}
	if gh.Protocol != packet.GeneveProtocolDisco {
		t.Fatal("unexpected geneve protocol")
	}
	if !gh.Control {
		t.Fatal("unexpected non-control")
	}
	if gh.VNI != c.vni {
		t.Fatal("unexpected vni")
	}
	b = b[packet.GeneveFixedHeaderLength:]
	headerLen := len(disco.Magic) + key.DiscoPublicRawLen
	if len(b) < headerLen {
		t.Fatal("disco message too short")
	}
	// The sender key offset below assumes the magic prefix is present, so
	// verify it rather than silently skipping over arbitrary bytes.
	if !bytes.HasPrefix(b, []byte(disco.Magic)) {
		t.Fatal("missing disco magic")
	}
	sender := key.DiscoPublicFromRaw32(mem.B(b[len(disco.Magic):headerLen]))
	if sender.Compare(c.server) != 0 {
		t.Fatal("unknown disco public key")
	}
	payload, ok := c.local.Shared(c.server).Open(b[headerLen:])
	if !ok {
		t.Fatal("failed to open sealed disco msg")
	}
	msg, err := disco.Parse(payload)
	if err != nil {
		// Include the underlying error so the failure is diagnosable.
		t.Fatalf("failed to parse disco payload: %v", err)
	}
	return msg
}
// handshake runs the client side of the bind handshake with the server: it
// sends a BindUDPRelayEndpoint message, reads the server's reply, which must
// be a BindUDPRelayEndpointChallenge, and responds with a
// BindUDPRelayEndpointAnswer echoing the challenge value. The test is failed
// if the reply is any other message type.
func (c *testClient) handshake(t *testing.T) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()
	c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{})
	msg := c.readControlDiscoMsg(t)
	challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge)
	if !ok {
		t.Fatalf("unexpected disco message type: %T", msg)
	}
	c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge})
}
// close closes the client's UDP socket. The error returned by Close is
// discarded.
func (c *testClient) close() {
	_ = c.uc.Close()
}
// TestServer starts a relay server on the IPv4 loopback address, allocates an
// endpoint for a pair of disco keys, checks that a repeated allocation for
// the same pair returns the same endpoint, completes the bind handshake for
// both clients, and verifies that data packets are relayed in each direction.
func TestServer(t *testing.T) {
	var (
		discoA   = key.NewDisco()
		discoB   = key.NewDisco()
		loopback = netip.MustParseAddr("127.0.0.1")
	)

	server, _, err := NewServer(0, []netip.Addr{loopback})
	if err != nil {
		t.Fatal(err)
	}
	defer server.Close()

	pubA, pubB := discoA.Public(), discoB.Public()
	endpoint, err := server.AllocateEndpoint(pubA, pubB)
	if err != nil {
		t.Fatal(err)
	}
	dupEndpoint, err := server.AllocateEndpoint(pubA, pubB)
	if err != nil {
		t.Fatal(err)
	}

	// Neither client has completed the 3-way bind handshake yet, so a second
	// allocation for the same key pair must yield identical endpoint details.
	equateOpts := cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})
	if diff := cmp.Diff(dupEndpoint, endpoint, equateOpts); diff != "" {
		t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
	}

	if len(endpoint.AddrPorts) != 1 {
		t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
	}
	serverAddr := endpoint.AddrPorts[0]

	tcA := newTestClient(t, endpoint.VNI, serverAddr, discoA, endpoint.ServerDisco)
	defer tcA.close()
	tcB := newTestClient(t, endpoint.VNI, serverAddr, discoB, endpoint.ServerDisco)
	defer tcB.close()

	tcA.handshake(t)
	tcB.handshake(t)

	// relay sends payload from one client and requires the other client to
	// receive exactly the same bytes.
	relay := func(from, to *testClient, payload []byte, failure string) {
		from.writeDataPkt(t, payload)
		if got := to.readDataPkt(t); !bytes.Equal(payload, got) {
			t.Fatal(failure)
		}
	}
	relay(tcA, tcB, []byte{1, 2, 3}, "unexpected msg A->B")
	relay(tcB, tcA, []byte{4, 5, 6}, "unexpected msg B->A")
}