diff --git a/net/udprelay/server.go b/net/udprelay/server.go index 7138cec7a..b260955e0 100644 --- a/net/udprelay/server.go +++ b/net/udprelay/server.go @@ -10,6 +10,7 @@ import ( "bytes" "context" "crypto/rand" + "encoding/binary" "errors" "fmt" "net" @@ -20,6 +21,7 @@ import ( "time" "go4.org/mem" + "golang.org/x/crypto/blake2s" "golang.org/x/net/ipv6" "tailscale.com/disco" "tailscale.com/net/batching" @@ -73,7 +75,9 @@ type Server struct { closeCh chan struct{} netChecker *netcheck.Client - mu sync.Mutex // guards the following fields + mu sync.Mutex // guards the following fields + macSecrets [][blake2s.Size]byte // [0] is most recent, max 2 elements + macSecretRotatedAt time.Time derpMap *tailcfg.DERPMap onlyStaticAddrPorts bool // no dynamic addr port discovery when set staticAddrPorts views.Slice[netip.AddrPort] // static ip:port pairs set with [Server.SetStaticAddrPorts] @@ -85,6 +89,8 @@ type Server struct { byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint } +const macSecretRotationInterval = time.Minute * 2 + const ( minVNI = uint32(1) maxVNI = uint32(1<<24 - 1) @@ -98,22 +104,42 @@ type serverEndpoint struct { // indexing of this array aligns with the following fields, e.g. // discoSharedSecrets[0] is the shared secret to use when sealing // Disco protocol messages for transmission towards discoPubKeys[0]. - discoPubKeys key.SortedPairOfDiscoPublic - discoSharedSecrets [2]key.DiscoShared - handshakeGeneration [2]uint32 // or zero if a handshake has never started for that relay leg - handshakeAddrPorts [2]netip.AddrPort // or zero value if a handshake has never started for that relay leg - boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg - lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time - challenge [2][disco.BindUDPRelayChallengeLen]byte - packetsRx [2]uint64 // num packets received from/sent by each client after they are bound - bytesRx [2]uint64 // num bytes received from/sent by each client after they are bound + discoPubKeys key.SortedPairOfDiscoPublic + discoSharedSecrets [2]key.DiscoShared + inProgressGeneration [2]uint32 // or zero if a handshake has never started, or has just completed + boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg + lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time + packetsRx [2]uint64 // num packets received from/sent by each client after they are bound + bytesRx [2]uint64 // num bytes received from/sent by each client after they are bound lamportID uint64 vni uint32 allocatedAt time.Time } -func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) { +func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg disco.BindUDPRelayEndpointCommon) ([blake2s.Size]byte, error) { + input := make([]byte, 8, 4+4+32+18) // vni + generation + invited party disco key + addr:port + binary.BigEndian.PutUint32(input[0:4], msg.VNI) + binary.BigEndian.PutUint32(input[4:8], msg.Generation) + input = msg.RemoteKey.AppendTo(input) + input, err := src.AppendBinary(input) + if err != nil { + return [blake2s.Size]byte{}, err + } + h, err := blake2s.New256(blakeKey[:]) + if err != nil { + return [blake2s.Size]byte{}, err + } + _, err = h.Write(input) + if err != nil { + return [blake2s.Size]byte{}, err + } + var out [blake2s.Size]byte + h.Sum(out[:0]) + return out, nil +} + +func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte) (write []byte, to netip.AddrPort) { if senderIndex != 0 && senderIndex != 1 { return nil, netip.AddrPort{} } @@ -144,18 +170,11 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex // Generation must be nonzero, silently drop return nil, netip.AddrPort{} } - if e.handshakeGeneration[senderIndex] == discoMsg.Generation { - // we've seen this generation before, silently drop - return nil, netip.AddrPort{} - } - e.handshakeGeneration[senderIndex] = discoMsg.Generation - e.handshakeAddrPorts[senderIndex] = from + e.inProgressGeneration[senderIndex] = discoMsg.Generation m := new(disco.BindUDPRelayEndpointChallenge) m.VNI = e.vni m.Generation = discoMsg.Generation m.RemoteKey = e.discoPubKeys.Get()[otherSender] - rand.Read(e.challenge[senderIndex][:]) - copy(m.Challenge[:], e.challenge[senderIndex][:]) reply := make([]byte, packet.GeneveFixedHeaderLength, 512) gh := packet.GeneveHeader{Control: true, Protocol: packet.GeneveProtocolDisco} gh.VNI.Set(e.vni) @@ -165,6 +184,11 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex } reply = append(reply, disco.Magic...) reply = serverDisco.AppendTo(reply) + mac, err := blakeMACFromBindMsg(macSecrets[0], from, m.BindUDPRelayEndpointCommon) + if err != nil { + return nil, netip.AddrPort{} + } + m.Challenge = mac box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil)) reply = append(reply, box...) return reply, from @@ -174,17 +198,29 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex // silently drop return nil, netip.AddrPort{} } - generation := e.handshakeGeneration[senderIndex] - if generation == 0 || // we have no active handshake - generation != discoMsg.Generation || // mismatching generation for the active handshake - e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake - !bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake + generation := e.inProgressGeneration[senderIndex] + if generation == 0 || // we have no in-progress handshake + generation != discoMsg.Generation { // mismatching generation for the in-progress handshake // silently drop return nil, netip.AddrPort{} } - // Handshake complete. Update the binding for this sender. - e.boundAddrPorts[senderIndex] = from - e.lastSeen[senderIndex] = time.Now() // record last seen as bound time + for _, macSecret := range macSecrets { + mac, err := blakeMACFromBindMsg(macSecret, from, discoMsg.BindUDPRelayEndpointCommon) + if err != nil { + // silently drop + return nil, netip.AddrPort{} + } + // Speed is favored over constant-time comparison here. The sender is + // already authenticated via disco. + if bytes.Equal(mac[:], discoMsg.Challenge[:]) { + // Handshake complete. Update the binding for this sender. + e.boundAddrPorts[senderIndex] = from + e.lastSeen[senderIndex] = time.Now() // record last seen as bound time + e.inProgressGeneration[senderIndex] = 0 // reset to zero, which indicates there is no in-progress handshake + return nil, netip.AddrPort{} + } + } + // MAC does not match, silently drop return nil, netip.AddrPort{} default: // unexpected message types, silently drop @@ -192,7 +228,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex } } -func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) { +func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte) (write []byte, to netip.AddrPort) { senderRaw, isDiscoMsg := disco.Source(b) if !isDiscoMsg { // Not a Disco message @@ -223,39 +259,29 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by return nil, netip.AddrPort{} } - return e.handleDiscoControlMsg(from, senderIndex, discoMsg, serverDisco) + return e.handleDiscoControlMsg(from, senderIndex, discoMsg, serverDisco, macSecrets) } -func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) { - if !gh.Control { - if !e.isBound() { - // not a control packet, but serverEndpoint isn't bound - return nil, netip.AddrPort{} - } - switch { - case from == e.boundAddrPorts[0]: - e.lastSeen[0] = time.Now() - e.packetsRx[0]++ - e.bytesRx[0] += uint64(len(b)) - return b, e.boundAddrPorts[1] - case from == e.boundAddrPorts[1]: - e.lastSeen[1] = time.Now() - e.packetsRx[1]++ - e.bytesRx[1] += uint64(len(b)) - return b, e.boundAddrPorts[0] - default: - // unrecognized source - return nil, netip.AddrPort{} - } - } - - if gh.Protocol != packet.GeneveProtocolDisco { - // control packet, but not Disco +func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now time.Time) (write []byte, to netip.AddrPort) { + if !e.isBound() { + // not a control packet, but serverEndpoint isn't bound + return nil, netip.AddrPort{} + } + switch { + case from == e.boundAddrPorts[0]: + e.lastSeen[0] = now + e.packetsRx[0]++ + e.bytesRx[0] += uint64(len(b)) + return b, e.boundAddrPorts[1] + case from == e.boundAddrPorts[1]: + e.lastSeen[1] = now + e.packetsRx[1]++ + e.bytesRx[1] += uint64(len(b)) + return b, e.boundAddrPorts[0] + default: + // unrecognized source return nil, netip.AddrPort{} } - - msg := b[packet.GeneveFixedHeaderLength:] - return e.handleSealedDiscoControlMsg(from, msg, serverDisco) } func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool { @@ -621,7 +647,35 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n return nil, netip.AddrPort{} } - return e.handlePacket(from, gh, b, s.discoPublic) + now := time.Now() + if gh.Control { + if gh.Protocol != packet.GeneveProtocolDisco { + // control packet, but not Disco + return nil, netip.AddrPort{} + } + msg := b[packet.GeneveFixedHeaderLength:] + s.maybeRotateMACSecretLocked(now) + return e.handleSealedDiscoControlMsg(from, msg, s.discoPublic, s.macSecrets) + } + return e.handleDataPacket(from, b, now) +} + +func (s *Server) maybeRotateMACSecretLocked(now time.Time) { + if !s.macSecretRotatedAt.IsZero() && now.Sub(s.macSecretRotatedAt) < macSecretRotationInterval { + return + } + switch len(s.macSecrets) { + case 0: + s.macSecrets = make([][blake2s.Size]byte, 1, 2) + case 1: + s.macSecrets = append(s.macSecrets, [blake2s.Size]byte{}) + fallthrough + case 2: + s.macSecrets[1] = s.macSecrets[0] + } + rand.Read(s.macSecrets[0][:]) + s.macSecretRotatedAt = now + return } func (s *Server) packetReadLoop(readFromSocket, otherSocket batching.Conn, readFromSocketIsIPv4 bool) { diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go index 6c3d61658..582d4cf67 100644 --- a/net/udprelay/server_test.go +++ b/net/udprelay/server_test.go @@ -5,6 +5,7 @@ package udprelay import ( "bytes" + "crypto/rand" "net" "net/netip" "testing" @@ -14,6 +15,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "go4.org/mem" + "golang.org/x/crypto/blake2s" "tailscale.com/disco" "tailscale.com/net/packet" "tailscale.com/types/key" @@ -352,3 +354,117 @@ func TestServer_getNextVNILocked(t *testing.T) { _, err = s.getNextVNILocked() c.Assert(err, qt.IsNil) } + +func Test_blakeMACFromBindMsg(t *testing.T) { + var macSecret [blake2s.Size]byte + rand.Read(macSecret[:]) + src := netip.MustParseAddrPort("[2001:db8::1]:7") + + msgA := disco.BindUDPRelayEndpointCommon{ + VNI: 1, + Generation: 1, + RemoteKey: key.NewDisco().Public(), + Challenge: [32]byte{}, + } + macA, err := blakeMACFromBindMsg(macSecret, src, msgA) + if err != nil { + t.Fatal(err) + } + + msgB := msgA + msgB.VNI++ + macB, err := blakeMACFromBindMsg(macSecret, src, msgB) + if err != nil { + t.Fatal(err) + } + if macA == macB { + t.Fatalf("varying VNI input produced identical mac: %v", macA) + } + + msgC := msgA + msgC.Generation++ + macC, err := blakeMACFromBindMsg(macSecret, src, msgC) + if err != nil { + t.Fatal(err) + } + if macA == macC { + t.Fatalf("varying Generation input produced identical mac: %v", macA) + } + + msgD := msgA + msgD.RemoteKey = key.NewDisco().Public() + macD, err := blakeMACFromBindMsg(macSecret, src, msgD) + if err != nil { + t.Fatal(err) + } + if macA == macD { + t.Fatalf("varying RemoteKey input produced identical mac: %v", macA) + } + + msgE := msgA + msgE.Challenge = [32]byte{0x01} // challenge is not part of the MAC and should be ignored + macE, err := blakeMACFromBindMsg(macSecret, src, msgE) + if err != nil { + t.Fatal(err) + } + if macA != macE { + t.Fatalf("varying Challenge input produced varying mac: %v", macA) + } + + macSecretB := macSecret + macSecretB[0] ^= 0xFF + macF, err := blakeMACFromBindMsg(macSecretB, src, msgA) + if err != nil { + t.Fatal(err) + } + if macA == macF { + t.Fatalf("varying macSecret input produced identical mac: %v", macA) + } + + srcB := netip.AddrPortFrom(src.Addr(), src.Port()+1) + macG, err := blakeMACFromBindMsg(macSecret, srcB, msgA) + if err != nil { + t.Fatal(err) + } + if macA == macG { + t.Fatalf("varying src input produced identical mac: %v", macA) + } +} + +func Benchmark_blakeMACFromBindMsg(b *testing.B) { + var macSecret [blake2s.Size]byte + rand.Read(macSecret[:]) + src := netip.MustParseAddrPort("[2001:db8::1]:7") + msg := disco.BindUDPRelayEndpointCommon{ + VNI: 1, + Generation: 1, + RemoteKey: key.NewDisco().Public(), + Challenge: [32]byte{}, + } + b.ReportAllocs() + for b.Loop() { + _, err := blakeMACFromBindMsg(macSecret, src, msg) + if err != nil { + b.Fatal(err) + } + } +} + +func TestServer_maybeRotateMACSecretLocked(t *testing.T) { + s := &Server{} + start := time.Now() + s.maybeRotateMACSecretLocked(start) + qt.Assert(t, len(s.macSecrets), qt.Equals, 1) + macSecret := s.macSecrets[0] + s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval - time.Nanosecond)) + qt.Assert(t, len(s.macSecrets), qt.Equals, 1) + qt.Assert(t, s.macSecrets[0], qt.Equals, macSecret) + s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval)) + qt.Assert(t, len(s.macSecrets), qt.Equals, 2) + qt.Assert(t, s.macSecrets[1], qt.Equals, macSecret) + qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1]) + s.maybeRotateMACSecretLocked(s.macSecretRotatedAt.Add(macSecretRotationInterval)) + qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[0]) + qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[1]) + qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1]) +}