tsnet: add test for Funnel connections

For the logic added in b797f77.

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2023-03-09 12:52:18 -08:00 committed by Maisem Ali
parent c6d96a2b61
commit f34590d9ed
3 changed files with 282 additions and 23 deletions

View File

@ -100,6 +100,8 @@ type Server struct {
// If empty, the Tailscale default is used. // If empty, the Tailscale default is used.
ControlURL string ControlURL string
getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error)
initOnce sync.Once initOnce sync.Once
initErr error initErr error
lb *ipnlocal.LocalBackend lb *ipnlocal.LocalBackend
@ -842,20 +844,30 @@ func (s *Server) ListenTLS(network, addr string) (net.Listener, error) {
return nil, errors.New("tsnet: you must enable HTTPS in the admin panel to proceed. See https://tailscale.com/s/https") return nil, errors.New("tsnet: you must enable HTTPS in the admin panel to proceed. See https://tailscale.com/s/https")
} }
lc, err := s.LocalClient() // do local client first before listening.
if err != nil {
return nil, err
}
ln, err := s.listen(network, addr, listenOnTailnet) ln, err := s.listen(network, addr, listenOnTailnet)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tls.NewListener(ln, &tls.Config{ return tls.NewListener(ln, &tls.Config{
GetCertificate: lc.GetCertificate, GetCertificate: s.getCert,
}), nil }), nil
} }
// getCert is the GetCertificate function used by ListenTLS.
//
// It calls GetCertificate on the localClient, passing in the ClientHelloInfo.
// For testing, if s.getCertForTesting is set, it will call that instead.
func (s *Server) getCert(hi *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.getCertForTesting != nil {
return s.getCertForTesting(hi)
}
lc, err := s.LocalClient()
if err != nil {
return nil, err
}
return lc.GetCertificate(hi)
}
// FunnelOption is an option passed to ListenFunnel to configure the listener. // FunnelOption is an option passed to ListenFunnel to configure the listener.
type FunnelOption interface { type FunnelOption interface {
funnelOption() funnelOption()
@ -909,10 +921,7 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L
return nil, err return nil, err
} }
lc, err := s.LocalClient() lc := s.localClient
if err != nil {
return nil, err
}
// May not have funnel enabled. Enable it. // May not have funnel enabled. Enable it.
srvConfig, err := lc.GetServeConfig(ctx) srvConfig, err := lc.GetServeConfig(ctx)
@ -944,7 +953,7 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L
return nil, err return nil, err
} }
return tls.NewListener(ln, &tls.Config{ return tls.NewListener(ln, &tls.Config{
GetCertificate: lc.GetCertificate, GetCertificate: s.getCert,
}), nil }), nil
} }

View File

