cmd/lopower: move lp init to newLP

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2024-11-02 14:16:40 -07:00 committed by Anton Tolchanov
parent 0f881a9d09
commit e3ee9c4980

View File

@ -16,6 +16,7 @@ import (
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"slices" "slices"
"sync"
"github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/device"
@ -41,6 +42,7 @@ import (
var ( var (
wgListenPort = flag.Int("wg-port", 51820, "port number to listen on for WireGuard from the client") wgListenPort = flag.Int("wg-port", 51820, "port number to listen on for WireGuard from the client")
qrListenAddr = flag.String("qr-listen", "127.0.0.1:8014", "HTTP address to serve a QR code for client's WireGuard configuration") qrListenAddr = flag.String("qr-listen", "127.0.0.1:8014", "HTTP address to serve a QR code for client's WireGuard configuration")
confDir = flag.String("dir", filepath.Join(os.Getenv("HOME"), ".config/lopower"), "directory to store configuration in")
) )
type config struct { type config struct {
@ -62,8 +64,8 @@ type Peer struct {
V6 netip.Addr V6 netip.Addr
} }
func storeConfig(cfg *config) { func (lp *lpServer) storeConfigLocked() {
path := filepath.Join(os.Getenv("HOME"), ".config/lopower/config.json") path := filepath.Join(lp.dir, "config.json")
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
log.Fatalf("os.MkdirAll(%q): %v", filepath.Dir(path), err) log.Fatalf("os.MkdirAll(%q): %v", filepath.Dir(path), err)
} }
@ -72,19 +74,23 @@ func storeConfig(cfg *config) {
log.Fatalf("os.OpenFile(%q): %v", path, err) log.Fatalf("os.OpenFile(%q): %v", path, err)
} }
defer f.Close() defer f.Close()
must.Do(json.NewEncoder(f).Encode(cfg)) must.Do(json.NewEncoder(f).Encode(lp.c))
if err := f.Close(); err != nil { if err := f.Close(); err != nil {
log.Fatalf("f.Close: %v", err) log.Fatalf("f.Close: %v", err)
} }
} }
func loadConfig() *config { func (lp *lpServer) loadConfig() {
path := filepath.Join(os.Getenv("HOME"), ".config/lopower/config.json") path := filepath.Join(lp.dir, "config.json")
f, err := os.OpenFile(path, os.O_RDONLY, 0) f, err := os.OpenFile(path, os.O_RDONLY, 0)
if err == nil { if err == nil {
defer f.Close()
var cfg *config var cfg *config
must.Do(json.NewDecoder(f).Decode(&cfg)) must.Do(json.NewDecoder(f).Decode(&cfg))
return cfg lp.mu.Lock()
defer lp.mu.Unlock()
lp.c = cfg
return
} }
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
log.Fatalf("os.OpenFile(%q): %v", path, err) log.Fatalf("os.OpenFile(%q): %v", path, err)
@ -98,11 +104,15 @@ func loadConfig() *config {
} }
c.V4 = c.V4CIDR.Addr().Next() c.V4 = c.V4CIDR.Addr().Next()
c.V6 = c.V6CIDR.Addr().Next() c.V6 = c.V6CIDR.Addr().Next()
storeConfig(c) lp.mu.Lock()
return c defer lp.mu.Unlock()
lp.c = c
lp.storeConfigLocked()
return
} }
func (lp *lpServer) reconfig() { func (lp *lpServer) reconfig() {
lp.mu.Lock()
wc := &wgcfg.Config{ wc := &wgcfg.Config{
Name: "lopower0", Name: "lopower0",
PrivateKey: lp.c.PrivKey, PrivateKey: lp.c.PrivKey,
@ -121,14 +131,44 @@ func (lp *lpServer) reconfig() {
}, },
}) })
} }
lp.mu.Unlock()
must.Do(wgcfg.ReconfigDevice(lp.d, wc, log.Printf)) must.Do(wgcfg.ReconfigDevice(lp.d, wc, log.Printf))
} }
func newLP(ctx context.Context) *lpServer {
logf := log.Printf
deviceLogger := &device.Logger{
Verbosef: logger.Discard,
Errorf: logf,
}
lp := &lpServer{
dir: *confDir,
}
lp.loadConfig()
lp.initNetstack(ctx)
nst := &nsTUN{
lp: lp,
closeCh: make(chan struct{}),
evChan: make(chan tun.Event),
}
wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger)
defer wgdev.Close()
lp.d = wgdev
must.Do(wgdev.Up())
lp.reconfig()
lp.startTSNet(ctx)
return lp
}
type lpServer struct { type lpServer struct {
c *config dir string
tsnet *tsnet.Server
d *device.Device d *device.Device
ns *stack.Stack ns *stack.Stack
linkEP *channel.Endpoint linkEP *channel.Endpoint
mu sync.Mutex // protects following
c *config
} }
// MaxPacketSize is the maximum size (in bytes) // MaxPacketSize is the maximum size (in bytes)
@ -161,9 +201,12 @@ func (lp *lpServer) initNetstack(ctx context.Context) error {
ns.SetSpoofing(nicID, true) ns.SetSpoofing(nicID, true)
var routes []tcpip.Route var routes []tcpip.Route
lp.mu.Lock()
v4, v6 := lp.c.V4, lp.c.V6
lp.mu.Unlock()
{ {
prefix := tcpip.AddrFrom4Slice(lp.c.V4.AsSlice()).WithPrefix() prefix := tcpip.AddrFrom4Slice(v4.AsSlice()).WithPrefix()
if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: prefix, AddressWithPrefix: prefix,
@ -181,7 +224,7 @@ func (lp *lpServer) initNetstack(ctx context.Context) error {
}) })
} }
{ {
prefix := tcpip.AddrFrom16(lp.c.V6.As16()).WithPrefix() prefix := tcpip.AddrFrom16(v6.As16()).WithPrefix()
if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{ if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber, Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: prefix, AddressWithPrefix: prefix,
@ -290,19 +333,20 @@ func (t *nsTUN) Name() (string, error) { return "nstun", nil }
func (t *nsTUN) Events() <-chan tun.Event { return t.evChan } func (t *nsTUN) Events() <-chan tun.Event { return t.evChan }
func (t *nsTUN) BatchSize() int { return 1 } func (t *nsTUN) BatchSize() int { return 1 }
func startTSNet(ctx context.Context) { func (lp *lpServer) startTSNet(ctx context.Context) {
hostname, err := os.Hostname() hostname, err := os.Hostname()
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
ts := &tsnet.Server{ lp.tsnet = &tsnet.Server{
Dir: filepath.Join(lp.dir, "tsnet"),
Hostname: hostname, Hostname: hostname,
UserLogf: log.Printf, UserLogf: log.Printf,
Ephemeral: false, Ephemeral: false,
} }
if _, err := ts.Up(ctx); err != nil { if _, err := lp.tsnet.Up(ctx); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
@ -310,29 +354,10 @@ func startTSNet(ctx context.Context) {
func main() { func main() {
flag.Parse() flag.Parse()
logf := log.Printf
deviceLogger := &device.Logger{
Verbosef: logger.Discard,
Errorf: logf,
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
lp := &lpServer{ lp := newLP(ctx)
c: loadConfig(), _ = lp
}
lp.initNetstack(ctx)
nst := &nsTUN{
lp: lp,
closeCh: make(chan struct{}),
evChan: make(chan tun.Event),
}
wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger)
defer wgdev.Close()
lp.d = wgdev
must.Do(wgdev.Up())
lp.reconfig()
// startTSNet(ctx)
sigCh := make(chan os.Signal, 1) sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, unix.SIGTERM, os.Interrupt) signal.Notify(sigCh, unix.SIGTERM, os.Interrupt)