191 lines
4.3 KiB
Go
Raw Permalink Normal View History

package notifier
import (
"context"
"fmt"
"slices"
"strings"
"sync"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
)
type Notifier struct {
l sync.RWMutex
nodes map[types.NodeID]chan<- types.StateUpdate
connected types.NodeConnectedMap
}
func NewNotifier() *Notifier {
return &Notifier{
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
connected: make(types.NodeConnectedMap),
}
}
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node")
defer log.Trace().
Caller().
Uint64("node.id", nodeID.Uint64()).
Msg("releasing lock to add node")
n.l.Lock()
defer n.l.Unlock()
n.nodes[nodeID] = c
n.connected[nodeID] = true
log.Trace().
Uint64("node.id", nodeID.Uint64()).
Int("open_chans", len(n.nodes)).
Msg("Added new channel")
}
func (n *Notifier) RemoveNode(nodeID types.NodeID) {
log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to remove node")
defer log.Trace().
Caller().
Uint64("node.id", nodeID.Uint64()).
Msg("releasing lock to remove node")
n.l.Lock()
defer n.l.Unlock()
if len(n.nodes) == 0 {
return
}
delete(n.nodes, nodeID)
n.connected[nodeID] = false
log.Trace().
Uint64("node.id", nodeID.Uint64()).
Int("open_chans", len(n.nodes)).
Msg("Removed channel")
}
// IsConnected reports if a node is connected to headscale and has a
// poll session open.
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
n.l.RLock()
defer n.l.RUnlock()
return n.connected[nodeID]
}
// IsLikelyConnected reports if a node is connected to headscale and has a
// poll session open, but doesnt lock, so might be wrong.
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
return n.connected[nodeID]
}
// TODO(kradalby): This returns a pointer and can be dangerous.
func (n *Notifier) ConnectedMap() types.NodeConnectedMap {
return n.connected
}
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
n.NotifyWithIgnore(ctx, update)
}
func (n *Notifier) NotifyWithIgnore(
ctx context.Context,
update types.StateUpdate,
ignoreNodeIDs ...types.NodeID,
) {
log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
defer log.Trace().
Caller().
Str("type", update.Type.String()).
Msg("releasing lock, finished notifying")
n.l.RLock()
defer n.l.RUnlock()
if update.Type == types.StatePeerChangedPatch {
log.Trace().Interface("update", update).Interface("online", n.connected).Msg("PATCH UPDATE SENT")
}
for nodeID, c := range n.nodes {
if slices.Contains(ignoreNodeIDs, nodeID) {
continue
}
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
}
}
func (n *Notifier) NotifyByMachineKey(
ctx context.Context,
update types.StateUpdate,
nodeID types.NodeID,
) {
log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
defer log.Trace().
Caller().
Str("type", update.Type.String()).
Msg("releasing lock, finished notifying")
n.l.RLock()
defer n.l.RUnlock()
if c, ok := n.nodes[nodeID]; ok {
select {
case <-ctx.Done():
log.Error().
Err(ctx.Err()).
Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")).
Msgf("update not sent, context cancelled")
return
case c <- update:
log.Trace().
Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
}
}
}
func (n *Notifier) String() string {
n.l.RLock()
defer n.l.RUnlock()
var b strings.Builder
b.WriteString("chans:\n")
for k, v := range n.nodes {
fmt.Fprintf(&b, "\t%d: %p\n", k, v)
}
b.WriteString("\n")
b.WriteString("connected:\n")
for k, v := range n.connected {
fmt.Fprintf(&b, "\t%d: %t\n", k, v)
}
return b.String()
}