diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 442633856..d2f3e6189 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -224,6 +224,49 @@ func TestNodeAddressIPFields(t *testing.T) { d1.MustCleanShutdown(t) } +func TestAddPingRequest(t *testing.T) { + t.Parallel() + bins := BuildTestBinaries(t) + + env := newTestEnv(t, bins) + defer env.Close() + + n1 := newTestNode(t, env) + d1 := n1.StartDaemon(t) + defer d1.Kill() + + n1.AwaitListening(t) + n1.MustUp() + n1.AwaitRunning(t) + + gotPing := make(chan bool, 1) + waitPing := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPing <- true + })) + defer waitPing.Close() + + nodes := env.Control.AllNodes() + if len(nodes) != 1 { + t.Fatalf("expected 1 node, got %d nodes", len(nodes)) + } + + nodeKey := nodes[0].Key + pr := &tailcfg.PingRequest{URL: waitPing.URL, Log: true} + ok := env.Control.AddPingRequest(nodeKey, pr) + if !ok { + t.Fatalf("no node found with NodeKey %v in AddPingRequest", nodeKey) + } + + // Wait for PingRequest to come back + pingTimeout := time.NewTimer(10 * time.Second) + select { + case <-gotPing: + pingTimeout.Stop() + case <-pingTimeout.C: + t.Error("didn't get PingRequest from tailscaled") + } +} + // testEnv contains the test environment (set of servers) used by one // or more nodes. type testEnv struct { diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index a7ab87d8d..2e8ce82b9 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -54,6 +54,7 @@ type Server struct { updates map[tailcfg.NodeID]chan updateType authPath map[string]*AuthPath nodeKeyAuthed map[tailcfg.NodeKey]bool // key => true once authenticated + pingReqsToAdd map[tailcfg.NodeKey]*tailcfg.PingRequest } // NumNodes returns the number of nodes in the testcontrol server. @@ -67,6 +68,27 @@ func (s *Server) NumNodes() int { return len(s.nodes) } +// AddPingRequest sends the ping pr to nodeKeyDst. It reports whether it did so. That is, +// it reports whether nodeKeyDst was connected. +func (s *Server) AddPingRequest(nodeKeyDst tailcfg.NodeKey, pr *tailcfg.PingRequest) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.pingReqsToAdd == nil { + s.pingReqsToAdd = map[tailcfg.NodeKey]*tailcfg.PingRequest{} + } + // Now send the update to the channel + node := s.nodeLocked(nodeKeyDst) + if node == nil { + return false + } + + s.pingReqsToAdd[nodeKeyDst] = pr + nodeID := node.ID + oldUpdatesCh := s.updates[nodeID] + sendUpdate(oldUpdatesCh, updateDebugInjection) + return true +} + type AuthPath struct { nodeKey tailcfg.NodeKey @@ -380,6 +402,9 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey tail // via a lite endpoint update. These ones are never dup-suppressed, // as the client is expecting an answer regardless. updateSelfChanged + + // updateDebugInjection is an update used for PingRequests + updateDebugInjection ) func (s *Server) updateLocked(source string, peers []tailcfg.NodeID) { @@ -561,6 +586,14 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, netaddr.MustParseIPPrefix(fmt.Sprintf("100.64.%d.%d/32", uint8(node.ID>>8), uint8(node.ID))), } res.Node.AllowedIPs = res.Node.Addresses + + // Consume the PingRequest while protected by mutex if it exists + s.mu.Lock() + if pr, ok := s.pingReqsToAdd[node.Key]; ok { + res.PingRequest = pr + delete(s.pingReqsToAdd, node.Key) + } + s.mu.Unlock() return res, nil }