cmd/lopower: move lp init to newLP

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali 2024-11-02 14:16:40 -07:00 committed by Anton Tolchanov
parent 0f881a9d09
commit e3ee9c4980

View File

@ -16,6 +16,7 @@ import (
"os/signal"
"path/filepath"
"slices"
"sync"
"github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/device"
@ -41,6 +42,7 @@ import (
var (
wgListenPort = flag.Int("wg-port", 51820, "port number to listen on for WireGuard from the client")
qrListenAddr = flag.String("qr-listen", "127.0.0.1:8014", "HTTP address to serve a QR code for client's WireGuard configuration")
confDir = flag.String("dir", filepath.Join(os.Getenv("HOME"), ".config/lopower"), "directory to store configuration in")
)
type config struct {
@ -62,8 +64,8 @@ type Peer struct {
V6 netip.Addr
}
func storeConfig(cfg *config) {
path := filepath.Join(os.Getenv("HOME"), ".config/lopower/config.json")
func (lp *lpServer) storeConfigLocked() {
path := filepath.Join(lp.dir, "config.json")
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
log.Fatalf("os.MkdirAll(%q): %v", filepath.Dir(path), err)
}
@ -72,19 +74,23 @@ func storeConfig(cfg *config) {
log.Fatalf("os.OpenFile(%q): %v", path, err)
}
defer f.Close()
must.Do(json.NewEncoder(f).Encode(cfg))
must.Do(json.NewEncoder(f).Encode(lp.c))
if err := f.Close(); err != nil {
log.Fatalf("f.Close: %v", err)
}
}
func loadConfig() *config {
path := filepath.Join(os.Getenv("HOME"), ".config/lopower/config.json")
func (lp *lpServer) loadConfig() {
path := filepath.Join(lp.dir, "config.json")
f, err := os.OpenFile(path, os.O_RDONLY, 0)
if err == nil {
defer f.Close()
var cfg *config
must.Do(json.NewDecoder(f).Decode(&cfg))
return cfg
lp.mu.Lock()
defer lp.mu.Unlock()
lp.c = cfg
return
}
if !os.IsNotExist(err) {
log.Fatalf("os.OpenFile(%q): %v", path, err)
@ -98,11 +104,15 @@ func loadConfig() *config {
}
c.V4 = c.V4CIDR.Addr().Next()
c.V6 = c.V6CIDR.Addr().Next()
storeConfig(c)
return c
lp.mu.Lock()
defer lp.mu.Unlock()
lp.c = c
lp.storeConfigLocked()
return
}
func (lp *lpServer) reconfig() {
lp.mu.Lock()
wc := &wgcfg.Config{
Name: "lopower0",
PrivateKey: lp.c.PrivKey,
@ -121,14 +131,44 @@ func (lp *lpServer) reconfig() {
},
})
}
lp.mu.Unlock()
must.Do(wgcfg.ReconfigDevice(lp.d, wc, log.Printf))
}
func newLP(ctx context.Context) *lpServer {
logf := log.Printf
deviceLogger := &device.Logger{
Verbosef: logger.Discard,
Errorf: logf,
}
lp := &lpServer{
dir: *confDir,
}
lp.loadConfig()
lp.initNetstack(ctx)
nst := &nsTUN{
lp: lp,
closeCh: make(chan struct{}),
evChan: make(chan tun.Event),
}
wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger)
defer wgdev.Close()
lp.d = wgdev
must.Do(wgdev.Up())
lp.reconfig()
lp.startTSNet(ctx)
return lp
}
type lpServer struct {
c *config
dir string
tsnet *tsnet.Server
d *device.Device
ns *stack.Stack
linkEP *channel.Endpoint
mu sync.Mutex // protects following
c *config
}
// MaxPacketSize is the maximum size (in bytes)
@ -161,9 +201,12 @@ func (lp *lpServer) initNetstack(ctx context.Context) error {
ns.SetSpoofing(nicID, true)
var routes []tcpip.Route
lp.mu.Lock()
v4, v6 := lp.c.V4, lp.c.V6
lp.mu.Unlock()
{
prefix := tcpip.AddrFrom4Slice(lp.c.V4.AsSlice()).WithPrefix()
prefix := tcpip.AddrFrom4Slice(v4.AsSlice()).WithPrefix()
if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: prefix,
@ -181,7 +224,7 @@ func (lp *lpServer) initNetstack(ctx context.Context) error {
})
}
{
prefix := tcpip.AddrFrom16(lp.c.V6.As16()).WithPrefix()
prefix := tcpip.AddrFrom16(v6.As16()).WithPrefix()
if tcpProb := ns.AddProtocolAddress(nicID, tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: prefix,
@ -290,19 +333,20 @@ func (t *nsTUN) Name() (string, error) { return "nstun", nil }
func (t *nsTUN) Events() <-chan tun.Event { return t.evChan }
func (t *nsTUN) BatchSize() int { return 1 }
func startTSNet(ctx context.Context) {
func (lp *lpServer) startTSNet(ctx context.Context) {
hostname, err := os.Hostname()
if err != nil {
log.Fatal(err)
}
ts := &tsnet.Server{
lp.tsnet = &tsnet.Server{
Dir: filepath.Join(lp.dir, "tsnet"),
Hostname: hostname,
UserLogf: log.Printf,
Ephemeral: false,
}
if _, err := ts.Up(ctx); err != nil {
if _, err := lp.tsnet.Up(ctx); err != nil {
log.Fatal(err)
}
}
@ -310,29 +354,10 @@ func startTSNet(ctx context.Context) {
func main() {
flag.Parse()
logf := log.Printf
deviceLogger := &device.Logger{
Verbosef: logger.Discard,
Errorf: logf,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
lp := &lpServer{
c: loadConfig(),
}
lp.initNetstack(ctx)
nst := &nsTUN{
lp: lp,
closeCh: make(chan struct{}),
evChan: make(chan tun.Event),
}
wgdev := wgcfg.NewDevice(nst, conn.NewDefaultBind(), deviceLogger)
defer wgdev.Close()
lp.d = wgdev
must.Do(wgdev.Up())
lp.reconfig()
// startTSNet(ctx)
lp := newLP(ctx)
_ = lp
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, unix.SIGTERM, os.Interrupt)