@ -4,27 +4,40 @@
package tsnet package tsnet
import ( import (
"bufio"
"context" "context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math/big"
"net" "net"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync"
"testing" "testing"
"time" "time"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
"tailscale.com/ipn"
"tailscale.com/ipn/store/mem" "tailscale.com/ipn/store/mem"
"tailscale.com/net/netns" "tailscale.com/net/netns"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/tstest/integration" "tailscale.com/tstest/integration"
"tailscale.com/tstest/integration/testcontrol" "tailscale.com/tstest/integration/testcontrol"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/util/must"
) )
// TestListener_Server ensures that the listener type always keeps the Server // TestListener_Server ensures that the listener type always keeps the Server
@ -93,6 +106,10 @@ func startControl(t *testing.T) (controlURL string) {
derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1") derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1")
control := &testcontrol.Server{ control := &testcontrol.Server{
DERPMap: derpMap, DERPMap: derpMap,
DNSConfig: &tailcfg.DNSConfig{
Proxied: true,
},
MagicDNSDomain: "tail-scale.ts.net",
} }
control.HTTPTestServer = httptest.NewUnstartedServer(control) control.HTTPTestServer = httptest.NewUnstartedServer(control)
control.HTTPTestServer.Start() control.HTTPTestServer.Start()
@ -102,17 +119,96 @@ func startControl(t *testing.T) (controlURL string) {
return controlURL return controlURL
} }
type testCertIssuer struct {
mu sync.Mutex
certs map[string]*tls.Certificate
root *x509.Certificate
rootKey *ecdsa.PrivateKey
}
func newCertIssuer() *testCertIssuer {
rootKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
panic(err)
}
t := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "root",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
rootDER, err := x509.CreateCertificate(rand.Reader, t, t, &rootKey.PublicKey, rootKey)
if err != nil {
panic(err)
}
rootCA, err := x509.ParseCertificate(rootDER)
if err != nil {
panic(err)
}
return &testCertIssuer{
certs: make(map[string]*tls.Certificate),
root: rootCA,
rootKey: rootKey,
}
}
func (tci *testCertIssuer) getCert(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
tci.mu.Lock()
defer tci.mu.Unlock()
cert, ok := tci.certs[chi.ServerName]
if ok {
return cert, nil
}
certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}
certTmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
DNSNames: []string{chi.ServerName},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Hour),
}
certDER, err := x509.CreateCertificate(rand.Reader, certTmpl, tci.root, &certPrivKey.PublicKey, tci.rootKey)
if err != nil {
return nil, err
}
cert = &tls.Certificate{
Certificate: [][]byte{certDER, tci.root.Raw},
PrivateKey: certPrivKey,
}
tci.certs[chi.ServerName] = cert
return cert, nil
}
func (tci *testCertIssuer) Pool() *x509.CertPool {
p := x509.NewCertPool()
p.AddCert(tci.root)
return p
}
var testCertRoot = newCertIssuer()
func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) (*Server, netip.Addr) { func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) (*Server, netip.Addr) {
t.Helper() t.Helper()
tmp := filepath.Join(t.TempDir(), hostname) tmp := filepath.Join(t.TempDir(), hostname)
os.MkdirAll(tmp, 0755) os.MkdirAll(tmp, 0755)
s := &Server{ s := &Server{
Dir: tmp, Dir: tmp,
ControlURL: controlURL, ControlURL: controlURL,
Hostname: hostname, Hostname: hostname,
Store: new(mem.Store), Store: new(mem.Store),
Ephemeral: true, Ephemeral: true,
getCertForTesting: testCertRoot.getCert,
} }
if !*verboseNodes { if !*verboseNodes {
s.Logf = logger.Discard s.Logf = logger.Discard
@ -368,3 +464,112 @@ func TestListenerCleanup(t *testing.T) {
t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err) t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err)
} }
} }
func TestFunnel(t *testing.T) {
ctx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer dialCancel()
controlURL := startControl(t)
s1, _ := startServer(t, ctx, controlURL, "s1")
s2, _ := startServer(t, ctx, controlURL, "s2")
ln := must.Get(s1.ListenFunnel("tcp", ":443"))
defer ln.Close()
wantSrcAddrPort := netip.MustParseAddrPort("127.0.0.1:1234")
wantTarget := ipn.HostPort("s1.tail-scale.ts.net:443")
srv := &http.Server{
ConnContext: func(ctx context.Context, c net.Conn) context.Context {
tc, ok := c.(*tls.Conn)
if !ok {
t.Errorf("ConnContext called with non-TLS conn: %T", c)
}
if fc, ok := tc.NetConn().(*ipn.FunnelConn); !ok {
t.Errorf("ConnContext called with non-FunnelConn: %T", c)
} else if fc.Src != wantSrcAddrPort {
t.Errorf("ConnContext called with wrong SrcAddrPort; got %v, want %v", fc.Src, wantSrcAddrPort)
} else if fc.Target != wantTarget {
t.Errorf("ConnContext called with wrong Target; got %q, want %q", fc.Target, wantTarget)
}
return ctx
},
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "hello")
}),
}
go srv.Serve(ln)
c := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialIngressConn(s2, s1, addr)
},
TLSClientConfig: &tls.Config{
RootCAs: testCertRoot.Pool(),
},
},
}
resp, err := c.Get("https://s1.tail-scale.ts.net:443")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Errorf("unexpected status code: %v", resp.StatusCode)
return
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != "hello" {
t.Errorf("unexpected body: %q", body)
}
}
func dialIngressConn(from, to *Server, target string) (net.Conn, error) {
toLC := must.Get(to.LocalClient())
toStatus := must.Get(toLC.StatusWithoutPeers(context.Background()))
peer6 := toStatus.Self.PeerAPIURL[1] // IPv6
toPeerAPI, ok := strings.CutPrefix(peer6, "http://")
if !ok {
return nil, fmt.Errorf("unexpected PeerAPIURL %q", peer6)
}
dialCtx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second)
outConn, err := from.Dial(dialCtx, "tcp", toPeerAPI)
dialCancel()
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", "/v0/ingress", nil)
if err != nil {
return nil, err
}
req.Host = toPeerAPI
req.Header.Set("Tailscale-Ingress-Src", "127.0.0.1:1234")
req.Header.Set("Tailscale-Ingress-Target", target)
if err := req.Write(outConn); err != nil {
return nil, err
}
br := bufio.NewReader(outConn)
res, err := http.ReadResponse(br, req)
if err != nil {
return nil, err
}
defer res.Body.Close() // just to appease vet
if res.StatusCode != 101 {
return nil, fmt.Errorf("unexpected status code: %v", res.StatusCode)
}
return &bufferedConn{outConn, br}, nil
}
type bufferedConn struct {
net.Conn
reader *bufio.Reader
}
func (c *bufferedConn) Read(b []byte) (int, error) {
return c.reader.Read(b)
}

View File

