mirror of
https://github.com/tailscale/tailscale.git
synced 2025-04-29 12:11:02 +00:00

The previous strategy assumed clients maintained adequate state to understand the relationship between endpoint allocation and the server it was allocated on. magicsock will not have awareness of the server's disco key pre-allocation, it only understands peerAPI address at this point. The second client to allocate on the same server could trigger re-allocation, breaking a functional relay server endpoint. If magicsock needs to force reallocation we can add opt-in behaviors for this later. Updates tailscale/corp#27502 Signed-off-by: Jordan Whited <jordan@tailscale.com>
309 lines
8.9 KiB
Go
309 lines
8.9 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
|
|
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package udprelay
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"math"
|
|
"net"
|
|
"net/netip"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/google/go-cmp/cmp/cmpopts"
|
|
"go4.org/mem"
|
|
"tailscale.com/disco"
|
|
"tailscale.com/net/packet"
|
|
"tailscale.com/tstime"
|
|
"tailscale.com/types/key"
|
|
)
|
|
|
|
type testClient struct {
|
|
vni uint32
|
|
local key.DiscoPrivate
|
|
server key.DiscoPublic
|
|
uc *net.UDPConn
|
|
}
|
|
|
|
func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient {
|
|
rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())}
|
|
uc, err := net.DialUDP("udp4", nil, rAddr)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return &testClient{
|
|
vni: vni,
|
|
local: local,
|
|
server: server,
|
|
uc: uc,
|
|
}
|
|
}
|
|
|
|
func (c *testClient) write(t *testing.T, b []byte) {
|
|
_, err := c.uc.Write(b)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func (c *testClient) read(t *testing.T) []byte {
|
|
c.uc.SetReadDeadline(time.Now().Add(time.Second))
|
|
b := make([]byte, 1<<16-1)
|
|
n, err := c.uc.Read(b)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return b[:n]
|
|
}
|
|
|
|
func (c *testClient) writeDataPkt(t *testing.T, b []byte) {
|
|
pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b))
|
|
gh := packet.GeneveHeader{Control: false, VNI: c.vni, Protocol: packet.GeneveProtocolWireGuard}
|
|
err := gh.Encode(pkt)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pkt = append(pkt, b...)
|
|
c.write(t, pkt)
|
|
}
|
|
|
|
func (c *testClient) readDataPkt(t *testing.T) []byte {
|
|
b := c.read(t)
|
|
gh := packet.GeneveHeader{}
|
|
err := gh.Decode(b)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if gh.Protocol != packet.GeneveProtocolWireGuard {
|
|
t.Fatal("unexpected geneve protocol")
|
|
}
|
|
if gh.Control {
|
|
t.Fatal("unexpected control")
|
|
}
|
|
if gh.VNI != c.vni {
|
|
t.Fatal("unexpected vni")
|
|
}
|
|
return b[packet.GeneveFixedHeaderLength:]
|
|
}
|
|
|
|
func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) {
|
|
pkt := make([]byte, packet.GeneveFixedHeaderLength, 512)
|
|
gh := packet.GeneveHeader{Control: true, VNI: c.vni, Protocol: packet.GeneveProtocolDisco}
|
|
err := gh.Encode(pkt)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
pkt = append(pkt, disco.Magic...)
|
|
pkt = c.local.Public().AppendTo(pkt)
|
|
box := c.local.Shared(c.server).Seal(msg.AppendMarshal(nil))
|
|
pkt = append(pkt, box...)
|
|
c.write(t, pkt)
|
|
}
|
|
|
|
func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message {
|
|
b := c.read(t)
|
|
gh := packet.GeneveHeader{}
|
|
err := gh.Decode(b)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if gh.Protocol != packet.GeneveProtocolDisco {
|
|
t.Fatal("unexpected geneve protocol")
|
|
}
|
|
if !gh.Control {
|
|
t.Fatal("unexpected non-control")
|
|
}
|
|
if gh.VNI != c.vni {
|
|
t.Fatal("unexpected vni")
|
|
}
|
|
b = b[packet.GeneveFixedHeaderLength:]
|
|
headerLen := len(disco.Magic) + key.DiscoPublicRawLen
|
|
if len(b) < headerLen {
|
|
t.Fatal("disco message too short")
|
|
}
|
|
sender := key.DiscoPublicFromRaw32(mem.B(b[len(disco.Magic):headerLen]))
|
|
if sender.Compare(c.server) != 0 {
|
|
t.Fatal("unknown disco public key")
|
|
}
|
|
payload, ok := c.local.Shared(c.server).Open(b[headerLen:])
|
|
if !ok {
|
|
t.Fatal("failed to open sealed disco msg")
|
|
}
|
|
msg, err := disco.Parse(payload)
|
|
if err != nil {
|
|
t.Fatal("failed to parse disco payload")
|
|
}
|
|
return msg
|
|
}
|
|
|
|
func (c *testClient) handshake(t *testing.T) {
|
|
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{})
|
|
msg := c.readControlDiscoMsg(t)
|
|
challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge)
|
|
if !ok {
|
|
t.Fatal("unexepcted disco message type")
|
|
}
|
|
c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge})
|
|
}
|
|
|
|
func (c *testClient) close() {
|
|
c.uc.Close()
|
|
}
|
|
|
|
func TestServer(t *testing.T) {
|
|
discoA := key.NewDisco()
|
|
discoB := key.NewDisco()
|
|
|
|
ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1")
|
|
|
|
server, _, err := NewServer(0, []netip.Addr{ipv4LoopbackAddr})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer server.Close()
|
|
|
|
endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// We expect the same endpoint details pre-handshake.
|
|
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
|
|
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
|
|
}
|
|
|
|
if len(endpoint.AddrPorts) != 1 {
|
|
t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts)
|
|
}
|
|
tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, endpoint.ServerDisco)
|
|
defer tcA.close()
|
|
tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, endpoint.ServerDisco)
|
|
defer tcB.close()
|
|
|
|
tcA.handshake(t)
|
|
tcB.handshake(t)
|
|
|
|
dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
// We expect the same endpoint details post-handshake.
|
|
if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
|
|
t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff)
|
|
}
|
|
|
|
txToB := []byte{1, 2, 3}
|
|
tcA.writeDataPkt(t, txToB)
|
|
rxFromA := tcB.readDataPkt(t)
|
|
if !bytes.Equal(txToB, rxFromA) {
|
|
t.Fatal("unexpected msg A->B")
|
|
}
|
|
|
|
txToA := []byte{4, 5, 6}
|
|
tcB.writeDataPkt(t, txToA)
|
|
rxFromB := tcA.readDataPkt(t)
|
|
if !bytes.Equal(txToA, rxFromB) {
|
|
t.Fatal("unexpected msg B->A")
|
|
}
|
|
}
|
|
|
|
func TestServerEndpointJSONUnmarshal(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
json []byte
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "valid",
|
|
json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "invalid ServerDisco",
|
|
json: []byte(`{"ServerDisco":"1","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid LamportID",
|
|
json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":1.1,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid AddrPorts",
|
|
json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid VNI",
|
|
json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":18446744073709551615,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid BindLifetime",
|
|
json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"5","SteadyStateLifetime":"5m0s"}`),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid SteadyStateLifetime",
|
|
json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5"}`),
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var out ServerEndpoint
|
|
err := json.Unmarshal(tt.json, &out)
|
|
if tt.wantErr != (err != nil) {
|
|
t.Fatalf("wantErr: %v (err == nil): %v", tt.wantErr, err == nil)
|
|
}
|
|
if tt.wantErr {
|
|
return
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestServerEndpointJSONMarshal(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
serverEndpoint ServerEndpoint
|
|
}{
|
|
{
|
|
name: "valid roundtrip",
|
|
serverEndpoint: ServerEndpoint{
|
|
ServerDisco: key.NewDisco().Public(),
|
|
LamportID: uint64(math.MaxUint64),
|
|
AddrPorts: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:1"), netip.MustParseAddrPort("127.0.0.2:2")},
|
|
VNI: 1<<24 - 1,
|
|
BindLifetime: tstime.GoDuration{Duration: defaultBindLifetime},
|
|
SteadyStateLifetime: tstime.GoDuration{Duration: defaultSteadyStateLifetime},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
b, err := json.Marshal(&tt.serverEndpoint)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
var got ServerEndpoint
|
|
err = json.Unmarshal(b, &got)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if diff := cmp.Diff(got, tt.serverEndpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" {
|
|
t.Fatalf("ServerEndpoint unequal (-got +want)\n%s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|