allowedPeers -> auth obj

This commit is contained in:
Fran Bull 2025-02-20 10:50:25 -08:00
parent 95131102df
commit 6d78f27d73
3 changed files with 138 additions and 98 deletions

View File

@ -0,0 +1,80 @@
package tsconsensus
import (
"context"
"net/netip"
"slices"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tsnet"
)
type authorization struct {
ts *tsnet.Server
tag string
peers *peers
}
func (a *authorization) refresh(ctx context.Context) error {
lc, err := a.ts.LocalClient()
if err != nil {
return err
}
tStatus, err := lc.Status(ctx)
if err != nil {
return err
}
a.peers = newPeers(tStatus)
return nil
}
func (a *authorization) allowsHost(addr netip.Addr) bool {
return a.peers.peerExists(addr, a.tag)
}
func (a *authorization) selfAllowed() bool {
return a.peers.status.Self.Tags != nil && slices.Contains(a.peers.status.Self.Tags.AsSlice(), a.tag)
}
func (a *authorization) allowedPeers() []*ipnstate.PeerStatus {
if a.peers.allowedPeers == nil {
return []*ipnstate.PeerStatus{}
}
return a.peers.allowedPeers
}
type peers struct {
status *ipnstate.Status
peerByIPAddressAndTag map[netip.Addr]map[string]*ipnstate.PeerStatus
allowedPeers []*ipnstate.PeerStatus
}
func (ps *peers) peerExists(a netip.Addr, tag string) bool {
byTag, ok := ps.peerByIPAddressAndTag[a]
if !ok {
return false
}
_, ok = byTag[tag]
return ok
}
func newPeers(status *ipnstate.Status) *peers {
ps := &peers{
peerByIPAddressAndTag: map[netip.Addr]map[string]*ipnstate.PeerStatus{},
status: status,
}
for _, p := range status.Peer {
for _, addr := range p.TailscaleIPs {
if ps.peerByIPAddressAndTag[addr] == nil {
ps.peerByIPAddressAndTag[addr] = map[string]*ipnstate.PeerStatus{}
}
if p.Tags != nil {
for _, tag := range p.Tags.AsSlice() {
ps.peerByIPAddressAndTag[addr][tag] = p
ps.allowedPeers = append(ps.allowedPeers, p)
}
}
}
}
return ps
}

View File

