wgengine/router: run netsh advfirewall less, rename, document setFirewall

This commit is contained in:
Brad Fitzpatrick 2020-09-16 14:41:28 -07:00
parent 48fbe93e72
commit acafe9811f
2 changed files with 54 additions and 24 deletions

View File

@ -156,10 +156,17 @@ func getDefaultRouteMTU() (uint32, error) {
return mtu, nil return mtu, nil
} }
func setFirewall(ifcGUID *windows.GUID) (bool, error) { // setPrivateNetwork marks the provided network adapter's category to private.
c := ole.Connection{} // It returns (false, nil) if the adapter was not found.
err := c.Initialize() func setPrivateNetwork(ifcGUID *windows.GUID) (bool, error) {
if err != nil { // NLM_NETWORK_CATEGORY values.
const (
categoryPublic = 0
categoryPrivate = 1
categoryDomain = 2
)
var c ole.Connection
if err := c.Initialize(); err != nil {
return false, fmt.Errorf("c.Initialize: %v", err) return false, fmt.Errorf("c.Initialize: %v", err)
} }
defer c.Uninitialize() defer c.Uninitialize()
@ -182,10 +189,8 @@ func setFirewall(ifcGUID *windows.GUID) (bool, error) {
return false, fmt.Errorf("nco.GetAdapterId: %v", err) return false, fmt.Errorf("nco.GetAdapterId: %v", err)
} }
if aid != ifcGUID.String() { if aid != ifcGUID.String() {
log.Printf("skipping adapter id: %v", aid)
continue continue
} }
log.Printf("found! adapter id: %v", aid)
n, err := nco.GetNetwork() n, err := nco.GetNetwork()
if err != nil { if err != nil {
@ -198,15 +203,11 @@ func setFirewall(ifcGUID *windows.GUID) (bool, error) {
return false, fmt.Errorf("GetCategory: %v", err) return false, fmt.Errorf("GetCategory: %v", err)
} }
if cat == 0 { if cat != categoryPrivate {
err = n.SetCategory(1) if err := n.SetCategory(categoryPrivate); err != nil {
if err != nil {
return false, fmt.Errorf("SetCategory: %v", err) return false, fmt.Errorf("SetCategory: %v", err)
} }
} else {
log.Printf("setFirewall: already category %v", cat)
} }
return true, nil return true, nil
} }
@ -225,17 +226,20 @@ func configureInterface(cfg *Config, tun *tun.NativeTun) error {
// It takes a weirdly long time for Windows to notice the // It takes a weirdly long time for Windows to notice the
// new interface has come up. Poll periodically until it // new interface has come up. Poll periodically until it
// does. // does.
for i := 0; i < 20; i++ { const tries = 20
found, err := setFirewall(&guid) for i := 0; i < tries; i++ {
found, err := setPrivateNetwork(&guid)
if err != nil { if err != nil {
log.Printf("setFirewall: %v", err) log.Printf("setPrivateNetwork(try=%d): %v", i, err)
// fall through anyway, this isn't fatal. } else {
}
if found { if found {
break return
}
log.Printf("setPrivateNetwork(try=%d): not found", i)
} }
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
log.Printf("setPrivateNetwork: adapter %v not found after %d tries, giving up", guid, tries)
}() }()
routes := []winipcfg.RouteData{} routes := []winipcfg.RouteData{}

View File

@ -8,6 +8,7 @@ import (
"fmt" "fmt"
"log" "log"
"os/exec" "os/exec"
"sync"
"syscall" "syscall"
winipcfg "github.com/tailscale/winipcfg-go" winipcfg "github.com/tailscale/winipcfg-go"
@ -24,6 +25,9 @@ type winRouter struct {
wgdev *device.Device wgdev *device.Device
routeChangeCallback *winipcfg.RouteChangeCallback routeChangeCallback *winipcfg.RouteChangeCallback
dns *dns.Manager dns *dns.Manager
mu sync.Mutex
firewallRuleIP string // the IP rule exists for, or "" if rule doesn't exist
} }
func newUserspaceRouter(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (Router, error) { func newUserspaceRouter(logf logger.Logf, wgdev *device.Device, tundev tun.Device) (Router, error) {
@ -65,15 +69,36 @@ func (r *winRouter) Up() error {
// //
// So callers should ignore its error for now. // So callers should ignore its error for now.
func (r *winRouter) removeFirewallAcceptRule() error { func (r *winRouter) removeFirewallAcceptRule() error {
r.mu.Lock()
defer r.mu.Unlock()
r.firewallRuleIP = ""
cmd := exec.Command("netsh", "advfirewall", "firewall", "delete", "rule", "name=Tailscale-In", "dir=in") cmd := exec.Command("netsh", "advfirewall", "firewall", "delete", "rule", "name=Tailscale-In", "dir=in")
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run() return cmd.Run()
} }
func (r *winRouter) addFirewallAcceptRule(ipStr string) error { // addFirewallAcceptRule adds a firewall rule to allow all incoming
// traffic to the given IP (the Tailscale adapter's IP) for network
// adapters in category private. (as previously set by
// setPrivateNetwork)
//
// It returns (false, nil) if the firewall rule was already previously existed with this IP.
func (r *winRouter) addFirewallAcceptRule(ipStr string) (added bool, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if ipStr == r.firewallRuleIP {
return false, nil
}
cmd := exec.Command("netsh", "advfirewall", "firewall", "add", "rule", "name=Tailscale-In", "dir=in", "action=allow", "localip="+ipStr, "profile=private", "enable=yes") cmd := exec.Command("netsh", "advfirewall", "firewall", "add", "rule", "name=Tailscale-In", "dir=in", "action=allow", "localip="+ipStr, "profile=private", "enable=yes")
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run() err = cmd.Run()
if err != nil {
return false, err
}
r.firewallRuleIP = ipStr
return true, nil
} }
func (r *winRouter) Set(cfg *Config) error { func (r *winRouter) Set(cfg *Config) error {
@ -81,14 +106,15 @@ func (r *winRouter) Set(cfg *Config) error {
cfg = &shutdownConfig cfg = &shutdownConfig
} }
r.removeFirewallAcceptRule()
if len(cfg.LocalAddrs) == 1 && cfg.LocalAddrs[0].Bits == 32 { if len(cfg.LocalAddrs) == 1 && cfg.LocalAddrs[0].Bits == 32 {
ipStr := cfg.LocalAddrs[0].IP.String() ipStr := cfg.LocalAddrs[0].IP.String()
if err := r.addFirewallAcceptRule(ipStr); err != nil { if ok, err := r.addFirewallAcceptRule(ipStr); err != nil {
r.logf("addFirewallRule(%q): %v", ipStr, err) r.logf("addFirewallRule(%q): %v", ipStr, err)
} else { } else if ok {
r.logf("added firewall rule Tailscale-In for %v", ipStr) r.logf("added firewall rule Tailscale-In for %v", ipStr)
} }
} else {
r.removeFirewallAcceptRule()
} }
err := configureInterface(cfg, r.nativeTun) err := configureInterface(cfg, r.nativeTun)