proxymap, various: distinguish between different protocols

Previously, we were registering TCP and UDP connections in the same map,
which could result in erroneously removing a mapping if one of the two
connections completes while the other one is still active.

Add a "proto string" argument to these functions to avoid this.
Additionally, take the "proto" argument in LocalAPI, and plumb that
through from the CLI and add a new LocalClient method.

Updates tailscale/corp#20600

Signed-off-by: Andrew Dunham <andrew@du.nham.ca>
Change-Id: I35d5efaefdfbf4721e315b8ca123f0c8af9125fb
This commit is contained in:
Andrew Dunham 2024-06-06 14:48:40 -04:00
parent 2cb408f9b1
commit 45d2f4301f
12 changed files with 89 additions and 30 deletions

View File

@ -285,6 +285,10 @@ func decodeJSON[T any](b []byte) (ret T, err error) {
// WhoIs returns the owner of the remoteAddr, which must be an IP or IP:port. // WhoIs returns the owner of the remoteAddr, which must be an IP or IP:port.
// //
// If not found, the error is ErrPeerNotFound. // If not found, the error is ErrPeerNotFound.
//
// For connections proxied by tailscaled, this looks up the owner of the given
// address as TCP first, falling back to UDP; if you want to only check a
// specific address family, use WhoIsProto.
func (lc *LocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { func (lc *LocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) {
body, err := lc.get200(ctx, "/localapi/v0/whois?addr="+url.QueryEscape(remoteAddr)) body, err := lc.get200(ctx, "/localapi/v0/whois?addr="+url.QueryEscape(remoteAddr))
if err != nil { if err != nil {
@ -313,6 +317,21 @@ func (lc *LocalClient) WhoIsNodeKey(ctx context.Context, key key.NodePublic) (*a
return decodeJSON[*apitype.WhoIsResponse](body) return decodeJSON[*apitype.WhoIsResponse](body)
} }
// WhoIsProto returns the owner of the remoteAddr, which must be an IP or
// IP:port, for the given protocol (tcp or udp).
//
// If not found, the error is ErrPeerNotFound.
func (lc *LocalClient) WhoIsProto(ctx context.Context, proto, remoteAddr string) (*apitype.WhoIsResponse, error) {
body, err := lc.get200(ctx, "/localapi/v0/whois?proto="+url.QueryEscape(proto)+"&addr="+url.QueryEscape(remoteAddr))
if err != nil {
if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound {
return nil, ErrPeerNotFound
}
return nil, err
}
return decodeJSON[*apitype.WhoIsResponse](body)
}
// Goroutines returns a dump of the Tailscale daemon's current goroutines. // Goroutines returns a dump of the Tailscale daemon's current goroutines.
func (lc *LocalClient) Goroutines(ctx context.Context) ([]byte, error) { func (lc *LocalClient) Goroutines(ctx context.Context) ([]byte, error) {
return lc.get200(ctx, "/localapi/v0/goroutines") return lc.get200(ctx, "/localapi/v0/goroutines")

View File

@ -26,12 +26,14 @@
FlagSet: func() *flag.FlagSet { FlagSet: func() *flag.FlagSet {
fs := newFlagSet("whois") fs := newFlagSet("whois")
fs.BoolVar(&whoIsArgs.json, "json", false, "output in JSON format") fs.BoolVar(&whoIsArgs.json, "json", false, "output in JSON format")
fs.StringVar(&whoIsArgs.proto, "proto", "", `protocol; one of "tcp" or "udp"; empty mans both `)
return fs return fs
}(), }(),
} }
var whoIsArgs struct { var whoIsArgs struct {
json bool // output in JSON format json bool // output in JSON format
proto string // "tcp" or "udp"
} }
func runWhoIs(ctx context.Context, args []string) error { func runWhoIs(ctx context.Context, args []string) error {
@ -40,7 +42,7 @@ func runWhoIs(ctx context.Context, args []string) error {
} else if len(args) == 0 { } else if len(args) == 0 {
return errors.New("missing argument, expected one peer") return errors.New("missing argument, expected one peer")
} }
who, err := localClient.WhoIs(ctx, args[0]) who, err := localClient.WhoIsProto(ctx, whoIsArgs.proto, args[0])
if err != nil { if err != nil {
return err return err
} }

View File

@ -995,8 +995,15 @@ func (b *LocalBackend) WhoIsNodeKey(k key.NodePublic) (n tailcfg.NodeView, u tai
// WhoIs reports the node and user who owns the node with the given IP:port. // WhoIs reports the node and user who owns the node with the given IP:port.
// If the IP address is a Tailscale IP, the provided port may be 0. // If the IP address is a Tailscale IP, the provided port may be 0.
//
// The 'proto' is used when looking up the IP:port in our proxy mapper; it
// tracks which local IP:ports correspond to connections proxied by tailscaled,
// and since tailscaled proxies both TCP and UDP, the 'proto' is needed to look
// up the correct IP:port based on the connection's protocol. If not provided,
// the lookup will be done for TCP and then UDP, in that order.
//
// If ok == true, n and u are valid. // If ok == true, n and u are valid.
func (b *LocalBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { func (b *LocalBackend) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
var zero tailcfg.NodeView var zero tailcfg.NodeView
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
@ -1005,7 +1012,20 @@ func (b *LocalBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.
if !ok { if !ok {
var ip netip.Addr var ip netip.Addr
if ipp.Port() != 0 { if ipp.Port() != 0 {
ip, ok = b.sys.ProxyMapper().WhoIsIPPort(ipp) var protos []string
if proto != "" {
protos = []string{proto}
} else {
// If the user didn't specify a protocol, try all of them
protos = []string{"tcp", "udp"}
}
for _, tryproto := range protos {
ip, ok = b.sys.ProxyMapper().WhoIsIPPort(tryproto, ipp)
if ok {
break
}
}
} }
if !ok { if !ok {
return zero, u, false return zero, u, false
@ -5044,7 +5064,7 @@ func (dt *driveTransport) RoundTrip(req *http.Request) (resp *http.Response, err
dt.b.mu.Lock() dt.b.mu.Lock()
selfNodeKey := dt.b.netMap.SelfNode.Key().ShortString() selfNodeKey := dt.b.netMap.SelfNode.Key().ShortString()
dt.b.mu.Unlock() dt.b.mu.Unlock()
n, _, ok := dt.b.WhoIs(netip.MustParseAddrPort(req.URL.Host)) n, _, ok := dt.b.WhoIs("tcp", netip.MustParseAddrPort(req.URL.Host))
shareNodeKey := "unknown" shareNodeKey := "unknown"
if ok { if ok {
shareNodeKey = string(n.Key().ShortString()) shareNodeKey = string(n.Key().ShortString())

View File

@ -1057,7 +1057,7 @@ func TestWhoIs(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.q, func(t *testing.T) { t.Run(tt.q, func(t *testing.T) {
nv, up, ok := b.WhoIs(netip.MustParseAddrPort(tt.q)) nv, up, ok := b.WhoIs("", netip.MustParseAddrPort(tt.q))
var got tailcfg.NodeID var got tailcfg.NodeID
if ok { if ok {
got = nv.ID() got = nv.ID()

View File

@ -187,7 +187,7 @@ func (pln *peerAPIListener) serve() {
func (pln *peerAPIListener) ServeConn(src netip.AddrPort, c net.Conn) { func (pln *peerAPIListener) ServeConn(src netip.AddrPort, c net.Conn) {
logf := pln.lb.logf logf := pln.lb.logf
peerNode, peerUser, ok := pln.lb.WhoIs(src) peerNode, peerUser, ok := pln.lb.WhoIs("tcp", src)
if !ok { if !ok {
logf("peerapi: unknown peer %v", src) logf("peerapi: unknown peer %v", src)
c.Close() c.Close()

View File

@ -710,7 +710,7 @@ func (b *LocalBackend) addTailscaleIdentityHeaders(r *httputil.ProxyRequest) {
if !ok { if !ok {
return return
} }
node, user, ok := b.WhoIs(c.SrcAddr) node, user, ok := b.WhoIs("tcp", c.SrcAddr)
if !ok { if !ok {
return // traffic from outside of Tailnet (funneled) return // traffic from outside of Tailnet (funneled)
} }

View File

@ -447,7 +447,7 @@ func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) {
// localBackendWhoIsMethods is the subset of ipn.LocalBackend as needed // localBackendWhoIsMethods is the subset of ipn.LocalBackend as needed
// by the localapi WhoIs method. // by the localapi WhoIs method.
type localBackendWhoIsMethods interface { type localBackendWhoIsMethods interface {
WhoIs(netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) WhoIs(string, netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
WhoIsNodeKey(key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) WhoIsNodeKey(key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
PeerCaps(netip.Addr) tailcfg.PeerCapMap PeerCaps(netip.Addr) tailcfg.PeerCapMap
} }
@ -482,7 +482,7 @@ func (h *Handler) serveWhoIsWithBackend(w http.ResponseWriter, r *http.Request,
} }
} }
if ipp.IsValid() { if ipp.IsValid() {
n, u, ok = b.WhoIs(ipp) n, u, ok = b.WhoIs(r.FormValue("proto"), ipp)
} }
} else { } else {
http.Error(w, "missing 'addr' parameter", http.StatusBadRequest) http.Error(w, "missing 'addr' parameter", http.StatusBadRequest)

View File

@ -100,13 +100,13 @@ func TestSetPushDeviceToken(t *testing.T) {
} }
type whoIsBackend struct { type whoIsBackend struct {
whoIs func(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) whoIs func(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
whoIsNodeKey func(key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) whoIsNodeKey func(key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
peerCaps map[netip.Addr]tailcfg.PeerCapMap peerCaps map[netip.Addr]tailcfg.PeerCapMap
} }
func (b whoIsBackend) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { func (b whoIsBackend) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
return b.whoIs(ipp) return b.whoIs(proto, ipp)
} }
func (b whoIsBackend) WhoIsNodeKey(k key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { func (b whoIsBackend) WhoIsNodeKey(k key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
@ -143,7 +143,7 @@ func TestWhoIsArgTypes(t *testing.T) {
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
t.Run(input, func(t *testing.T) { t.Run(input, func(t *testing.T) {
b := whoIsBackend{ b := whoIsBackend{
whoIs: func(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { whoIs: func(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
if !strings.Contains(input, ":") { if !strings.Contains(input, ":") {
want := netip.MustParseAddrPort("100.101.102.103:0") want := netip.MustParseAddrPort("100.101.102.103:0")
if ipp != want { if ipp != want {

View File

@ -9,8 +9,6 @@
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
"tailscale.com/util/mak"
) )
// Mapper tracks which localhost ip:ports correspond to which remote Tailscale // Mapper tracks which localhost ip:ports correspond to which remote Tailscale
@ -21,26 +19,39 @@
// given localhost:port corresponds to. // given localhost:port corresponds to.
type Mapper struct { type Mapper struct {
mu sync.Mutex mu sync.Mutex
m map[netip.AddrPort]netip.Addr m map[string]map[netip.AddrPort]netip.Addr // proto ("tcp", "udp") => ephemeral => tailscale IP
} }
// RegisterIPPortIdentity registers a given node (identified by its // RegisterIPPortIdentity registers a given node (identified by its
// Tailscale IP) as temporarily having the given IP:port for whois lookups. // Tailscale IP) as temporarily having the given IP:port for whois lookups.
//
// The IP:port is generally a localhost IP and an ephemeral port, used // The IP:port is generally a localhost IP and an ephemeral port, used
// while proxying connections to localhost when tailscaled is running // while proxying connections to localhost when tailscaled is running
// in netstack mode. // in netstack mode.
func (m *Mapper) RegisterIPPortIdentity(ipport netip.AddrPort, tsIP netip.Addr) { //
// The proto is the network protocol that is being proxied; it must be "tcp" or
// "udp" (not e.g. "tcp4", "udp6", etc.)
func (m *Mapper) RegisterIPPortIdentity(proto string, ipport netip.AddrPort, tsIP netip.Addr) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
mak.Set(&m.m, ipport, tsIP) if m.m == nil {
m.m = make(map[string]map[netip.AddrPort]netip.Addr)
}
p, ok := m.m[proto]
if !ok {
p = make(map[netip.AddrPort]netip.Addr)
m.m[proto] = p
}
p[ipport] = tsIP
} }
// UnregisterIPPortIdentity removes a temporary IP:port registration // UnregisterIPPortIdentity removes a temporary IP:port registration
// made previously by RegisterIPPortIdentity. // made previously by RegisterIPPortIdentity.
func (m *Mapper) UnregisterIPPortIdentity(ipport netip.AddrPort) { func (m *Mapper) UnregisterIPPortIdentity(proto string, ipport netip.AddrPort) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
delete(m.m, ipport) p := m.m[proto]
delete(p, ipport) // safe to delete from a nil map
} }
var whoIsSleeps = [...]time.Duration{ var whoIsSleeps = [...]time.Duration{
@ -53,7 +64,7 @@ func (m *Mapper) UnregisterIPPortIdentity(ipport netip.AddrPort) {
// WhoIsIPPort looks up an IP:port in the temporary registrations, // WhoIsIPPort looks up an IP:port in the temporary registrations,
// and returns a matching Tailscale IP, if it exists. // and returns a matching Tailscale IP, if it exists.
func (m *Mapper) WhoIsIPPort(ipport netip.AddrPort) (tsIP netip.Addr, ok bool) { func (m *Mapper) WhoIsIPPort(proto string, ipport netip.AddrPort) (tsIP netip.Addr, ok bool) {
// We currently have a registration race, // We currently have a registration race,
// https://github.com/tailscale/tailscale/issues/1616, // https://github.com/tailscale/tailscale/issues/1616,
// so loop a few times for now waiting for the registration // so loop a few times for now waiting for the registration
@ -62,7 +73,10 @@ func (m *Mapper) WhoIsIPPort(ipport netip.AddrPort) (tsIP netip.Addr, ok bool) {
for _, d := range whoIsSleeps { for _, d := range whoIsSleeps {
time.Sleep(d) time.Sleep(d)
m.mu.Lock() m.mu.Lock()
tsIP, ok = m.m[ipport] p, ok := m.m[proto]
if ok {
tsIP, ok = p[ipport]
}
m.mu.Unlock() m.mu.Unlock()
if ok { if ok {
return tsIP, true return tsIP, true

View File

@ -68,7 +68,7 @@ type ipnLocalBackend interface {
GetSSH_HostKeys() ([]gossh.Signer, error) GetSSH_HostKeys() ([]gossh.Signer, error)
ShouldRunSSH() bool ShouldRunSSH() bool
NetMap() *netmap.NetworkMap NetMap() *netmap.NetworkMap
WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool)
DoNoiseRequest(req *http.Request) (*http.Response, error) DoNoiseRequest(req *http.Request) (*http.Response, error)
Dialer() *tsdial.Dialer Dialer() *tsdial.Dialer
TailscaleVarRoot() string TailscaleVarRoot() string
@ -604,7 +604,7 @@ func (c *conn) setInfo(ctx ssh.Context) error {
if !tsaddr.IsTailscaleIP(ci.src.Addr()) { if !tsaddr.IsTailscaleIP(ci.src.Addr()) {
return fmt.Errorf("tailssh: rejecting non-Tailscale remote address %v", ci.src) return fmt.Errorf("tailssh: rejecting non-Tailscale remote address %v", ci.src)
} }
node, uprof, ok := c.srv.lb.WhoIs(ci.src) node, uprof, ok := c.srv.lb.WhoIs("tcp", ci.src)
if !ok { if !ok {
return fmt.Errorf("unknown Tailscale identity from src %v", ci.src) return fmt.Errorf("unknown Tailscale identity from src %v", ci.src)
} }

View File

@ -281,7 +281,11 @@ func (ts *localState) NetMap() *netmap.NetworkMap {
} }
} }
func (ts *localState) WhoIs(ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { func (ts *localState) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) {
if proto != "tcp" {
return tailcfg.NodeView{}, tailcfg.UserProfile{}, false
}
return (&tailcfg.Node{ return (&tailcfg.Node{
ID: 2, ID: 2,
StableID: "peer-id", StableID: "peer-id",

View File

@ -1328,8 +1328,8 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.
backendLocalAddr := server.LocalAddr().(*net.TCPAddr) backendLocalAddr := server.LocalAddr().(*net.TCPAddr)
backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort()) backendLocalIPPort := netaddr.Unmap(backendLocalAddr.AddrPort())
ns.pm.RegisterIPPortIdentity(backendLocalIPPort, clientRemoteIP) ns.pm.RegisterIPPortIdentity("tcp", backendLocalIPPort, clientRemoteIP)
defer ns.pm.UnregisterIPPortIdentity(backendLocalIPPort) defer ns.pm.UnregisterIPPortIdentity("tcp", backendLocalIPPort)
connClosed := make(chan error, 2) connClosed := make(chan error, 2)
go func() { go func() {
_, err := io.Copy(server, client) _, err := io.Copy(server, client)
@ -1533,7 +1533,7 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, clientAddr, dstAddr netip.Addr
ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port) ns.logf("could not get backend local IP:port from %v:%v", backendLocalAddr.IP, backendLocalAddr.Port)
} }
if isLocal { if isLocal {
ns.pm.RegisterIPPortIdentity(backendLocalIPPort, dstAddr.Addr()) ns.pm.RegisterIPPortIdentity("udp", backendLocalIPPort, clientAddr.Addr())
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -1549,7 +1549,7 @@ func (ns *Impl) forwardUDP(client *gonet.UDPConn, clientAddr, dstAddr netip.Addr
} }
timer := time.AfterFunc(idleTimeout, func() { timer := time.AfterFunc(idleTimeout, func() {
if isLocal { if isLocal {
ns.pm.UnregisterIPPortIdentity(backendLocalIPPort) ns.pm.UnregisterIPPortIdentity("udp", backendLocalIPPort)
} }
ns.logf("netstack: UDP session between %s and %s timed out", backendListenAddr, backendRemoteAddr) ns.logf("netstack: UDP session between %s and %s timed out", backendListenAddr, backendRemoteAddr)
cancel() cancel()