// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause package udprelay import ( "bytes" "encoding/json" "math" "net" "net/netip" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "go4.org/mem" "tailscale.com/disco" "tailscale.com/net/packet" "tailscale.com/tstime" "tailscale.com/types/key" ) type testClient struct { vni uint32 local key.DiscoPrivate server key.DiscoPublic uc *net.UDPConn } func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient { rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())} uc, err := net.DialUDP("udp4", nil, rAddr) if err != nil { t.Fatal(err) } return &testClient{ vni: vni, local: local, server: server, uc: uc, } } func (c *testClient) write(t *testing.T, b []byte) { _, err := c.uc.Write(b) if err != nil { t.Fatal(err) } } func (c *testClient) read(t *testing.T) []byte { c.uc.SetReadDeadline(time.Now().Add(time.Second)) b := make([]byte, 1<<16-1) n, err := c.uc.Read(b) if err != nil { t.Fatal(err) } return b[:n] } func (c *testClient) writeDataPkt(t *testing.T, b []byte) { pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b)) gh := packet.GeneveHeader{Control: false, VNI: c.vni, Protocol: packet.GeneveProtocolWireGuard} err := gh.Encode(pkt) if err != nil { t.Fatal(err) } pkt = append(pkt, b...) c.write(t, pkt) } func (c *testClient) readDataPkt(t *testing.T) []byte { b := c.read(t) gh := packet.GeneveHeader{} err := gh.Decode(b) if err != nil { t.Fatal(err) } if gh.Protocol != packet.GeneveProtocolWireGuard { t.Fatal("unexpected geneve protocol") } if gh.Control { t.Fatal("unexpected control") } if gh.VNI != c.vni { t.Fatal("unexpected vni") } return b[packet.GeneveFixedHeaderLength:] } func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) { pkt := make([]byte, packet.GeneveFixedHeaderLength, 512) gh := packet.GeneveHeader{Control: true, VNI: c.vni, Protocol: packet.GeneveProtocolDisco} err := gh.Encode(pkt) if err != nil { t.Fatal(err) } pkt = append(pkt, disco.Magic...) pkt = c.local.Public().AppendTo(pkt) box := c.local.Shared(c.server).Seal(msg.AppendMarshal(nil)) pkt = append(pkt, box...) c.write(t, pkt) } func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message { b := c.read(t) gh := packet.GeneveHeader{} err := gh.Decode(b) if err != nil { t.Fatal(err) } if gh.Protocol != packet.GeneveProtocolDisco { t.Fatal("unexpected geneve protocol") } if !gh.Control { t.Fatal("unexpected non-control") } if gh.VNI != c.vni { t.Fatal("unexpected vni") } b = b[packet.GeneveFixedHeaderLength:] headerLen := len(disco.Magic) + key.DiscoPublicRawLen if len(b) < headerLen { t.Fatal("disco message too short") } sender := key.DiscoPublicFromRaw32(mem.B(b[len(disco.Magic):headerLen])) if sender.Compare(c.server) != 0 { t.Fatal("unknown disco public key") } payload, ok := c.local.Shared(c.server).Open(b[headerLen:]) if !ok { t.Fatal("failed to open sealed disco msg") } msg, err := disco.Parse(payload) if err != nil { t.Fatal("failed to parse disco payload") } return msg } func (c *testClient) handshake(t *testing.T) { c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{}) msg := c.readControlDiscoMsg(t) challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge) if !ok { t.Fatal("unexepcted disco message type") } c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge}) } func (c *testClient) close() { c.uc.Close() } func TestServer(t *testing.T) { discoA := key.NewDisco() discoB := key.NewDisco() ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1") server, _, err := NewServer(0, []netip.Addr{ipv4LoopbackAddr}) if err != nil { t.Fatal(err) } defer server.Close() endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) if err != nil { t.Fatal(err) } dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) if err != nil { t.Fatal(err) } // We expect the same endpoint details as the 3-way bind handshake has not // yet been completed for both relay client parties. if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) } if len(endpoint.AddrPorts) != 1 { t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts) } tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, endpoint.ServerDisco) defer tcA.close() tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, endpoint.ServerDisco) defer tcB.close() tcA.handshake(t) tcB.handshake(t) txToB := []byte{1, 2, 3} tcA.writeDataPkt(t, txToB) rxFromA := tcB.readDataPkt(t) if !bytes.Equal(txToB, rxFromA) { t.Fatal("unexpected msg A->B") } txToA := []byte{4, 5, 6} tcB.writeDataPkt(t, txToA) rxFromB := tcA.readDataPkt(t) if !bytes.Equal(txToA, rxFromB) { t.Fatal("unexpected msg B->A") } } func TestServerEndpointJSONUnmarshal(t *testing.T) { tests := []struct { name string json []byte wantErr bool }{ { name: "valid", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), wantErr: false, }, { name: "invalid ServerDisco", json: []byte(`{"ServerDisco":"1","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { name: "invalid LamportID", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":1.1,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { name: "invalid AddrPorts", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { name: "invalid VNI", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":18446744073709551615,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { name: "invalid BindLifetime", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"5","SteadyStateLifetime":"5m0s"}`), wantErr: true, }, { name: "invalid SteadyStateLifetime", json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5"}`), wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var out ServerEndpoint err := json.Unmarshal(tt.json, &out) if tt.wantErr != (err != nil) { t.Fatalf("wantErr: %v (err == nil): %v", tt.wantErr, err == nil) } if tt.wantErr { return } }) } } func TestServerEndpointJSONMarshal(t *testing.T) { tests := []struct { name string serverEndpoint ServerEndpoint }{ { name: "valid roundtrip", serverEndpoint: ServerEndpoint{ ServerDisco: key.NewDisco().Public(), LamportID: uint64(math.MaxUint64), AddrPorts: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:1"), netip.MustParseAddrPort("127.0.0.2:2")}, VNI: 1<<24 - 1, BindLifetime: tstime.GoDuration{Duration: defaultBindLifetime}, SteadyStateLifetime: tstime.GoDuration{Duration: defaultSteadyStateLifetime}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b, err := json.Marshal(&tt.serverEndpoint) if err != nil { t.Fatal(err) } var got ServerEndpoint err = json.Unmarshal(b, &got) if err != nil { t.Fatal(err) } if diff := cmp.Diff(got, tt.serverEndpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { t.Fatalf("ServerEndpoint unequal (-got +want)\n%s", diff) } }) } }