From 5f24261a1eee73b619e40b328ce20c709eff57df Mon Sep 17 00:00:00 2001 From: Fran Bull Date: Mon, 30 Sep 2024 10:27:39 -0700 Subject: [PATCH] wip --- cmd/natc/consensus.go | 267 ++++++++++++++++++++++++++++++++++++++++ cmd/natc/http.go | 132 ++++++++++++++++++++ cmd/natc/ippool.go | 257 ++++++++++++++++++++++++++++++++++++++ cmd/natc/ippool_test.go | 129 +++++++++++++++++++ cmd/natc/natc.go | 154 ++++++++--------------- 5 files changed, 833 insertions(+), 106 deletions(-) create mode 100644 cmd/natc/consensus.go create mode 100644 cmd/natc/http.go create mode 100644 cmd/natc/ippool.go create mode 100644 cmd/natc/ippool_test.go diff --git a/cmd/natc/consensus.go b/cmd/natc/consensus.go new file mode 100644 index 000000000..c326bab94 --- /dev/null +++ b/cmd/natc/consensus.go @@ -0,0 +1,267 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "net/netip" + "time" + + "github.com/hashicorp/raft" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tsnet" +) + +type consensus struct { + Raft *raft.Raft + CommandClient *commandClient + Self selfRaftNode +} + +type selfRaftNode struct { + ID string + Addr netip.Addr +} + +func (n *selfRaftNode) addrRaftPort() netip.AddrPort { + return netip.AddrPortFrom(n.Addr, 6311) +} + +// StreamLayer implements an interface asked for by raft.NetworkTransport. +// Do the raft interprocess comms via tailscale. +type StreamLayer struct { + net.Listener + s *tsnet.Server +} + +// Dial is used to create a new outgoing connection +func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { + ctx, _ := context.WithTimeout(context.Background(), timeout) + return sl.s.Dial(ctx, "tcp", string(address)) +} + +type listeners struct { + raft *StreamLayer // for the raft goroutine + command net.Listener // for the command http goroutine +} + +func NewConsensus(myAddr netip.Addr, httpClient *http.Client) *consensus { + cc := commandClient{ + port: 6312, + httpClient: httpClient, + } + self := selfRaftNode{ + ID: myAddr.String(), + Addr: myAddr, + } + return &consensus{ + CommandClient: &cc, + Self: self, + } +} + +func (c *consensus) Start(lns *listeners, sm *fsm) error { + config := raft.DefaultConfig() + config.LocalID = raft.ServerID(c.Self.ID) + config.HeartbeatTimeout = 1000 * time.Millisecond + config.ElectionTimeout = 1000 * time.Millisecond + logStore := raft.NewInmemStore() + stableStore := raft.NewInmemStore() + snapshots := raft.NewInmemSnapshotStore() + transport := raft.NewNetworkTransport(lns.raft, 5, 5*time.Second, nil) + + ra, err := raft.NewRaft(config, sm, logStore, stableStore, snapshots, transport) + if err != nil { + return fmt.Errorf("new raft: %s", err) + } + c.Raft = ra + + mux := c.makeCommandMux() + go func() { + defer lns.command.Close() + log.Fatal(http.Serve(lns.command, mux)) + }() + return nil +} + +func (c *consensus) handleJoin(jr joinRequest) error { + configFuture := c.Raft.GetConfiguration() + if err := configFuture.Error(); err != nil { + return err + } + + for _, srv := range configFuture.Configuration().Servers { + // If a node already exists with either the joining node's ID or address, + // that node may need to be removed from the config first. + if srv.ID == raft.ServerID(jr.RemoteID) || srv.Address == raft.ServerAddress(jr.RemoteAddr) { + // However if *both* the ID and the address are the same, then nothing -- not even + // a join operation -- is needed. 
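+ // Otherwise exactly one of the ID or the address matches: remove the stale entry so the AddVoter call below can re-register the joining node under its current ID/address pair.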
+ if srv.Address == raft.ServerAddress(jr.RemoteAddr) && srv.ID == raft.ServerID(jr.RemoteID) { + log.Printf("node %s at %s already member of cluster, ignoring join request", jr.RemoteID, jr.RemoteAddr) + return nil + } + + future := c.Raft.RemoveServer(srv.ID, 0, 0) + if err := future.Error(); err != nil { + return fmt.Errorf("error removing existing node %s at %s: %s", jr.RemoteID, jr.RemoteAddr, err) + } + } + } + + f := c.Raft.AddVoter(raft.ServerID(jr.RemoteID), raft.ServerAddress(jr.RemoteAddr), 0, 0) + if f.Error() != nil { + return f.Error() + } + return nil +} + +// try to join a raft cluster, or start one +func BootstrapConsensus(sm *fsm, myAddr netip.Addr, lns *listeners, targets []*ipnstate.PeerStatus, httpClient *http.Client) (*consensus, error) { + cns := NewConsensus(myAddr, httpClient) + err := cns.Start(lns, sm) + if err != nil { + return cns, err + } + joined := false + log.Printf("Trying to find cluster: num targets to try: %d", len(targets)) + for _, p := range targets { + if !p.Online { + log.Printf("Trying to find cluster: tailscale reports not online: %s", p.TailscaleIPs[0]) + } else { + log.Printf("Trying to find cluster: trying %s", p.TailscaleIPs[0]) + err = cns.JoinCluster(p.TailscaleIPs[0]) + if err != nil { + log.Printf("Trying to find cluster: could not join %s: %v", p.TailscaleIPs[0], err) + } else { + log.Printf("Trying to find cluster: joined %s", p.TailscaleIPs[0]) + joined = true + break + } + } + } + + if !joined { + log.Printf("Trying to find cluster: unsuccessful, starting as leader: %s", myAddr) + err = cns.LeadCluster() + if err != nil { + return cns, err + } + } + return cns, nil +} + +func (c *consensus) JoinCluster(a netip.Addr) error { + return c.CommandClient.Join(c.CommandClient.ServerAddressFromAddr(a), joinRequest{ + RemoteAddr: c.Self.addrRaftPort().String(), + RemoteID: c.Self.ID, + }) + +} + +func (c *consensus) LeadCluster() error { + configuration := raft.Configuration{ + Servers: []raft.Server{ + { + ID: raft.ServerID(c.Self.ID), + Address: raft.ServerAddress(fmt.Sprintf("%s:6311", c.Self.Addr)), + }, + }, + } + f := c.Raft.BootstrapCluster(configuration) + return f.Error() +} + +// plumbing for executing a command either locally or via http transport +// and telling peers we're not the leader and who we think the leader is +type command struct { + Name string + Args []byte +} + +type commandResult struct { + Err error + Result []byte +} + +type lookElsewhereError struct { + where string +} + +func (e lookElsewhereError) Error() string { + return fmt.Sprintf("not the leader, try: %s", e.where) +} + +func (c *consensus) executeCommandLocally(cmd command) (commandResult, error) { + b, err := json.Marshal(cmd) + if err != nil { + return commandResult{}, err + } + f := c.Raft.Apply(b, 10*time.Second) + err = f.Error() + result := f.Response() + if errors.Is(err, raft.ErrNotLeader) { + raftLeaderAddr, _ := c.Raft.LeaderWithID() + leaderAddr := (string)(raftLeaderAddr) + if leaderAddr != "" { + leaderAddr = leaderAddr[:len(raftLeaderAddr)-1] + "2" // TODO + } + return commandResult{}, lookElsewhereError{where: leaderAddr} + } + return result.(commandResult), err +} + +func (c *consensus) executeCommand(cmd command) (commandResult, error) { + b, err := json.Marshal(cmd) + if err != nil { + return commandResult{}, err + } + result, err := c.executeCommandLocally(cmd) + var leErr lookElsewhereError + for errors.As(err, &leErr) { + result, err = c.CommandClient.ExecuteCommand(leErr.where, b) + } + return result, err +} + +// fulfil the raft lib 
functional state machine interface +type fsm ipPool +type fsmSnapshot struct{} + +func (f *fsm) Apply(l *raft.Log) interface{} { + var c command + if err := json.Unmarshal(l.Data, &c); err != nil { + panic(fmt.Sprintf("failed to unmarshal command: %s", err.Error())) + } + switch c.Name { + case "checkoutAddr": + return f.executeCheckoutAddr(c.Args) + case "markLastUsed": + return f.executeMarkLastUsed(c.Args) + default: + panic(fmt.Sprintf("unrecognized command: %s", c.Name)) + } +} + +func (f *fsm) Snapshot() (raft.FSMSnapshot, error) { + panic("Snapshot unexpectedly used") + return nil, nil +} + +func (f *fsm) Restore(rc io.ReadCloser) error { + panic("Restore unexpectedly used") + return nil +} + +func (f *fsmSnapshot) Persist(sink raft.SnapshotSink) error { + panic("Persist unexpectedly used") + return nil +} + +func (f *fsmSnapshot) Release() { + panic("Release unexpectedly used") +} diff --git a/cmd/natc/http.go b/cmd/natc/http.go new file mode 100644 index 000000000..06ff334d7 --- /dev/null +++ b/cmd/natc/http.go @@ -0,0 +1,132 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/netip" + "time" +) + +type joinRequest struct { + RemoteAddr string `json:'remoteAddr'` + RemoteID string `json:'remoteID'` +} + +type commandClient struct { + port int + httpClient *http.Client +} + +func (rac *commandClient) ServerAddressFromAddr(addr netip.Addr) string { + return fmt.Sprintf("%s:%d", addr, rac.port) +} + +func (rac *commandClient) Url(serverAddr string, path string) string { + return fmt.Sprintf("http://%s%s", serverAddr, path) +} + +func (rac *commandClient) Join(serverAddr string, jr joinRequest) error { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + rBs, err := json.Marshal(jr) + if err != nil { + return err + } + url := rac.Url(serverAddr, "/join") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rBs)) + if err != nil { + return err + } + resp, err := rac.httpClient.Do(req) + if err != nil { + return err + } + respBs, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + if resp.StatusCode != 200 { + return errors.New(fmt.Sprintf("remote responded %d: %s", resp.StatusCode, string(respBs))) + } + return nil +} + +func (rac *commandClient) ExecuteCommand(serverAddr string, bs []byte) (commandResult, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + url := rac.Url(serverAddr, "/executeCommand") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bs)) + if err != nil { + return commandResult{}, err + } + resp, err := rac.httpClient.Do(req) + if err != nil { + return commandResult{}, err + } + respBs, err := io.ReadAll(resp.Body) + if err != nil { + return commandResult{}, err + } + if resp.StatusCode != 200 { + return commandResult{}, errors.New(fmt.Sprintf("remote responded %d: %s", resp.StatusCode, string(respBs))) + } + var cr commandResult + if err = json.Unmarshal(respBs, &cr); err != nil { + return commandResult{}, err + } + return cr, nil +} + +func (c *consensus) makeCommandMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + decoder := json.NewDecoder(r.Body) + var jr joinRequest + err := decoder.Decode(&jr) + if err != nil { + http.Error(w, err.Error(), 
http.StatusInternalServerError) + return + } + if jr.RemoteAddr == "" { + http.Error(w, "Required: remoteAddr", http.StatusBadRequest) + return + } + if jr.RemoteID == "" { + http.Error(w, "Required: remoteID", http.StatusBadRequest) + return + } + err = c.handleJoin(jr) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + mux.HandleFunc("/executeCommand", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Bad Request", http.StatusBadRequest) + return + } + decoder := json.NewDecoder(r.Body) + var cmd command + err := decoder.Decode(&cmd) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + result, err := c.executeCommandLocally(cmd) + if err := json.NewEncoder(w).Encode(result); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }) + return mux +} diff --git a/cmd/natc/ippool.go b/cmd/natc/ippool.go new file mode 100644 index 000000000..091a93c87 --- /dev/null +++ b/cmd/natc/ippool.go @@ -0,0 +1,257 @@ +package main + +import ( + "encoding/json" + "errors" + "fmt" + "log" + "net/netip" + "sync" + "time" + + "github.com/gaissmai/bart" + "tailscale.com/ipn/ipnstate" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/util/mak" +) + +type ipPool struct { + perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState] + v4Ranges []netip.Prefix + dnsAddr netip.Addr + consensus *consensus +} + +func (ipp *ipPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr, updatedAt time.Time) string { + // TODO lock + pm, ok := ipp.perPeerMap.Load(from) + if !ok { + log.Printf("DomainForIP: peer state absent for: %d", from) + return "" + } + ww, ok := pm.AddrToDomain.Lookup(addr) + if !ok { + log.Printf("DomainForIP: peer state doesn't recognize domain") + return "" + } + go func() { + err := ipp.markLastUsed(from, addr, ww.Domain, updatedAt) + if err != nil { + panic(err) + } + }() + return ww.Domain +} + +type markLastUsedArgs struct { + NodeID tailcfg.NodeID + Addr netip.Addr + Domain string + UpdatedAt time.Time +} + +// called by raft +func (cd *fsm) executeMarkLastUsed(bs []byte) commandResult { + var args markLastUsedArgs + err := json.Unmarshal(bs, &args) + if err != nil { + return commandResult{Err: err} + } + err = cd.applyMarkLastUsed(args.NodeID, args.Addr, args.Domain, args.UpdatedAt) + if err != nil { + return commandResult{Err: err} + } + return commandResult{} +} + +func (ipp *fsm) applyMarkLastUsed(from tailcfg.NodeID, addr netip.Addr, domain string, updatedAt time.Time) error { + // TODO lock + ps, ok := ipp.perPeerMap.Load(from) + if !ok { + // unexpected in normal operation (but not an error?) + return nil + } + ww, ok := ps.AddrToDomain.Lookup(addr) + if !ok { + // unexpected in normal operation (but not an error?) 
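+ // the address has no current mapping (it may already have been recycled), so there is no lastUsed entry to update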
+ return nil + } + if ww.Domain != domain { + // then I guess we're too late to update lastUsed + return nil + } + if ww.LastUsed.After(updatedAt) { + // prefer the most recent + return nil + } + ww.LastUsed = updatedAt + ps.AddrToDomain.Insert(netip.PrefixFrom(addr, addr.BitLen()), ww) + return nil +} + +func (ipp *ipPool) StartConsensus(peers []*ipnstate.PeerStatus, ts *tsnet.Server) { + v4, _ := ts.TailscaleIPs() + adminLn, err := ts.Listen("tcp", fmt.Sprintf("%s:6312", v4)) + if err != nil { + log.Fatal(err) + } + raftLn, err := ts.Listen("tcp", fmt.Sprintf("%s:6311", v4)) + if err != nil { + log.Fatal(err) + } + sl := StreamLayer{s: ts, Listener: raftLn} + lns := listeners{command: adminLn, raft: &sl} + cns, err := BootstrapConsensus((*fsm)(ipp), v4, &lns, peers, ts.HTTPClient()) + if err != nil { + log.Fatalf("BootstrapConsensus failed: %v", err) + } + ipp.consensus = cns +} + +type whereWhen struct { + Domain string + LastUsed time.Time +} + +type perPeerState struct { + DomainToAddr map[string]netip.Addr + AddrToDomain *bart.Table[whereWhen] + mu sync.Mutex // not jsonified +} + +func (ps *perPeerState) unusedIPV4(ranges []netip.Prefix, exclude netip.Addr, reuseDeadline time.Time) (netip.Addr, bool, string, error) { + // TODO here we iterate through each ip within the ranges until we find one that's unused + // could be done more efficiently either by: + // 1) storing an index into ranges and an ip we had last used from that range in perPeerState + // (how would this work with checking ips back into the pool though?) + // 2) using a random approach like the natc does now, except the raft state machine needs to + // be deterministic so it can replay logs, so I think we would do something like generate a + // random ip each time, and then have a call into the state machine that says "give me whatever + // ip you have, and if you don't have one use this one". I think that would work. + for _, r := range ranges { + ip := r.Addr() + for r.Contains(ip) { + if ip != exclude { + ww, ok := ps.AddrToDomain.Lookup(ip) + if !ok { + return ip, false, "", nil + } + if ww.LastUsed.Before(reuseDeadline) { + return ip, true, ww.Domain, nil + } + } + ip = ip.Next() + } + } + return netip.Addr{}, false, "", errors.New("ip pool exhausted") +} + +func (cd *ipPool) IpForDomain(nid tailcfg.NodeID, domain string) (netip.Addr, error) { + now := time.Now() + args := checkoutAddrArgs{ + NodeID: nid, + Domain: domain, + ReuseDeadline: now.Add(-10 * time.Second), // TODO what time period? 48 hours? 
+ UpdatedAt: now, + } + bs, err := json.Marshal(args) + if err != nil { + return netip.Addr{}, err + } + c := command{ + Name: "checkoutAddr", + Args: bs, + } + result, err := cd.consensus.executeCommand(c) + if err != nil { + log.Printf("IpForDomain: raft error executing command: %v", err) + return netip.Addr{}, err + } + if result.Err != nil { + log.Printf("IpForDomain: error returned from state machine: %v", err) + return netip.Addr{}, result.Err + } + var addr netip.Addr + err = json.Unmarshal(result.Result, &addr) + return addr, err +} + +func (cd *ipPool) markLastUsed(nid tailcfg.NodeID, addr netip.Addr, domain string, lastUsed time.Time) error { + args := markLastUsedArgs{ + NodeID: nid, + Addr: addr, + Domain: domain, + UpdatedAt: lastUsed, + } + bs, err := json.Marshal(args) + if err != nil { + return err + } + c := command{ + Name: "markLastUsed", + Args: bs, + } + result, err := cd.consensus.executeCommand(c) + if err != nil { + log.Printf("markLastUsed: raft error executing command: %v", err) + return err + } + if result.Err != nil { + log.Printf("markLastUsed: error returned from state machine: %v", err) + return result.Err + } + return nil +} + +type checkoutAddrArgs struct { + NodeID tailcfg.NodeID + Domain string + ReuseDeadline time.Time + UpdatedAt time.Time +} + +// called by raft +func (cd *fsm) executeCheckoutAddr(bs []byte) commandResult { + var args checkoutAddrArgs + err := json.Unmarshal(bs, &args) + if err != nil { + return commandResult{Err: err} + } + addr, err := cd.applyCheckoutAddr(args.NodeID, args.Domain, args.ReuseDeadline, args.UpdatedAt) + if err != nil { + return commandResult{Err: err} + } + resultBs, err := json.Marshal(addr) + if err != nil { + return commandResult{Err: err} + } + return commandResult{Result: resultBs} +} + +func (cd *fsm) applyCheckoutAddr(nid tailcfg.NodeID, domain string, reuseDeadline, updatedAt time.Time) (netip.Addr, error) { + // TODO lock and unlock + pm, _ := cd.perPeerMap.LoadOrStore(nid, &perPeerState{ + AddrToDomain: &bart.Table[whereWhen]{}, + }) + if existing, ok := pm.DomainToAddr[domain]; ok { + // TODO handle error case where this doesn't exist + ww, _ := pm.AddrToDomain.Lookup(existing) + ww.LastUsed = updatedAt + pm.AddrToDomain.Insert(netip.PrefixFrom(existing, existing.BitLen()), ww) + return existing, nil + } + addr, wasInUse, previousDomain, err := pm.unusedIPV4(cd.v4Ranges, cd.dnsAddr, reuseDeadline) + if err != nil { + return netip.Addr{}, err + } + mak.Set(&pm.DomainToAddr, domain, addr) + if wasInUse { + // remove it from domaintoaddr + delete(pm.DomainToAddr, previousDomain) + // don't need to remove it from addrtodomain, insert will do that + } + pm.AddrToDomain.Insert(netip.PrefixFrom(addr, addr.BitLen()), whereWhen{Domain: domain, LastUsed: updatedAt}) + return addr, nil +} diff --git a/cmd/natc/ippool_test.go b/cmd/natc/ippool_test.go new file mode 100644 index 000000000..fd5fc9c3e --- /dev/null +++ b/cmd/natc/ippool_test.go @@ -0,0 +1,129 @@ +package main + +import ( + "encoding/json" + "fmt" + "net/netip" + "testing" + + "tailscale.com/tailcfg" +) + +func TestV6V4(t *testing.T) { + c := connector{ + v6ULA: ula(uint16(1)), + } + + tests := [][]string{ + []string{"100.64.0.0", "fd7a:115c:a1e0:a99c:1:0:6440:0"}, + []string{"0.0.0.0", "fd7a:115c:a1e0:a99c:1::"}, + []string{"255.255.255.255", "fd7a:115c:a1e0:a99c:1:0:ffff:ffff"}, + } + + for i, test := range tests { + // to v6 + v6 := c.v6ForV4(netip.MustParseAddr(test[0])) + want := netip.MustParseAddr(test[1]) + if v6 != want { + t.Fatalf("test %d: 
want: %v, got: %v", i, want, v6) + } + + // to v4 + v4 := v4ForV6(netip.MustParseAddr(test[1])) + want = netip.MustParseAddr(test[0]) + if v4 != want { + t.Fatalf("test %d: want: %v, got: %v", i, want, v4) + } + } +} + +func TestIPForDomain(t *testing.T) { + pfx := netip.MustParsePrefix("100.64.0.0/16") + ipp := fsm{ + v4Ranges: []netip.Prefix{pfx}, + dnsAddr: netip.MustParseAddr("100.64.0.0"), + } + a, err := ipp.applyCheckoutAddr(tailcfg.NodeID(1), "example.com") + if err != nil { + t.Fatal(err) + } + if !pfx.Contains(a) { + t.Fatalf("expected %v to be in the prefix %v", a, pfx) + } + + b, err := ipp.applyCheckoutAddr(tailcfg.NodeID(1), "a.example.com") + if err != nil { + t.Fatal(err) + } + if !pfx.Contains(b) { + t.Fatalf("expected %v to be in the prefix %v", b, pfx) + } + if b == a { + t.Fatalf("same address issued twice %v, %v", a, b) + } + + c, err := ipp.applyCheckoutAddr(tailcfg.NodeID(1), "example.com") + if err != nil { + t.Fatal(err) + } + if c != a { + t.Fatalf("expected %v to be remembered as the addr for example.com, but got %v", a, c) + } +} + +func TestDomainForIP(t *testing.T) { + pfx := netip.MustParsePrefix("100.64.0.0/16") + sm := fsm{ + v4Ranges: []netip.Prefix{pfx}, + dnsAddr: netip.MustParseAddr("100.64.0.0"), + } + ipp := (*ipPool)(&sm) + nid := tailcfg.NodeID(1) + domain := "example.com" + d := ipp.DomainForIP(nid, netip.MustParseAddr("100.64.0.1")) + if d != "" { + t.Fatalf("expected an empty string if the addr is not found but got %s", d) + } + a, err := sm.applyCheckoutAddr(nid, domain) + if err != nil { + t.Fatal(err) + } + d2 := ipp.DomainForIP(nid, a) + if d2 != domain { + t.Fatalf("expected %s but got %s", domain, d2) + } +} + +func TestBlah(t *testing.T) { + type ecr interface { + getResult() interface{} + setResult(interface{}) + toJSON() ([]byte, error) + fromJSON([]byte) err + } + type fran struct { + Result netip.Addr + } + func(f *fran) toJSON() string { + return json.Marshal(f) + } + func(f *fran) fromJSON(bs []byte) err { + return json.UnMarshal(bs, f) + } + thrujson := func(in ecr) ecr { + bs, err := json.Marshal(in) + if err != nil { + t.Fatal(err) + } + var out ecr + err = json.Unmarshal(bs, &out) + if err != nil { + t.Fatal(err) + } + return out + } + a := netip.Addr{} + out := thrujson(ecr{Result: a}).Result + b := (out).(netip.Addr) + fmt.Println(b) +} diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index d94523c6e..433bdc6a6 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -8,18 +8,16 @@ package main import ( "context" - "encoding/binary" "errors" "flag" "fmt" "log" - "math/rand/v2" "net" "net/http" "net/netip" "os" + "slices" "strings" - "sync" "time" "github.com/gaissmai/bart" @@ -30,13 +28,11 @@ import ( "tailscale.com/envknob" "tailscale.com/hostinfo" "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" "tailscale.com/net/netutil" - "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tsweb" - "tailscale.com/util/dnsname" - "tailscale.com/util/mak" ) func main() { @@ -56,6 +52,7 @@ func main() { printULA = fs.Bool("print-ula", false, "print the ULA prefix and exit") ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore") wgPort = fs.Uint("wg-port", 0, "udp port for wireguard and peer to peer traffic") + clusterTag = fs.String("cluster-tag", "", "TODO") ) ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_NATC")) @@ -105,6 +102,7 @@ func main() { ts := &tsnet.Server{ Hostname: *hostname, } + ts.ControlURL = "http://host.docker.internal:31544" if *wgPort != 0 { 
if *wgPort >= 1<<16 { log.Fatalf("wg-port must be in the range [0, 65535]") @@ -112,6 +110,7 @@ func main() { ts.Port = uint16(*wgPort) } defer ts.Close() + if *verboseTSNet { ts.Logf = log.Printf } @@ -136,6 +135,28 @@ func main() { if _, err := ts.Up(ctx); err != nil { log.Fatalf("ts.Up: %v", err) } + woo, err := lc.Status(ctx) + if err != nil { + panic(err) + } + var peers []*ipnstate.PeerStatus + if *clusterTag != "" && woo.Self.Tags != nil && slices.Contains(woo.Self.Tags.AsSlice(), *clusterTag) { + for _, v := range woo.Peer { + if v.Tags != nil && slices.Contains(v.Tags.AsSlice(), *clusterTag) { + peers = append(peers, v) + } + } + } else { + // we are not in clustering mode I guess? + panic("todo") + } + + ipp := ipPool{ + v4Ranges: v4Prefixes, + dnsAddr: dnsAddr, + } + + ipp.StartConsensus(peers, ts) c := &connector{ ts: ts, @@ -144,6 +165,7 @@ func main() { v4Ranges: v4Prefixes, v6ULA: ula(uint16(*siteID)), ignoreDsts: ignoreDstTable, + ipAddrs: &ipp, } c.run(ctx) } @@ -165,7 +187,7 @@ type connector struct { // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. v6ULA netip.Prefix - perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState] + ipAddrs *ipPool // ignoreDsts is initialized at start up with the contents of --ignore-destinations (if none it is nil) // It is never mutated, only used for lookups. @@ -332,16 +354,15 @@ var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") // generateDNSResponse generates a DNS response for the given request. The from // argument is the NodeID of the node that sent the request. func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) { - pm, _ := c.perPeerMap.LoadOrStore(from, &perPeerState{c: c}) var addrs []netip.Addr if len(req.Questions) > 0 { switch req.Questions[0].Type { case dnsmessage.TypeAAAA, dnsmessage.TypeA: - var err error - addrs, err = pm.ipForDomain(req.Questions[0].Name.String()) + v4, err := c.ipAddrs.IpForDomain(from, req.Questions[0].Name.String()) if err != nil { return nil, err } + addrs = []netip.Addr{v4, c.v6ForV4(v4)} } } return dnsResponse(req, addrs) @@ -429,14 +450,13 @@ func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Con } from := who.Node.ID - ps, ok := c.perPeerMap.Load(from) - if !ok { - log.Printf("handleTCPFlow: no perPeerState for %v", from) - return nil, false + dstAddr := dst.Addr() + if dstAddr.Is6() { + dstAddr = v4ForV6(dstAddr) } - domain, ok := ps.domainForIP(dst.Addr()) - if !ok { - log.Printf("handleTCPFlow: no domain for IP %v\n", dst.Addr()) + domain := c.ipAddrs.DomainForIP(from, dstAddr, time.Now()) + if domain == "" { + log.Print("handleTCPFlow: found no domain") return nil, false } return func(conn net.Conn) { @@ -480,96 +500,18 @@ func proxyTCPConn(c net.Conn, dest string) { p.Start() } -// perPeerState holds the state for a single peer. -type perPeerState struct { - c *connector - - mu sync.Mutex - domainToAddr map[string][]netip.Addr - addrToDomain *bart.Table[string] -} - -// domainForIP returns the domain name assigned to the given IP address and -// whether it was found. -func (ps *perPeerState) domainForIP(ip netip.Addr) (_ string, ok bool) { - ps.mu.Lock() - defer ps.mu.Unlock() - if ps.addrToDomain == nil { - return "", false - } - return ps.addrToDomain.Lookup(ip) -} - -// ipForDomain assigns a pair of unique IP addresses for the given domain and -// returns them. The first address is an IPv4 address and the second is an IPv6 -// address. 
If the domain already has assigned addresses, it returns them. -func (ps *perPeerState) ipForDomain(domain string) ([]netip.Addr, error) { - fqdn, err := dnsname.ToFQDN(domain) - if err != nil { - return nil, err - } - domain = fqdn.WithoutTrailingDot() - - ps.mu.Lock() - defer ps.mu.Unlock() - if addrs, ok := ps.domainToAddr[domain]; ok { - return addrs, nil - } - addrs := ps.assignAddrsLocked(domain) - return addrs, nil -} - -// isIPUsedLocked reports whether the given IP address is already assigned to a -// domain. -// ps.mu must be held. -func (ps *perPeerState) isIPUsedLocked(ip netip.Addr) bool { - _, ok := ps.addrToDomain.Lookup(ip) - return ok -} - -// unusedIPv4Locked returns an unused IPv4 address from the available ranges. -func (ps *perPeerState) unusedIPv4Locked() netip.Addr { - // TODO: skip ranges that have been exhausted - for _, r := range ps.c.v4Ranges { - ip := randV4(r) - for r.Contains(ip) { - if !ps.isIPUsedLocked(ip) && ip != ps.c.dnsAddr { - return ip - } - ip = ip.Next() - } - } - return netip.Addr{} -} - -// randV4 returns a random IPv4 address within the given prefix. -func randV4(maskedPfx netip.Prefix) netip.Addr { - bits := 32 - maskedPfx.Bits() - randBits := rand.Uint32N(1 << uint(bits)) - - ip4 := maskedPfx.Addr().As4() - pn := binary.BigEndian.Uint32(ip4[:]) - binary.BigEndian.PutUint32(ip4[:], randBits|pn) - return netip.AddrFrom4(ip4) -} - -// assignAddrsLocked assigns a pair of unique IP addresses for the given domain -// and returns them. The first address is an IPv4 address and the second is an -// IPv6 address. It does not check if the domain already has assigned addresses. -// ps.mu must be held. -func (ps *perPeerState) assignAddrsLocked(domain string) []netip.Addr { - if ps.addrToDomain == nil { - ps.addrToDomain = &bart.Table[string]{} - } - v4 := ps.unusedIPv4Locked() - as16 := ps.c.v6ULA.Addr().As16() +func (c *connector) v6ForV4(v4 netip.Addr) netip.Addr { + as16 := c.v6ULA.Addr().As16() as4 := v4.As4() copy(as16[12:], as4[:]) v6 := netip.AddrFrom16(as16) - addrs := []netip.Addr{v4, v6} - mak.Set(&ps.domainToAddr, domain, addrs) - for _, a := range addrs { - ps.addrToDomain.Insert(netip.PrefixFrom(a, a.BitLen()), domain) - } - return addrs + return v6 +} + +func v4ForV6(v6 netip.Addr) netip.Addr { + as16 := v6.As16() + var as4 [4]byte + copy(as4[:], as16[12:]) + v4 := netip.AddrFrom4(as4) + return v4 }
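
A possible shape for option 2 in the unusedIPV4 TODO above (keep the random pick, but keep the raft state machine deterministic), sketched under loose assumptions: the proposing node chooses a random candidate address, reusing the same math as the randV4 helper this patch deletes from natc.go, and ships the candidate inside the command, so replaying the log on any replica reproduces the same assignment. The names here (candidateArgs, pool, checkout) are illustrative only and not part of the patch.

package main

import (
	"encoding/binary"
	"errors"
	"fmt"
	"math/rand/v2"
	"net/netip"
)

// candidateArgs is what the proposing node would marshal into the raft
// command: the domain plus a randomly chosen starting address from the pool.
type candidateArgs struct {
	Domain    string
	Candidate netip.Addr
}

// pool is a stand-in for the per-peer state the fsm would hold.
type pool struct {
	pfx    netip.Prefix
	byAddr map[netip.Addr]string
	byName map[string]netip.Addr
}

// randV4 picks a random address within the prefix; this mirrors the helper
// the patch removes from natc.go. Randomness lives on the proposing node,
// outside the replicated state machine.
func randV4(pfx netip.Prefix) netip.Addr {
	bits := 32 - pfx.Bits()
	randBits := rand.Uint32N(1 << uint(bits))
	ip4 := pfx.Addr().As4()
	pn := binary.BigEndian.Uint32(ip4[:])
	binary.BigEndian.PutUint32(ip4[:], randBits|pn)
	return netip.AddrFrom4(ip4)
}

// checkout is the deterministic half: given the same log entry and the same
// prior state, every replica assigns the same address.
func (p *pool) checkout(a candidateArgs) (netip.Addr, error) {
	if ip, ok := p.byName[a.Domain]; ok {
		return ip, nil // domain already has an address
	}
	ip := a.Candidate
	for tries := 0; tries < 1<<(32-p.pfx.Bits()); tries++ {
		if !p.pfx.Contains(ip) {
			ip = p.pfx.Addr() // wrap around to the start of the range
		}
		if _, used := p.byAddr[ip]; !used {
			p.byAddr[ip] = a.Domain
			p.byName[a.Domain] = ip
			return ip, nil
		}
		ip = ip.Next()
	}
	return netip.Addr{}, errors.New("ip pool exhausted")
}

func main() {
	pfx := netip.MustParsePrefix("100.64.0.0/24")
	p := &pool{pfx: pfx, byAddr: map[netip.Addr]string{}, byName: map[string]netip.Addr{}}
	ip, err := p.checkout(candidateArgs{Domain: "example.com", Candidate: randV4(pfx)})
	fmt.Println(ip, err)
}

Because the randomness is captured in the log entry rather than recomputed inside Apply, two replicas that replay the same entries converge on the same domain-to-address mapping, which is the property the TODO is reaching for.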