vnet: add control/derps to test, stateful firewall

Updates #13038

Change-Id: Icd65b34c5f03498b5a7109785bb44692bce8911a
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2024-08-06 17:33:45 -07:00
committed by Maisem Ali
parent 20691894f5
commit 8594292aa4
8 changed files with 599 additions and 48 deletions

View File

@@ -27,6 +27,10 @@ type Config struct {
networks []*Network
}
func (c *Config) NumNodes() int {
return len(c.nodes)
}
// AddNode creates a new node in the world.
//
// The opts may be of the following types:
@@ -110,6 +114,11 @@ type Node struct {
nets []*Network
}
// MAC returns the MAC address of the node.
func (n *Node) MAC() MAC {
return n.mac
}
// Network returns the first network this node is connected to,
// or nil if none.
func (n *Node) Network() *Network {

View File

@@ -5,6 +5,7 @@ package vnet
import (
"errors"
"log"
"math/rand/v2"
"net/netip"
"time"
@@ -111,9 +112,9 @@ func (n *oneToOneNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (la
return netip.AddrPortFrom(n.lanIP, dst.Port())
}
type hardKeyOut struct {
lanIP netip.Addr
dst netip.AddrPort
type srcDstTuple struct {
src netip.AddrPort
dst netip.AddrPort
}
type hardKeyIn struct {
@@ -137,7 +138,7 @@ type lanAddrAndTime struct {
type hardNAT struct {
wanIP netip.Addr
out map[hardKeyOut]portMappingAndTime
out map[srcDstTuple]portMappingAndTime
in map[hardKeyIn]lanAddrAndTime
}
@@ -148,7 +149,7 @@ func init() {
}
func (n *hardNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) {
ko := hardKeyOut{src.Addr(), dst}
ko := srcDstTuple{src, dst}
if pm, ok := n.out[ko]; ok {
// Existing flow.
// TODO: bump timestamp
@@ -196,9 +197,10 @@ func (n *hardNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (lanDst
// Unlike Linux, this implementation is capped at 32k entries and doesn't resort
// to other allocation strategies when all 32k WAN ports are taken.
type easyNAT struct {
wanIP netip.Addr
out map[netip.AddrPort]portMappingAndTime
in map[uint16]lanAddrAndTime
wanIP netip.Addr
out map[netip.AddrPort]portMappingAndTime
in map[uint16]lanAddrAndTime
lastOut map[srcDstTuple]time.Time // (lan:port, wan:port) => last packet out time
}
func init() {
@@ -208,6 +210,7 @@ func init() {
}
func (n *easyNAT) PickOutgoingSrc(src, dst netip.AddrPort, at time.Time) (wanSrc netip.AddrPort) {
mak.Set(&n.lastOut, srcDstTuple{src, dst}, at)
if pm, ok := n.out[src]; ok {
// Existing flow.
// TODO: bump timestamp
@@ -235,5 +238,14 @@ func (n *easyNAT) PickIncomingDst(src, dst netip.AddrPort, at time.Time) (lanDst
if dst.Addr() != n.wanIP {
return netip.AddrPort{} // drop; not for us. shouldn't happen if natlabd routing isn't broken.
}
return n.in[dst.Port()].lanAddr
lanDst = n.in[dst.Port()].lanAddr
// Stateful firewall: drop incoming packets that don't have traffic out.
// TODO(bradfitz): verify Linux does this in the router code, not in the NAT code.
if t, ok := n.lastOut[srcDstTuple{lanDst, src}]; !ok || at.Sub(t) > 300*time.Second {
log.Printf("Drop incoming packet from %v to %v; no recent outgoing packet", src, dst)
return netip.AddrPort{}
}
return lanDst
}

View File

@@ -16,6 +16,7 @@ package vnet
import (
"bufio"
"context"
"crypto/tls"
"encoding/binary"
"encoding/json"
"errors"
@@ -24,6 +25,7 @@ import (
"log"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"os/exec"
"strconv"
@@ -44,9 +46,15 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/net/netutil"
"tailscale.com/net/stun"
"tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tstest/integration/testcontrol"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/mak"
"tailscale.com/util/set"
)
@@ -240,6 +248,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
log.Printf("AcceptTCP: %v", stringifyTEI(reqDetails))
clientRemoteIP := netaddrIPFromNetstackIP(reqDetails.RemoteAddress)
destIP := netaddrIPFromNetstackIP(reqDetails.LocalAddress)
destPort := reqDetails.LocalPort
if !clientRemoteIP.IsValid() {
r.Complete(true) // sends a RST
return
@@ -254,7 +263,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
}
ep.SocketOptions().SetKeepAlive(true)
if reqDetails.LocalPort == 123 {
if destPort == 123 {
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
io.WriteString(tc, "Hello from Go\nGoodbye.\n")
@@ -262,7 +271,21 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
return
}
if reqDetails.LocalPort == 8008 && destIP == fakeTestAgentIP {
if destPort == 124 {
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
go func() {
defer tc.Close()
bs := bufio.NewScanner(tc)
for bs.Scan() {
line := bs.Text()
log.Printf("LOG from guest: %s", line)
}
}()
return
}
if destPort == 8008 && destIP == fakeTestAgentIP {
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
node := n.nodesByIP[clientRemoteIP]
@@ -271,11 +294,40 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
return
}
if destPort == 80 && destIP == fakeControlIP {
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
hs := &http.Server{Handler: n.s.control}
go hs.Serve(netutil.NewOneConnListener(tc, nil))
return
}
if destPort == 443 && (destIP == fakeDERP1IP || destIP == fakeDERP2IP) {
ds := n.s.derps[0]
if destIP == fakeDERP2IP {
ds = n.s.derps[1]
}
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
tlsConn := tls.Server(tc, ds.tlsConfig)
hs := &http.Server{Handler: ds.handler}
go hs.Serve(netutil.NewOneConnListener(tlsConn, nil))
return
}
if destPort == 80 && (destIP == fakeDERP1IP || destIP == fakeDERP2IP) {
r.Complete(false)
tc := gonet.NewTCPConn(&wq, ep)
hs := &http.Server{Handler: n.s.derps[0].handler}
go hs.Serve(netutil.NewOneConnListener(tc, nil))
return
}
var targetDial string
if n.s.derpIPs.Contains(destIP) {
targetDial = destIP.String() + ":" + strconv.Itoa(int(reqDetails.LocalPort))
} else if destIP == fakeControlplaneIP {
targetDial = "controlplane.tailscale.com:" + strconv.Itoa(int(reqDetails.LocalPort))
targetDial = destIP.String() + ":" + strconv.Itoa(int(destPort))
} else if destIP == fakeProxyControlplaneIP {
targetDial = "controlplane.tailscale.com:" + strconv.Itoa(int(destPort))
}
if targetDial != "" {
c, err := net.Dial("tcp", targetDial)
@@ -298,9 +350,12 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) {
}
var (
fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11})
fakeControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1})
fakeTestAgentIP = netip.AddrFrom4([4]byte{52, 52, 0, 2})
fakeDNSIP = netip.AddrFrom4([4]byte{4, 11, 4, 11})
fakeProxyControlplaneIP = netip.AddrFrom4([4]byte{52, 52, 0, 1}) // real controlplane.tailscale.com proxy
fakeTestAgentIP = netip.AddrFrom4([4]byte{52, 52, 0, 2})
fakeControlIP = netip.AddrFrom4([4]byte{52, 52, 0, 3}) // 3=C for "Control"
fakeDERP1IP = netip.AddrFrom4([4]byte{33, 4, 0, 1}) // 3340=DERP; 1=derp 1
fakeDERP2IP = netip.AddrFrom4([4]byte{33, 4, 0, 2}) // 3340=DERP; 1=derp 1
)
type EthernetPacket struct {
@@ -381,9 +436,33 @@ type node struct {
lanIP netip.Addr // must be in net.lanIP prefix + unique in net
}
type derpServer struct {
srv *derp.Server
handler http.Handler
tlsConfig *tls.Config
}
func newDERPServer() *derpServer {
// Just to get a self-signed TLS cert:
ts := httptest.NewTLSServer(nil)
ts.Close()
ds := &derpServer{
srv: derp.NewServer(key.NewNode(), logger.Discard),
tlsConfig: ts.TLS, // self-signed; test client configure to not check
}
var mux http.ServeMux
mux.Handle("/derp", derphttp.Handler(ds.srv))
mux.HandleFunc("/generate_204", derphttp.ServeNoContent)
ds.handler = &mux
return ds
}
type Server struct {
shutdownCtx context.Context
shutdownCancel context.CancelFunc
blendReality bool
derpIPs set.Set[netip.Addr]
@@ -392,10 +471,50 @@ type Server struct {
networks set.Set[*network]
networkByWAN map[netip.Addr]*network
mu sync.Mutex
agentConnWaiter map[*node]chan<- struct{} // signaled after added to set
agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all
agentRoundTripper map[*node]*http.Transport
control *testcontrol.Server
derps []*derpServer
mu sync.Mutex
agentConnWaiter map[*node]chan<- struct{} // signaled after added to set
agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all
agentDialer map[*node]DialFunc
}
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
var derpMap = &tailcfg.DERPMap{
Regions: map[int]*tailcfg.DERPRegion{
1: {
RegionID: 1,
RegionCode: "atlantis",
RegionName: "Atlantis",
Nodes: []*tailcfg.DERPNode{
{
Name: "1a",
RegionID: 1,
HostName: "derp1.tailscale",
IPv4: fakeDERP1IP.String(),
InsecureForTests: true,
CanPort80: true,
},
},
},
2: {
RegionID: 2,
RegionCode: "northpole",
RegionName: "North Pole",
Nodes: []*tailcfg.DERPNode{
{
Name: "2a",
RegionID: 2,
HostName: "derp2.tailscale",
IPv4: fakeDERP2IP.String(),
InsecureForTests: true,
CanPort80: true,
},
},
},
},
}
func New(c *Config) (*Server, error) {
@@ -404,12 +523,20 @@ func New(c *Config) (*Server, error) {
shutdownCtx: ctx,
shutdownCancel: cancel,
control: &testcontrol.Server{
DERPMap: derpMap,
ExplicitBaseURL: "http://control.tailscale",
},
derpIPs: set.Of[netip.Addr](),
nodeByMAC: map[MAC]*node{},
networkByWAN: map[netip.Addr]*network{},
networks: set.Of[*network](),
}
for range 2 {
s.derps = append(s.derps, newDERPServer())
}
if err := s.initFromConfig(c); err != nil {
return nil, err
}
@@ -418,9 +545,14 @@ func New(c *Config) (*Server, error) {
return nil, fmt.Errorf("newServer: initStack: %v", err)
}
}
return s, nil
}
func (s *Server) Close() {
s.shutdownCancel()
}
func (s *Server) HWAddr(mac MAC) net.HardwareAddr {
// TODO: cache
return net.HardwareAddr(mac[:])
@@ -435,7 +567,13 @@ func (s *Server) IPv4ForDNS(qname string) (netip.Addr, bool) {
case "test-driver.tailscale":
return fakeTestAgentIP, true
case "controlplane.tailscale.com":
return fakeControlplaneIP, true
return fakeProxyControlplaneIP, true
case "control.tailscale":
return fakeControlIP, true
case "derp1.tailscale":
return fakeDERP1IP, true
case "derp2.tailscale":
return fakeDERP2IP, true
}
return netip.Addr{}, false
}
@@ -538,7 +676,10 @@ func (s *Server) routeUDPPacket(up UDPPacket) {
if up.Dst.Port() == stunPort {
// TODO(bradfitz): fake latency; time.AfterFunc the response
if res, ok := makeSTUNReply(up); ok {
//log.Printf("STUN reply: %+v", res)
s.routeUDPPacket(res)
} else {
log.Printf("weird: STUN packet not handled")
}
return
}
@@ -622,6 +763,7 @@ func (n *network) HandleEthernetPacket(ep EthernetPacket) {
func (n *network) HandleUDPPacket(p UDPPacket) {
dst := n.doNATIn(p.Src, p.Dst)
if !dst.IsValid() {
log.Printf("Warning: NAT dropped packet; no mapping for %v=>%v", p.Src, p.Dst)
return
}
p.Dst = dst
@@ -726,7 +868,10 @@ func (n *network) HandleEthernetIPv4PacketForRouter(ep EthernetPacket) {
if toForward && isUDP {
src := netip.AddrPortFrom(srcIP, uint16(udp.SrcPort))
dst := netip.AddrPortFrom(dstIP, uint16(udp.DstPort))
src0 := src
src = n.doNATOut(src, dst)
_ = src0
//log.Printf("XXX UDP out %v=>%v to %v", src0, src, dst)
n.s.routeUDPPacket(UDPPacket{
Src: src,
@@ -891,12 +1036,19 @@ func (s *Server) shouldInterceptTCP(pkt gopacket.Packet) bool {
if !ok {
return false
}
if tcp.DstPort == 123 {
if tcp.DstPort == 123 || tcp.DstPort == 124 {
return true
}
dstIP, _ := netip.AddrFromSlice(ipv4.DstIP.To4())
if tcp.DstPort == 80 || tcp.DstPort == 443 {
if dstIP == fakeControlplaneIP || s.derpIPs.Contains(dstIP) {
switch dstIP {
case fakeControlIP, fakeDERP1IP, fakeDERP2IP:
return true
}
if dstIP == fakeProxyControlplaneIP {
return s.blendReality
}
if s.derpIPs.Contains(dstIP) {
return true
}
}
@@ -1166,12 +1318,15 @@ func (s *Server) takeAgentConn(ctx context.Context, n *node) (_ *agentConn, ok b
for {
ac, ok := s.takeAgentConnOne(n)
if ok {
log.Printf("got agent conn for %v", n.mac)
return ac, true
}
s.mu.Lock()
ready := make(chan struct{})
mak.Set(&s.agentConnWaiter, n, ready)
s.mu.Unlock()
log.Printf("waiting for agent conn for %v", n.mac)
select {
case <-ctx.Done():
return nil, false
@@ -1190,36 +1345,40 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) {
for ac := range s.agentConns {
if ac.node == n {
s.agentConns.Delete(ac)
log.Printf("XXX takeAgentConnOne HIT for %v", n.mac)
return ac, true
}
}
log.Printf("XXX takeAgentConnOne MISS for %v", n.mac)
return nil, false
}
func (s *Server) NodeAgentRoundTripper(ctx context.Context, n *Node) http.RoundTripper {
func (s *Server) NodeAgentDialer(n *Node) DialFunc {
s.mu.Lock()
defer s.mu.Unlock()
if rt, ok := s.agentRoundTripper[n.n]; ok {
return rt
if d, ok := s.agentDialer[n.n]; ok {
return d
}
var rt = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
ac, ok := s.takeAgentConn(ctx, n.n)
if !ok {
return nil, ctx.Err()
}
return ac.tc, nil
},
d := func(ctx context.Context, network, addr string) (net.Conn, error) {
ac, ok := s.takeAgentConn(ctx, n.n)
if !ok {
return nil, ctx.Err()
}
return ac.tc, nil
}
mak.Set(&s.agentDialer, n.n, d)
return d
}
mak.Set(&s.agentRoundTripper, n.n, rt)
return rt
func (s *Server) NodeAgentRoundTripper(n *Node) http.RoundTripper {
return &http.Transport{
DialContext: s.NodeAgentDialer(n),
}
}
func (s *Server) NodeStatus(ctx context.Context, n *Node) ([]byte, error) {
rt := s.NodeAgentRoundTripper(ctx, n)
rt := s.NodeAgentRoundTripper(n)
req, err := http.NewRequestWithContext(ctx, "GET", "http://node/status", nil)
if err != nil {
return nil, err