From 6d78f27d730a2ed953149364a93a836aee19cd5c Mon Sep 17 00:00:00 2001 From: Fran Bull Date: Thu, 20 Feb 2025 10:50:25 -0800 Subject: [PATCH] allowedPeers -> auth obj --- tsconsensus/authorization.go | 80 ++++++++++++++++++++ tsconsensus/http.go | 18 +++-- tsconsensus/tsconsensus.go | 138 ++++++++++++----------------------- 3 files changed, 138 insertions(+), 98 deletions(-) create mode 100644 tsconsensus/authorization.go diff --git a/tsconsensus/authorization.go b/tsconsensus/authorization.go new file mode 100644 index 000000000..67e685def --- /dev/null +++ b/tsconsensus/authorization.go @@ -0,0 +1,80 @@ +package tsconsensus + +import ( + "context" + "net/netip" + "slices" + + "tailscale.com/ipn/ipnstate" + "tailscale.com/tsnet" +) + +type authorization struct { + ts *tsnet.Server + tag string + peers *peers +} + +func (a *authorization) refresh(ctx context.Context) error { + lc, err := a.ts.LocalClient() + if err != nil { + return err + } + tStatus, err := lc.Status(ctx) + if err != nil { + return err + } + a.peers = newPeers(tStatus) + return nil +} + +func (a *authorization) allowsHost(addr netip.Addr) bool { + return a.peers.peerExists(addr, a.tag) +} + +func (a *authorization) selfAllowed() bool { + return a.peers.status.Self.Tags != nil && slices.Contains(a.peers.status.Self.Tags.AsSlice(), a.tag) +} + +func (a *authorization) allowedPeers() []*ipnstate.PeerStatus { + if a.peers.allowedPeers == nil { + return []*ipnstate.PeerStatus{} + } + return a.peers.allowedPeers +} + +type peers struct { + status *ipnstate.Status + peerByIPAddressAndTag map[netip.Addr]map[string]*ipnstate.PeerStatus + allowedPeers []*ipnstate.PeerStatus +} + +func (ps *peers) peerExists(a netip.Addr, tag string) bool { + byTag, ok := ps.peerByIPAddressAndTag[a] + if !ok { + return false + } + _, ok = byTag[tag] + return ok +} + +func newPeers(status *ipnstate.Status) *peers { + ps := &peers{ + peerByIPAddressAndTag: map[netip.Addr]map[string]*ipnstate.PeerStatus{}, + status: status, + } + for _, p := range status.Peer { + for _, addr := range p.TailscaleIPs { + if ps.peerByIPAddressAndTag[addr] == nil { + ps.peerByIPAddressAndTag[addr] = map[string]*ipnstate.PeerStatus{} + } + if p.Tags != nil { + for _, tag := range p.Tags.AsSlice() { + ps.peerByIPAddressAndTag[addr][tag] = p + ps.allowedPeers = append(ps.allowedPeers, p) + } + } + } + } + return ps +} diff --git a/tsconsensus/http.go b/tsconsensus/http.go index 301127687..7570e2936 100644 --- a/tsconsensus/http.go +++ b/tsconsensus/http.go @@ -9,8 +9,6 @@ import ( "io" "net/http" "time" - - "tailscale.com/tsnet" ) type joinRequest struct { @@ -79,13 +77,19 @@ func (rac *commandClient) ExecuteCommand(host string, bs []byte) (CommandResult, return cr, nil } -func taggedOnly(ts *tsnet.Server, tag string, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { +func authorized(auth *authorization, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - allowed, err := allowedPeer(r.RemoteAddr, tag, ts) + err := auth.refresh(r.Context()) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } + a, err := addrFromServerAddress(r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + allowed := auth.allowsHost(a) if !allowed { http.Error(w, "peer not allowed", http.StatusBadRequest) return @@ -94,9 +98,9 @@ func taggedOnly(ts *tsnet.Server, tag string, fx func(http.ResponseWriter, *http } } -func (c *Consensus) makeCommandMux(ts *tsnet.Server, tag string) *http.ServeMux { +func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux { mux := http.NewServeMux() - mux.HandleFunc("/join", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/join", authorized(auth, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Bad Request", http.StatusBadRequest) return @@ -122,7 +126,7 @@ func (c *Consensus) makeCommandMux(ts *tsnet.Server, tag string) *http.ServeMux return } })) - mux.HandleFunc("/executeCommand", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/executeCommand", authorized(auth, func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "Bad Request", http.StatusBadRequest) return diff --git a/tsconsensus/tsconsensus.go b/tsconsensus/tsconsensus.go index c6fb4f30b..a5aa75635 100644 --- a/tsconsensus/tsconsensus.go +++ b/tsconsensus/tsconsensus.go @@ -9,7 +9,6 @@ import ( "net" "net/http" "net/netip" - "slices" "time" "github.com/hashicorp/raft" @@ -79,118 +78,67 @@ func DefaultConfig() Config { } } +func addrFromServerAddress(sa string) (netip.Addr, error) { + sAddr, _, err := net.SplitHostPort(sa) + if err != nil { + return netip.Addr{}, err + } + return netip.ParseAddr(sAddr) +} + // StreamLayer implements an interface asked for by raft.NetworkTransport. // It does the raft interprocess communication via tailscale. type StreamLayer struct { net.Listener - s *tsnet.Server - tag string + auth *authorization + s *tsnet.Server } // Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial. func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { - allowed, err := allowedPeer(string(address), sl.tag, sl.s) + ctx, _ := context.WithTimeout(context.Background(), timeout) + err := sl.auth.refresh(ctx) if err != nil { return nil, err } - if !allowed { + + addr, err := addrFromServerAddress(string(address)) + if err != nil { + return nil, err + } + + if !sl.auth.allowsHost(addr) { return nil, errors.New("peer is not allowed") } - ctx, _ := context.WithTimeout(context.Background(), timeout) return sl.s.Dial(ctx, "tcp", string(address)) } -func allowedPeer(remoteAddr string, tag string, s *tsnet.Server) (bool, error) { - sAddr, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return false, err - } - a, err := netip.ParseAddr(sAddr) - if err != nil { - return false, err - } - ctx := context.Background() // TODO very much a sign I shouldn't be doing this here - peers, err := taggedNodesFromStatus(ctx, tag, s) - if err != nil { - return false, err - } - return peers.has(a), nil -} - func (sl StreamLayer) Accept() (net.Conn, error) { for { conn, err := sl.Listener.Accept() if err != nil || conn == nil { return conn, err } - allowed, err := allowedPeer(conn.RemoteAddr().String(), sl.tag, sl.s) + ctx := context.Background() // TODO + err = sl.auth.refresh(ctx) if err != nil { // TODO should we stay alive here? return nil, err } - if !allowed { + + addr, err := addrFromServerAddress(conn.RemoteAddr().String()) + if err != nil { + // TODO should we stay alive here? + return nil, err + } + + if !sl.auth.allowsHost(addr) { continue } return conn, err } } -type allowedPeers struct { - self *ipnstate.PeerStatus - peers []*ipnstate.PeerStatus - peerByIPAddress map[netip.Addr]*ipnstate.PeerStatus - clusterTag string -} - -func (ap *allowedPeers) allowed(n *ipnstate.PeerStatus) bool { - return n.Tags != nil && slices.Contains(n.Tags.AsSlice(), ap.clusterTag) -} - -func (ap *allowedPeers) addPeerIfAllowed(p *ipnstate.PeerStatus) { - if !ap.allowed(p) { - return - } - ap.peers = append(ap.peers, p) - for _, addr := range p.TailscaleIPs { - ap.peerByIPAddress[addr] = p - } -} - -func (ap *allowedPeers) addSelfIfAllowed(n *ipnstate.PeerStatus) { - if ap.allowed(n) { - ap.self = n - } -} - -func (ap *allowedPeers) has(a netip.Addr) bool { - _, ok := ap.peerByIPAddress[a] - return ok -} - -func taggedNodesFromStatus(ctx context.Context, clusterTag string, ts *tsnet.Server) (*allowedPeers, error) { - lc, err := ts.LocalClient() - if err != nil { - return nil, err - } - tStatus, err := lc.Status(ctx) - if err != nil { - return nil, err - } - ap := newAllowedPeers(clusterTag) - for _, v := range tStatus.Peer { - ap.addPeerIfAllowed(v) - } - ap.addSelfIfAllowed(tStatus.Self) - return ap, nil -} - -func newAllowedPeers(tag string) *allowedPeers { - return &allowedPeers{ - peerByIPAddress: map[netip.Addr]*ipnstate.PeerStatus{}, - clusterTag: tag, - } -} - // Start returns a pointer to a running Consensus instance. func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag string, cfg Config) (*Consensus, error) { if clusterTag == "" { @@ -211,22 +159,30 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag strin Config: cfg, } - tnfs, err := taggedNodesFromStatus(ctx, clusterTag, ts) - if tnfs.self == nil { + auth := &authorization{ + tag: clusterTag, + ts: ts, + } + err := auth.refresh(ctx) + if err != nil { + return nil, err + } + + if !auth.selfAllowed() { return nil, errors.New("this node is not tagged with the cluster tag") } - r, err := startRaft(ts, &fsm, c.Self, clusterTag, cfg) + r, err := startRaft(ts, &fsm, c.Self, auth, cfg) if err != nil { return nil, err } c.Raft = r - srv, err := c.serveCmdHttp(ts, clusterTag) + srv, err := c.serveCmdHttp(ts, auth) if err != nil { return nil, err } c.cmdHttpServer = srv - c.bootstrap(tnfs.peers) + c.bootstrap(auth.allowedPeers()) srv, err = serveMonitor(&c, ts, addr(c.Self.Host, cfg.MonitorPort)) if err != nil { return nil, err @@ -235,7 +191,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag strin return &c, nil } -func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, clusterTag string, cfg Config) (*raft.Raft, error) { +func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, auth *authorization, cfg Config) (*raft.Raft, error) { config := cfg.Raft config.LocalID = raft.ServerID(self.ID) @@ -251,9 +207,9 @@ func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, clusterTag st } transport := raft.NewNetworkTransport(StreamLayer{ - s: ts, Listener: ln, - tag: clusterTag, + auth: auth, + s: ts, }, cfg.MaxConnPool, cfg.ConnTimeout, @@ -386,12 +342,12 @@ func (e lookElsewhereError) Error() string { var ErrLeaderUnknown = errors.New("Leader Unknown") -func (c *Consensus) serveCmdHttp(ts *tsnet.Server, tag string) (*http.Server, error) { +func (c *Consensus) serveCmdHttp(ts *tsnet.Server, auth *authorization) (*http.Server, error) { ln, err := ts.Listen("tcp", c.commandAddr(c.Self.Host)) if err != nil { return nil, err } - mux := c.makeCommandMux(ts, tag) + mux := c.makeCommandMux(auth) srv := &http.Server{Handler: mux} go func() { err := srv.Serve(ln)