@ -26,6 +26,7 @@
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"go4.org/mem" "go4.org/mem"
"golang.org/x/exp/slices"
"tailscale.com/net/netaddr" "tailscale.com/net/netaddr"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/smallzstd" "tailscale.com/smallzstd"
@ -39,11 +40,12 @@
// Server is a control plane server. Its zero value is ready for use. // Server is a control plane server. Its zero value is ready for use.
// Everything is stored in-memory in one tailnet. // Everything is stored in-memory in one tailnet.
type Server struct { type Server struct {
Logf logger.Logf // nil means to use the log package Logf logger.Logf // nil means to use the log package
DERPMap *tailcfg.DERPMap // nil means to use prod DERP map DERPMap *tailcfg.DERPMap // nil means to use prod DERP map
RequireAuth bool RequireAuth bool
Verbose bool Verbose bool
DNSConfig *tailcfg.DNSConfig // nil means no DNS config DNSConfig *tailcfg.DNSConfig // nil means no DNS config
MagicDNSDomain string
// ExplicitBaseURL or HTTPTestServer must be set. // ExplicitBaseURL or HTTPTestServer must be set.
ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL
@ -328,6 +330,15 @@ func (s *Server) AddFakeNode() {
// TODO: send updates to other (non-fake?) nodes // TODO: send updates to other (non-fake?) nodes
} }
func (s *Server) AllUsers() (users []*tailcfg.User) {
s.mu.Lock()
defer s.mu.Unlock()
for _, u := range s.users {
users = append(users, u.Clone())
}
return users
}
func (s *Server) AllNodes() (nodes []*tailcfg.Node) { func (s *Server) AllNodes() (nodes []*tailcfg.Node) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -494,6 +505,11 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key.
Addresses: allowedIPs, Addresses: allowedIPs,
AllowedIPs: allowedIPs, AllowedIPs: allowedIPs,
Hostinfo: req.Hostinfo.View(), Hostinfo: req.Hostinfo.View(),
Name: req.Hostinfo.Hostname,
Capabilities: []string{
tailcfg.NodeAttrFunnel,
tailcfg.CapabilityFunnelPorts + "?ports=8080,443",
},
} }
requireAuth := s.RequireAuth requireAuth := s.RequireAuth
if requireAuth && s.nodeKeyAuthed[nk] { if requireAuth && s.nodeKeyAuthed[nk] {
@ -729,6 +745,20 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi
KeepAlive: true, KeepAlive: true,
} }
func packetFilterWithIngressCaps() []tailcfg.FilterRule {
out := slices.Clone(tailcfg.FilterAllowAll)
out = append(out, tailcfg.FilterRule{
SrcIPs: []string{"*"},
CapGrant: []tailcfg.CapGrant{
{
Dsts: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()},
Caps: []string{tailcfg.CapabilityIngress},
},
},
})
return out
}
// MapResponse generates a MapResponse for a MapRequest. // MapResponse generates a MapResponse for a MapRequest.
// //
// No updates to s are done here. // No updates to s are done here.
@ -741,16 +771,24 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
} }
user, _ := s.getUser(nk) user, _ := s.getUser(nk)
t := time.Date(2020, 8, 3, 0, 0, 0, 1, time.UTC) t := time.Date(2020, 8, 3, 0, 0, 0, 1, time.UTC)
dns := s.DNSConfig
if dns != nil && s.MagicDNSDomain != "" {
dns = dns.Clone()
dns.CertDomains = []string{
fmt.Sprintf(node.Hostinfo.Hostname() + "." + s.MagicDNSDomain),
}
}
res = &tailcfg.MapResponse{ res = &tailcfg.MapResponse{
Node: node, Node: node,
DERPMap: s.DERPMap, DERPMap: s.DERPMap,
Domain: string(user.Domain), Domain: string(user.Domain),
CollectServices: "true", CollectServices: "true",
PacketFilter: tailcfg.FilterAllowAll, PacketFilter: packetFilterWithIngressCaps(),
Debug: &tailcfg.Debug{ Debug: &tailcfg.Debug{
DisableUPnP: "true", DisableUPnP: "true",
}, },
DNSConfig: s.DNSConfig, DNSConfig: dns,
ControlTime: &t, ControlTime: &t,
} }
for _, p := range s.AllNodes() { for _, p := range s.AllNodes() {
@ -761,6 +799,13 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
sort.Slice(res.Peers, func(i, j int) bool { sort.Slice(res.Peers, func(i, j int) bool {
return res.Peers[i].ID < res.Peers[j].ID return res.Peers[i].ID < res.Peers[j].ID
}) })
for _, u := range s.AllUsers() {
res.UserProfiles = append(res.UserProfiles, tailcfg.UserProfile{
ID: u.ID,
LoginName: u.LoginName,
DisplayName: u.DisplayName,
})
}
v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32)
v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128)