mirror of
https://github.com/tailscale/tailscale.git
synced 2024-11-25 19:15:34 +00:00
a076213f58
Updates tailscale/corp#11620 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
101 lines
2.2 KiB
Go
101 lines
2.2 KiB
Go
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
package memnet
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// bufferSize is the maxBuf value Dial passes to the connection-pair
// constructor for every new connection: 256 * 1024 (262144).
// NOTE(review): presumably a byte count; confirm against NewConn.
const bufferSize = 256 * 1024
|
|
|
|
// Listener is a net.Listener using NewConn to create pairs of network
|
|
// connections connected in memory using a buffered pipe. It also provides a
|
|
// Dial method to establish new connections.
|
|
type Listener struct {
|
|
addr connAddr
|
|
ch chan Conn
|
|
closeOnce sync.Once
|
|
closed chan struct{}
|
|
|
|
// NewConn, if non-nil, is called to create a new pair of connections
|
|
// when dialing. If nil, NewConn is used.
|
|
NewConn func(network, addr string, maxBuf int) (Conn, Conn)
|
|
}
|
|
|
|
// Listen returns a new Listener for the provided address.
|
|
func Listen(addr string) *Listener {
|
|
return &Listener{
|
|
addr: connAddr(addr),
|
|
ch: make(chan Conn),
|
|
closed: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// Addr implements net.Listener.Addr.
|
|
func (l *Listener) Addr() net.Addr {
|
|
return l.addr
|
|
}
|
|
|
|
// Close closes the pipe listener.
|
|
func (l *Listener) Close() error {
|
|
l.closeOnce.Do(func() {
|
|
close(l.closed)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// Accept blocks until a new connection is available or the listener is closed.
|
|
func (l *Listener) Accept() (net.Conn, error) {
|
|
select {
|
|
case c := <-l.ch:
|
|
return c, nil
|
|
case <-l.closed:
|
|
return nil, net.ErrClosed
|
|
}
|
|
}
|
|
|
|
// Dial connects to the listener using the provided context.
|
|
// The provided Context must be non-nil. If the context expires before the
|
|
// connection is complete, an error is returned. Once successfully connected
|
|
// any expiration of the context will not affect the connection.
|
|
func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) {
|
|
if !strings.HasSuffix(network, "tcp") {
|
|
return nil, net.UnknownNetworkError(network)
|
|
}
|
|
if connAddr(addr) != l.addr {
|
|
return nil, &net.AddrError{
|
|
Err: "invalid address",
|
|
Addr: addr,
|
|
}
|
|
}
|
|
|
|
newConn := l.NewConn
|
|
if newConn == nil {
|
|
newConn = func(network, addr string, maxBuf int) (Conn, Conn) {
|
|
return NewConn(addr, maxBuf)
|
|
}
|
|
}
|
|
c, s := newConn(network, addr, bufferSize)
|
|
defer func() {
|
|
if err != nil {
|
|
c.Close()
|
|
s.Close()
|
|
}
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-l.closed:
|
|
return nil, net.ErrClosed
|
|
case l.ch <- s:
|
|
return c, nil
|
|
}
|
|
}
|