diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 2f3c47bc2..88f6467a6 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -49,7 +49,7 @@ type Conn struct { stunServers []string startEpUpdate chan struct{} // send to trigger endpoint update epFunc func(endpoints []string) - logf func(format string, args ...interface{}) + logf logger.Logf sendLogLimit *rate.Limiter // bufferedIPv4From and bufferedIPv4Packet are owned by @@ -121,6 +121,8 @@ type udpAddr struct { // Options contains options for Listen. type Options struct { + Logf logger.Logf + // Port is the port to listen on. // Zero means to pick one automatically. Port uint16 @@ -147,15 +149,21 @@ func (o *Options) endpointsFunc() func([]string) { func Listen(opts Options) (*Conn, error) { var packetConn net.PacketConn var err error + + logf := log.Printf + if opts.Logf != nil { + logf = opts.Logf + } + if opts.Port == 0 { // Our choice of port. Start with DefaultPort. // If unavailable, pick any port. want := fmt.Sprintf(":%d", DefaultPort) - log.Printf("magicsock: bind: trying %v\n", want) + logf("magicsock: bind: trying %v\n", want) packetConn, err = net.ListenPacket("udp4", want) if err != nil { want = ":0" - log.Printf("magicsock: bind: falling back to %v (%v)\n", want, err) + logf("magicsock: bind: falling back to %v (%v)\n", want, err) packetConn, err = net.ListenPacket("udp4", want) } } else { @@ -175,7 +183,7 @@ func Listen(opts Options) (*Conn, error) { connCtx: connCtx, connCtxCancel: connCtxCancel, epFunc: opts.endpointsFunc(), - logf: log.Printf, + logf: logf, addrsByUDP: make(map[udpAddr]*AddrSet), addrsByKey: make(map[key.Public]*AddrSet), wantDerp: true, @@ -398,7 +406,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) (ipPorts []string, err er var eps []string // unique endpoints addAddr := func(s, reason string) { - log.Printf("magicsock: found local %s (%s)\n", s, reason) + c.logf("magicsock: found local %s (%s)\n", s, reason) alreadyMu.Lock() defer alreadyMu.Unlock() @@ -608,7 +616,7 @@ func (c *Conn) Send(b []byte, ep conn.Endpoint) error { ret = err } if err != nil && addr != roamAddr && c.sendLogLimit.Allow() { - log.Printf("magicsock: Conn.Send(%v): %v", addr, err) + c.logf("magicsock: Conn.Send(%v): %v", addr, err) } } if success { @@ -684,7 +692,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest { return nil } // TODO(bradfitz): don't hold derpMu here. It's slow. Release first and use singleflight to dial+re-lock to add. - dc, err := derphttp.NewClient(c.privateKey, "https://"+host+"/derp", log.Printf) + dc, err := derphttp.NewClient(c.privateKey, "https://"+host+"/derp", c.logf) if err != nil { c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err) return nil @@ -765,7 +773,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc res.n = len(m.Data) res.src = m.Source if logDerpVerbose { - log.Printf("got derp %v packet: %q", derpFakeAddr, m.Data) + c.logf("got derp %v packet: %q", derpFakeAddr, m.Data) } default: // Ignore. @@ -800,7 +808,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc case wr := <-ch: err := dc.Send(wr.pubKey, wr.b) if err != nil { - log.Printf("magicsock: derp.Send(%v): %v", wr.addr, err) + c.logf("magicsock: derp.Send(%v): %v", wr.addr, err) } select { case wr.errc <- err: @@ -911,7 +919,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr ncopy := dm.copyBuf(b) if ncopy != n { err = fmt.Errorf("received DERP packet of length %d that's too big for WireGuard ReceiveIPv4 buf size %d", n, ncopy) - log.Printf("magicsock: %v", err) + c.logf("magicsock: %v", err) return 0, nil, nil, err } @@ -921,7 +929,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr if addrSet == nil { key := wgcfg.Key(dm.src) - log.Printf("magicsock: DERP packet from unknown key: %s", key.ShortString()) + c.logf("magicsock: DERP packet from unknown key: %s", key.ShortString()) } case um := <-c.udpRecvCh: @@ -1061,23 +1069,23 @@ func (c *Conn) LinkChange() { if c.pconnPort != 0 { c.pconn.mu.Lock() if err := c.pconn.pconn.Close(); err != nil { - log.Printf("magicsock: link change close failed: %v", err) + c.logf("magicsock: link change close failed: %v", err) } packetConn, err := net.ListenPacket("udp4", fmt.Sprintf(":%d", c.pconnPort)) if err == nil { - log.Printf("magicsock: link change rebound port: %d", c.pconnPort) + c.logf("magicsock: link change rebound port: %d", c.pconnPort) c.pconn.pconn = packetConn.(*net.UDPConn) c.pconn.mu.Unlock() return } - log.Printf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.pconnPort, err) + c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.pconnPort, err) c.pconn.mu.Unlock() } - log.Printf("magicsock: link change, binding new port") + c.logf("magicsock: link change, binding new port") packetConn, err := net.ListenPacket("udp4", ":0") if err != nil { - log.Printf("magicsock: link change failed to bind new port: %v", err) + c.logf("magicsock: link change failed to bind new port: %v", err) return } c.pconn.Reset(packetConn.(*net.UDPConn)) @@ -1291,7 +1299,7 @@ func (c *Conn) CreateBind(uint16) (conn.Bind, uint16, error) { // comma-separated list of UDP ip:ports. func (c *Conn) CreateEndpoint(key [32]byte, addrs string) (conn.Endpoint, error) { pk := wgcfg.Key(key) - log.Printf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), addrs) + c.logf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), addrs) a := &AddrSet{ publicKey: key, curAddr: -1, diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index b9e37aa43..44572765a 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -9,7 +9,6 @@ crand "crypto/rand" "crypto/tls" "fmt" - "log" "net" "net/http" "net/http/httptest" @@ -156,17 +155,17 @@ func serveSTUN(t *testing.T) (addr net.Addr, cleanupFn func()) { t.Fatalf("failed to open STUN listener: %v", err) } - go runSTUN(pc, &stats) + go runSTUN(t, pc, &stats) return pc.LocalAddr(), func() { pc.Close() } } -func runSTUN(pc net.PacketConn, stats *stunStats) { +func runSTUN(t *testing.T, pc net.PacketConn, stats *stunStats) { var buf [64 << 10]byte for { n, addr, err := pc.ReadFrom(buf[:]) if err != nil { if strings.Contains(err.Error(), "closed network connection") { - log.Printf("STUN server shutdown") + t.Logf("STUN server shutdown") return } continue @@ -191,7 +190,7 @@ func runSTUN(pc net.PacketConn, stats *stunStats) { res := stun.Response(txid, ua.IP, uint16(ua.Port)) if _, err := pc.WriteTo(res, addr); err != nil { - log.Printf("STUN server write failed: %v", err) + t.Logf("STUN server write failed: %v", err) } } } @@ -510,7 +509,7 @@ func TestTwoDevicePing(t *testing.T) { } case <-time.After(3 * time.Second): if strict { - t.Fatalf("return ping %d did not transit", i) + t.Errorf("return ping %d did not transit", i) } } }