diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 770abd506..da7581070 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -52,6 +52,7 @@ import ( "tailscale.com/types/opt" "tailscale.com/types/ptr" "tailscale.com/util/dnsname" + "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/util/rands" "tailscale.com/version" @@ -521,6 +522,164 @@ func TestIncrementalMapUpdatePeersRemoved(t *testing.T) { d2.MustCleanShutdown(t) } +func TestCapMapPacketFilter(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + env := newTestEnv(t) + + n1 := newTestNode(t, env) + d1 := n1.StartDaemon() + n1.AwaitListening() + n1.MustUp() + n1.AwaitRunning() + + all := env.Control.AllNodes() + if len(all) != 1 { + t.Fatalf("expected 1 node, got %d nodes", len(all)) + } + tnode1 := all[0] + + n2 := newTestNode(t, env) + d2 := n2.StartDaemon() + n2.AwaitListening() + n2.MustUp() + n2.AwaitRunning() + + all = env.Control.AllNodes() + if len(all) != 2 { + t.Fatalf("expected 2 nodes, got %d nodes", len(all)) + } + var tnode2 *tailcfg.Node + for _, n := range all { + if n.ID != tnode1.ID { + tnode2 = n + } + } + if tnode2 == nil { + t.Fatalf("failed to find second node ID (two dups?)") + } + + t.Logf("node1=%v, node2=%v", tnode1.ID, tnode2.ID) + + n1.AwaitStatus(func(st *ipnstate.Status) error { + if len(st.Peer) != 1 { + return fmt.Errorf("got %d peers; want 1", len(st.Peer)) + } + peer := st.Peer[st.Peers()[0]] + if peer.ID == st.Self.ID { + return errors.New("peer is self") + } + return nil + }) + + // Check that n2 is reachable from n1 with the default packet filter. + if err := n1.PingICMP(n2); err != nil { + t.Fatalf("ping: %v", err) + } + + t.Logf("setting packet filter with a cap") + if !env.Control.AddRawMapResponse(tnode2.Key, &tailcfg.MapResponse{ + PacketFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"cap:foobar"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "*", + Ports: tailcfg.PortRange{First: 0, Last: 65535}, + }}, + }, + }, + }) { + t.Fatalf("failed to add map response") + } + + // Wait until n2 is no longer reachable from n1. + if err := tstest.WaitFor(5*time.Second, func() error { + if err := n1.PingICMP(n2); err == nil { + return errors.New("ping successful, wanted an error") + } + return nil + }); err != nil { + t.Fatal(err) + } + + t.Logf("setting cap on node1") + peer := env.Control.PeerForNode(tnode2.Key, tnode1.Key) + mak.Set(&peer.CapMap, tailcfg.NodeCapability("foobar"), []tailcfg.RawMessage{}) + if !env.Control.AddRawMapResponse(tnode2.Key, &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{peer}, + }) { + t.Fatalf("failed to add map response") + } + n2.AwaitPeerHasCap(tnode1.Key, tailcfg.NodeCapability("foobar")) + + t.Logf("confirming that n1 can ping n2") + if err := n1.PingICMP(n2); err != nil { + t.Fatalf("ping error %s", err) + } + + t.Logf("removing cap from node1") + peer = env.Control.PeerForNode(tnode2.Key, tnode1.Key) + if _, ok := peer.CapMap[tailcfg.NodeCapability("foobar")]; ok { + t.Fatal("unexpected cap") + } + if !env.Control.AddRawMapResponse(tnode2.Key, &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{peer}, + }) { + t.Fatalf("failed to add map response") + } + n2.AwaitPeerHasNoCap(tnode1.Key, tailcfg.NodeCapability("foobar")) + + t.Logf("confirming that n1 cannot ping n2") + if err := n1.PingICMP(n2); err == nil { + t.Fatal("ping successful, wanted an error") + } + + t.Logf("adding third node") + n3 := newTestNode(t, env) + d3 := n3.StartDaemon() + n3.AwaitListening() + n3.MustUp() + n3.AwaitRunning() + + all = env.Control.AllNodes() + if len(all) != 3 { + t.Fatalf("expected 3 nodes, got %d nodes", len(all)) + } + var tnode3 *tailcfg.Node + for _, n := range all { + if n.ID != tnode1.ID && n.ID != tnode2.ID { + tnode3 = n + } + } + if tnode3 == nil { + t.Fatalf("failed to find third node ID") + } + + t.Logf("confirming that n3 cannot ping n2") + if err := n3.PingICMP(n2); err == nil { + t.Fatal("ping successful, wanted an error") + } + + t.Logf("setting cap on the node3") + peer = env.Control.PeerForNode(tnode2.Key, tnode3.Key) + mak.Set(&peer.CapMap, tailcfg.NodeCapability("foobar"), []tailcfg.RawMessage{}) + if !env.Control.AddRawMapResponse(tnode2.Key, &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{peer}, + }) { + t.Fatalf("failed to add map response") + } + n2.AwaitPeerHasCap(tnode3.Key, tailcfg.NodeCapability("foobar")) + + t.Logf("confirming that n3 can ping n2") + if err := n3.PingICMP(n2); err != nil { + t.Fatalf("ping error %s", err) + } + + d1.MustCleanShutdown(t) + d2.MustCleanShutdown(t) + d3.MustCleanShutdown(t) +} + func TestNodeAddressIPFields(t *testing.T) { tstest.Shard(t) flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7008") @@ -1876,6 +2035,14 @@ func (n *testNode) Ping(otherNode *testNode) error { return n.Tailscale("ping", ip).Run() } +func (n *testNode) PingICMP(otherNode *testNode) error { + t := n.env.t + t.Helper() + ip := otherNode.AwaitIP4().String() + t.Logf("Running ping --icmp %v (from %v)...", ip, n.AwaitIP4()) + return n.Tailscale("ping", "--timeout", "1s", "--icmp", ip).Run() +} + // AwaitListening waits for the tailscaled to be serving local clients // over its localhost IPC mechanism. (Unix socket, etc) func (n *testNode) AwaitListening() { @@ -1945,25 +2112,52 @@ func (n *testNode) AwaitRunning() { n.AwaitBackendState("Running") } -func (n *testNode) AwaitBackendState(state string) { - t := n.env.t - t.Helper() - if err := tstest.WaitFor(20*time.Second, func() error { - st, err := n.Status() - if err != nil { - return err +// AwaitPeerHasCap waits until peer has a cap. +func (n *testNode) AwaitPeerHasCap(peerKey key.NodePublic, cap tailcfg.NodeCapability) { + n.awaitPeerCapPresence(peerKey, cap, true) +} + +// AwaitPeerHasNoCap waits until peer does not have a cap. +func (n *testNode) AwaitPeerHasNoCap(peerKey key.NodePublic, cap tailcfg.NodeCapability) { + n.awaitPeerCapPresence(peerKey, cap, false) +} + +func (n *testNode) awaitPeerCapPresence(peerKey key.NodePublic, cap tailcfg.NodeCapability, wantCap bool) { + n.AwaitStatus(func(st *ipnstate.Status) error { + for pk, peer := range st.Peer { + if pk != peerKey { + continue + } + if _, ok := peer.CapMap[cap]; ok != wantCap { + return fmt.Errorf("peer cap=%v want=%v", ok, wantCap) + } + return nil } + return fmt.Errorf("peer not found") + }) +} + +func (n *testNode) AwaitBackendState(state string) { + n.AwaitStatus(func(st *ipnstate.Status) error { if st.BackendState != state { return fmt.Errorf("in state %q; want %q", st.BackendState, state) } return nil - }); err != nil { - t.Fatalf("failure/timeout waiting for transition to Running status: %v", err) - } + }) } // AwaitNeedsLogin waits for n to reach the IPN state "NeedsLogin". func (n *testNode) AwaitNeedsLogin() { + n.AwaitStatus(func(st *ipnstate.Status) error { + if st.BackendState != "NeedsLogin" { + return fmt.Errorf("in state %q", st.BackendState) + } + return nil + }) +} + +// AwaitStatus waits until the ready function returns no error. +func (n *testNode) AwaitStatus(ready func(*ipnstate.Status) error) { t := n.env.t t.Helper() if err := tstest.WaitFor(20*time.Second, func() error { @@ -1971,12 +2165,9 @@ func (n *testNode) AwaitNeedsLogin() { if err != nil { return err } - if st.BackendState != "NeedsLogin" { - return fmt.Errorf("in state %q", st.BackendState) - } - return nil + return ready(st) }); err != nil { - t.Fatalf("failure/timeout waiting for transition to NeedsLogin status: %v", err) + t.Fatalf("failure/timeout waiting for callback function to return no error: %v", err) } } diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index e127087a6..559e083e2 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -931,6 +931,47 @@ func packetFilterWithIngressCaps() []tailcfg.FilterRule { return out } +// PeerForNode returns a Node object for a specific peer of a given node. +func (s *Server) PeerForNode(nodeKey, peerKey key.NodePublic) *tailcfg.Node { + s.mu.Lock() + defer s.mu.Unlock() + node := s.nodes[nodeKey].Clone() + peer := s.nodes[peerKey].Clone() + s.fillPeerDetailsLocked(node, peer) + return peer +} + +func (s *Server) fillPeerDetailsLocked(node, p *tailcfg.Node) { + nodeMasqs := s.masquerades[node.Key] + if masqIP := nodeMasqs[p.Key]; masqIP.IsValid() { + if masqIP.Is6() { + p.SelfNodeV6MasqAddrForThisPeer = ptr.To(masqIP) + } else { + p.SelfNodeV4MasqAddrForThisPeer = ptr.To(masqIP) + } + } + + jailed := maps.Clone(s.peerIsJailed[node.Key]) + p.IsJailed = jailed[p.Key] + + peerAddress := s.masquerades[p.Key][node.Key] + if peerAddress.IsValid() { + if peerAddress.Is6() { + p.Addresses[1] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) + p.AllowedIPs[1] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) + } else { + p.Addresses[0] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) + p.AllowedIPs[0] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) + } + } + + routes := s.nodeSubnetRoutes[p.Key] + if len(routes) > 0 { + p.PrimaryRoutes = routes + p.AllowedIPs = append(p.AllowedIPs, routes...) + } +} + // MapResponse generates a MapResponse for a MapRequest. // // No updates to s are done here. @@ -969,40 +1010,13 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, ControlTime: &t, } - s.mu.Lock() - nodeMasqs := s.masquerades[node.Key] - jailed := maps.Clone(s.peerIsJailed[node.Key]) - s.mu.Unlock() for _, p := range s.AllNodes() { if p.StableID == node.StableID { continue } - if masqIP := nodeMasqs[p.Key]; masqIP.IsValid() { - if masqIP.Is6() { - p.SelfNodeV6MasqAddrForThisPeer = ptr.To(masqIP) - } else { - p.SelfNodeV4MasqAddrForThisPeer = ptr.To(masqIP) - } - } - p.IsJailed = jailed[p.Key] - s.mu.Lock() - peerAddress := s.masquerades[p.Key][node.Key] - routes := s.nodeSubnetRoutes[p.Key] + s.fillPeerDetailsLocked(node, p) s.mu.Unlock() - if peerAddress.IsValid() { - if peerAddress.Is6() { - p.Addresses[1] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) - p.AllowedIPs[1] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) - } else { - p.Addresses[0] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) - p.AllowedIPs[0] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) - } - } - if len(routes) > 0 { - p.PrimaryRoutes = routes - p.AllowedIPs = append(p.AllowedIPs, routes...) - } res.Peers = append(res.Peers, p) }