diff --git a/wgengine/userspace.go b/wgengine/userspace.go index a369fa343..9bc0c6dc8 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -687,7 +687,7 @@ func (e *userspaceEngine) noteRecvActivity(nk key.NodePublic) { // couple minutes (just not on every packet). if e.trimmedNodes[nk] { e.logf("wgengine: idle peer %v now active, reconfiguring WireGuard", nk.ShortString()) - e.maybeReconfigWireguardLocked(nil) + e.maybeReconfigWireguardLocked(false) } } @@ -735,7 +735,7 @@ func (i *maybeReconfigInputs) Clone() *maybeReconfigInputs { // If discoChanged is nil or empty, this extra removal step isn't done. // // e.wgLock must be held. -func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.NodePublic]bool) error { +func (e *userspaceEngine) maybeReconfigWireguardLocked(forceReconfig bool) error { if hook := e.testMaybeReconfigHook; hook != nil { hook() return nil @@ -779,15 +779,11 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node e.trimmedNodes = make(map[key.NodePublic]bool) } - needRemoveStep := false for i := range full.Peers { p := &full.Peers[i] nk := p.PublicKey if !buildfeatures.HasLazyWG || !e.isTrimmablePeer(p, len(full.Peers)) { min.Peers = append(min.Peers, *p) - if discoChanged[nk] { - needRemoveStep = true - } continue } trackNodes = append(trackNodes, nk) @@ -798,9 +794,6 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node } if recentlyActive { min.Peers = append(min.Peers, *p) - if discoChanged[nk] { - needRemoveStep = true - } } else { e.trimmedNodes[nk] = true } @@ -812,7 +805,7 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node TrimmedNodes: e.trimmedNodes, TrackNodes: views.SliceOf(trackNodes), TrackIPs: views.SliceOf(trackIPs), - }); !changed { + }); !changed && !forceReconfig { return nil } @@ -820,26 +813,6 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node e.updateActivityMapsLocked(trackNodes, trackIPs) } - if needRemoveStep { - minner := min - minner.Peers = nil - numRemove := 0 - for _, p := range min.Peers { - if discoChanged[p.PublicKey] { - numRemove++ - continue - } - minner.Peers = append(minner.Peers, p) - } - if numRemove > 0 { - e.logf("wgengine: Reconfig: removing session keys for %d peers", numRemove) - if err := wgcfg.ReconfigDevice(e.wgdev, &minner, e.logf); err != nil { - e.logf("wgdev.Reconfig: %v", err) - return err - } - } - } - e.logf("wgengine: Reconfig: configuring userspace WireGuard config (with %d/%d peers)", len(min.Peers), len(full.Peers)) if err := wgcfg.ReconfigDevice(e.wgdev, &min, e.logf); err != nil { e.logf("wgdev.Reconfig: %v", err) @@ -896,7 +869,7 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, if elapsed >= packetSendRecheckWireguardThreshold { e.wgLock.Lock() defer e.wgLock.Unlock() - e.maybeReconfigWireguardLocked(nil) + e.maybeReconfigWireguardLocked(false) } } } @@ -1029,10 +1002,8 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } // See if any peers have changed disco keys, which means they've restarted. - // If so, we need to update the wireguard-go/device.Device in two phases: - // once without the node which has restarted, to clear its wireguard session key, - // and a second time with it. - discoChanged := make(map[key.NodePublic]bool) + // If we see that, we clear our wireguard-go session state for that peer. + forceReconfig := false { prevEP := make(map[key.NodePublic]key.DiscoPublic) for i := range e.lastCfgFull.Peers { @@ -1047,7 +1018,8 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } pub := p.PublicKey if old, ok := prevEP[pub]; ok && old != p.DiscoKey { - discoChanged[pub] = true + e.wgdev.RemovePeer(pub.Raw32()) + forceReconfig = true // to make sure we add it back e.logf("wgengine: Reconfig: %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey) } } @@ -1066,7 +1038,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, e.magicConn.SetPreferredPort(listenPort) e.magicConn.UpdatePMTUD() - if err := e.maybeReconfigWireguardLocked(discoChanged); err != nil { + if err := e.maybeReconfigWireguardLocked(forceReconfig); err != nil { return err }