diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index b30ea2f33..df35482d5 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -27,6 +27,7 @@ "go4.org/mem" "tailscale.com/smallzstd" "tailscale.com/tstest" + "tailscale.com/tstest/integration/testcontrol" ) func TestIntegration(t *testing.T) { @@ -41,14 +42,20 @@ func TestIntegration(t *testing.T) { ts := httptest.NewServer(logc) defer ts.Close() - httpProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // catchBadTrafficProxy explodes if it gets any traffic. + // It's here to catch anything that would otherwise try to leave localhost. + catchBadTrafficProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var got bytes.Buffer r.Write(&got) err := fmt.Errorf("unexpected HTTP proxy via proxy: %s", got.Bytes()) t.Error(err) go panic(err) })) - defer httpProxy.Close() + defer catchBadTrafficProxy.Close() + + controlServer := new(testcontrol.Server) + controlHTTPServer := httptest.NewServer(controlServer) + defer controlHTTPServer.Close() socketPath := filepath.Join(td, "tailscale.sock") dcmd := exec.Command(daemonExe, @@ -58,8 +65,8 @@ func TestIntegration(t *testing.T) { ) dcmd.Env = append(os.Environ(), "TS_LOG_TARGET="+ts.URL, - "HTTP_PROXY="+httpProxy.URL, - "HTTPS_PROXY="+httpProxy.URL, + "HTTP_PROXY="+catchBadTrafficProxy.URL, + "HTTPS_PROXY="+catchBadTrafficProxy.URL, ) if err := dcmd.Start(); err != nil { t.Fatalf("starting tailscaled: %v", err) @@ -77,11 +84,6 @@ func TestIntegration(t *testing.T) { t.Fatal(err) } - if os.Getenv("TS_RUN_TEST") == "failing_up" { - // Force a connection through the HTTP proxy to panic and fail. - exec.Command(cliExe, "--socket="+socketPath, "up").Run() - } - if err := tstest.WaitFor(20*time.Second, func() error { const sub = `Program starting: ` if !logc.logsContains(mem.S(sub)) { @@ -92,6 +94,23 @@ func TestIntegration(t *testing.T) { t.Error(err) } + if err := exec.Command(cliExe, "--socket="+socketPath, "up", "--login-server="+controlHTTPServer.URL).Run(); err != nil { + t.Fatalf("up: %v", err) + } + + var ip string + if err := tstest.WaitFor(20*time.Second, func() error { + out, err := exec.Command(cliExe, "--socket="+socketPath, "ip").Output() + if err != nil { + return err + } + ip = string(out) + return nil + }); err != nil { + t.Error(err) + } + t.Logf("Got IP: %v", ip) + dcmd.Process.Signal(os.Interrupt) ps, err := dcmd.Process.Wait() diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go new file mode 100644 index 000000000..fe3a12911 --- /dev/null +++ b/tstest/integration/testcontrol/testcontrol.go @@ -0,0 +1,545 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package testcontrol contains a minimal control plane server for testing purposes. +package testcontrol + +import ( + "bytes" + crand "crypto/rand" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "math/rand" + "net/http" + "strings" + "sync" + "time" + + "github.com/klauspost/compress/zstd" + "golang.org/x/crypto/nacl/box" + "inet.af/netaddr" + "tailscale.com/derp/derpmap" + "tailscale.com/smallzstd" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/types/wgkey" +) + +// Server is a control plane server. 
Its zero value is ready for use. +// Everything is stored in-memory in one tailnet. +type Server struct { + Logf logger.Logf // nil means to use the log package + DERPMap *tailcfg.DERPMap // nil means to use prod DERP map + + initMuxOnce sync.Once + mux *http.ServeMux + + mu sync.Mutex + pubKey wgkey.Key + privKey wgkey.Private + nodes map[tailcfg.NodeKey]*tailcfg.Node + users map[tailcfg.NodeKey]*tailcfg.User + logins map[tailcfg.NodeKey]*tailcfg.Login + updates map[tailcfg.NodeID]chan updateType +} + +func (s *Server) logf(format string, a ...interface{}) { + if s.Logf != nil { + s.Logf(format, a...) + } else { + log.Printf(format, a...) + } +} + +func (s *Server) initMux() { + s.mux = http.NewServeMux() + s.mux.HandleFunc("/", s.serveUnhandled) + s.mux.HandleFunc("/key", s.serveKey) + s.mux.HandleFunc("/machine/", s.serveMachine) +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.initMuxOnce.Do(s.initMux) + s.mux.ServeHTTP(w, r) +} + +func (s *Server) serveUnhandled(w http.ResponseWriter, r *http.Request) { + var got bytes.Buffer + r.Write(&got) + go panic(fmt.Sprintf("testcontrol.Server received unhandled request: %s", got.Bytes())) +} + +func (s *Server) publicKey() wgkey.Key { + pub, _ := s.keyPair() + return pub +} + +func (s *Server) privateKey() wgkey.Private { + _, priv := s.keyPair() + return priv +} + +func (s *Server) keyPair() (pub wgkey.Key, priv wgkey.Private) { + s.mu.Lock() + defer s.mu.Unlock() + if s.pubKey.IsZero() { + var err error + s.privKey, err = wgkey.NewPrivate() + if err != nil { + go panic(err) // bring down test, even if in http.Handler + } + s.pubKey = s.privKey.Public() + } + return s.pubKey, s.privKey +} + +func (s *Server) serveKey(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(200) + io.WriteString(w, s.publicKey().HexString()) +} + +func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) { + mkeyStr := strings.TrimPrefix(r.URL.Path, "/machine/") + rem := "" + if i := strings.IndexByte(mkeyStr, '/'); i != -1 { + rem = mkeyStr[i:] + mkeyStr = mkeyStr[:i] + } + + key, err := wgkey.ParseHex(mkeyStr) + if err != nil { + http.Error(w, "bad machine key hex", 400) + return + } + mkey := tailcfg.MachineKey(key) + + if r.Method != "POST" { + http.Error(w, "POST required", 400) + return + } + + switch rem { + case "": + s.serveRegister(w, r, mkey) + case "/map": + s.serveMap(w, r, mkey) + default: + s.serveUnhandled(w, r) + } +} + +// Node returns the node for nodeKey. It's always nil or cloned memory. 
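+// Callers may freely mutate the returned Node without holding s.mu.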
+func (s *Server) Node(nodeKey tailcfg.NodeKey) *tailcfg.Node { + s.mu.Lock() + defer s.mu.Unlock() + return s.nodes[nodeKey].Clone() +} + +func (s *Server) getUser(nodeKey tailcfg.NodeKey) (*tailcfg.User, *tailcfg.Login) { + s.mu.Lock() + defer s.mu.Unlock() + if s.users == nil { + s.users = map[tailcfg.NodeKey]*tailcfg.User{} + } + if s.logins == nil { + s.logins = map[tailcfg.NodeKey]*tailcfg.Login{} + } + if u, ok := s.users[nodeKey]; ok { + return u, s.logins[nodeKey] + } + id := tailcfg.UserID(len(s.users) + 1) + domain := "fake-control.example.net" + loginName := fmt.Sprintf("user-%d@%s", id, domain) + displayName := fmt.Sprintf("User %d", id) + login := &tailcfg.Login{ + ID: tailcfg.LoginID(id), + Provider: "testcontrol", + LoginName: loginName, + DisplayName: displayName, + ProfilePicURL: "https://tailscale.com/static/images/marketing/team-carney.jpg", + Domain: domain, + } + user := &tailcfg.User{ + ID: id, + LoginName: loginName, + DisplayName: displayName, + Domain: domain, + Logins: []tailcfg.LoginID{login.ID}, + } + s.users[nodeKey] = user + s.logins[nodeKey] = login + return user, login +} + +func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey tailcfg.MachineKey) { + var req tailcfg.RegisterRequest + if err := s.decode(mkey, r.Body, &req); err != nil { + panic(fmt.Sprintf("serveRegister: decode: %v", err)) + } + if req.Version != 1 { + panic(fmt.Sprintf("serveRegister: unsupported version: %d", req.Version)) + } + if req.NodeKey.IsZero() { + panic("serveRegister: request has zero node key") + } + + user, login := s.getUser(req.NodeKey) + s.mu.Lock() + if s.nodes == nil { + s.nodes = map[tailcfg.NodeKey]*tailcfg.Node{} + } + s.nodes[req.NodeKey] = &tailcfg.Node{ + ID: tailcfg.NodeID(user.ID), + StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(user.ID))), + User: user.ID, + Machine: mkey, + Key: req.NodeKey, + MachineAuthorized: true, + } + s.mu.Unlock() + + res, err := s.encode(mkey, false, tailcfg.RegisterResponse{ + User: *user, + Login: *login, + NodeKeyExpired: false, + MachineAuthorized: true, + AuthURL: "", // all good; TODO(bradfitz): add ways to not start all good. + }) + if err != nil { + go panic(fmt.Sprintf("serveRegister: encode: %v", err)) + } + w.WriteHeader(200) + w.Write(res) +} + +// updateType indicates why a long-polling map request is being woken +// up for an update. +type updateType int + +const ( + // updatePeerChanged is an update that a peer has changed. + updatePeerChanged updateType = iota + 1 + + // updateSelfChanged is an update that the node changed itself + // via a lite endpoint update. These ones are never dup-suppressed, + // as the client is expecting an answer regardless. + updateSelfChanged +) + +func (s *Server) updateLocked(source string, peers []tailcfg.NodeID) { + for _, peer := range peers { + sendUpdate(s.updates[peer], updatePeerChanged) + } +} + +// sendUpdate sends updateType to dst if dst is non-nil and +// has capacity. +func sendUpdate(dst chan<- updateType, updateType updateType) { + if dst == nil { + return + } + // The dst channel has a buffer size of 1. + // If we fail to insert an update into the buffer that + // means there is already an update pending. 
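+	// Dropping the value here is safe: the pending update will wake
+	// the map stream, which recomputes a full MapResponse anyway.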
+ select { + case dst <- updateType: + default: + } +} + +func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey tailcfg.MachineKey) { + ctx := r.Context() + + req := new(tailcfg.MapRequest) + if err := s.decode(mkey, r.Body, req); err != nil { + go panic(fmt.Sprintf("bad map request: %v", err)) + } + + jitter := time.Duration(rand.Intn(8000)) * time.Millisecond + keepAlive := 50*time.Second + jitter + + node := s.Node(req.NodeKey) + if node == nil { + http.Error(w, "node not found", 400) + return + } + if node.Machine != mkey { + http.Error(w, "node doesn't match machine key", 400) + return + } + + var peersToUpdate []tailcfg.NodeID + if !req.ReadOnly { + endpoints := filterInvalidIPv6Endpoints(req.Endpoints) + node.Endpoints = endpoints + // TODO: more + // TODO: register node, + //s.UpdateEndpoint(mkey, req.NodeKey, + // XXX + } + + nodeID := node.ID + + s.mu.Lock() + updatesCh := make(chan updateType, 1) + oldUpdatesCh := s.updates[nodeID] + if breakSameNodeMapResponseStreams(req) { + if oldUpdatesCh != nil { + close(oldUpdatesCh) + } + if s.updates == nil { + s.updates = map[tailcfg.NodeID]chan updateType{} + } + s.updates[nodeID] = updatesCh + } else { + sendUpdate(oldUpdatesCh, updateSelfChanged) + } + s.updateLocked("serveMap", peersToUpdate) + s.mu.Unlock() + + // ReadOnly implies no streaming, as it doesn't + // register an updatesCh to get updates. + streaming := req.Stream && !req.ReadOnly + compress := req.Compress != "" + + w.WriteHeader(200) + for { + res, err := s.MapResponse(req) + if err != nil { + // TODO: log + return + } + if res == nil { + return // done + } + // TODO: add minner if/when needed + resBytes, err := json.Marshal(res) + if err != nil { + s.logf("json.Marshal: %v", err) + return + } + if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil { + return + } + if !streaming { + return + } + keepAliveLoop: + for { + var keepAliveTimer *time.Timer + var keepAliveTimerCh <-chan time.Time + if keepAlive > 0 { + keepAliveTimer = time.NewTimer(keepAlive) + keepAliveTimerCh = keepAliveTimer.C + } + select { + case <-ctx.Done(): + if keepAliveTimer != nil { + keepAliveTimer.Stop() + } + return + case _, ok := <-updatesCh: + if !ok { + // replaced by new poll request + return + } + break keepAliveLoop + case <-keepAliveTimerCh: + if err := s.sendMapMsg(w, mkey, compress, keepAliveMsg); err != nil { + return + } + } + } + } +} + +var keepAliveMsg = &struct { + KeepAlive bool +}{ + KeepAlive: true, +} + +var prodDERPMap = derpmap.Prod() + +// MapResponse generates a MapResponse for a MapRequest. +// +// No updates to s are done here. 
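+// If the node key is unknown, MapResponse returns a nil response
+// and a nil error.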
+func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, err error) {
+	node := s.Node(req.NodeKey)
+	if node == nil {
+		// node key rotated away (once test server supports that)
+		return nil, nil
+	}
+	derpMap := s.DERPMap
+	if derpMap == nil {
+		derpMap = prodDERPMap
+	}
+	user, _ := s.getUser(req.NodeKey)
+	res = &tailcfg.MapResponse{
+		Node:            node,
+		DERPMap:         derpMap,
+		Domain:          string(user.Domain),
+		CollectServices: "true",
+		PacketFilter:    tailcfg.FilterAllowAll,
+	}
+	res.Node.Addresses = []netaddr.IPPrefix{
+		netaddr.MustParseIPPrefix(fmt.Sprintf("100.64.%d.%d/32", uint8(node.ID>>8), uint8(node.ID))),
+	}
+	res.Node.AllowedIPs = res.Node.Addresses
+	return res, nil
+}
+
+func (s *Server) sendMapMsg(w http.ResponseWriter, mkey tailcfg.MachineKey, compress bool, msg interface{}) error {
+	resBytes, err := s.encode(mkey, compress, msg)
+	if err != nil {
+		return err
+	}
+	if len(resBytes) > 16<<20 {
+		return fmt.Errorf("map message too big: %d", len(resBytes))
+	}
+	var siz [4]byte
+	binary.LittleEndian.PutUint32(siz[:], uint32(len(resBytes)))
+	if _, err := w.Write(siz[:]); err != nil {
+		return err
+	}
+	if _, err := w.Write(resBytes); err != nil {
+		return err
+	}
+	if f, ok := w.(http.Flusher); ok {
+		f.Flush()
+	} else {
+		s.logf("[unexpected] ResponseWriter %T is not a Flusher", w)
+	}
+	return nil
+}
+
+func (s *Server) decode(mkey tailcfg.MachineKey, r io.Reader, v interface{}) error {
+	if c, _ := r.(io.Closer); c != nil {
+		defer c.Close()
+	}
+	const msgLimit = 1 << 20
+	msg, err := ioutil.ReadAll(io.LimitReader(r, msgLimit))
+	if err != nil {
+		return err
+	}
+	if len(msg) == msgLimit {
+		return errors.New("encrypted message too long")
+	}
+
+	var nonce [24]byte
+	if len(msg) < len(nonce)+1 {
+		return errors.New("missing nonce")
+	}
+	copy(nonce[:], msg)
+	msg = msg[len(nonce):]
+
+	priv := s.privateKey()
+	pub, pri := (*[32]byte)(&mkey), (*[32]byte)(&priv)
+	decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
+	if !ok {
+		return errors.New("can't decrypt request")
+	}
+	return json.Unmarshal(decrypted, v)
+}
+
+var zstdEncoderPool = &sync.Pool{
+	New: func() interface{} {
+		encoder, err := smallzstd.NewEncoder(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
+		if err != nil {
+			panic(err)
+		}
+		return encoder
+	},
+}
+
+func (s *Server) encode(mkey tailcfg.MachineKey, compress bool, v interface{}) (b []byte, err error) {
+	var isBytes bool
+	if b, isBytes = v.([]byte); !isBytes {
+		b, err = json.Marshal(v)
+		if err != nil {
+			return nil, err
+		}
+	}
+	if compress {
+		encoder := zstdEncoderPool.Get().(*zstd.Encoder)
+		b = encoder.EncodeAll(b, nil)
+		encoder.Close()
+		zstdEncoderPool.Put(encoder)
+	}
+	var nonce [24]byte
+	if _, err := io.ReadFull(crand.Reader, nonce[:]); err != nil {
+		panic(err)
+	}
+	priv := s.privateKey()
+	pub, pri := (*[32]byte)(&mkey), (*[32]byte)(&priv)
+	msgData := box.Seal(nonce[:], b, &nonce, pub, pri)
+	return msgData, nil
+}
+
+// filterInvalidIPv6Endpoints removes invalid IPv6 endpoints from eps,
+// modifying the slice in place and returning the potentially smaller subset
+// (aliasing the original memory).
+//
+// Two types of IPv6 endpoints are considered invalid: link-local
+// addresses, and anything with a zone.
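+//
+// For example, eps of ["203.0.113.9:41641", "[fe80::1%eth0]:41641"]
+// filters down to just ["203.0.113.9:41641"].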
+func filterInvalidIPv6Endpoints(eps []string) []string {
+	clean := eps[:0]
+	for _, ep := range eps {
+		if keepClientEndpoint(ep) {
+			clean = append(clean, ep)
+		}
+	}
+	return clean
+}
+
+func keepClientEndpoint(ep string) bool {
+	ipp, err := netaddr.ParseIPPort(ep)
+	if err != nil {
+		// Shouldn't happen: the endpoint strings came from
+		// JSON that already unmarshalled successfully.
+		return false
+	}
+	ip := ipp.IP
+	if ip.Zone() != "" {
+		return false
+	}
+	if ip.Is6() && ip.IsLinkLocalUnicast() {
+		// We let clients send these, but tailscaled
+		// doesn't know how to use them yet, so we
+		// filter them out for now. A future
+		// MapRequest.Version might signal that
+		// clients know how to use them (e.g. try all
+		// local scopes).
+		return false
+	}
+	return true
+}
+
+// breakSameNodeMapResponseStreams reports whether req should break a
+// prior long-polling MapResponse stream (if active) from the same
+// node ID.
+func breakSameNodeMapResponseStreams(req *tailcfg.MapRequest) bool {
+	if req.ReadOnly {
+		// Don't register our updatesCh for closability,
+		// nor close another peer's, if this is a read-only request.
+		return false
+	}
+	if !req.Stream && req.OmitPeers {
+		// Likewise, if we're not streaming and not asking for peers
+		// (but are still mutable, without ReadOnly set), treat this
+		// as an endpoint update request only, and don't close any
+		// existing map response for this node ID. It's likely the
+		// same client with a built-up compression context; we want
+		// to let it update its endpoints with us without breaking
+		// that other long-running map response.
+		return false
+	}
+	return true
+}
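
Below is a minimal usage sketch, not part of the patch above, showing how a test might drive the new server directly. It relies only on what the diff adds: the zero value of testcontrol.Server is ready for use, and its /key handler returns the control plane's public key hex-encoded (a wgkey.Key is 32 bytes, so 64 hex characters). The test name and assertions are illustrative.

package testcontrol_test

import (
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"testing"

	"tailscale.com/tstest/integration/testcontrol"
)

// TestControlKey is an illustrative sketch: it stands up a zero-value
// testcontrol.Server behind httptest and fetches /key, the endpoint a
// client uses to learn the control server's public key.
func TestControlKey(t *testing.T) {
	srv := httptest.NewServer(new(testcontrol.Server))
	defer srv.Close()

	res, err := http.Get(srv.URL + "/key")
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	keyHex, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatal(err)
	}
	if len(keyHex) != 64 { // hex of a 32-byte key
		t.Fatalf("/key returned %q; want 64 hex characters", keyHex)
	}
}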