2023-06-21 11:29:52 +02:00
|
|
|
package notifier
|
|
|
|
|
|
|
|
import (
	"cmp"
	"context"
	"fmt"
	"slices"
	"strings"
	"sync"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/rs/zerolog/log"
)
|
|
|
|
|
|
|
|
type Notifier struct {
|
2024-02-08 17:28:19 +01:00
|
|
|
l sync.RWMutex
|
2024-02-23 10:59:24 +01:00
|
|
|
nodes map[types.NodeID]chan<- types.StateUpdate
|
|
|
|
connected types.NodeConnectedMap
|
2023-06-21 11:29:52 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
func NewNotifier() *Notifier {
|
2024-02-08 17:28:19 +01:00
|
|
|
return &Notifier{
|
2024-02-23 10:59:24 +01:00
|
|
|
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
|
|
|
|
connected: make(types.NodeConnectedMap),
|
2024-02-08 17:28:19 +01:00
|
|
|
}
|
2023-06-21 11:29:52 +02:00
|
|
|
}
|
|
|
|
|
2024-02-23 10:59:24 +01:00
|
|
|
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
|
|
|
log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node")
|
2024-02-08 17:28:19 +01:00
|
|
|
defer log.Trace().
|
|
|
|
Caller().
|
2024-02-23 10:59:24 +01:00
|
|
|
Uint64("node.id", nodeID.Uint64()).
|
2024-02-08 17:28:19 +01:00
|
|
|
Msg("releasing lock to add node")
|
2023-09-11 06:08:44 -05:00
|
|
|
|
2023-06-21 11:29:52 +02:00
|
|
|
n.l.Lock()
|
|
|
|
defer n.l.Unlock()
|
|
|
|
|
2024-02-23 10:59:24 +01:00
|
|
|
n.nodes[nodeID] = c
|
|
|
|
n.connected[nodeID] = true
|
2023-07-24 08:58:51 +02:00
|
|
|
|
|
|
|
log.Trace().
|
2024-02-23 10:59:24 +01:00
|
|
|
Uint64("node.id", nodeID.Uint64()).
|
2023-07-24 08:58:51 +02:00
|
|
|
Int("open_chans", len(n.nodes)).
|
|
|
|
Msg("Added new channel")
|
2023-06-21 11:29:52 +02:00
|
|
|
}
|
|
|
|
|
2024-02-23 10:59:24 +01:00
|
|
|
func (n *Notifier) RemoveNode(nodeID types.NodeID) {
|
|
|
|
log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to remove node")
|
2024-02-08 17:28:19 +01:00
|
|
|
defer log.Trace().
|
|
|
|
Caller().
|
2024-02-23 10:59:24 +01:00
|
|
|
Uint64("node.id", nodeID.Uint64()).
|
2024-02-08 17:28:19 +01:00
|
|
|
Msg("releasing lock to remove node")
|
2023-09-11 06:08:44 -05:00
|
|
|
|
2023-06-21 11:29:52 +02:00
|
|
|
n.l.Lock()
|
|
|
|
defer n.l.Unlock()
|
|
|
|
|
2024-02-08 17:28:19 +01:00
|
|
|
if len(n.nodes) == 0 {
|
2023-06-21 11:29:52 +02:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2024-02-23 10:59:24 +01:00
|
|
|
delete(n.nodes, nodeID)
|
|
|
|
n.connected[nodeID] = false
|
2023-07-24 08:58:51 +02:00
|
|
|
|
|
|
|
log.Trace().
|
2024-02-23 10:59:24 +01:00
|
|
|
Uint64("node.id", nodeID.Uint64()).
|
2023-07-24 08:58:51 +02:00
|
|
|
Int("open_chans", len(n.nodes)).
|
|
|
|
Msg("Removed channel")
|
2023-06-21 11:29:52 +02:00
|
|
|
}
|
|
|
|
|
2023-12-09 18:09:24 +01:00
|
|
|
// IsConnected reports if a node is connected to headscale and has a
|
|
|
|
// poll session open.
|
2024-02-23 10:59:24 +01:00
|
|
|
func (n *Notifier) IsConnected(nodeID types.NodeID) bool {
|
2023-12-09 18:09:24 +01:00
|
|
|
n.l.RLock()
|
|
|
|
defer n.l.RUnlock()
|
|
|
|
|
2024-02-23 10:59:24 +01:00
|
|
|
return n.connected[nodeID]
|
|
|
|
}
|
|
|
|
|
|
|
|
// IsLikelyConnected reports if a node is connected to headscale and has a
|
|
|
|
// poll session open, but doesnt lock, so might be wrong.
|
|
|
|
func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool {
|
|
|
|
return n.connected[nodeID]
|
2024-02-08 17:28:19 +01:00
|
|
|
}
|
2023-12-09 18:09:24 +01:00
|
|
|
|
2024-02-08 17:28:19 +01:00
|
|
|
// TODO(kradalby): This returns a pointer and can be dangerous.
|
2024-02-23 10:59:24 +01:00
|
|
|
func (n *Notifier) ConnectedMap() types.NodeConnectedMap {
|
2024-02-08 17:28:19 +01:00
|
|
|
return n.connected
|
2023-12-09 18:09:24 +01:00
|
|
|
}
|
|
|
|
|
2024-02-08 17:28:19 +01:00
|
|
|
func (n *Notifier) NotifyAll(ctx context.Context, update types.StateUpdate) {
|
|
|
|
n.NotifyWithIgnore(ctx, update)
|
2023-06-21 11:29:52 +02:00
|
|
|
}
|
|
|
|
|
2024-02-08 17:28:19 +01:00
|
|
|
func (n *Notifier) NotifyWithIgnore(
|
|
|
|
ctx context.Context,
|
|
|
|
update types.StateUpdate,
|
2024-02-23 10:59:24 +01:00
|
|
|
ignoreNodeIDs ...types.NodeID,
|
2024-02-08 17:28:19 +01:00
|
|
|
) {
|
2024-02-23 10:59:24 +01:00
|
|
|
log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
|
2023-09-11 06:08:44 -05:00
|
|
|
defer log.Trace().
|
|
|
|
Caller().
|
2024-02-23 10:59:24 +01:00
|
|
|
Str("type", update.Type.String()).
|
2024-02-08 17:28:19 +01:00
|
|
|
Msg("releasing lock, finished notifying")
|
2023-09-11 06:08:44 -05:00
|
|
|
|
|
|
|
n.l.RLock()
|
|
|
|
defer n.l.RUnlock()
|
2023-06-21 11:29:52 +02:00
|
|
|
|
2024-02-23 10:59:24 +01:00
|
|
|
if update.Type == types.StatePeerChangedPatch {
|
|
|
|
log.Trace().Interface("update", update).Interface("online", n.connected).Msg("PATCH UPDATE SENT")
|
|
|
|
}
|
|
|
|
|
|
|
|
for nodeID, c := range n.nodes {
|
|
|
|
if slices.Contains(ignoreNodeIDs, nodeID) {
|
2023-06-21 11:29:52 +02:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2024-02-08 17:28:19 +01:00
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
log.Error().
|
|
|
|
Err(ctx.Err()).
|
2024-02-23 10:59:24 +01:00
|
|
|
Uint64("node.id", nodeID.Uint64()).
|
2024-02-08 17:28:19 +01:00
|
|
|
Any("origin", ctx.Value("origin")).
|
2024-02-23 10:59:24 +01:00
|
|
|
Any("origin-hostname", ctx.Value("hostname")).
|
2024-02-08 17:28:19 +01:00
|
|
|
Msgf("update not sent, context cancelled")
|
|
|
|
|
|
|
|
return
|
|
|
|
case c <- update:
|
|
|
|
log.Trace().
|
2024-02-23 10:59:24 +01:00
|
|
|
Uint64("node.id", nodeID.Uint64()).
|
2024-02-08 17:28:19 +01:00
|
|
|
Any("origin", ctx.Value("origin")).
|
2024-02-23 10:59:24 +01:00
|
|
|
Any("origin-hostname", ctx.Value("hostname")).
|
2024-02-08 17:28:19 +01:00
|
|
|
Msgf("update successfully sent on chan")
|
|
|
|
}
|
2023-06-21 11:29:52 +02:00
|
|
|
}
|
|
|
|
}
|
2023-12-09 18:09:24 +01:00
|
|
|
|
2024-02-08 17:28:19 +01:00
|
|
|
func (n *Notifier) NotifyByMachineKey(
|
|
|
|
ctx context.Context,
|
|
|
|
update types.StateUpdate,
|
2024-02-23 10:59:24 +01:00
|
|
|
nodeID types.NodeID,
|
2024-02-08 17:28:19 +01:00
|
|
|
) {
|
2024-02-23 10:59:24 +01:00
|
|
|
log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify")
|
2024-01-05 10:41:56 +01:00
|
|
|
defer log.Trace().
|
|
|
|
Caller().
|
2024-02-23 10:59:24 +01:00
|
|
|
Str("type", update.Type.String()).
|
2024-02-08 17:28:19 +01:00
|
|
|
Msg("releasing lock, finished notifying")
|
2024-01-05 10:41:56 +01:00
|
|
|
|
|
|
|
n.l.RLock()
|
|
|
|
defer n.l.RUnlock()
|
|
|
|
|
2024-02-23 10:59:24 +01:00
|
|
|
if c, ok := n.nodes[nodeID]; ok {
|
2024-02-08 17:28:19 +01:00
|
|
|
select {
|
|
|
|
case <-ctx.Done():
|
|
|
|
log.Error().
|
|
|
|
Err(ctx.Err()).
|
2024-02-23 10:59:24 +01:00
|
|
|
Uint64("node.id", nodeID.Uint64()).
|
2024-02-08 17:28:19 +01:00
|
|
|
Any("origin", ctx.Value("origin")).
|
2024-02-23 10:59:24 +01:00
|
|
|
Any("origin-hostname", ctx.Value("hostname")).
|
2024-02-08 17:28:19 +01:00
|
|
|
Msgf("update not sent, context cancelled")
|
|
|
|
|
|
|
|
return
|
|
|
|
case c <- update:
|
|
|
|
log.Trace().
|
2024-02-23 10:59:24 +01:00
|
|
|
Uint64("node.id", nodeID.Uint64()).
|
2024-02-08 17:28:19 +01:00
|
|
|
Any("origin", ctx.Value("origin")).
|
2024-02-23 10:59:24 +01:00
|
|
|
Any("origin-hostname", ctx.Value("hostname")).
|
2024-02-08 17:28:19 +01:00
|
|
|
Msgf("update successfully sent on chan")
|
|
|
|
}
|
2024-01-05 10:41:56 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-12-09 18:09:24 +01:00
|
|
|
func (n *Notifier) String() string {
|
|
|
|
n.l.RLock()
|
|
|
|
defer n.l.RUnlock()
|
|
|
|
|
2024-04-10 15:35:09 +02:00
|
|
|
var b strings.Builder
|
|
|
|
b.WriteString("chans:\n")
|
2023-12-09 18:09:24 +01:00
|
|
|
|
|
|
|
for k, v := range n.nodes {
|
2024-04-10 15:35:09 +02:00
|
|
|
fmt.Fprintf(&b, "\t%d: %p\n", k, v)
|
2023-12-09 18:09:24 +01:00
|
|
|
}
|
|
|
|
|
2024-04-10 15:35:09 +02:00
|
|
|
b.WriteString("\n")
|
|
|
|
b.WriteString("connected:\n")
|
|
|
|
|
|
|
|
for k, v := range n.connected {
|
|
|
|
fmt.Fprintf(&b, "\t%d: %t\n", k, v)
|
|
|
|
}
|
|
|
|
|
|
|
|
return b.String()
|
2023-12-09 18:09:24 +01:00
|
|
|
}
|