diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 685fff4da..c57086201 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -4015,3 +4015,84 @@ func TestConn_receiveIP(t *testing.T) { }) } } + +func Test_lazyEndpoint_InitiationMessagePublicKey(t *testing.T) { + tests := []struct { + name string + callWithPeerMapKey bool + maybeEPMatchingKey bool + wantNoteRecvActivityCalled bool + }{ + { + name: "noteRecvActivity called", + callWithPeerMapKey: true, + maybeEPMatchingKey: false, + wantNoteRecvActivityCalled: true, + }, + { + name: "maybeEP early return", + callWithPeerMapKey: true, + maybeEPMatchingKey: true, + wantNoteRecvActivityCalled: false, + }, + { + name: "not in peerMap early return", + callWithPeerMapKey: false, + maybeEPMatchingKey: false, + wantNoteRecvActivityCalled: false, + }, + { + name: "not in peerMap maybeEP early return", + callWithPeerMapKey: false, + maybeEPMatchingKey: true, + wantNoteRecvActivityCalled: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ep := &endpoint{ + nodeID: 1, + publicKey: key.NewNode().Public(), + } + ep.disco.Store(&endpointDisco{ + key: key.NewDisco().Public(), + }) + + var noteRecvActivityCalledFor key.NodePublic + conn := newConn(t.Logf) + conn.noteRecvActivity = func(public key.NodePublic) { + // wireguard-go will call into ParseEndpoint if the "real" + // noteRecvActivity ends up JIT configuring the peer. Mimic that + // to ensure there are no deadlocks around conn.mu. + // See tailscale/tailscale#16651 & http://go/corp#30836 + _, err := conn.ParseEndpoint(ep.publicKey.UntypedHexString()) + if err != nil { + t.Fatalf("ParseEndpoint() err: %v", err) + } + noteRecvActivityCalledFor = public + } + ep.c = conn + + var pubKey [32]byte + if tt.callWithPeerMapKey { + copy(pubKey[:], ep.publicKey.AppendTo(nil)) + } + conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + + le := &lazyEndpoint{ + c: conn, + } + if tt.maybeEPMatchingKey { + le.maybeEP = ep + } + le.InitiationMessagePublicKey(pubKey) + want := key.NodePublic{} + if tt.wantNoteRecvActivityCalled { + want = ep.publicKey + } + if noteRecvActivityCalledFor.Compare(want) != 0 { + t.Fatalf("noteRecvActivityCalledFor = %v, want %v", noteRecvActivityCalledFor, want) + } + }) + } +}