diff --git a/tstest/controll/controll.go b/tstest/controll/controll.go new file mode 100644 index 000000000..9afb1979f --- /dev/null +++ b/tstest/controll/controll.go @@ -0,0 +1,290 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The controll program trolls tailscaleds, simulating huge and busy tailnets. +package main + +import ( + "bytes" + "context" + "encoding/json" + "flag" + "fmt" + "log" + "math/rand/v2" + "net/http" + "net/netip" + "os" + "path/filepath" + "testing" + "time" + + "golang.org/x/crypto/acme/autocert" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/ptr" + "tailscale.com/util/must" +) + +var ( + flagNFake = flag.Int("nfake", 0, "number of fake nodes to add to network") + certHost = flag.String("certhost", "controll.fitz.dev", "hostname to use in TLS certificate") +) + +type state struct { + Legacy key.ControlPrivate + Machine key.MachinePrivate +} + +func loadState() *state { + st := &state{} + path := filepath.Join(must.Get(os.UserCacheDir()), "controll.state") + f, _ := os.ReadFile(path) + f = bytes.TrimSpace(f) + if err := json.Unmarshal(f, st); err == nil { + return st + } + st.Legacy = key.NewControl() + st.Machine = key.NewMachine() + f = must.Get(json.Marshal(st)) + must.Do(os.WriteFile(path, f, 0600)) + return st +} + +func main() { + flag.Parse() + + var t fakeTB + derpMap := integration.RunDERPAndSTUN(t, logger.Discard, "127.0.0.1") + + certManager := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist(*certHost), + Cache: autocert.DirCache(filepath.Join(must.Get(os.UserCacheDir()), "controll-cert")), + } + + control := &testcontrol.Server{ + DERPMap: derpMap, + ExplicitBaseURL: "http://127.0.0.1:9911", + TolerateUnknownPaths: true, + AltMapStream: sendClientChaos, + } + + st := loadState() + control.SetPrivateKeys(st.Machine, st.Legacy) + for range *flagNFake { + control.AddFakeNode() + } + mux := http.NewServeMux() + mux.HandleFunc("/{$}", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("controll")) + }) + mux.Handle("/", control) + + go func() { + addr := "127.0.0.1:9911" + log.Printf("listening on %s", addr) + err := http.ListenAndServe(addr, mux) + log.Fatal(err) + }() + + if *certHost != "" { + go func() { + srv := &http.Server{ + Addr: ":https", + Handler: mux, + TLSConfig: certManager.TLSConfig(), + } + log.Fatalf("TLS: %v", srv.ListenAndServeTLS("", "")) + }() + } + + select {} +} + +func node4(nid tailcfg.NodeID) netip.Prefix { + return netip.PrefixFrom( + netip.AddrFrom4([4]byte{100, 100 + byte(nid>>16), byte(nid >> 8), byte(nid)}), + 32) +} + +func node6(nid tailcfg.NodeID) netip.Prefix { + a := tsaddr.TailscaleULARange().Addr().As16() + a[13] = byte(nid >> 16) + a[14] = byte(nid >> 8) + a[15] = byte(nid) + v6 := netip.AddrFrom16(a) + return netip.PrefixFrom(v6, 128) +} + +func sendClientChaos(ctx context.Context, w testcontrol.MapStreamWriter, r *tailcfg.MapRequest) { + selfPub := r.NodeKey + + nodeID := tailcfg.NodeID(0) + newNodeID := func() tailcfg.NodeID { + nodeID++ + return nodeID + } + + selfNodeID := newNodeID() + selfIP4 := node4(nodeID) + selfIP6 := node6(nodeID) + + selfUserID := tailcfg.UserID(1_000_000) + + var peers []*tailcfg.Node + for range *flagNFake { + nid := newNodeID() + v4, v6 := node4(nid), node6(nid) + user := selfUserID + if rand.IntN(2) == 0 { + // Randomly assign a different user to the peer. + // ... + } + peers = append(peers, &tailcfg.Node{ + ID: nid, + StableID: tailcfg.StableNodeID(fmt.Sprintf("peer-%d", nid)), + Name: fmt.Sprintf("peer-%d.troll.ts.net.", nid), + Key: key.NewNode().Public(), + MachineAuthorized: true, + DiscoKey: key.NewDisco().Public(), + Addresses: []netip.Prefix{v4, v6}, + AllowedIPs: []netip.Prefix{v4, v6}, + User: user, + }) + } + + w.SendMapMessage(&tailcfg.MapResponse{ + Node: &tailcfg.Node{ + ID: selfNodeID, + StableID: "self", + Name: "test-mctestfast.troll.ts.net.", + User: selfUserID, + Key: selfPub, + KeyExpiry: time.Now().Add(5000 * time.Hour), + Machine: key.NewMachine().Public(), // fake; client shouldn't care + DiscoKey: r.DiscoKey, + MachineAuthorized: true, + Addresses: []netip.Prefix{selfIP4, selfIP6}, + AllowedIPs: []netip.Prefix{selfIP4, selfIP6}, + Capabilities: []tailcfg.NodeCapability{}, + CapMap: map[tailcfg.NodeCapability][]tailcfg.RawMessage{}, + }, + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: {RegionID: 1, + Nodes: []*tailcfg.DERPNode{{ + RegionID: 1, + Name: "1i", + IPv4: "199.38.181.103", + IPv6: "2607:f740:f::e19", + HostName: "derp1i.tailscale.com", + CanPort80: true, + }}}, + }, + }, + Peers: peers, + }) + + sendChange := func() error { + const ( + actionToggleOnline = iota + numActions + ) + action := rand.IntN(numActions) + switch action { + case actionToggleOnline: + peer := peers[rand.IntN(len(peers))] + online := peer.Online != nil && *peer.Online + peer.Online = ptr.To(!online) + var lastSeen *time.Time + if !online { + lastSeen = ptr.To(time.Now().UTC().Round(time.Second)) + } + w.SendMapMessage(&tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{ + { + NodeID: peer.ID, + Online: peer.Online, + LastSeen: lastSeen, + }, + }, + }) + } + return nil + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := sendChange(); err != nil { + log.Printf("sendChange: %v", err) + return + } + case <-ctx.Done(): + return + } + } +} + +type fakeTB struct { + *testing.T +} + +func (t fakeTB) Cleanup(_ func()) {} +func (t fakeTB) Error(args ...any) { + t.Fatal(args...) +} +func (t fakeTB) Errorf(format string, args ...any) { + t.Fatalf(format, args...) +} +func (t fakeTB) Fail() { + t.Fatal("failed") +} +func (t fakeTB) FailNow() { + t.Fatal("failed") +} +func (t fakeTB) Failed() bool { + return false +} +func (t fakeTB) Fatal(args ...any) { + log.Fatal(args...) +} +func (t fakeTB) Fatalf(format string, args ...any) { + log.Fatalf(format, args...) +} +func (t fakeTB) Helper() {} +func (t fakeTB) Log(args ...any) { + log.Print(args...) +} +func (t fakeTB) Logf(format string, args ...any) { + log.Printf(format, args...) +} +func (t fakeTB) Name() string { + return "faketest" +} +func (t fakeTB) Setenv(key string, value string) { + panic("not implemented") +} +func (t fakeTB) Skip(args ...any) { + t.Fatal("skipped") +} +func (t fakeTB) SkipNow() { + t.Fatal("skipnow") +} +func (t fakeTB) Skipf(format string, args ...any) { + t.Logf(format, args...) + t.Fatal("skipped") +} +func (t fakeTB) Skipped() bool { + return false +} +func (t fakeTB) TempDir() string { + panic("not implemented") +} diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index e127087a6..249d6f6e0 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -46,19 +46,22 @@ const msgLimit = 1 << 20 // encrypted message length limit // Server is a control plane server. Its zero value is ready for use. // Everything is stored in-memory in one tailnet. type Server struct { - Logf logger.Logf // nil means to use the log package - DERPMap *tailcfg.DERPMap // nil means to use prod DERP map - RequireAuth bool - RequireAuthKey string // required authkey for all nodes - Verbose bool - DNSConfig *tailcfg.DNSConfig // nil means no DNS config - MagicDNSDomain string - HandleC2N http.Handler // if non-nil, used for /some-c2n-path/ in tests + Logf logger.Logf // nil means to use the log package + DERPMap *tailcfg.DERPMap // nil means to use prod DERP map + RequireAuth bool + RequireAuthKey string // required authkey for all nodes + Verbose bool + DNSConfig *tailcfg.DNSConfig // nil means no DNS config + MagicDNSDomain string + HandleC2N http.Handler // if non-nil, used for /some-c2n-path/ in tests + TolerateUnknownPaths bool // if true, serve 404 instead of panicking on unknown URLs paths // ExplicitBaseURL or HTTPTestServer must be set. ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL HTTPTestServer *httptest.Server // if non-nil, used to get BaseURL + AltMapStream func(context.Context, MapStreamWriter, *tailcfg.MapRequest) + initMuxOnce sync.Once mux *http.ServeMux @@ -268,10 +271,15 @@ func (s *Server) initMux() { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.initMuxOnce.Do(s.initMux) + log.Printf("control: %s %s", r.Method, r.URL.Path) s.mux.ServeHTTP(w, r) } func (s *Server) serveUnhandled(w http.ResponseWriter, r *http.Request) { + if s.TolerateUnknownPaths { + http.Error(w, "unknown control URL", http.StatusNotFound) + return + } var got bytes.Buffer r.Write(&got) go panic(fmt.Sprintf("testcontrol.Server received unhandled request: %s", got.Bytes())) @@ -324,6 +332,13 @@ func (s *Server) ensureKeyPairLocked() { s.pubKey = s.privKey.Public() } +func (s *Server) SetPrivateKeys(noise key.MachinePrivate, legacy key.ControlPrivate) { + s.noisePrivKey = noise + s.noisePubKey = noise.Public() + s.privKey = legacy + s.pubKey = legacy.Public() +} + func (s *Server) serveKey(w http.ResponseWriter, r *http.Request) { noiseKey, legacyKey := s.publicKeys() if r.FormValue("v") == "" { @@ -460,19 +475,20 @@ func (s *Server) AddFakeNode() { mk := key.NewMachine().Public() dk := key.NewDisco().Public() r := nk.Raw32() - id := int64(binary.LittleEndian.Uint64(r[:])) + id := int64(binary.LittleEndian.Uint64(r[:]) >> 11) ip := netaddr.IPv4(r[0], r[1], r[2], r[3]) addr := netip.PrefixFrom(ip, 32) s.nodes[nk] = &tailcfg.Node{ ID: tailcfg.NodeID(id), StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", id)), - User: tailcfg.UserID(id), + User: 123, Machine: mk, Key: nk, MachineAuthorized: true, DiscoKey: dk, Addresses: []netip.Prefix{addr}, AllowedIPs: []netip.Prefix{addr}, + Name: fmt.Sprintf("node-%d.big-troll.ts.net.", id), } // TODO: send updates to other (non-fake?) nodes } @@ -613,7 +629,7 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. log.Printf("Got %T: %s", req, j) } if s.RequireAuthKey != "" && (req.Auth == nil || req.Auth.AuthKey != s.RequireAuthKey) { - res := must.Get(s.encode(false, tailcfg.RegisterResponse{ + res := must.Get(encode(false, tailcfg.RegisterResponse{ Error: "invalid authkey", })) w.WriteHeader(200) @@ -687,7 +703,7 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. authURL = s.BaseURL() + authPath } - res, err := s.encode(false, tailcfg.RegisterResponse{ + res, err := encode(false, tailcfg.RegisterResponse{ User: *user, Login: *login, NodeKeyExpired: allExpired, @@ -765,6 +781,37 @@ func (s *Server) InServeMap() int { return s.inServeMap } +type MapStreamWriter interface { + SendMapMessage(*tailcfg.MapResponse) error +} + +type mapStreamSender struct { + w io.Writer + compress bool +} + +func (s mapStreamSender) SendMapMessage(msg *tailcfg.MapResponse) error { + resBytes, err := encode(s.compress, msg) + if err != nil { + return err + } + if len(resBytes) > 16<<20 { + return fmt.Errorf("map message too big: %d", len(resBytes)) + } + var siz [4]byte + binary.LittleEndian.PutUint32(siz[:], uint32(len(resBytes))) + if _, err := s.w.Write(siz[:]); err != nil { + return err + } + if _, err := s.w.Write(resBytes); err != nil { + return err + } + if f, ok := s.w.(http.Flusher); ok { + f.Flush() + } + return nil +} + func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.MachinePublic) { s.incrInServeMap(1) defer s.incrInServeMap(-1) @@ -783,6 +830,19 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi go panic(fmt.Sprintf("bad map request: %v", err)) } + // ReadOnly implies no streaming, as it doesn't + // register an updatesCh to get updates. + streaming := req.Stream && !req.ReadOnly + compress := req.Compress != "" + + if s.AltMapStream != nil { + s.AltMapStream(ctx, mapStreamSender{ + w: w, + compress: compress, + }, req) + return + } + jitter := rand.N(8 * time.Second) keepAlive := 50*time.Second + jitter @@ -832,11 +892,6 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi s.condLocked().Broadcast() s.mu.Unlock() - // ReadOnly implies no streaming, as it doesn't - // register an updatesCh to get updates. - streaming := req.Stream && !req.ReadOnly - compress := req.Compress != "" - w.WriteHeader(200) for { if resBytes, ok := s.takeRawMapMessage(req.NodeKey); ok { @@ -1063,7 +1118,7 @@ func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok boo } func (s *Server) sendMapMsg(w http.ResponseWriter, compress bool, msg any) error { - resBytes, err := s.encode(compress, msg) + resBytes, err := encode(compress, msg) if err != nil { return err } @@ -1093,7 +1148,7 @@ func (s *Server) decode(msg []byte, v any) error { return json.Unmarshal(msg, v) } -func (s *Server) encode(compress bool, v any) (b []byte, err error) { +func encode(compress bool, v any) (b []byte, err error) { var isBytes bool if b, isBytes = v.([]byte); !isBytes { b, err = json.Marshal(v)