use a set for authorization, test it

This commit is contained in:
Fran Bull 2025-02-26 14:09:24 -08:00
parent f63ce0066d
commit 66ecab9540
4 changed files with 229 additions and 35 deletions

View File

@ -13,47 +13,81 @@ import (
"tailscale.com/ipn"
"tailscale.com/ipn/ipnstate"
"tailscale.com/tsnet"
"tailscale.com/util/set"
)
type statusGetter interface {
getStatus(context.Context) (*ipnstate.Status, error)
}
type tailscaleStatusGetter struct {
ts *tsnet.Server
}
func (sg tailscaleStatusGetter) getStatus(ctx context.Context) (*ipnstate.Status, error) {
lc, err := sg.ts.LocalClient()
if err != nil {
return nil, err
}
return lc.Status(ctx)
}
type authorization struct {
ts *tsnet.Server
sg statusGetter
tag string
mu sync.Mutex
peers *peers // protected by mu
}
func newAuthorization(ts *tsnet.Server, tag string) *authorization {
return &authorization{
sg: tailscaleStatusGetter{
ts: ts,
},
tag: tag,
}
}
func (a *authorization) refresh(ctx context.Context) error {
lc, err := a.ts.LocalClient()
tStatus, err := a.sg.getStatus(ctx)
if err != nil {
return err
}
tStatus, err := lc.Status(ctx)
if err != nil {
return err
if tStatus == nil {
return errors.New("no status")
}
if tStatus.BackendState != ipn.Running.String() {
return errors.New("ts Server is not running")
}
a.mu.Lock()
defer a.mu.Unlock()
a.peers = newPeers(tStatus)
a.peers = newPeers(tStatus, a.tag)
return nil
}
func (a *authorization) allowsHost(addr netip.Addr) bool {
if a.peers == nil {
return false
}
a.mu.Lock()
defer a.mu.Unlock()
return a.peers.peerExists(addr, a.tag)
}
func (a *authorization) selfAllowed() bool {
if a.peers == nil {
return false
}
a.mu.Lock()
defer a.mu.Unlock()
return a.peers.status.Self.Tags != nil && slices.Contains(a.peers.status.Self.Tags.AsSlice(), a.tag)
}
func (a *authorization) allowedPeers() []*ipnstate.PeerStatus {
if a.peers == nil {
return nil
}
a.mu.Lock()
defer a.mu.Unlock()
if a.peers.allowedPeers == nil {
@ -63,35 +97,27 @@ func (a *authorization) allowedPeers() []*ipnstate.PeerStatus {
}
type peers struct {
status *ipnstate.Status
peerByIPAddressAndTag map[netip.Addr]map[string]*ipnstate.PeerStatus
allowedPeers []*ipnstate.PeerStatus
status *ipnstate.Status
allowedRemoteAddrs set.Set[netip.Addr]
allowedPeers []*ipnstate.PeerStatus
}
func (ps *peers) peerExists(a netip.Addr, tag string) bool {
byTag, ok := ps.peerByIPAddressAndTag[a]
if !ok {
return false
}
_, ok = byTag[tag]
return ok
return ps.allowedRemoteAddrs.Contains(a)
}
func newPeers(status *ipnstate.Status) *peers {
func newPeers(status *ipnstate.Status, tag string) *peers {
ps := &peers{
peerByIPAddressAndTag: map[netip.Addr]map[string]*ipnstate.PeerStatus{},
status: status,
status: status,
allowedRemoteAddrs: set.Set[netip.Addr]{},
}
for _, p := range status.Peer {
for _, addr := range p.TailscaleIPs {
if ps.peerByIPAddressAndTag[addr] == nil {
ps.peerByIPAddressAndTag[addr] = map[string]*ipnstate.PeerStatus{}
}
if p.Tags != nil {
for _, tag := range p.Tags.AsSlice() {
ps.peerByIPAddressAndTag[addr][tag] = p
ps.allowedPeers = append(ps.allowedPeers, p)
}
if p.Tags != nil && p.Tags.ContainsFunc(func(s string) bool {
return s == tag
}) {
ps.allowedPeers = append(ps.allowedPeers, p)
for _, addr := range p.TailscaleIPs {
ps.allowedRemoteAddrs.Add(addr)
}
}
}

View File

@ -0,0 +1,174 @@
package tsconsensus
import (
"context"
"net/netip"
"testing"
"tailscale.com/ipn"
"tailscale.com/ipn/ipnstate"
"tailscale.com/types/key"
"tailscale.com/types/views"
)
type testStatusGetter struct {
status *ipnstate.Status
}
func (sg testStatusGetter) getStatus(ctx context.Context) (*ipnstate.Status, error) {
return sg.status, nil
}
const testTag string = "tag:clusterTag"
func authForStatus(s *ipnstate.Status) *authorization {
return &authorization{
sg: testStatusGetter{
status: s,
},
tag: testTag,
}
}
func addrsForIndex(i int) []netip.Addr {
return []netip.Addr{
netip.AddrFrom4([4]byte{100, 0, 0, byte(i)}),
netip.AddrFrom4([4]byte{100, 0, 1, byte(i)}),
}
}
func statusForTags(self []string, peers [][]string) *ipnstate.Status {
selfTags := views.SliceOf(self)
s := &ipnstate.Status{
BackendState: ipn.Running.String(),
Self: &ipnstate.PeerStatus{
Tags: &selfTags,
},
Peer: map[key.NodePublic]*ipnstate.PeerStatus{},
}
for i, tagStrings := range peers {
tags := views.SliceOf(tagStrings)
s.Peer[key.NewNode().Public()] = &ipnstate.PeerStatus{
Tags: &tags,
TailscaleIPs: addrsForIndex(i),
}
}
return s
}
func authForTags(self []string, peers [][]string) *authorization {
return authForStatus(statusForTags(self, peers))
}
func TestAuthRefreshErrorsNotRunning(t *testing.T) {
ctx := context.Background()
a := authForStatus(nil)
err := a.refresh(ctx)
if err == nil {
t.Fatalf("expected err to be non-nil")
}
expected := "no status"
if err.Error() != expected {
t.Fatalf("expected: %s, got: %s", expected, err.Error())
}
a = authForStatus(&ipnstate.Status{
BackendState: "NeedsMachineAuth",
})
err = a.refresh(ctx)
if err == nil {
t.Fatalf("expected err to be non-nil")
}
expected = "ts Server is not running"
if err.Error() != expected {
t.Fatalf("expected: %s, got: %s", expected, err.Error())
}
}
func TestAuthUnrefreshed(t *testing.T) {
a := authForStatus(nil)
if a.allowsHost(netip.MustParseAddr("100.0.0.1")) {
t.Fatalf("never refreshed authorization, allowsHost: expected false, got true")
}
gotAllowedPeers := a.allowedPeers()
if gotAllowedPeers != nil {
t.Fatalf("never refreshed authorization, allowedPeers: expected [], got %v", gotAllowedPeers)
}
if a.selfAllowed() != false {
t.Fatalf("never refreshed authorization, selfAllowed: expected false got true")
}
}
func TestAuthAllowsHost(t *testing.T) {
ctx := context.Background()
peerTags := [][]string{
[]string{"woo"},
nil,
[]string{"woo", testTag},
[]string{testTag},
}
expected := []bool{
false,
false,
true,
true,
}
a := authForTags(nil, peerTags)
err := a.refresh(ctx)
if err != nil {
t.Fatal(err)
}
for i, tags := range peerTags {
for _, addr := range addrsForIndex(i) {
got := a.allowsHost(addr)
if got != expected[i] {
t.Fatalf("allowed %v, expected: %t, got %t", tags, expected[i], got)
}
}
}
}
func TestAuthAllowedPeers(t *testing.T) {
ctx := context.Background()
a := authForTags(nil, [][]string{
[]string{"woo"},
nil,
[]string{"woo", testTag},
[]string{testTag},
})
err := a.refresh(ctx)
if err != nil {
t.Fatal(err)
}
ps := a.allowedPeers()
if len(ps) != 2 {
t.Fatalf("expected: 2, got: %d", len(ps))
}
}
func TestAuthSelfAllowed(t *testing.T) {
ctx := context.Background()
a := authForTags([]string{"woo"}, nil)
err := a.refresh(ctx)
if err != nil {
t.Fatal(err)
}
got := a.selfAllowed()
if got {
t.Fatalf("expected: false, got: %t", got)
}
a = authForTags([]string{"woo", testTag}, nil)
err = a.refresh(ctx)
if err != nil {
t.Fatal(err)
}
got = a.selfAllowed()
if !got {
t.Fatalf("expected: true, got: %t", got)
}
}

View File

@ -171,10 +171,7 @@ func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag strin
config: cfg,
}
auth := &authorization{
tag: clusterTag,
ts: ts,
}
auth := newAuthorization(ts, clusterTag)
err := auth.refresh(ctx)
if err != nil {
return nil, fmt.Errorf("auth refresh: %w", err)

View File

@ -594,10 +594,7 @@ func TestOnlyTaggedPeersCanBeDialed(t *testing.T) {
// make a StreamLayer for ps[0]
ts := ps[0].ts
auth := &authorization{
tag: clusterTag,
ts: ts,
}
auth := newAuthorization(ts, clusterTag)
port := 19841
lns := make([]net.Listener, 3)