diff --git a/cmd/lopower/lopower.go b/cmd/lopower/lopower.go index 6263e001c..17515f3fa 100644 --- a/cmd/lopower/lopower.go +++ b/cmd/lopower/lopower.go @@ -16,6 +16,7 @@ import ( "os/signal" "path/filepath" "slices" + "sync" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" @@ -41,6 +42,7 @@ import ( var ( wgListenPort = flag.Int("wg-port", 51820, "port number to listen on for WireGuard from the client") qrListenAddr = flag.String("qr-listen", "127.0.0.1:8014", "HTTP address to serve a QR code for client's WireGuard configuration") + confDir = flag.String("dir", filepath.Join(os.Getenv("HOME"), ".config/lopower"), "directory to store configuration in") ) type config struct { @@ -62,8 +64,8 @@ type Peer struct { V6 netip.Addr } -func storeConfig(cfg *config) { - path := filepath.Join(os.Getenv("HOME"), ".config/lopower/config.json") +func (lp *lpServer) storeConfigLocked() { + path := filepath.Join(lp.dir, "config.json") if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { log.Fatalf("os.MkdirAll(%q): %v", filepath.Dir(path), err) } @@ -72,19 +74,23 @@ func storeConfig(cfg *config) { log.Fatalf("os.OpenFile(%q): %v", path, err) } defer f.Close() - must.Do(json.NewEncoder(f).Encode(cfg)) + must.Do(json.NewEncoder(f).Encode(lp.c)) if err := f.Close(); err != nil { log.Fatalf("f.Close: %v", err) } } -func loadConfig() *config { - path := filepath.Join(os.Getenv("HOME"), ".config/lopower/config.json") +func (lp *lpServer) loadConfig() { + path := filepath.Join(lp.dir, "config.json") f, err := os.OpenFile(path, os.O_RDONLY, 0) if err == nil { + defer f.Close() var cfg *config must.Do(json.NewDecoder(f).Decode(&cfg)) - return cfg + lp.mu.Lock() + defer lp.mu.Unlock() + lp.c = cfg + return } if !os.IsNotExist(err) { log.Fatalf("os.OpenFile(%q): %v", path, err) @@ -98,11 +104,15 @@ func loadConfig() *config { } c.V4 = c.V4CIDR.Addr().Next() c.V6 = c.V6CIDR.Addr().Next() - storeConfig(c) - return c + lp.mu.Lock() + defer lp.mu.Unlock() + lp.c = c + lp.storeConfigLocked() + return } func (lp *lpServer) reconfig() { + lp.mu.Lock() wc := &wgcfg.Config{ Name: "lopower0", PrivateKey: lp.c.PrivKey, @@ -121,14 +131,44 @@ func (lp *lpServer) reconfig() { }, }) } + lp.mu.Unlock() must.Do(wgcfg.ReconfigDevice(lp.d, wc, log.Printf)) } +func newLP(ctx context.Context) *lpServer { + logf := log.Printf + deviceLogger := &device.Logger{ + Verbosef: logger.Discard, + Errorf: logf, + } + lp := &lpServer{ + dir: *confDir, + } + lp.loadConfig() + lp.initNetstack(ctx) + nst := &nsTUN{ + lp: lp, + closeCh: make(chan struct{}), + evChan: make(chan tun.Event), + } + wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger) + defer wgdev.Close() + lp.d = wgdev + must.Do(wgdev.Up()) + lp.reconfig() + lp.startTSNet(ctx) + return lp +} + type lpServer struct { - c *config + dir string + tsnet *tsnet.Server d *device.Device ns *stack.Stack linkEP *channel.Endpoint + + mu sync.Mutex // protects following + c *config } // MaxPacketSize is the maximum size (in bytes) @@ -161,9 +201,12 @@ func (lp *lpServer) initNetstack(ctx context.Context) error { ns.SetSpoofing(nicID, true) var routes []tcpip.Route + lp.mu.Lock() + v4, v6 := lp.c.V4, lp.c.V6 + lp.mu.Unlock() { - prefix := tcpip.AddrFrom4Slice(lp.c.V4.AsSlice()).WithPrefix() + prefix := tcpip.AddrFrom4Slice(v4.AsSlice()).WithPrefix() if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: prefix, @@ -181,7 +224,7 @@ func (lp *lpServer) initNetstack(ctx context.Context) error { }) } { - prefix := tcpip.AddrFrom16(lp.c.V6.As16()).WithPrefix() + prefix := tcpip.AddrFrom16(v6.As16()).WithPrefix() if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: prefix, @@ -290,19 +333,20 @@ func (t *nsTUN) Name() (string, error) { return "nstun", nil } func (t *nsTUN) Events() <-chan tun.Event { return t.evChan } func (t *nsTUN) BatchSize() int { return 1 } -func startTSNet(ctx context.Context) { +func (lp *lpServer) startTSNet(ctx context.Context) { hostname, err := os.Hostname() if err != nil { log.Fatal(err) } - ts := &tsnet.Server{ + lp.tsnet = &tsnet.Server{ + Dir: filepath.Join(lp.dir, "tsnet"), Hostname: hostname, UserLogf: log.Printf, Ephemeral: false, } - if _, err := ts.Up(ctx); err != nil { + if _, err := lp.tsnet.Up(ctx); err != nil { log.Fatal(err) } } @@ -310,29 +354,10 @@ func startTSNet(ctx context.Context) { func main() { flag.Parse() - logf := log.Printf - deviceLogger := &device.Logger{ - Verbosef: logger.Discard, - Errorf: logf, - } ctx, cancel := context.WithCancel(context.Background()) defer cancel() - lp := &lpServer{ - c: loadConfig(), - } - lp.initNetstack(ctx) - nst := &nsTUN{ - lp: lp, - closeCh: make(chan struct{}), - evChan: make(chan tun.Event), - } - wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger) - defer wgdev.Close() - lp.d = wgdev - must.Do(wgdev.Up()) - lp.reconfig() - - // startTSNet(ctx) + lp := newLP(ctx) + _ = lp sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, unix.SIGTERM, os.Interrupt)