diff --git a/tstest/natlab/natlab.go b/tstest/natlab/natlab.go index 92a4ccb68..ffa02eee4 100644 --- a/tstest/natlab/natlab.go +++ b/tstest/natlab/natlab.go @@ -684,10 +684,11 @@ func (m *Machine) ListenPacket(ctx context.Context, network, address string) (ne ipp := netip.AddrPortFrom(ip, port) c := &conn{ - m: m, - fam: fam, - ipp: ipp, - in: make(chan *Packet, 100), // arbitrary + m: m, + fam: fam, + ipp: ipp, + closedCh: make(chan struct{}), + in: make(chan *Packet, 100), // arbitrary } switch c.fam { case 0: @@ -716,70 +717,28 @@ type conn struct { fam uint8 // 0, 4, or 6 ipp netip.AddrPort - mu sync.Mutex - closed bool - readDeadline time.Time - activeReads map[*activeRead]bool - in chan *Packet -} + closeOnce sync.Once + closedCh chan struct{} // closed by Close -type activeRead struct { - cancel context.CancelFunc -} - -// canRead reports whether we can do a read. -func (c *conn) canRead() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return net.ErrClosed - } - if !c.readDeadline.IsZero() && c.readDeadline.Before(time.Now()) { - return errors.New("read deadline exceeded") - } - return nil -} - -func (c *conn) registerActiveRead(ar *activeRead, active bool) { - c.mu.Lock() - defer c.mu.Unlock() - if c.activeReads == nil { - c.activeReads = make(map[*activeRead]bool) - } - if active { - c.activeReads[ar] = true - } else { - delete(c.activeReads, ar) - } + in chan *Packet } func (c *conn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return nil - } - c.closed = true - switch c.fam { - case 0: - c.m.unregisterConn4(c) - c.m.unregisterConn6(c) - case 4: - c.m.unregisterConn4(c) - case 6: - c.m.unregisterConn6(c) - } - c.breakActiveReadsLocked() + c.closeOnce.Do(func() { + switch c.fam { + case 0: + c.m.unregisterConn4(c) + c.m.unregisterConn6(c) + case 4: + c.m.unregisterConn4(c) + case 6: + c.m.unregisterConn6(c) + } + close(c.closedCh) + }) return nil } -func (c *conn) breakActiveReadsLocked() { - for ar := range c.activeReads { - ar.cancel() - } - c.activeReads = nil -} - func (c *conn) LocalAddr() net.Addr { return &net.UDPAddr{ IP: c.ipp.Addr().AsSlice(), @@ -809,25 +768,13 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } func (c *conn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ar := &activeRead{cancel: cancel} - - if err := c.canRead(); err != nil { - return 0, netip.AddrPort{}, err - } - - c.registerActiveRead(ar, true) - defer c.registerActiveRead(ar, false) - select { + case <-c.closedCh: + return 0, netip.AddrPort{}, net.ErrClosed case pkt := <-c.in: n = copy(p, pkt.Payload) pkt.Trace("PacketConn.ReadFrom") return n, pkt.Src, nil - case <-ctx.Done(): - return 0, netip.AddrPort{}, context.DeadlineExceeded } } @@ -857,18 +804,5 @@ func (c *conn) SetWriteDeadline(t time.Time) error { panic("SetWriteDeadline unsupported; TODO when needed") } func (c *conn) SetReadDeadline(t time.Time) error { - c.mu.Lock() - defer c.mu.Unlock() - - now := time.Now() - if t.After(now) { - panic("SetReadDeadline in the future not yet supported; TODO?") - } - - if !t.IsZero() && t.Before(now) { - c.breakActiveReadsLocked() - } - c.readDeadline = t - - return nil + panic("SetReadDeadline unsupported; TODO when needed") }