diff --git a/wgengine/router/ifconfig_windows.go b/wgengine/router/ifconfig_windows.go index 525173970..4c8e2e0b6 100644 --- a/wgengine/router/ifconfig_windows.go +++ b/wgengine/router/ifconfig_windows.go @@ -156,10 +156,17 @@ func getDefaultRouteMTU() (uint32, error) { return mtu, nil } -func setFirewall(ifcGUID *windows.GUID) (bool, error) { - c := ole.Connection{} - err := c.Initialize() - if err != nil { +// setPrivateNetwork marks the provided network adapter's category to private. +// It returns (false, nil) if the adapter was not found. +func setPrivateNetwork(ifcGUID *windows.GUID) (bool, error) { + // NLM_NETWORK_CATEGORY values. + const ( + categoryPublic = 0 + categoryPrivate = 1 + categoryDomain = 2 + ) + var c ole.Connection + if err := c.Initialize(); err != nil { return false, fmt.Errorf("c.Initialize: %v", err) } defer c.Uninitialize() @@ -182,10 +189,8 @@ func setFirewall(ifcGUID *windows.GUID) (bool, error) { return false, fmt.Errorf("nco.GetAdapterId: %v", err) } if aid != ifcGUID.String() { - log.Printf("skipping adapter id: %v", aid) continue } - log.Printf("found! adapter id: %v", aid) n, err := nco.GetNetwork() if err != nil { @@ -198,15 +203,11 @@ func setFirewall(ifcGUID *windows.GUID) (bool, error) { return false, fmt.Errorf("GetCategory: %v", err) } - if cat == 0 { - err = n.SetCategory(1) - if err != nil { + if cat != categoryPrivate { + if err := n.SetCategory(categoryPrivate); err != nil { return false, fmt.Errorf("SetCategory: %v", err) } - } else { - log.Printf("setFirewall: already category %v", cat) } - return true, nil } @@ -225,17 +226,20 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) error { // It takes a weirdly long time for Windows to notice the // new interface has come up. Poll periodically until it // does. - for i := 0; i < 20; i++ { - found, err := setFirewall(&guid) + const tries = 20 + for i := 0; i < tries; i++ { + found, err := setPrivateNetwork(&guid) if err != nil { - log.Printf("setFirewall: %v", err) - // fall through anyway, this isn't fatal. - } - if found { - break + log.Printf("setPrivateNetwork(try=%d): %v", i, err) + } else { + if found { + return + } + log.Printf("setPrivateNetwork(try=%d): not found", i) } time.Sleep(1 * time.Second) } + log.Printf("setPrivateNetwork: adapter %v not found after %d tries, giving up", guid, tries) }() routes := []winipcfg.RouteData{} diff --git a/wgengine/router/router_windows.go b/wgengine/router/router_windows.go index 1f36e5c7c..5e038fde4 100644 --- a/wgengine/router/router_windows.go +++ b/wgengine/router/router_windows.go @@ -8,6 +8,7 @@ "fmt" "log" "os/exec" + "sync" "syscall" winipcfg "github.com/tailscale/winipcfg-go" @@ -24,6 +25,9 @@ type winRouter struct { wgdev *device.Device routeChangeCallback *winipcfg.RouteChangeCallback dns *dns.Manager + + mu sync.Mutex + firewallRuleIP string // the IP rule exists for, or "" if rule doesn't exist } func newUserspaceRouter(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (Router, error) { @@ -65,15 +69,36 @@ func (r *winRouter) Up() error { // // So callers should ignore its error for now. func (r *winRouter) removeFirewallAcceptRule() error { + r.mu.Lock() + defer r.mu.Unlock() + + r.firewallRuleIP = "" + cmd := exec.Command("netsh", "advfirewall", "firewall", "delete", "rule", "name=Tailscale-In", "dir=in") cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} return cmd.Run() } -func (r *winRouter) addFirewallAcceptRule(ipStr string) error { +// addFirewallAcceptRule adds a firewall rule to allow all incoming +// traffic to the given IP (the Tailscale adapter's IP) for network +// adapters in category private. (as previously set by +// setPrivateNetwork) +// +// It returns (false, nil) if the firewall rule was already previously existed with this IP. +func (r *winRouter) addFirewallAcceptRule(ipStr string) (added bool, err error) { + r.mu.Lock() + defer r.mu.Unlock() + if ipStr == r.firewallRuleIP { + return false, nil + } cmd := exec.Command("netsh", "advfirewall", "firewall", "add", "rule", "name=Tailscale-In", "dir=in", "action=allow", "localip="+ipStr, "profile=private", "enable=yes") cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} - return cmd.Run() + err = cmd.Run() + if err != nil { + return false, err + } + r.firewallRuleIP = ipStr + return true, nil } func (r *winRouter) Set(cfg *Config) error { @@ -81,14 +106,15 @@ func (r *winRouter) Set(cfg *Config) error { cfg = &shutdownConfig } - r.removeFirewallAcceptRule() if len(cfg.LocalAddrs) == 1 && cfg.LocalAddrs[0].Bits == 32 { ipStr := cfg.LocalAddrs[0].IP.String() - if err := r.addFirewallAcceptRule(ipStr); err != nil { + if ok, err := r.addFirewallAcceptRule(ipStr); err != nil { r.logf("addFirewallRule(%q): %v", ipStr, err) - } else { + } else if ok { r.logf("added firewall rule Tailscale-In for %v", ipStr) } + } else { + r.removeFirewallAcceptRule() } err := configureInterface(cfg, r.nativeTun)