diff --git a/net/stunserver/stunserver.go b/net/stunserver/stunserver.go index b45bb6331..c7068e116 100644 --- a/net/stunserver/stunserver.go +++ b/net/stunserver/stunserver.go @@ -15,6 +15,8 @@ import ( "net/netip" "time" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "tailscale.com/metrics" "tailscale.com/net/stun" ) @@ -70,13 +72,17 @@ func (s *STUNServer) Listen(listenAddr string) error { // Serve starts serving responses to STUN requests. Listen must be called before Serve. func (s *STUNServer) Serve() error { var buf [64 << 10]byte + var oob [4096]byte var ( - n int - ua *net.UDPAddr - err error + n, oobn int + remote netip.AddrPort + local net.IP + err error + cm4 ipv4.ControlMessage + cm6 ipv6.ControlMessage ) for { - n, ua, err = s.pc.ReadFromUDP(buf[:]) + n, oobn, _, remote, err = s.pc.ReadMsgUDPAddrPort(buf[:], oob[:]) if err != nil { if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { return nil @@ -86,6 +92,22 @@ func (s *STUNServer) Serve() error { stunReadError.Add(1) continue } + + if remote.Addr().Is4() { + err = cm4.Parse(oob[:oobn]) + } else { + err = cm6.Parse(oob[:oobn]) + } + if err != nil { + log.Printf("parse control msg error: %v", err) + continue + } + if remote.Addr().Is4() { + local = cm4.Dst + } else { + local = cm6.Dst + } + pkt := buf[:n] if !stun.Is(pkt) { stunNotSTUN.Add(1) @@ -96,14 +118,27 @@ func (s *STUNServer) Serve() error { stunNotSTUN.Add(1) continue } - if ua.IP.To4() != nil { + if remote.Addr().Is4() { stunIPv4.Add(1) } else { stunIPv6.Add(1) } - addr, _ := netip.AddrFromSlice(ua.IP) - res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(ua.Port))) - _, err = s.pc.WriteTo(res, ua) + res := stun.Response(txid, remote) + + // TODO(raggi): send upstream patch to provide a way to serialize a + // control message into an existng buffer. + if remote.Addr().Is4() { + cm4 = ipv4.ControlMessage{ + Src: local, + } + oobn = copy(oob[:], cm4.Marshal()) + } else { + cm6 = ipv6.ControlMessage{ + Src: local, + } + oobn = copy(oob[:], cm6.Marshal()) + } + _, _, err = s.pc.WriteMsgUDPAddrPort(res, oob[:oobn], remote) if err != nil { stunWriteError.Add(1) } else {