From 7dc7078d9609b03c4f834ee1b1496f7219c30166 Mon Sep 17 00:00:00 2001 From: Josh Bleecher Snyder Date: Wed, 28 Apr 2021 10:28:44 -0700 Subject: [PATCH] wgengine/magicsock: use netaddr.IP in listenPacket It must be an IP address; enforce that at the type level. Suggested-by: Brad Fitzpatrick Signed-off-by: Josh Bleecher Snyder --- wgengine/magicsock/magicsock.go | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index e6e722ef2..d7cbfb7ea 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -2596,10 +2596,17 @@ func (c *Conn) initialBind() error { return nil } -// listenPacket opens a listener. Host must be "" or an IP address. -func (c *Conn) listenPacket(network, host string, port uint16) (net.PacketConn, error) { +// listenPacket opens a packet listener. +// The network must be "udp4" or "udp6". +// Host is the (local) IP address to listen on; use the zero IP to leave unspecified. +func (c *Conn) listenPacket(network string, host netaddr.IP, port uint16) (net.PacketConn, error) { ctx := context.Background() // unused without DNS name to resolve - addr := net.JoinHostPort(host, fmt.Sprint(port)) + // Translate host to package net: "" for the zero value, the IP address string otherwise. + var s string + if !host.IsZero() { + s = host.String() + } + addr := net.JoinHostPort(s, fmt.Sprint(port)) if c.packetListener != nil { return c.packetListener.ListenPacket(ctx, network, addr) } @@ -2611,11 +2618,15 @@ func (c *Conn) listenPacket(network, host string, port uint16) (net.PacketConn, // If rucPtr had an existing UDP socket bound, it closes that socket. // The caller is responsible for informing the portMapper of any changes. func (c *Conn) bindSocket(rucPtr **RebindingUDPConn, network string) error { - host := "" + var host netaddr.IP if inTest() && !c.simulatedNetwork { - host = "127.0.0.1" - if network == "udp6" { - host = "::1" + switch network { + case "udp4": + host = netaddr.MustParseIP("127.0.0.1") + case "udp6": + host = netaddr.MustParseIP("::1") + default: + panic("unrecognized network in bindSocket: " + network) } }