wgengine/magicsock: plumb logf throughout, and expose in Options.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson 2020-03-07 13:11:52 -08:00 committed by Dave Anderson
parent f42b9b6c9a
commit bb93d7aaba
2 changed files with 30 additions and 23 deletions

View File

@ -49,7 +49,7 @@ type Conn struct {
stunServers []string stunServers []string
startEpUpdate chan struct{} // send to trigger endpoint update startEpUpdate chan struct{} // send to trigger endpoint update
epFunc func(endpoints []string) epFunc func(endpoints []string)
logf func(format string, args ...interface{}) logf logger.Logf
sendLogLimit *rate.Limiter sendLogLimit *rate.Limiter
// bufferedIPv4From and bufferedIPv4Packet are owned by // bufferedIPv4From and bufferedIPv4Packet are owned by
@ -121,6 +121,8 @@ type udpAddr struct {
// Options contains options for Listen. // Options contains options for Listen.
type Options struct { type Options struct {
Logf logger.Logf
// Port is the port to listen on. // Port is the port to listen on.
// Zero means to pick one automatically. // Zero means to pick one automatically.
Port uint16 Port uint16
@ -147,15 +149,21 @@ func (o *Options) endpointsFunc() func([]string) {
func Listen(opts Options) (*Conn, error) { func Listen(opts Options) (*Conn, error) {
var packetConn net.PacketConn var packetConn net.PacketConn
var err error var err error
logf := log.Printf
if opts.Logf != nil {
logf = opts.Logf
}
if opts.Port == 0 { if opts.Port == 0 {
// Our choice of port. Start with DefaultPort. // Our choice of port. Start with DefaultPort.
// If unavailable, pick any port. // If unavailable, pick any port.
want := fmt.Sprintf(":%d", DefaultPort) want := fmt.Sprintf(":%d", DefaultPort)
log.Printf("magicsock: bind: trying %v\n", want) logf("magicsock: bind: trying %v\n", want)
packetConn, err = net.ListenPacket("udp4", want) packetConn, err = net.ListenPacket("udp4", want)
if err != nil { if err != nil {
want = ":0" want = ":0"
log.Printf("magicsock: bind: falling back to %v (%v)\n", want, err) logf("magicsock: bind: falling back to %v (%v)\n", want, err)
packetConn, err = net.ListenPacket("udp4", want) packetConn, err = net.ListenPacket("udp4", want)
} }
} else { } else {
@ -175,7 +183,7 @@ func Listen(opts Options) (*Conn, error) {
connCtx: connCtx, connCtx: connCtx,
connCtxCancel: connCtxCancel, connCtxCancel: connCtxCancel,
epFunc: opts.endpointsFunc(), epFunc: opts.endpointsFunc(),
logf: log.Printf, logf: logf,
addrsByUDP: make(map[udpAddr]*AddrSet), addrsByUDP: make(map[udpAddr]*AddrSet),
addrsByKey: make(map[key.Public]*AddrSet), addrsByKey: make(map[key.Public]*AddrSet),
wantDerp: true, wantDerp: true,
@ -398,7 +406,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) (ipPorts []string, err er
var eps []string // unique endpoints var eps []string // unique endpoints
addAddr := func(s, reason string) { addAddr := func(s, reason string) {
log.Printf("magicsock: found local %s (%s)\n", s, reason) c.logf("magicsock: found local %s (%s)\n", s, reason)
alreadyMu.Lock() alreadyMu.Lock()
defer alreadyMu.Unlock() defer alreadyMu.Unlock()
@ -608,7 +616,7 @@ func (c *Conn) Send(b []byte, ep conn.Endpoint) error {
ret = err ret = err
} }
if err != nil && addr != roamAddr && c.sendLogLimit.Allow() { if err != nil && addr != roamAddr && c.sendLogLimit.Allow() {
log.Printf("magicsock: Conn.Send(%v): %v", addr, err) c.logf("magicsock: Conn.Send(%v): %v", addr, err)
} }
} }
if success { if success {
@ -684,7 +692,7 @@ func (c *Conn) derpWriteChanOfAddr(addr *net.UDPAddr) chan<- derpWriteRequest {
return nil return nil
} }
// TODO(bradfitz): don't hold derpMu here. It's slow. Release first and use singleflight to dial+re-lock to add. // TODO(bradfitz): don't hold derpMu here. It's slow. Release first and use singleflight to dial+re-lock to add.
dc, err := derphttp.NewClient(c.privateKey, "https://"+host+"/derp", log.Printf) dc, err := derphttp.NewClient(c.privateKey, "https://"+host+"/derp", c.logf)
if err != nil { if err != nil {
c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err) c.logf("derphttp.NewClient: port %d, host %q invalid? err: %v", addr.Port, host, err)
return nil return nil
@ -765,7 +773,7 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
res.n = len(m.Data) res.n = len(m.Data)
res.src = m.Source res.src = m.Source
if logDerpVerbose { if logDerpVerbose {
log.Printf("got derp %v packet: %q", derpFakeAddr, m.Data) c.logf("got derp %v packet: %q", derpFakeAddr, m.Data)
} }
default: default:
// Ignore. // Ignore.
@ -800,7 +808,7 @@ func (c *Conn) runDerpWriter(ctx context.Context, derpFakeAddr *net.UDPAddr, dc
case wr := <-ch: case wr := <-ch:
err := dc.Send(wr.pubKey, wr.b) err := dc.Send(wr.pubKey, wr.b)
if err != nil { if err != nil {
log.Printf("magicsock: derp.Send(%v): %v", wr.addr, err) c.logf("magicsock: derp.Send(%v): %v", wr.addr, err)
} }
select { select {
case wr.errc <- err: case wr.errc <- err:
@ -911,7 +919,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
ncopy := dm.copyBuf(b) ncopy := dm.copyBuf(b)
if ncopy != n { if ncopy != n {
err = fmt.Errorf("received DERP packet of length %d that's too big for WireGuard ReceiveIPv4 buf size %d", n, ncopy) err = fmt.Errorf("received DERP packet of length %d that's too big for WireGuard ReceiveIPv4 buf size %d", n, ncopy)
log.Printf("magicsock: %v", err) c.logf("magicsock: %v", err)
return 0, nil, nil, err return 0, nil, nil, err
} }
@ -921,7 +929,7 @@ func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr
if addrSet == nil { if addrSet == nil {
key := wgcfg.Key(dm.src) key := wgcfg.Key(dm.src)
log.Printf("magicsock: DERP packet from unknown key: %s", key.ShortString()) c.logf("magicsock: DERP packet from unknown key: %s", key.ShortString())
} }
case um := <-c.udpRecvCh: case um := <-c.udpRecvCh:
@ -1061,23 +1069,23 @@ func (c *Conn) LinkChange() {
if c.pconnPort != 0 { if c.pconnPort != 0 {
c.pconn.mu.Lock() c.pconn.mu.Lock()
if err := c.pconn.pconn.Close(); err != nil { if err := c.pconn.pconn.Close(); err != nil {
log.Printf("magicsock: link change close failed: %v", err) c.logf("magicsock: link change close failed: %v", err)
} }
packetConn, err := net.ListenPacket("udp4", fmt.Sprintf(":%d", c.pconnPort)) packetConn, err := net.ListenPacket("udp4", fmt.Sprintf(":%d", c.pconnPort))
if err == nil { if err == nil {
log.Printf("magicsock: link change rebound port: %d", c.pconnPort) c.logf("magicsock: link change rebound port: %d", c.pconnPort)
c.pconn.pconn = packetConn.(*net.UDPConn) c.pconn.pconn = packetConn.(*net.UDPConn)
c.pconn.mu.Unlock() c.pconn.mu.Unlock()
return return
} }
log.Printf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.pconnPort, err) c.logf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.pconnPort, err)
c.pconn.mu.Unlock() c.pconn.mu.Unlock()
} }
log.Printf("magicsock: link change, binding new port") c.logf("magicsock: link change, binding new port")
packetConn, err := net.ListenPacket("udp4", ":0") packetConn, err := net.ListenPacket("udp4", ":0")
if err != nil { if err != nil {
log.Printf("magicsock: link change failed to bind new port: %v", err) c.logf("magicsock: link change failed to bind new port: %v", err)
return return
} }
c.pconn.Reset(packetConn.(*net.UDPConn)) c.pconn.Reset(packetConn.(*net.UDPConn))
@ -1291,7 +1299,7 @@ func (c *Conn) CreateBind(uint16) (conn.Bind, uint16, error) {
// comma-separated list of UDP ip:ports. // comma-separated list of UDP ip:ports.
func (c *Conn) CreateEndpoint(key [32]byte, addrs string) (conn.Endpoint, error) { func (c *Conn) CreateEndpoint(key [32]byte, addrs string) (conn.Endpoint, error) {
pk := wgcfg.Key(key) pk := wgcfg.Key(key)
log.Printf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), addrs) c.logf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), addrs)
a := &AddrSet{ a := &AddrSet{
publicKey: key, publicKey: key,
curAddr: -1, curAddr: -1,

View File

@ -9,7 +9,6 @@
crand "crypto/rand" crand "crypto/rand"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"log"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -156,17 +155,17 @@ func serveSTUN(t *testing.T) (addr net.Addr, cleanupFn func()) {
t.Fatalf("failed to open STUN listener: %v", err) t.Fatalf("failed to open STUN listener: %v", err)
} }
go runSTUN(pc, &stats) go runSTUN(t, pc, &stats)
return pc.LocalAddr(), func() { pc.Close() } return pc.LocalAddr(), func() { pc.Close() }
} }
func runSTUN(pc net.PacketConn, stats *stunStats) { func runSTUN(t *testing.T, pc net.PacketConn, stats *stunStats) {
var buf [64 << 10]byte var buf [64 << 10]byte
for { for {
n, addr, err := pc.ReadFrom(buf[:]) n, addr, err := pc.ReadFrom(buf[:])
if err != nil { if err != nil {
if strings.Contains(err.Error(), "closed network connection") { if strings.Contains(err.Error(), "closed network connection") {
log.Printf("STUN server shutdown") t.Logf("STUN server shutdown")
return return
} }
continue continue
@ -191,7 +190,7 @@ func runSTUN(pc net.PacketConn, stats *stunStats) {
res := stun.Response(txid, ua.IP, uint16(ua.Port)) res := stun.Response(txid, ua.IP, uint16(ua.Port))
if _, err := pc.WriteTo(res, addr); err != nil { if _, err := pc.WriteTo(res, addr); err != nil {
log.Printf("STUN server write failed: %v", err) t.Logf("STUN server write failed: %v", err)
} }
} }
} }
@ -510,7 +509,7 @@ func TestTwoDevicePing(t *testing.T) {
} }
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
if strict { if strict {
t.Fatalf("return ping %d did not transit", i) t.Errorf("return ping %d did not transit", i)
} }
} }
} }