diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 58192c50a..c6b8f6962 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -45,8 +45,9 @@ import ( // A Conn routes UDP packets and actively manages a list of its endpoints. // It implements wireguard/device.Bind. type Conn struct { + pconnPort uint16 // the preferred port from opts.Port; 0 means auto pconn4 *RebindingUDPConn - pconnPort uint16 + pconn6 *RebindingUDPConn // non-nil if IPv6 available epFunc func(endpoints []string) logf logger.Logf sendLogLimit *rate.Limiter @@ -137,6 +138,8 @@ var DisableSTUNForTesting bool // Options contains options for Listen. type Options struct { + // Logf optionally provides a log function to use. + // If nil, log.Printf is used. Logf logger.Logf // Port is the port to listen on. @@ -153,6 +156,13 @@ type Options struct { derpTLSConfig *tls.Config // normally nil; used by tests } +func (o *Options) logf() logger.Logf { + if o.Logf != nil { + return o.Logf + } + return log.Printf +} + func (o *Options) endpointsFunc() func([]string) { if o == nil || o.EndpointsFunc == nil { return func([]string) {} @@ -164,41 +174,11 @@ func (o *Options) endpointsFunc() func([]string) { // As the set of possible endpoints for a Conn changes, the // callback opts.EndpointsFunc is called. func Listen(opts Options) (*Conn, error) { - var packetConn net.PacketConn - var err error - - logf := log.Printf - if opts.Logf != nil { - logf = opts.Logf - } - - if opts.Port == 0 { - // Our choice of port. Start with DefaultPort. - // If unavailable, pick any port. - want := fmt.Sprintf(":%d", DefaultPort) - logf("magicsock: bind: trying %v\n", want) - packetConn, err = net.ListenPacket("udp4", want) - if err != nil { - want = ":0" - logf("magicsock: bind: falling back to %v (%v)\n", want, err) - packetConn, err = net.ListenPacket("udp4", want) - } - } else { - packetConn, err = net.ListenPacket("udp4", fmt.Sprintf(":%d", opts.Port)) - } - if err != nil { - return nil, fmt.Errorf("magicsock.Listen: %v", err) - } - - connCtx, connCtxCancel := context.WithCancel(context.Background()) c := &Conn{ - pconn4: new(RebindingUDPConn), pconnPort: opts.Port, - sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), - connCtx: connCtx, - connCtxCancel: connCtxCancel, + logf: opts.logf(), epFunc: opts.endpointsFunc(), - logf: logf, + sendLogLimit: rate.NewLimiter(rate.Every(1*time.Minute), 1), addrsByUDP: make(map[udpAddr]*AddrSet), addrsByKey: make(map[key.Public]*AddrSet), wantDerp: true, @@ -207,6 +187,12 @@ func Listen(opts Options) (*Conn, error) { derpTLSConfig: opts.derpTLSConfig, derps: opts.DERPs, } + + if err := c.initialBind(); err != nil { + return nil, err + } + + c.connCtx, c.connCtxCancel = context.WithCancel(context.Background()) if c.derps == nil { c.derps = derpmap.Prod() } @@ -214,11 +200,12 @@ func Listen(opts Options) (*Conn, error) { DERP: c.derps, Logf: logger.WithPrefix(c.logf, "netcheck: "), GetSTUNConn4: func() netcheck.STUNConn { return c.pconn4 }, - // TODO: add GetSTUNConn6 once Conn has a pconn6 + } + if c.pconn6 != nil { + c.netChecker.GetSTUNConn6 = func() netcheck.STUNConn { return c.pconn6 } } c.ignoreSTUNPackets() - c.pconn4.Reset(packetConn.(*net.UDPConn)) c.ReSTUN("initial") // We assume that LinkChange notifications are plumbed through well @@ -438,8 +425,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) (ipPorts []string, err er if nr.GlobalV4 != "" { addAddr(nr.GlobalV4, "stun") } - const tailControlDoesIPv6 = false // TODO: when IPv6 filtering/splitting is enabled in tailcontrol - if nr.GlobalV6 != "" && tailControlDoesIPv6 { + if nr.GlobalV6 != "" { addAddr(nr.GlobalV6, "stun") } @@ -1005,9 +991,23 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr return n, ep, addr, nil } -func (c *Conn) ReceiveIPv6(buff []byte) (int, conn.Endpoint, *net.UDPAddr, error) { - // TODO(crawshaw): IPv6 support - return 0, nil, nil, syscall.EAFNOSUPPORT +func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) { + if c.pconn6 == nil { + return 0, nil, nil, syscall.EAFNOSUPPORT + } + for { + n, pAddr, err := c.pconn6.ReadFrom(b) + if err != nil { + return 0, nil, nil, err + } + addr := pAddr.(*net.UDPAddr) + if stun.Is(b[:n]) { + c.stunReceiveFunc.Load().(func([]byte, *net.UDPAddr))(b, addr) + continue + } + // TODO(bradfitz): finish. look up addrset, return etc. + // For now we're only using this for STUN. + } } // SetPrivateKey sets the connection's private key. @@ -1107,6 +1107,9 @@ func (c *Conn) Close() error { c.closed = true c.connCtxCancel() c.closeAllDerpLocked() + if c.pconn6 != nil { + c.pconn6.Close() + } return c.pconn4.Close() } @@ -1152,6 +1155,40 @@ func (c *Conn) ReSTUN(why string) { } } +func (c *Conn) initialBind() error { + if err := c.bind1(&c.pconn4, "udp4"); err != nil { + return err + } + if err := c.bind1(&c.pconn6, "udp6"); err != nil { + c.logf("ignoring IPv6 bind failure: %v", err) + } + return nil +} + +func (c *Conn) bind1(ruc **RebindingUDPConn, which string) error { + var pc net.PacketConn + var err error + if c.pconnPort == 0 && DefaultPort != 0 { + pc, err = net.ListenPacket(which, fmt.Sprintf(":%d", DefaultPort)) + if err != nil { + c.logf("magicsock: bind: default port %s/%v unavailable; picking random", which, DefaultPort) + } + } + if pc == nil { + // If unavailable, pick any port. + pc, err = net.ListenPacket(which, fmt.Sprintf(":%d", c.pconnPort)) + } + if err != nil { + c.logf("magicsock: bind(%s/%v): %v", which, c.pconnPort, err) + return fmt.Errorf("magicsock: bind: %s/%d: %v", which, c.pconnPort, err) + } + if *ruc == nil { + *ruc = new(RebindingUDPConn) + } + (*ruc).Reset(pc.(*net.UDPConn)) + return nil +} + // Rebind closes and re-binds the UDP sockets. // It should be followed by a call to ReSTUN. func (c *Conn) Rebind() {