diff --git a/net/udprelay/server.go b/net/udprelay/server.go index d2661e59f..979ccf717 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -112,7 +112,7 @@ type serverEndpoint struct { allocatedAt time.Time } -func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, uw udpWriter, serverDisco key.DiscoPublic) { +func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, conn *net.UDPConn, serverDisco key.DiscoPublic) { if senderIndex != 0 && senderIndex != 1 { return } @@ -165,7 +165,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex reply = serverDisco.AppendTo(reply) box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil)) reply = append(reply, box...) - uw.WriteMsgUDPAddrPort(reply, nil, from) + conn.WriteMsgUDPAddrPort(reply, nil, from) return case *disco.BindUDPRelayEndpointAnswer: err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon) @@ -191,7 +191,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex } } -func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, uw udpWriter, serverDisco key.DiscoPublic) { +func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, conn *net.UDPConn, serverDisco key.DiscoPublic) { senderRaw, isDiscoMsg := disco.Source(b) if !isDiscoMsg { // Not a Disco message @@ -222,14 +222,10 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by return } - e.handleDiscoControlMsg(from, senderIndex, discoMsg, uw, serverDisco) + e.handleDiscoControlMsg(from, senderIndex, discoMsg, conn, serverDisco) } -type udpWriter interface { - WriteMsgUDPAddrPort(b []byte, oob []byte, addr netip.AddrPort) (n, oobn int, err error) -} - -func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, uw udpWriter, serverDisco key.DiscoPublic) { +func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, rxSocket, otherAFSocket *net.UDPConn, serverDisco key.DiscoPublic) { if !gh.Control { if !e.isBound() { // not a control packet, but serverEndpoint isn't bound @@ -247,8 +243,16 @@ func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeade // unrecognized source return } - // relay packet - uw.WriteMsgUDPAddrPort(b, nil, to) + // Relay the packet towards the other party via the socket associated + // with the destination's address family. If source and destination + // address families are matching we tx on the same socket the packet + // was received (rxSocket), otherwise we use the "other" socket + // (otherAFSocket). [Server] makes no use of dual-stack sockets. + if from.Addr().Is4() == to.Addr().Is4() { + rxSocket.WriteMsgUDPAddrPort(b, nil, to) + } else if otherAFSocket != nil { + otherAFSocket.WriteMsgUDPAddrPort(b, nil, to) + } return } @@ -258,7 +262,7 @@ func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeade } msg := b[packet.GeneveFixedHeaderLength:] - e.handleSealedDiscoControlMsg(from, msg, uw, serverDisco) + e.handleSealedDiscoControlMsg(from, msg, rxSocket, serverDisco) } func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool { @@ -346,10 +350,10 @@ func NewServer(logf logger.Logf, port int, overrideAddrs []netip.Addr) (s *Serve } s.wg.Add(1) - go s.packetReadLoop(s.uc4) + go s.packetReadLoop(s.uc4, s.uc6) if s.uc6 != nil { s.wg.Add(1) - go s.packetReadLoop(s.uc6) + go s.packetReadLoop(s.uc6, s.uc4) } s.wg.Add(1) go s.endpointGCLoop() @@ -531,7 +535,7 @@ func (s *Server) endpointGCLoop() { } } -func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) { +func (s *Server) handlePacket(from netip.AddrPort, b []byte, rxSocket, otherAFSocket *net.UDPConn) { if stun.Is(b) && b[1] == 0x01 { // A b[1] value of 0x01 (STUN method binding) is sufficiently // non-overlapping with the Geneve header where the LSB is always 0 @@ -555,10 +559,10 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) { return } - e.handlePacket(from, gh, b, uw, s.discoPublic) + e.handlePacket(from, gh, b, rxSocket, otherAFSocket, s.discoPublic) } -func (s *Server) packetReadLoop(uc *net.UDPConn) { +func (s *Server) packetReadLoop(readFromSocket, otherSocket *net.UDPConn) { defer func() { s.wg.Done() s.Close() @@ -566,11 +570,11 @@ func (s *Server) packetReadLoop(uc *net.UDPConn) { b := make([]byte, 1<<16-1) for { // TODO: extract laddr from IP_PKTINFO for use in reply - n, from, err := uc.ReadFromUDPAddrPort(b) + n, from, err := readFromSocket.ReadFromUDPAddrPort(b) if err != nil { return } - s.handlePacket(from, b[:n], uc) + s.handlePacket(from, b[:n], readFromSocket, otherSocket) } } diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index 8c0c5aff6..de1c29364 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -181,8 +181,9 @@ func TestServer(t *testing.T) { discoB := key.NewDisco() cases := []struct { - name string - overrideAddrs []netip.Addr + name string + overrideAddrs []netip.Addr + forceClientsMixedAF bool }{ { name: "over ipv4", @@ -192,6 +193,11 @@ func TestServer(t *testing.T) { name: "over ipv6", overrideAddrs: []netip.Addr{netip.MustParseAddr("::1")}, }, + { + name: "mixed address families", + overrideAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")}, + forceClientsMixedAF: true, + }, } for _, tt := range cases { @@ -216,16 +222,47 @@ func TestServer(t *testing.T) { t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) } - if len(endpoint.AddrPorts) != 1 { + if len(endpoint.AddrPorts) < 1 { t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts) } - tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco) + tcAServerEndpointAddr := endpoint.AddrPorts[0] + tcA := newTestClient(t, endpoint.VNI, tcAServerEndpointAddr, discoA, discoB.Public(), endpoint.ServerDisco) defer tcA.close() - tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco) + tcBServerEndpointAddr := tcAServerEndpointAddr + if tt.forceClientsMixedAF { + foundMixedAF := false + for _, addr := range endpoint.AddrPorts { + if addr.Addr().Is4() != tcBServerEndpointAddr.Addr().Is4() { + tcBServerEndpointAddr = addr + foundMixedAF = true + } + } + if !foundMixedAF { + t.Fatal("force clients to mixed address families is set, but relay server lacks address family diversity") + } + } + tcB := newTestClient(t, endpoint.VNI, tcBServerEndpointAddr, discoB, discoA.Public(), endpoint.ServerDisco) defer tcB.close() - tcA.handshake(t) - tcB.handshake(t) + for i := 0; i < 2; i++ { + // We handshake both clients twice to guarantee server-side + // packet reading goroutines, which are independent across + // address families, have seen an answer from both clients + // before proceeding. This is needed because the test assumes + // that B's handshake is complete (the first send is A->B below), + // but the server may not have handled B's handshake answer + // before it handles A's data pkt towards B. + // + // Data transmissions following "re-handshakes" orient so that + // the sender is the same as the party that performed the + // handshake, for the same reasons. + // + // [magicsock.relayManager] is not prone to this issue as both + // parties transmit data packets immediately following their + // handshake answer. + tcA.handshake(t) + tcB.handshake(t) + } dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public()) if err != nil { @@ -250,30 +287,32 @@ func TestServer(t *testing.T) { t.Fatal("unexpected msg B->A") } - tcAOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, discoB.Public(), endpoint.ServerDisco) + tcAOnNewPort := newTestClient(t, endpoint.VNI, tcAServerEndpointAddr, discoA, discoB.Public(), endpoint.ServerDisco) tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1 defer tcAOnNewPort.close() - // Handshake client A on a new source IP:port, verify we receive packets on the new binding + // Handshake client A on a new source IP:port, verify we can send packets on the new binding tcAOnNewPort.handshake(t) - txToAOnNewPort := []byte{7, 8, 9} - tcB.writeDataPkt(t, txToAOnNewPort) - rxFromB = tcAOnNewPort.readDataPkt(t) - if !bytes.Equal(txToAOnNewPort, rxFromB) { - t.Fatal("unexpected msg B->A") + + fromAOnNewPort := []byte{7, 8, 9} + tcAOnNewPort.writeDataPkt(t, fromAOnNewPort) + rxFromA = tcB.readDataPkt(t) + if !bytes.Equal(fromAOnNewPort, rxFromA) { + t.Fatal("unexpected msg A->B") } - tcBOnNewPort := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, discoA.Public(), endpoint.ServerDisco) + tcBOnNewPort := newTestClient(t, endpoint.VNI, tcBServerEndpointAddr, discoB, discoA.Public(), endpoint.ServerDisco) tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1 defer tcBOnNewPort.close() - // Handshake client B on a new source IP:port, verify we receive packets on the new binding + // Handshake client B on a new source IP:port, verify we can send packets on the new binding tcBOnNewPort.handshake(t) - txToBOnNewPort := []byte{7, 8, 9} - tcAOnNewPort.writeDataPkt(t, txToBOnNewPort) - rxFromA = tcBOnNewPort.readDataPkt(t) - if !bytes.Equal(txToBOnNewPort, rxFromA) { - t.Fatal("unexpected msg A->B") + + fromBOnNewPort := []byte{7, 8, 9} + tcBOnNewPort.writeDataPkt(t, fromBOnNewPort) + rxFromB = tcAOnNewPort.readDataPkt(t) + if !bytes.Equal(fromBOnNewPort, rxFromB) { + t.Fatal("unexpected msg B->A") } }) }