@ -9,8 +9,6 @@ import (
"io" "io"
"net/http" "net/http"
"time" "time"
"tailscale.com/tsnet"
) )
type joinRequest struct { type joinRequest struct {
@ -79,13 +77,19 @@ func (rac *commandClient) ExecuteCommand(host string, bs []byte) (CommandResult,
return cr, nil return cr, nil
} }
func taggedOnly(ts *tsnet.Server, tag string, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) { func authorized(auth *authorization, fx func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
allowed, err := allowedPeer(r.RemoteAddr, tag, ts) err := auth.refresh(r.Context())
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
a, err := addrFromServerAddress(r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
allowed := auth.allowsHost(a)
if !allowed { if !allowed {
http.Error(w, "peer not allowed", http.StatusBadRequest) http.Error(w, "peer not allowed", http.StatusBadRequest)
return return
@ -94,9 +98,9 @@ func taggedOnly(ts *tsnet.Server, tag string, fx func(http.ResponseWriter, *http
} }
} }
func (c *Consensus) makeCommandMux(ts *tsnet.Server, tag string) *http.ServeMux { func (c *Consensus) makeCommandMux(auth *authorization) *http.ServeMux {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/join", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/join", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest) http.Error(w, "Bad Request", http.StatusBadRequest)
return return
@ -122,7 +126,7 @@ func (c *Consensus) makeCommandMux(ts *tsnet.Server, tag string) *http.ServeMux
return return
} }
})) }))
mux.HandleFunc("/executeCommand", taggedOnly(ts, tag, func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/executeCommand", authorized(auth, func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "Bad Request", http.StatusBadRequest) http.Error(w, "Bad Request", http.StatusBadRequest)
return return

View File

@ -9,7 +9,6 @@ import (
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
"slices"
"time" "time"
"github.com/hashicorp/raft" "github.com/hashicorp/raft"
@ -79,118 +78,67 @@ func DefaultConfig() Config {
} }
} }
func addrFromServerAddress(sa string) (netip.Addr, error) {
sAddr, _, err := net.SplitHostPort(sa)
if err != nil {
return netip.Addr{}, err
}
return netip.ParseAddr(sAddr)
}
// StreamLayer implements an interface asked for by raft.NetworkTransport. // StreamLayer implements an interface asked for by raft.NetworkTransport.
// It does the raft interprocess communication via tailscale. // It does the raft interprocess communication via tailscale.
type StreamLayer struct { type StreamLayer struct {
net.Listener net.Listener
s *tsnet.Server auth *authorization
tag string s *tsnet.Server
} }
// Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial. // Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial.
func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
allowed, err := allowedPeer(string(address), sl.tag, sl.s) ctx, _ := context.WithTimeout(context.Background(), timeout)
err := sl.auth.refresh(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !allowed {
addr, err := addrFromServerAddress(string(address))
if err != nil {
return nil, err
}
if !sl.auth.allowsHost(addr) {
return nil, errors.New("peer is not allowed") return nil, errors.New("peer is not allowed")
} }
ctx, _ := context.WithTimeout(context.Background(), timeout)
return sl.s.Dial(ctx, "tcp", string(address)) return sl.s.Dial(ctx, "tcp", string(address))
} }
func allowedPeer(remoteAddr string, tag string, s *tsnet.Server) (bool, error) {
sAddr, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return false, err
}
a, err := netip.ParseAddr(sAddr)
if err != nil {
return false, err
}
ctx := context.Background() // TODO very much a sign I shouldn't be doing this here
peers, err := taggedNodesFromStatus(ctx, tag, s)
if err != nil {
return false, err
}
return peers.has(a), nil
}
func (sl StreamLayer) Accept() (net.Conn, error) { func (sl StreamLayer) Accept() (net.Conn, error) {
for { for {
conn, err := sl.Listener.Accept() conn, err := sl.Listener.Accept()
if err != nil || conn == nil { if err != nil || conn == nil {
return conn, err return conn, err
} }
allowed, err := allowedPeer(conn.RemoteAddr().String(), sl.tag, sl.s) ctx := context.Background() // TODO
err = sl.auth.refresh(ctx)
if err != nil { if err != nil {
// TODO should we stay alive here? // TODO should we stay alive here?
return nil, err return nil, err
} }
if !allowed {
addr, err := addrFromServerAddress(conn.RemoteAddr().String())
if err != nil {
// TODO should we stay alive here?
return nil, err
}
if !sl.auth.allowsHost(addr) {
continue continue
} }
return conn, err return conn, err
} }
} }
type allowedPeers struct {
self *ipnstate.PeerStatus
peers []*ipnstate.PeerStatus
peerByIPAddress map[netip.Addr]*ipnstate.PeerStatus
clusterTag string
}
func (ap *allowedPeers) allowed(n *ipnstate.PeerStatus) bool {
return n.Tags != nil && slices.Contains(n.Tags.AsSlice(), ap.clusterTag)
}
func (ap *allowedPeers) addPeerIfAllowed(p *ipnstate.PeerStatus) {
if !ap.allowed(p) {
return
}
ap.peers = append(ap.peers, p)
for _, addr := range p.TailscaleIPs {
ap.peerByIPAddress[addr] = p
}
}
func (ap *allowedPeers) addSelfIfAllowed(n *ipnstate.PeerStatus) {
if ap.allowed(n) {
ap.self = n
}
}
func (ap *allowedPeers) has(a netip.Addr) bool {
_, ok := ap.peerByIPAddress[a]
return ok
}
func taggedNodesFromStatus(ctx context.Context, clusterTag string, ts *tsnet.Server) (*allowedPeers, error) {
lc, err := ts.LocalClient()
if err != nil {
return nil, err
}
tStatus, err := lc.Status(ctx)
if err != nil {
return nil, err
}
ap := newAllowedPeers(clusterTag)
for _, v := range tStatus.Peer {
ap.addPeerIfAllowed(v)
}
ap.addSelfIfAllowed(tStatus.Self)
return ap, nil
}
func newAllowedPeers(tag string) *allowedPeers {
return &allowedPeers{
peerByIPAddress: map[netip.Addr]*ipnstate.PeerStatus{},
clusterTag: tag,
}
}
// Start returns a pointer to a running Consensus instance. // Start returns a pointer to a running Consensus instance.
func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag string, cfg Config) (*Consensus, error) { func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag string, cfg Config) (*Consensus, error) {
if clusterTag == "" { if clusterTag == "" {
@ -211,22 +159,30 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag strin
Config: cfg, Config: cfg,
} }
tnfs, err := taggedNodesFromStatus(ctx, clusterTag, ts) auth := &authorization{
if tnfs.self == nil { tag: clusterTag,
ts: ts,
}
err := auth.refresh(ctx)
if err != nil {
return nil, err
}
if !auth.selfAllowed() {
return nil, errors.New("this node is not tagged with the cluster tag") return nil, errors.New("this node is not tagged with the cluster tag")
} }
r, err := startRaft(ts, &fsm, c.Self, clusterTag, cfg) r, err := startRaft(ts, &fsm, c.Self, auth, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.Raft = r c.Raft = r
srv, err := c.serveCmdHttp(ts, clusterTag) srv, err := c.serveCmdHttp(ts, auth)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c.cmdHttpServer = srv c.cmdHttpServer = srv
c.bootstrap(tnfs.peers) c.bootstrap(auth.allowedPeers())
srv, err = serveMonitor(&c, ts, addr(c.Self.Host, cfg.MonitorPort)) srv, err = serveMonitor(&c, ts, addr(c.Self.Host, cfg.MonitorPort))
if err != nil { if err != nil {
return nil, err return nil, err
@ -235,7 +191,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag strin
return &c, nil return &c, nil
} }
func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, clusterTag string, cfg Config) (*raft.Raft, error) { func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, auth *authorization, cfg Config) (*raft.Raft, error) {
config := cfg.Raft config := cfg.Raft
config.LocalID = raft.ServerID(self.ID) config.LocalID = raft.ServerID(self.ID)
@ -251,9 +207,9 @@ func startRaft(ts *tsnet.Server, fsm *raft.FSM, self SelfRaftNode, clusterTag st
} }
transport := raft.NewNetworkTransport(StreamLayer{ transport := raft.NewNetworkTransport(StreamLayer{
s: ts,
Listener: ln, Listener: ln,
tag: clusterTag, auth: auth,
s: ts,
}, },
cfg.MaxConnPool, cfg.MaxConnPool,
cfg.ConnTimeout, cfg.ConnTimeout,
@ -386,12 +342,12 @@ func (e lookElsewhereError) Error() string {
var ErrLeaderUnknown = errors.New("Leader Unknown") var ErrLeaderUnknown = errors.New("Leader Unknown")
func (c *Consensus) serveCmdHttp(ts *tsnet.Server, tag string) (*http.Server, error) { func (c *Consensus) serveCmdHttp(ts *tsnet.Server, auth *authorization) (*http.Server, error) {
ln, err := ts.Listen("tcp", c.commandAddr(c.Self.Host)) ln, err := ts.Listen("tcp", c.commandAddr(c.Self.Host))
if err != nil { if err != nil {
return nil, err return nil, err
} }
mux := c.makeCommandMux(ts, tag) mux := c.makeCommandMux(auth)
srv := &http.Server{Handler: mux} srv := &http.Server{Handler: mux}
go func() { go func() {
err := srv.Serve(ln) err := srv.Serve(ln)