Refactored app code with Node

This commit is contained in:
Juan Font 2023-05-01 14:52:03 +00:00
parent 89fffeab31
commit 83b4389090
17 changed files with 416 additions and 416 deletions

90
acls.go
View File

@ -119,7 +119,7 @@ func (h *Headscale) LoadACLPolicy(path string) error {
} }
func (h *Headscale) UpdateACLRules() error { func (h *Headscale) UpdateACLRules() error {
machines, err := h.ListMachines() nodes, err := h.ListNodes()
if err != nil { if err != nil {
return err return err
} }
@ -128,7 +128,7 @@ func (h *Headscale) UpdateACLRules() error {
return errEmptyPolicy return errEmptyPolicy
} }
rules, err := generateACLRules(machines, *h.aclPolicy, h.cfg.OIDC.StripEmaildomain) rules, err := generateACLRules(nodes, *h.aclPolicy, h.cfg.OIDC.StripEmaildomain)
if err != nil { if err != nil {
return err return err
} }
@ -225,7 +225,7 @@ func expandACLPeerAddr(srcIP string) []string {
} }
func generateACLRules( func generateACLRules(
machines []Machine, nodes []Node,
aclPolicy ACLPolicy, aclPolicy ACLPolicy,
stripEmaildomain bool, stripEmaildomain bool,
) ([]tailcfg.FilterRule, error) { ) ([]tailcfg.FilterRule, error) {
@ -238,7 +238,7 @@ func generateACLRules(
srcIPs := []string{} srcIPs := []string{}
for innerIndex, src := range acl.Sources { for innerIndex, src := range acl.Sources {
srcs, err := generateACLPolicySrc(machines, aclPolicy, src, stripEmaildomain) srcs, err := generateACLPolicySrc(nodes, aclPolicy, src, stripEmaildomain)
if err != nil { if err != nil {
log.Error(). log.Error().
Msgf("Error parsing ACL %d, Source %d", index, innerIndex) Msgf("Error parsing ACL %d, Source %d", index, innerIndex)
@ -259,7 +259,7 @@ func generateACLRules(
destPorts := []tailcfg.NetPortRange{} destPorts := []tailcfg.NetPortRange{}
for innerIndex, dest := range acl.Destinations { for innerIndex, dest := range acl.Destinations {
dests, err := generateACLPolicyDest( dests, err := generateACLPolicyDest(
machines, nodes,
aclPolicy, aclPolicy,
dest, dest,
needsWildcard, needsWildcard,
@ -291,7 +291,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
return nil, errEmptyPolicy return nil, errEmptyPolicy
} }
machines, err := h.ListMachines() nodes, err := h.ListNodes()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -339,7 +339,7 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
for innerIndex, rawSrc := range sshACL.Sources { for innerIndex, rawSrc := range sshACL.Sources {
expandedSrcs, err := expandAlias( expandedSrcs, err := expandAlias(
machines, nodes,
*h.aclPolicy, *h.aclPolicy,
rawSrc, rawSrc,
h.cfg.OIDC.StripEmaildomain, h.cfg.OIDC.StripEmaildomain,
@ -390,16 +390,16 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
} }
func generateACLPolicySrc( func generateACLPolicySrc(
machines []Machine, nodes []Node,
aclPolicy ACLPolicy, aclPolicy ACLPolicy,
src string, src string,
stripEmaildomain bool, stripEmaildomain bool,
) ([]string, error) { ) ([]string, error) {
return expandAlias(machines, aclPolicy, src, stripEmaildomain) return expandAlias(nodes, aclPolicy, src, stripEmaildomain)
} }
func generateACLPolicyDest( func generateACLPolicyDest(
machines []Machine, nodes []Node,
aclPolicy ACLPolicy, aclPolicy ACLPolicy,
dest string, dest string,
needsWildcard bool, needsWildcard bool,
@ -449,7 +449,7 @@ func generateACLPolicyDest(
} }
expanded, err := expandAlias( expanded, err := expandAlias(
machines, nodes,
aclPolicy, aclPolicy,
alias, alias,
stripEmaildomain, stripEmaildomain,
@ -535,7 +535,7 @@ func parseProtocol(protocol string) ([]int, bool, error) {
// - a cidr // - a cidr
// and transform these in IPAddresses. // and transform these in IPAddresses.
func expandAlias( func expandAlias(
machines Machines, nodes Nodes,
aclPolicy ACLPolicy, aclPolicy ACLPolicy,
alias string, alias string,
stripEmailDomain bool, stripEmailDomain bool,
@ -555,7 +555,7 @@ func expandAlias(
return ips, err return ips, err
} }
for _, n := range users { for _, n := range users {
nodes := filterMachinesByUser(machines, n) nodes := filterNodesByUser(nodes, n)
for _, node := range nodes { for _, node := range nodes {
ips = append(ips, node.IPAddresses.ToStringSlice()...) ips = append(ips, node.IPAddresses.ToStringSlice()...)
} }
@ -566,9 +566,9 @@ func expandAlias(
if strings.HasPrefix(alias, "tag:") { if strings.HasPrefix(alias, "tag:") {
// check for forced tags // check for forced tags
for _, machine := range machines { for _, node := range nodes {
if contains(machine.ForcedTags, alias) { if contains(node.ForcedTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...) ips = append(ips, node.IPAddresses.ToStringSlice()...)
} }
} }
@ -590,13 +590,13 @@ func expandAlias(
} }
} }
// filter out machines per tag owner // filter out nodes per tag owner
for _, user := range owners { for _, user := range owners {
machines := filterMachinesByUser(machines, user) nodes := filterNodesByUser(nodes, user)
for _, machine := range machines { for _, node := range nodes {
hi := machine.GetHostInfo() hi := node.GetHostInfo()
if contains(hi.RequestTags, alias) { if contains(hi.RequestTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...) ips = append(ips, node.IPAddresses.ToStringSlice()...)
} }
} }
} }
@ -605,10 +605,10 @@ func expandAlias(
} }
// if alias is a user // if alias is a user
nodes := filterMachinesByUser(machines, alias) filteredNodes := filterNodesByUser(nodes, alias)
nodes = excludeCorrectlyTaggedNodes(aclPolicy, nodes, alias, stripEmailDomain) filteredNodes = excludeCorrectlyTaggedNodes(aclPolicy, filteredNodes, alias, stripEmailDomain)
for _, n := range nodes { for _, n := range filteredNodes {
ips = append(ips, n.IPAddresses.ToStringSlice()...) ips = append(ips, n.IPAddresses.ToStringSlice()...)
} }
if len(ips) > 0 { if len(ips) > 0 {
@ -619,17 +619,17 @@ func expandAlias(
if h, ok := aclPolicy.Hosts[alias]; ok { if h, ok := aclPolicy.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry") log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry")
return expandAlias(machines, aclPolicy, h.String(), stripEmailDomain) return expandAlias(filteredNodes, aclPolicy, h.String(), stripEmailDomain)
} }
// if alias is an IP // if alias is an IP
if ip, err := netip.ParseAddr(alias); err == nil { if ip, err := netip.ParseAddr(alias); err == nil {
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip") log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip")
ips := []string{ip.String()} ips := []string{ip.String()}
matches := machines.FilterByIP(ip) matches := nodes.FilterByIP(ip)
for _, machine := range matches { for _, node := range matches {
ips = append(ips, machine.IPAddresses.ToStringSlice()...) ips = append(ips, node.IPAddresses.ToStringSlice()...)
} }
return lo.Uniq(ips), nil return lo.Uniq(ips), nil
@ -640,12 +640,12 @@ func expandAlias(
val := []string{cidr.String()} val := []string{cidr.String()}
// This is suboptimal and quite expensive, but if we only add the cidr, we will miss all the relevant IPv6 // This is suboptimal and quite expensive, but if we only add the cidr, we will miss all the relevant IPv6
// addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers. // addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers.
for _, machine := range machines { for _, node := range nodes {
for _, ip := range machine.IPAddresses { for _, ip := range node.IPAddresses {
// log.Trace(). // log.Trace().
// Msgf("checking if machine ip (%s) is part of cidr (%s): %v, is single ip cidr (%v), addr: %s", ip.String(), cidr.String(), cidr.Contains(ip), cidr.IsSingleIP(), cidr.Addr().String()) // Msgf("checking if node ip (%s) is part of cidr (%s): %v, is single ip cidr (%v), addr: %s", ip.String(), cidr.String(), cidr.Contains(ip), cidr.IsSingleIP(), cidr.Addr().String())
if cidr.Contains(ip) { if cidr.Contains(ip) {
val = append(val, machine.IPAddresses.ToStringSlice()...) val = append(val, node.IPAddresses.ToStringSlice()...)
} }
} }
} }
@ -663,11 +663,11 @@ func expandAlias(
// we assume in this function that we only have nodes from 1 user. // we assume in this function that we only have nodes from 1 user.
func excludeCorrectlyTaggedNodes( func excludeCorrectlyTaggedNodes(
aclPolicy ACLPolicy, aclPolicy ACLPolicy,
nodes []Machine, nodes []Node,
user string, user string,
stripEmailDomain bool, stripEmailDomain bool,
) []Machine { ) []Node {
out := []Machine{} out := []Node{}
tags := []string{} tags := []string{}
for tag := range aclPolicy.TagOwners { for tag := range aclPolicy.TagOwners {
owners, _ := expandTagOwners(aclPolicy, user, stripEmailDomain) owners, _ := expandTagOwners(aclPolicy, user, stripEmailDomain)
@ -676,9 +676,9 @@ func excludeCorrectlyTaggedNodes(
tags = append(tags, tag) tags = append(tags, tag)
} }
} }
// for each machine if tag is in tags list, don't append it. // for each node if tag is in tags list, don't append it.
for _, machine := range nodes { for _, node := range nodes {
hi := machine.GetHostInfo() hi := node.GetHostInfo()
found := false found := false
for _, t := range hi.RequestTags { for _, t := range hi.RequestTags {
@ -688,11 +688,11 @@ func excludeCorrectlyTaggedNodes(
break break
} }
} }
if len(machine.ForcedTags) > 0 { if len(node.ForcedTags) > 0 {
found = true found = true
} }
if !found { if !found {
out = append(out, machine) out = append(out, node)
} }
} }
@ -747,11 +747,11 @@ func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, err
return &ports, nil return &ports, nil
} }
func filterMachinesByUser(machines []Machine, user string) []Machine { func filterNodesByUser(nodes []Node, user string) []Node {
out := []Machine{} out := []Node{}
for _, machine := range machines { for _, node := range nodes {
if machine.User.Name == user { if node.User.Name == user {
out = append(out, machine) out = append(out, node)
} }
} }

6
api.go
View File

@ -20,7 +20,7 @@ const (
RegisterMethodOIDC = "oidc" RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli" RegisterMethodCLI = "cli"
ErrRegisterMethodCLIDoesNotSupportExpire = Error( ErrRegisterMethodCLIDoesNotSupportExpire = Error(
"machines registered with CLI does not support expire", "node registered with CLI does not support expire",
) )
) )
@ -74,9 +74,9 @@ var registerWebAPITemplate = template.Must(
</head> </head>
<body> <body>
<h1>headscale</h1> <h1>headscale</h1>
<h2>Machine registration</h2> <h2>Node registration</h2>
<p> <p>
Run the command below in the headscale server to add this machine to your network: Run the command below in the headscale server to add this node to your network:
</p> </p>
<pre><code>headscale nodes register --user USERNAME --key {{.Key}}</code></pre> <pre><code>headscale nodes register --user USERNAME --key {{.Key}}</code></pre>
</body> </body>

View File

@ -9,13 +9,13 @@ import (
func (h *Headscale) generateMapResponse( func (h *Headscale) generateMapResponse(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, node *Node,
) (*tailcfg.MapResponse, error) { ) (*tailcfg.MapResponse, error) {
log.Trace(). log.Trace().
Str("func", "generateMapResponse"). Str("func", "generateMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname). Str("node", mapRequest.Hostinfo.Hostname).
Msg("Creating Map response") Msg("Creating Map response")
node, err := h.toNode(*machine, h.cfg.BaseDomain, h.cfg.DNSConfig) tailNode, err := h.toNode(*node, h.cfg.BaseDomain, h.cfg.DNSConfig)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -26,7 +26,7 @@ func (h *Headscale) generateMapResponse(
return nil, err return nil, err
} }
peers, err := h.getValidPeers(machine) peers, err := h.getValidPeers(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -37,7 +37,7 @@ func (h *Headscale) generateMapResponse(
return nil, err return nil, err
} }
profiles := h.getMapResponseUserProfiles(*machine, peers) profiles := h.getMapResponseUserProfiles(*node, peers)
nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.DNSConfig) nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.DNSConfig)
if err != nil { if err != nil {
@ -53,7 +53,7 @@ func (h *Headscale) generateMapResponse(
dnsConfig := getMapResponseDNSConfig( dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig, h.cfg.DNSConfig,
h.cfg.BaseDomain, h.cfg.BaseDomain,
*machine, *node,
peers, peers,
) )
@ -61,7 +61,7 @@ func (h *Headscale) generateMapResponse(
resp := tailcfg.MapResponse{ resp := tailcfg.MapResponse{
KeepAlive: false, KeepAlive: false,
Node: node, Node: tailNode,
// TODO: Only send if updated // TODO: Only send if updated
DERPMap: h.DERPMap, DERPMap: h.DERPMap,
@ -105,7 +105,7 @@ func (h *Headscale) generateMapResponse(
log.Trace(). log.Trace().
Str("func", "generateMapResponse"). Str("func", "generateMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname). Str("node", mapRequest.Hostinfo.Hostname).
// Interface("payload", resp). // Interface("payload", resp).
Msgf("Generated map response: %s", tailMapResponseToString(resp)) Msgf("Generated map response: %s", tailMapResponseToString(resp))

54
app.go
View File

@ -211,7 +211,7 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
http.Redirect(w, req, target, http.StatusFound) http.Redirect(w, req, target, http.StatusFound)
} }
// expireEphemeralNodes deletes ephemeral machine records that have not been // expireEphemeralNodes deletes ephemeral node records that have not been
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout. // seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
@ -220,12 +220,12 @@ func (h *Headscale) expireEphemeralNodes(milliSeconds int64) {
} }
} }
// expireExpiredMachines expires machines that have an explicit expiry set // expireExpiredNodes expires node that have an explicit expiry set
// after that expiry time has passed. // after that expiry time has passed.
func (h *Headscale) expireExpiredMachines(milliSeconds int64) { func (h *Headscale) expireExpiredNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
for range ticker.C { for range ticker.C {
h.expireExpiredMachinesWorker() h.expireExpiredNodesWorker()
} }
} }
@ -248,32 +248,32 @@ func (h *Headscale) expireEphemeralNodesWorker() {
} }
for _, user := range users { for _, user := range users {
machines, err := h.ListMachinesByUser(user.Name) nodes, err := h.ListNodesByUser(user.Name)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("user", user.Name). Str("user", user.Name).
Msg("Error listing machines in user") Msg("Error listing nodes in user")
return return
} }
expiredFound := false expiredFound := false
for _, machine := range machines { for _, node := range nodes {
if machine.isEphemeral() && machine.LastSeen != nil && if node.isEphemeral() && node.LastSeen != nil &&
time.Now(). time.Now().
After(machine.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) { After(node.LastSeen.Add(h.cfg.EphemeralNodeInactivityTimeout)) {
expiredFound = true expiredFound = true
log.Info(). log.Info().
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Ephemeral client removed from database") Msg("Ephemeral client removed from database")
err = h.db.Unscoped().Delete(machine).Error err = h.db.Unscoped().Delete(node).Error
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("🤮 Cannot delete ephemeral machine from the database") Msg("Cannot delete ephemeral node from the database")
} }
} }
} }
@ -284,7 +284,7 @@ func (h *Headscale) expireEphemeralNodesWorker() {
} }
} }
func (h *Headscale) expireExpiredMachinesWorker() { func (h *Headscale) expireExpiredNodesWorker() {
users, err := h.ListUsers() users, err := h.ListUsers()
if err != nil { if err != nil {
log.Error().Err(err).Msg("Error listing users") log.Error().Err(err).Msg("Error listing users")
@ -293,34 +293,34 @@ func (h *Headscale) expireExpiredMachinesWorker() {
} }
for _, user := range users { for _, user := range users {
machines, err := h.ListMachinesByUser(user.Name) nodes, err := h.ListNodesByUser(user.Name)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("user", user.Name). Str("user", user.Name).
Msg("Error listing machines in user") Msg("Error listing nodes in user")
return return
} }
expiredFound := false expiredFound := false
for index, machine := range machines { for index, node := range nodes {
if machine.isExpired() && if node.isExpired() &&
machine.Expiry.After(h.getLastStateChange(user)) { node.Expiry.After(h.getLastStateChange(user)) {
expiredFound = true expiredFound = true
err := h.ExpireMachine(&machines[index]) err := h.ExpireNode(&nodes[index])
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("name", machine.GivenName). Str("name", node.GivenName).
Msg("🤮 Cannot expire machine") Msg("Cannot expire node")
} else { } else {
log.Info(). log.Info().
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("name", machine.GivenName). Str("name", node.GivenName).
Msg("Machine successfully expired") Msg("Node successfully expired")
} }
} }
} }
@ -552,7 +552,7 @@ func (h *Headscale) Serve() error {
} }
go h.expireEphemeralNodes(updateInterval) go h.expireEphemeralNodes(updateInterval)
go h.expireExpiredMachines(updateInterval) go h.expireExpiredNodes(updateInterval)
go h.failoverSubnetRoutes(updateInterval) go h.failoverSubnetRoutes(updateInterval)

22
dns.go
View File

@ -159,22 +159,22 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
} }
// If any nextdns DoH resolvers are present in the list of resolvers it will // If any nextdns DoH resolvers are present in the list of resolvers it will
// take metadata from the machine metadata and instruct tailscale to add it // take metadata from the node metadata and instruct tailscale to add it
// to the requests. This makes it possible to identify from which device the // to the requests. This makes it possible to identify from which device the
// requests come in the NextDNS dashboard. // requests come in the NextDNS dashboard.
// //
// This will produce a resolver like: // This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1` // `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) { func addNextDNSMetadata(resolvers []*dnstype.Resolver, node Node) {
for _, resolver := range resolvers { for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{ attrs := url.Values{
"device_name": []string{machine.Hostname}, "device_name": []string{node.Hostname},
"device_model": []string{machine.HostInfo.OS}, "device_model": []string{node.HostInfo.OS},
} }
if len(machine.IPAddresses) > 0 { if len(node.IPAddresses) > 0 {
attrs.Add("device_ip", machine.IPAddresses[0].String()) attrs.Add("device_ip", node.IPAddresses[0].String())
} }
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode()) resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
@ -185,8 +185,8 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) {
func getMapResponseDNSConfig( func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig, dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string, baseDomain string,
machine Machine, node Node,
peers Machines, peers Nodes,
) *tailcfg.DNSConfig { ) *tailcfg.DNSConfig {
var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone() var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone()
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
@ -195,13 +195,13 @@ func getMapResponseDNSConfig(
dnsConfig.Domains, dnsConfig.Domains,
fmt.Sprintf( fmt.Sprintf(
"%s.%s", "%s.%s",
machine.User.Name, node.User.Name,
baseDomain, baseDomain,
), ),
) )
userSet := mapset.NewSet[User]() userSet := mapset.NewSet[User]()
userSet.Add(machine.User) userSet.Add(node.User)
for _, p := range peers { for _, p := range peers {
userSet.Add(p.User) userSet.Add(p.User)
} }
@ -213,7 +213,7 @@ func getMapResponseDNSConfig(
dnsConfig = dnsConfigOrig dnsConfig = dnsConfigOrig
} }
addNextDNSMetadata(dnsConfig.Resolvers, machine) addNextDNSMetadata(dnsConfig.Resolvers, node)
return dnsConfig return dnsConfig
} }

View File

@ -8,34 +8,34 @@ import (
const prometheusNamespace = "headscale" const prometheusNamespace = "headscale"
var ( var (
// This is a high cardinality metric (user x machines), we might want to make this // This is a high cardinality metric (user x nodes), we might want to make this
// configurable/opt-in in the future. // configurable/opt-in in the future.
lastStateUpdate = promauto.NewGaugeVec(prometheus.GaugeOpts{ lastStateUpdate = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace, Namespace: prometheusNamespace,
Name: "last_update_seconds", Name: "last_update_seconds",
Help: "Time stamp in unix time when a machine or headscale was updated", Help: "Time stamp in unix time when a node or headscale was updated",
}, []string{"user", "machine"}) }, []string{"user", "nodes"})
machineRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{ nodeRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace, Namespace: prometheusNamespace,
Name: "machine_registrations_total", Name: "node_registrations_total",
Help: "The total amount of registered machine attempts", Help: "The total amount of registered node attempts",
}, []string{"action", "auth", "status", "user"}) }, []string{"action", "auth", "status", "user"})
updateRequestsFromNode = promauto.NewCounterVec(prometheus.CounterOpts{ updateRequestsFromNode = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace, Namespace: prometheusNamespace,
Name: "update_request_from_node_total", Name: "update_request_from_node_total",
Help: "The number of updates requested by a node/update function", Help: "The number of updates requested by a node/update function",
}, []string{"user", "machine", "state"}) }, []string{"user", "node", "state"})
updateRequestsSentToNode = promauto.NewCounterVec(prometheus.CounterOpts{ updateRequestsSentToNode = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace, Namespace: prometheusNamespace,
Name: "update_request_sent_to_node_total", Name: "update_request_sent_to_node_total",
Help: "The number of calls/messages issued on a specific nodes update channel", Help: "The number of calls/messages issued on a specific nodes update channel",
}, []string{"user", "machine", "status"}) }, []string{"user", "node", "status"})
// TODO(kradalby): This is very debugging, we might want to remove it. // TODO(kradalby): This is very debugging, we might want to remove it.
updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{ updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace, Namespace: prometheusNamespace,
Name: "update_request_received_on_channel_total", Name: "update_request_received_on_channel_total",
Help: "The number of update requests received on an update channel", Help: "The number of update requests received on an update channel",
}, []string{"user", "machine"}) }, []string{"user", "node"})
) )

64
oidc.go
View File

@ -27,8 +27,8 @@ const (
errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain") errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain")
errOIDCAllowedGroups = Error("authenticated principal is not in any allowed group") errOIDCAllowedGroups = Error("authenticated principal is not in any allowed group")
errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user") errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user")
errOIDCInvalidMachineState = Error( errOIDCInvalidNodeState = Error(
"requested machine state key expired before authorisation completed", "requested node state key expired before authorisation completed",
) )
errOIDCNodeKeyMissing = Error("could not get node key from cache") errOIDCNodeKeyMissing = Error("could not get node key from cache")
) )
@ -181,9 +181,9 @@ var oidcCallbackTemplate = template.Must(
) )
// OIDCCallback handles the callback from the OIDC endpoint // OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the nkey from the state cache and adds the machine to the users email user // Retrieves the nkey from the state cache and adds the node to the users email user
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities // TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo // TODO: Add groups information from OIDC tokens into node HostInfo
// Listens in /oidc/callback. // Listens in /oidc/callback.
func (h *Headscale) OIDCCallback( func (h *Headscale) OIDCCallback(
writer http.ResponseWriter, writer http.ResponseWriter,
@ -229,13 +229,13 @@ func (h *Headscale) OIDCCallback(
return return
} }
nodeKey, machineExists, err := h.validateMachineForOIDCCallback( nodeKey, nodeExists, err := h.validateNodeForOIDCCallback(
writer, writer,
state, state,
claims, claims,
idTokenExpiry, idTokenExpiry,
) )
if err != nil || machineExists { if err != nil || nodeExists {
return return
} }
@ -244,15 +244,15 @@ func (h *Headscale) OIDCCallback(
return return
} }
// register the machine if it's new // register the node if it's new
log.Debug().Msg("Registering new machine after successful callback") log.Debug().Msg("Registering new node after successful callback")
user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName) user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName)
if err != nil { if err != nil {
return return
} }
if err := h.registerMachineForOIDCCallback(writer, user, nodeKey, idTokenExpiry); err != nil { if err := h.registerNodeForOIDCCallback(writer, user, nodeKey, idTokenExpiry); err != nil {
return return
} }
@ -484,21 +484,21 @@ func validateOIDCAllowedUsers(
return nil return nil
} }
// validateMachine retrieves machine information if it exist // validateNode retrieves node information if it exist
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new node and we will move
// on to registration. // on to registration.
func (h *Headscale) validateMachineForOIDCCallback( func (h *Headscale) validateNodeForOIDCCallback(
writer http.ResponseWriter, writer http.ResponseWriter,
state string, state string,
claims *IDTokenClaims, claims *IDTokenClaims,
expiry time.Time, expiry time.Time,
) (*key.NodePublic, bool, error) { ) (*key.NodePublic, bool, error) {
// retrieve machinekey from state cache // retrieve nodekey from state cache
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state) nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
if !nodeKeyFound { if !nodeKeyFound {
log.Error(). log.Error().
Msg("requested machine state key expired before authorisation completed") Msg("requested node state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state has expired")) _, err := writer.Write([]byte("state has expired"))
@ -516,7 +516,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string) nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
if !nodeKeyOK { if !nodeKeyOK {
log.Error(). log.Error().
Msg("requested machine state key is not a string") Msg("requested node state key is not a string")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest) writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid")) _, err := writer.Write([]byte("state is invalid"))
@ -527,7 +527,7 @@ func (h *Headscale) validateMachineForOIDCCallback(
Msg("Failed to write response") Msg("Failed to write response")
} }
return nil, false, errOIDCInvalidMachineState return nil, false, errOIDCInvalidNodeState
} }
err := nodeKey.UnmarshalText( err := nodeKey.UnmarshalText(
@ -551,36 +551,36 @@ func (h *Headscale) validateMachineForOIDCCallback(
return nil, false, err return nil, false, err
} }
// retrieve machine information if it exist // retrieve node information if it exist
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new node and we will move
// on to registration. // on to registration.
machine, _ := h.GetMachineByNodeKey(nodeKey) node, _ := h.GetNodeByNodeKey(nodeKey)
if machine != nil { if node != nil {
log.Trace(). log.Trace().
Caller(). Caller().
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("machine already registered, reauthenticating") Msg("node already registered, reauthenticating")
err := h.RefreshMachine(machine, expiry) err := h.RefreshNode(node, expiry)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("Failed to refresh machine") Msg("Failed to refresh node")
http.Error( http.Error(
writer, writer,
"Failed to refresh machine", "Failed to refresh node",
http.StatusInternalServerError, http.StatusInternalServerError,
) )
return nil, true, err return nil, true, err
} }
log.Debug(). log.Debug().
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("expiresAt", fmt.Sprintf("%v", expiry)). Str("expiresAt", fmt.Sprintf("%v", expiry)).
Msg("successfully refreshed machine") Msg("successfully refreshed node")
var content bytes.Buffer var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
@ -696,13 +696,13 @@ func (h *Headscale) findOrCreateNewUserForOIDCCallback(
return user, nil return user, nil
} }
func (h *Headscale) registerMachineForOIDCCallback( func (h *Headscale) registerNodeForOIDCCallback(
writer http.ResponseWriter, writer http.ResponseWriter,
user *User, user *User,
nodeKey *key.NodePublic, nodeKey *key.NodePublic,
expiry time.Time, expiry time.Time,
) error { ) error {
if _, err := h.RegisterMachineFromAuthCallback( if _, err := h.RegisterNodeFromAuthCallback(
nodeKey.String(), nodeKey.String(),
user.Name, user.Name,
&expiry, &expiry,
@ -711,10 +711,10 @@ func (h *Headscale) registerMachineForOIDCCallback(
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("could not register machine") Msg("could not register node")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError) writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not register machine")) _, werr := writer.Write([]byte("could not register node"))
if werr != nil { if werr != nil {
log.Error(). log.Error().
Caller(). Caller().

View File

@ -193,12 +193,12 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
return &pak, nil return &pak, nil
} }
machines := []Machine{} nodes := []Node{}
if err := h.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil { if err := h.db.Preload("AuthKey").Where(&Node{AuthKeyID: uint(pak.ID)}).Find(&nodes).Error; err != nil {
return nil, err return nil, err
} }
if len(machines) != 0 || pak.Used { if len(nodes) != 0 || pak.Used {
return nil, ErrSingleUseAuthKeyHasBeenUsed return nil, ErrSingleUseAuthKeyHasBeenUsed
} }

View File

@ -102,9 +102,9 @@ func (h *Headscale) handleRegisterCommon(
isNoise bool, isNoise bool,
) { ) {
now := time.Now().UTC() now := time.Now().UTC()
machine, err := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) node, err := h.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
// If the machine has AuthKey set, handle registration via PreAuthKeys // If the node has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" { if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, registerRequest, machineKey, isNoise) h.handleAuthKeyCommon(writer, registerRequest, machineKey, isNoise)
@ -115,7 +115,7 @@ func (h *Headscale) handleRegisterCommon(
// //
// TODO(juan): We could use this field to improve our protocol implementation, // TODO(juan): We could use this field to improve our protocol implementation,
// and hold the request until the client closes it, or the interactive // and hold the request until the client closes it, or the interactive
// login is completed (i.e., the user registers the machine). // login is completed (i.e., the user registers the node).
// This is not implemented yet, as it is no strictly required. The only side-effect // This is not implemented yet, as it is no strictly required. The only side-effect
// is that the client will hammer headscale with requests until it gets a // is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse. // successful RegisterResponse.
@ -123,19 +123,19 @@ func (h *Headscale) handleRegisterCommon(
if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok { if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
log.Debug(). log.Debug().
Caller(). Caller().
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()). Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup). Str("follow_up", registerRequest.Followup).
Bool("noise", isNoise). Bool("noise", isNoise).
Msg("Machine is waiting for interactive login") Msg("Node is waiting for interactive login")
select { select {
case <-req.Context().Done(): case <-req.Context().Done():
return return
case <-time.After(registrationHoldoff): case <-time.After(registrationHoldoff):
h.handleNewMachineCommon(writer, registerRequest, machineKey, isNoise) h.handleNewNodeCommon(writer, registerRequest, machineKey, isNoise)
return return
} }
@ -144,13 +144,13 @@ func (h *Headscale) handleRegisterCommon(
log.Info(). log.Info().
Caller(). Caller().
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()). Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup). Str("follow_up", registerRequest.Followup).
Bool("noise", isNoise). Bool("noise", isNoise).
Msg("New machine not yet in the database") Msg("New node not yet in the database")
givenName, err := h.GenerateGivenName( givenName, err := h.GenerateGivenName(
machineKey.String(), machineKey.String(),
@ -166,11 +166,11 @@ func (h *Headscale) handleRegisterCommon(
return return
} }
// The machine did not have a key to authenticate, which means // The node did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI) // that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback // We create the node and then keep it around until a callback
// happens // happens
newMachine := Machine{ newNode := Node{
MachineKey: MachinePublicKeyStripPrefix(machineKey), MachineKey: MachinePublicKeyStripPrefix(machineKey),
Hostname: registerRequest.Hostinfo.Hostname, Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName, GivenName: givenName,
@ -183,42 +183,42 @@ func (h *Headscale) handleRegisterCommon(
log.Trace(). log.Trace().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Time("expiry", registerRequest.Expiry). Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested") Msg("Non-zero expiry time requested")
newMachine.Expiry = &registerRequest.Expiry newNode.Expiry = &registerRequest.Expiry
} }
h.registrationCache.Set( h.registrationCache.Set(
newMachine.NodeKey, newNode.NodeKey,
newMachine, newNode,
registerCacheExpiration, registerCacheExpiration,
) )
h.handleNewMachineCommon(writer, registerRequest, machineKey, isNoise) h.handleNewNodeCommon(writer, registerRequest, machineKey, isNoise)
return return
} }
// The machine is already in the DB. This could mean one of the following: // The node is already in the DB. This could mean one of the following:
// - The machine is authenticated and ready to /map // - The node is authenticated and ready to /map
// - We are doing a key refresh // - We are doing a key refresh
// - The machine is logged out (or expired) and pending to be authorized. TODO(juan): We need to keep alive the connection here // - The node is logged out (or expired) and pending to be authorized. TODO(juan): We need to keep alive the connection here
if machine != nil { if node != nil {
// (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021, // (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021,
// due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054 // due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054
// So if we have a not valid MachineKey (but we were able to fetch the machine with the NodeKeys), we update it. // So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it.
var storedMachineKey key.MachinePublic var storedMachineKey key.MachinePublic
err = storedMachineKey.UnmarshalText( err = storedMachineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)), []byte(MachinePublicKeyEnsurePrefix(node.MachineKey)),
) )
if err != nil || storedMachineKey.IsZero() { if err != nil || storedMachineKey.IsZero() {
machine.MachineKey = MachinePublicKeyStripPrefix(machineKey) node.MachineKey = MachinePublicKeyStripPrefix(machineKey)
if err := h.db.Save(&machine).Error; err != nil { if err := h.db.Save(&node).Error; err != nil {
log.Error(). log.Error().
Caller(). Caller().
Str("func", "RegistrationHandler"). Str("func", "RegistrationHandler").
Str("machine", machine.Hostname). Str("node", node.Hostname).
Err(err). Err(err).
Msg("Error saving machine key to database") Msg("Error saving machine key to database")
@ -229,34 +229,34 @@ func (h *Headscale) handleRegisterCommon(
// If the NodeKey stored in headscale is the same as the key presented in a registration // If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either: // request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past) // - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for /map // - A valid, registered node, looking for /map
// - Expired machine wanting to reauthenticate // - Expired node wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) { if node.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() && if !registerRequest.Expiry.IsZero() &&
registerRequest.Expiry.UTC().Before(now) { registerRequest.Expiry.UTC().Before(now) {
h.handleMachineLogOutCommon(writer, *machine, machineKey, isNoise) h.handleNodeLogOutCommon(writer, *node, machineKey, isNoise)
return return
} }
// If machine is not expired, and it is register, we have a already accepted this machine, // If node is not expired, and it is register, we have a already accepted this node,
// let it proceed with a valid registration // let it proceed with a valid registration
if !machine.isExpired() { if !node.isExpired() {
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise) h.handleNodeValidRegistrationCommon(writer, *node, machineKey, isNoise)
return return
} }
} }
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) && if node.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
!machine.isExpired() { !node.isExpired() {
h.handleMachineRefreshKeyCommon( h.handleNodeRefreshKeyCommon(
writer, writer,
registerRequest, registerRequest,
*machine, *node,
machineKey, machineKey,
isNoise, isNoise,
) )
@ -272,20 +272,20 @@ func (h *Headscale) handleRegisterCommon(
} }
} }
// The machine has expired or it is logged out // The node has expired or it is logged out
h.handleMachineExpiredOrLoggedOutCommon(writer, registerRequest, *machine, machineKey, isNoise) h.handleNodeExpiredOrLoggedOutCommon(writer, registerRequest, *node, machineKey, isNoise)
// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use // TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
machine.Expiry = &time.Time{} node.Expiry = &time.Time{}
// If we are here it means the client needs to be reauthorized, // If we are here it means the client needs to be reauthorized,
// we need to make sure the NodeKey matches the one in the request // we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two // TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets? // headscale-managed tailnets?
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) node.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
h.registrationCache.Set( h.registrationCache.Set(
NodePublicKeyStripPrefix(registerRequest.NodeKey), NodePublicKeyStripPrefix(registerRequest.NodeKey),
*machine, *node,
registerCacheExpiration, registerCacheExpiration,
) )
@ -306,7 +306,7 @@ func (h *Headscale) handleAuthKeyCommon(
) { ) {
log.Debug(). log.Debug().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Bool("noise", isNoise). Bool("noise", isNoise).
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
@ -317,7 +317,7 @@ func (h *Headscale) handleAuthKeyCommon(
Caller(). Caller().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false resp.MachineAuthorized = false
@ -328,11 +328,11 @@ func (h *Headscale) handleAuthKeyCommon(
Caller(). Caller().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). nodeRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
return return
@ -353,14 +353,14 @@ func (h *Headscale) handleAuthKeyCommon(
Caller(). Caller().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey") Msg("Failed authentication via AuthKey")
if pak != nil { if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). nodeRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
} else { } else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc() nodeRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
} }
return return
@ -369,33 +369,33 @@ func (h *Headscale) handleAuthKeyCommon(
log.Debug(). log.Debug().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses") Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey) nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
// retrieve machine information if it exist // retrieve node information if it exist
// The error is not important, because if it does not // The error is not important, because if it does not
// exist, then this is a new machine and we will move // exist, then this is a new node and we will move
// on to registration. // on to registration.
machine, _ := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) node, _ := h.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if machine != nil { if node != nil {
log.Trace(). log.Trace().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("machine was already registered before, refreshing with new auth key") Msg("node was already registered before, refreshing with new auth key")
machine.NodeKey = nodeKey node.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID) node.AuthKeyID = uint(pak.ID)
err := h.RefreshMachine(machine, registerRequest.Expiry) err := h.RefreshNode(node, registerRequest.Expiry)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Err(err). Err(err).
Msg("Failed to refresh machine") Msg("Failed to refresh node")
return return
} }
@ -403,16 +403,16 @@ func (h *Headscale) handleAuthKeyCommon(
aclTags := pak.toProto().AclTags aclTags := pak.toProto().AclTags
if len(aclTags) > 0 { if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.SetTags(machine, aclTags) err = h.SetTags(node, aclTags)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Strs("aclTags", aclTags). Strs("aclTags", aclTags).
Err(err). Err(err).
Msg("Failed to set tags after refreshing machine") Msg("Failed to set tags after refreshing node")
return return
} }
@ -432,7 +432,7 @@ func (h *Headscale) handleAuthKeyCommon(
return return
} }
machineToRegister := Machine{ nodeToRegister := Node{
Hostname: registerRequest.Hostinfo.Hostname, Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName, GivenName: givenName,
UserID: pak.User.ID, UserID: pak.User.ID,
@ -445,16 +445,16 @@ func (h *Headscale) handleAuthKeyCommon(
ForcedTags: pak.toProto().AclTags, ForcedTags: pak.toProto().AclTags,
} }
machine, err = h.RegisterMachine( node, err = h.RegisterNode(
machineToRegister, nodeToRegister,
) )
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Err(err). Err(err).
Msg("could not register machine") Msg("could not register node")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). nodeRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
@ -469,7 +469,7 @@ func (h *Headscale) handleAuthKeyCommon(
Bool("noise", isNoise). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Failed to use pre-auth key") Msg("Failed to use pre-auth key")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). nodeRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
@ -488,16 +488,16 @@ func (h *Headscale) handleAuthKeyCommon(
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name). nodeRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc() Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name). nodeRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name).
Inc() Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
@ -513,14 +513,14 @@ func (h *Headscale) handleAuthKeyCommon(
log.Info(). log.Info().
Str("func", "handleAuthKeyCommon"). Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")). Str("ips", strings.Join(node.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey") Msg("Successfully authenticated via AuthKey")
} }
// handleNewMachineCommon exposes for both legacy and Noise the functionality to get a URL // handleNewNodeCommon exposes for both legacy and Noise the functionality to get a URL
// for authorizing the machine. This url is then showed to the user by the local Tailscale client. // for authorizing the node. This url is then showed to the user by the local Tailscale client.
func (h *Headscale) handleNewMachineCommon( func (h *Headscale) handleNewNodeCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic, machineKey key.MachinePublic,
@ -528,11 +528,11 @@ func (h *Headscale) handleNewMachineCommon(
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// The machine registration is new, redirect the client to the registration URL // The node registration is new, redirect the client to the registration URL
log.Debug(). log.Debug().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Msg("The node seems to be new, sending auth url") Msg("The node seems to be new, sending auth url")
if h.oauth2Config != nil { if h.oauth2Config != nil {
@ -574,13 +574,13 @@ func (h *Headscale) handleNewMachineCommon(
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("AuthURL", resp.AuthURL). Str("AuthURL", resp.AuthURL).
Str("machine", registerRequest.Hostinfo.Hostname). Str("node", registerRequest.Hostinfo.Hostname).
Msg("Successfully sent auth url") Msg("Successfully sent auth url")
} }
func (h *Headscale) handleMachineLogOutCommon( func (h *Headscale) handleNodeLogOutCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
machine Machine, node Node,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool, isNoise bool,
) { ) {
@ -588,17 +588,17 @@ func (h *Headscale) handleMachineLogOutCommon(
log.Info(). log.Info().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Client requested logout") Msg("Client requested logout")
err := h.ExpireMachine(&machine) err := h.ExpireNode(&node)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("func", "handleMachineLogOutCommon"). Str("func", "handleNodeLogOutCommon").
Err(err). Err(err).
Msg("Failed to expire machine") Msg("Failed to expire node")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
@ -607,7 +607,7 @@ func (h *Headscale) handleMachineLogOutCommon(
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = false resp.MachineAuthorized = false
resp.NodeKeyExpired = true resp.NodeKeyExpired = true
resp.User = *machine.User.toTailscaleUser() resp.User = *node.User.toTailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -633,13 +633,13 @@ func (h *Headscale) handleMachineLogOutCommon(
return return
} }
if machine.isEphemeral() { if node.isEphemeral() {
err = h.HardDeleteMachine(&machine) err = h.HardDeleteNode(&node)
if err != nil { if err != nil {
log.Error(). log.Error().
Err(err). Err(err).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Cannot delete ephemeral machine from the database") Msg("Cannot delete ephemeral node from the database")
} }
return return
@ -648,29 +648,29 @@ func (h *Headscale) handleMachineLogOutCommon(
log.Info(). log.Info().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Successfully logged out") Msg("Successfully logged out")
} }
func (h *Headscale) handleMachineValidRegistrationCommon( func (h *Headscale) handleNodeValidRegistrationCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
machine Machine, node Node,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool, isNoise bool,
) { ) {
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
// The machine registration is valid, respond with redirect to /map // The node registration is valid, respond with redirect to /map
log.Debug(). log.Debug().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Client is registered and we have the current NodeKey. All clear to /map") Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = "" resp.AuthURL = ""
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *machine.User.toTailscaleUser() resp.User = *node.User.toTailscaleUser()
resp.Login = *machine.User.toTailscaleLogin() resp.Login = *node.User.toTailscaleLogin()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
@ -679,13 +679,13 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
Bool("noise", isNoise). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", machine.User.Name). nodeRegistrations.WithLabelValues("update", "web", "error", node.User.Name).
Inc() Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
machineRegistrations.WithLabelValues("update", "web", "success", machine.User.Name). nodeRegistrations.WithLabelValues("update", "web", "success", node.User.Name).
Inc() Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
@ -702,14 +702,14 @@ func (h *Headscale) handleMachineValidRegistrationCommon(
log.Info(). log.Info().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Machine successfully authorized") Msg("Node successfully authorized")
} }
func (h *Headscale) handleMachineRefreshKeyCommon( func (h *Headscale) handleNodeRefreshKeyCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, node Node,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool, isNoise bool,
) { ) {
@ -718,22 +718,22 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
log.Info(). log.Info().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh") Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey) node.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
if err := h.db.Save(&machine).Error; err != nil { if err := h.db.Save(&node).Error; err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Msg("Failed to update machine key in the database") Msg("Failed to update node key in the database")
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
resp.AuthURL = "" resp.AuthURL = ""
resp.User = *machine.User.toTailscaleUser() resp.User = *node.User.toTailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise) respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
@ -762,14 +762,14 @@ func (h *Headscale) handleMachineRefreshKeyCommon(
Bool("noise", isNoise). Bool("noise", isNoise).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("old_node_key", registerRequest.OldNodeKey.ShortString()). Str("old_node_key", registerRequest.OldNodeKey.ShortString()).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Node key successfully refreshed") Msg("Node key successfully refreshed")
} }
func (h *Headscale) handleMachineExpiredOrLoggedOutCommon( func (h *Headscale) handleNodeExpiredOrLoggedOutCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest, registerRequest tailcfg.RegisterRequest,
machine Machine, node Node,
machineKey key.MachinePublic, machineKey key.MachinePublic,
isNoise bool, isNoise bool,
) { ) {
@ -785,11 +785,11 @@ func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
log.Trace(). log.Trace().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("machine_key", machineKey.ShortString()). Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Msg("Machine registration has expired or logged out. Sending a auth url to register") Msg("Node registration has expired or logged out. Sending a auth url to register")
if h.oauth2Config != nil { if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
@ -808,13 +808,13 @@ func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
Bool("noise", isNoise). Bool("noise", isNoise).
Err(err). Err(err).
Msg("Cannot encode message") Msg("Cannot encode message")
machineRegistrations.WithLabelValues("reauth", "web", "error", machine.User.Name). nodeRegistrations.WithLabelValues("reauth", "web", "error", node.User.Name).
Inc() Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError) http.Error(writer, "Internal server error", http.StatusInternalServerError)
return return
} }
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.User.Name). nodeRegistrations.WithLabelValues("reauth", "web", "success", node.User.Name).
Inc() Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
@ -834,6 +834,6 @@ func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
Str("machine_key", machineKey.ShortString()). Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()). Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()). Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Machine logged out. Sent AuthURL for reauthentication") Msg("Node logged out. Sent AuthURL for reauthentication")
} }

View File

@ -16,29 +16,29 @@ const (
type contextKey string type contextKey string
const machineNameContextKey = contextKey("machineName") const nodeNameContextKey = contextKey("machineName")
// handlePollCommon is the common code for the legacy and Noise protocols to // handlePollCommon is the common code for the legacy and Noise protocols to
// managed the poll loop. // managed the poll loop.
func (h *Headscale) handlePollCommon( func (h *Headscale) handlePollCommon(
writer http.ResponseWriter, writer http.ResponseWriter,
ctx context.Context, ctx context.Context,
machine *Machine, node *Node,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
isNoise bool, isNoise bool,
) { ) {
machine.Hostname = mapRequest.Hostinfo.Hostname node.Hostname = mapRequest.Hostinfo.Hostname
machine.HostInfo = HostInfo(*mapRequest.Hostinfo) node.HostInfo = HostInfo(*mapRequest.Hostinfo)
machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey) node.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
now := time.Now().UTC() now := time.Now().UTC()
err := h.processMachineRoutes(machine) err := h.processNodeRoutes(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Err(err). Err(err).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Error processing machine routes") Msg("Error processing node routes")
} }
// update ACLRules with peer informations (to update server tags if necessary) // update ACLRules with peer informations (to update server tags if necessary)
@ -48,17 +48,17 @@ func (h *Headscale) handlePollCommon(
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Err(err) Err(err)
} }
// update routes with peer information // update routes with peer information
err = h.EnableAutoApprovedRoutes(machine) err = h.EnableAutoApprovedRoutes(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Err(err). Err(err).
Msg("Error running auto approved routes") Msg("Error running auto approved routes")
} }
@ -73,32 +73,32 @@ func (h *Headscale) handlePollCommon(
// The intended use is for clients to discover the DERP map at start-up // The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update. // before their first real endpoint update.
if !mapRequest.ReadOnly { if !mapRequest.ReadOnly {
machine.Endpoints = mapRequest.Endpoints node.Endpoints = mapRequest.Endpoints
machine.LastSeen = &now node.LastSeen = &now
} }
if err := h.db.Updates(machine).Error; err != nil { if err := h.db.Updates(node).Error; err != nil {
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("node_key", machine.NodeKey). Str("node_key", node.NodeKey).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Err(err). Err(err).
Msg("Failed to persist/update machine in the database") Msg("Failed to persist/update node in the database")
http.Error(writer, "", http.StatusInternalServerError) http.Error(writer, "", http.StatusInternalServerError)
return return
} }
} }
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise) mapResp, err := h.getMapResponseData(mapRequest, node, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("node_key", machine.NodeKey). Str("node_key", node.NodeKey).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Err(err). Err(err).
Msg("Failed to get Map response") Msg("Failed to get Map response")
http.Error(writer, "", http.StatusInternalServerError) http.Error(writer, "", http.StatusInternalServerError)
@ -114,7 +114,7 @@ func (h *Headscale) handlePollCommon(
log.Debug(). log.Debug().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Bool("readOnly", mapRequest.ReadOnly). Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers). Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream). Bool("stream", mapRequest.Stream).
@ -124,7 +124,7 @@ func (h *Headscale) handlePollCommon(
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Client is starting up. Probably interested in a DERP map") Msg("Client is starting up. Probably interested in a DERP map")
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
@ -155,14 +155,14 @@ func (h *Headscale) handlePollCommon(
log.Trace(). log.Trace().
Caller(). Caller().
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Loading or creating update channel") Msg("Loading or creating update channel")
const chanSize = 8 const chanSize = 8
updateChan := make(chan struct{}, chanSize) updateChan := make(chan struct{}, chanSize)
pollDataChan := make(chan []byte, chanSize) pollDataChan := make(chan []byte, chanSize)
defer closeChanWithLog(pollDataChan, machine.Hostname, "pollDataChan") defer closeChanWithLog(pollDataChan, node.Hostname, "pollDataChan")
keepAliveChan := make(chan []byte) keepAliveChan := make(chan []byte)
@ -170,7 +170,7 @@ func (h *Headscale) handlePollCommon(
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Client sent endpoint update and is ok with a response without peer list") Msg("Client sent endpoint update and is ok with a response without peer list")
writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
@ -183,7 +183,7 @@ func (h *Headscale) handlePollCommon(
} }
// It sounds like we should update the nodes when we have received a endpoint update // It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so. // even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update"). updateRequestsFromNode.WithLabelValues(node.User.Name, node.Hostname, "endpoint-update").
Inc() Inc()
updateChan <- struct{}{} updateChan <- struct{}{}
@ -192,7 +192,7 @@ func (h *Headscale) handlePollCommon(
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Ignoring request, don't know how to handle it") Msg("Ignoring request, don't know how to handle it")
http.Error(writer, "", http.StatusBadRequest) http.Error(writer, "", http.StatusBadRequest)
@ -202,28 +202,28 @@ func (h *Headscale) handlePollCommon(
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Client is ready to access the tailnet") Msg("Client is ready to access the tailnet")
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Sending initial map") Msg("Sending initial map")
pollDataChan <- mapResp pollDataChan <- mapResp
log.Info(). log.Info().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Notifying peers") Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "full-update"). updateRequestsFromNode.WithLabelValues(node.User.Name, node.Hostname, "full-update").
Inc() Inc()
updateChan <- struct{}{} updateChan <- struct{}{}
h.pollNetMapStream( h.pollNetMapStream(
writer, writer,
ctx, ctx,
machine, node,
mapRequest, mapRequest,
pollDataChan, pollDataChan,
keepAliveChan, keepAliveChan,
@ -234,7 +234,7 @@ func (h *Headscale) handlePollCommon(
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Finished stream, closing PollNetMap session") Msg("Finished stream, closing PollNetMap session")
} }
@ -243,7 +243,7 @@ func (h *Headscale) handlePollCommon(
func (h *Headscale) pollNetMapStream( func (h *Headscale) pollNetMapStream(
writer http.ResponseWriter, writer http.ResponseWriter,
ctxReq context.Context, ctxReq context.Context,
machine *Machine, node *Node,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
pollDataChan chan []byte, pollDataChan chan []byte,
keepAliveChan chan []byte, keepAliveChan chan []byte,
@ -253,7 +253,7 @@ func (h *Headscale) pollNetMapStream(
h.pollNetMapStreamWG.Add(1) h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done() defer h.pollNetMapStreamWG.Done()
ctx := context.WithValue(ctxReq, machineNameContextKey, machine.Hostname) ctx := context.WithValue(ctxReq, nodeNameContextKey, node.Hostname)
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@ -263,20 +263,20 @@ func (h *Headscale) pollNetMapStream(
updateChan, updateChan,
keepAliveChan, keepAliveChan,
mapRequest, mapRequest,
machine, node,
isNoise, isNoise,
) )
log.Trace(). log.Trace().
Str("handler", "pollNetMapStream"). Str("handler", "pollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("Waiting for data to stream...") Msg("Waiting for data to stream...")
log.Trace(). log.Trace().
Str("handler", "pollNetMapStream"). Str("handler", "pollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan) Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
for { for {
@ -285,7 +285,7 @@ func (h *Headscale) pollNetMapStream(
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending data received via pollData channel") Msg("Sending data received via pollData channel")
@ -294,7 +294,7 @@ func (h *Headscale) pollNetMapStream(
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "pollData"). Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot write data") Msg("Cannot write data")
@ -308,7 +308,7 @@ func (h *Headscale) pollNetMapStream(
Caller(). Caller().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "pollData"). Str("channel", "pollData").
Msg("Cannot cast writer to http.Flusher") Msg("Cannot cast writer to http.Flusher")
} else { } else {
@ -318,43 +318,43 @@ func (h *Headscale) pollNetMapStream(
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Data from pollData channel written successfully") Msg("Data from pollData channel written successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated node object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine) err = h.UpdateNodeFromDatabase(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "pollData"). Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update node from database")
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return return
} }
now := time.Now().UTC() now := time.Now().UTC()
machine.LastSeen = &now node.LastSeen = &now
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname). lastStateUpdate.WithLabelValues(node.User.Name, node.Hostname).
Set(float64(now.Unix())) Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now node.LastSuccessfulUpdate = &now
err = h.TouchMachine(machine) err = h.TouchNode(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "pollData"). Str("channel", "pollData").
Err(err). Err(err).
Msg("Cannot update machine LastSuccessfulUpdate") Msg("Cannot update node LastSuccessfulUpdate")
return return
} }
@ -362,15 +362,15 @@ func (h *Headscale) pollNetMapStream(
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "pollData"). Str("channel", "pollData").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending data") Msg("Node entry in database updated successfully after sending data")
case data := <-keepAliveChan: case data := <-keepAliveChan:
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Sending keep alive message") Msg("Sending keep alive message")
@ -379,7 +379,7 @@ func (h *Headscale) pollNetMapStream(
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Err(err). Err(err).
Msg("Cannot write keep alive message") Msg("Cannot write keep alive message")
@ -392,7 +392,7 @@ func (h *Headscale) pollNetMapStream(
Caller(). Caller().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Msg("Cannot cast writer to http.Flusher") Msg("Cannot cast writer to http.Flusher")
} else { } else {
@ -402,38 +402,38 @@ func (h *Headscale) pollNetMapStream(
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Keep alive sent successfully") Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated node object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine) err = h.UpdateNodeFromDatabase(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update node from database")
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return return
} }
now := time.Now().UTC() now := time.Now().UTC()
machine.LastSeen = &now node.LastSeen = &now
err = h.TouchMachine(machine) err = h.TouchNode(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Err(err). Err(err).
Msg("Cannot update machine LastSeen") Msg("Cannot update node LastSeen")
return return
} }
@ -441,39 +441,39 @@ func (h *Headscale) pollNetMapStream(
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "keepAlive"). Str("channel", "keepAlive").
Int("bytes", len(data)). Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive") Msg("Node updated successfully after sending keep alive")
case <-updateChan: case <-updateChan:
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "update"). Str("channel", "update").
Msg("Received a request for update") Msg("Received a request for update")
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname). updateRequestsReceivedOnChannel.WithLabelValues(node.User.Name, node.Hostname).
Inc() Inc()
if h.isOutdated(machine) { if h.isOutdated(node) {
var lastUpdate time.Time var lastUpdate time.Time
if machine.LastSuccessfulUpdate != nil { if node.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate lastUpdate = *node.LastSuccessfulUpdate
} }
log.Debug(). log.Debug().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Time("last_successful_update", lastUpdate). Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.User)). Time("last_state_change", h.getLastStateChange(node.User)).
Msgf("There has been updates since the last successful update to %s", machine.Hostname) Msgf("There has been updates since the last successful update to %s", node.Hostname)
data, err := h.getMapResponseData(mapRequest, machine, isNoise) data, err := h.getMapResponseData(mapRequest, node, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Could not get the map update") Msg("Could not get the map update")
@ -485,11 +485,11 @@ func (h *Headscale) pollNetMapStream(
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Could not write the map response") Msg("Could not write the map response")
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "failed"). updateRequestsSentToNode.WithLabelValues(node.User.Name, node.Hostname, "failed").
Inc() Inc()
return return
@ -501,7 +501,7 @@ func (h *Headscale) pollNetMapStream(
Caller(). Caller().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "update"). Str("channel", "update").
Msg("Cannot cast writer to http.Flusher") Msg("Cannot cast writer to http.Flusher")
} else { } else {
@ -511,10 +511,10 @@ func (h *Headscale) pollNetMapStream(
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "update"). Str("channel", "update").
Msg("Updated Map has been sent") Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "success"). updateRequestsSentToNode.WithLabelValues(node.User.Name, node.Hostname, "success").
Inc() Inc()
// Keep track of the last successful update, // Keep track of the last successful update,
@ -522,17 +522,17 @@ func (h *Headscale) pollNetMapStream(
// is not picked up by a client and we use this // is not picked up by a client and we use this
// to determine if we should "force" an update. // to determine if we should "force" an update.
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions // TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated node object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine) err = h.UpdateNodeFromDatabase(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update node from database")
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
@ -540,69 +540,69 @@ func (h *Headscale) pollNetMapStream(
} }
now := time.Now().UTC() now := time.Now().UTC()
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname). lastStateUpdate.WithLabelValues(node.User.Name, node.Hostname).
Set(float64(now.Unix())) Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now node.LastSuccessfulUpdate = &now
err = h.TouchMachine(machine) err = h.TouchNode(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "update"). Str("channel", "update").
Err(err). Err(err).
Msg("Cannot update machine LastSuccessfulUpdate") Msg("Cannot update node LastSuccessfulUpdate")
return return
} }
} else { } else {
var lastUpdate time.Time var lastUpdate time.Time
if machine.LastSuccessfulUpdate != nil { if node.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate lastUpdate = *node.LastSuccessfulUpdate
} }
log.Trace(). log.Trace().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Time("last_successful_update", lastUpdate). Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.User)). Time("last_state_change", h.getLastStateChange(node.User)).
Msgf("%s is up to date", machine.Hostname) Msgf("%s is up to date", node.Hostname)
} }
case <-ctx.Done(): case <-ctx.Done():
log.Info(). log.Info().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("The client has closed the connection") Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions // TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from // when an outdated node object is kept alive, e.g. db is update from
// command line, but then overwritten. // command line, but then overwritten.
err := h.UpdateMachineFromDatabase(machine) err := h.UpdateNodeFromDatabase(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "Done"). Str("channel", "Done").
Err(err). Err(err).
Msg("Cannot update machine from database") Msg("Cannot update node from database")
// client has been removed from database // client has been removed from database
// since the stream opened, terminate connection. // since the stream opened, terminate connection.
return return
} }
now := time.Now().UTC() now := time.Now().UTC()
machine.LastSeen = &now node.LastSeen = &now
err = h.TouchMachine(machine) err = h.TouchNode(node)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Str("channel", "Done"). Str("channel", "Done").
Err(err). Err(err).
Msg("Cannot update machine LastSeen") Msg("Cannot update node LastSeen")
} }
// The connection has been closed, so we can stop polling. // The connection has been closed, so we can stop polling.
@ -612,7 +612,7 @@ func (h *Headscale) pollNetMapStream(
log.Info(). log.Info().
Str("handler", "PollNetMapStream"). Str("handler", "PollNetMapStream").
Bool("noise", isNoise). Bool("noise", isNoise).
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("The long-poll handler is shutting down") Msg("The long-poll handler is shutting down")
return return
@ -625,7 +625,7 @@ func (h *Headscale) scheduledPollWorker(
updateChan chan struct{}, updateChan chan struct{},
keepAliveChan chan []byte, keepAliveChan chan []byte,
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, node *Node,
isNoise bool, isNoise bool,
) { ) {
keepAliveTicker := time.NewTicker(keepAliveInterval) keepAliveTicker := time.NewTicker(keepAliveInterval)
@ -633,12 +633,12 @@ func (h *Headscale) scheduledPollWorker(
defer closeChanWithLog( defer closeChanWithLog(
updateChan, updateChan,
fmt.Sprint(ctx.Value(machineNameContextKey)), fmt.Sprint(ctx.Value(nodeNameContextKey)),
"updateChan", "updateChan",
) )
defer closeChanWithLog( defer closeChanWithLog(
keepAliveChan, keepAliveChan,
fmt.Sprint(ctx.Value(machineNameContextKey)), fmt.Sprint(ctx.Value(nodeNameContextKey)),
"keepAliveChan", "keepAliveChan",
) )
@ -648,7 +648,7 @@ func (h *Headscale) scheduledPollWorker(
return return
case <-keepAliveTicker.C: case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponseData(mapRequest, machine, isNoise) data, err := h.getMapKeepAliveResponseData(mapRequest, node, isNoise)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "keepAlive"). Str("func", "keepAlive").
@ -661,7 +661,7 @@ func (h *Headscale) scheduledPollWorker(
log.Debug(). log.Debug().
Str("func", "keepAlive"). Str("func", "keepAlive").
Str("machine", machine.Hostname). Str("node", node.Hostname).
Bool("noise", isNoise). Bool("noise", isNoise).
Msg("Sending keepalive") Msg("Sending keepalive")
select { select {
@ -673,10 +673,10 @@ func (h *Headscale) scheduledPollWorker(
case <-updateCheckerTicker.C: case <-updateCheckerTicker.C:
log.Debug(). log.Debug().
Str("func", "scheduledPollWorker"). Str("func", "scheduledPollWorker").
Str("machine", machine.Hostname). Str("node", node.Hostname).
Bool("noise", isNoise). Bool("noise", isNoise).
Msg("Sending update request") Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "scheduled-update"). updateRequestsFromNode.WithLabelValues(node.User.Name, node.Hostname, "scheduled-update").
Inc() Inc()
select { select {
case updateChan <- struct{}{}: case updateChan <- struct{}{}:
@ -687,10 +687,10 @@ func (h *Headscale) scheduledPollWorker(
} }
} }
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) { func closeChanWithLog[C chan []byte | chan struct{}](channel C, node, name string) {
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("machine", machine). Str("node", node).
Str("channel", "Done"). Str("channel", "Done").
Msg(fmt.Sprintf("Closing %s channel", name)) Msg(fmt.Sprintf("Closing %s channel", name))

View File

@ -14,10 +14,10 @@ import (
func (h *Headscale) getMapResponseData( func (h *Headscale) getMapResponseData(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, node *Node,
isNoise bool, isNoise bool,
) ([]byte, error) { ) ([]byte, error) {
mapResponse, err := h.generateMapResponse(mapRequest, machine) mapResponse, err := h.generateMapResponse(mapRequest, node)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -27,7 +27,7 @@ func (h *Headscale) getMapResponseData(
} }
var machineKey key.MachinePublic var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey))) err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(node.MachineKey)))
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().
@ -42,7 +42,7 @@ func (h *Headscale) getMapResponseData(
func (h *Headscale) getMapKeepAliveResponseData( func (h *Headscale) getMapKeepAliveResponseData(
mapRequest tailcfg.MapRequest, mapRequest tailcfg.MapRequest,
machine *Machine, node *Node,
isNoise bool, isNoise bool,
) ([]byte, error) { ) ([]byte, error) {
keepAliveResponse := tailcfg.MapResponse{ keepAliveResponse := tailcfg.MapResponse{
@ -54,7 +54,7 @@ func (h *Headscale) getMapKeepAliveResponseData(
} }
var machineKey key.MachinePublic var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey))) err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(node.MachineKey)))
if err != nil { if err != nil {
log.Error(). log.Error().
Caller(). Caller().

View File

@ -12,7 +12,7 @@ import (
"tailscale.com/types/key" "tailscale.com/types/key"
) )
// RegistrationHandler handles the actual registration process of a machine // RegistrationHandler handles the actual registration process of a node
// Endpoint /machine/:mkey. // Endpoint /machine/:mkey.
func (h *Headscale) RegistrationHandler( func (h *Headscale) RegistrationHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
@ -38,7 +38,7 @@ func (h *Headscale) RegistrationHandler(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot parse machine key") Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() nodeRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Cannot parse machine key", http.StatusBadRequest) http.Error(writer, "Cannot parse machine key", http.StatusBadRequest)
return return
@ -50,7 +50,7 @@ func (h *Headscale) RegistrationHandler(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot decode message") Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() nodeRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Cannot decode message", http.StatusBadRequest) http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return return

View File

@ -67,12 +67,12 @@ func (h *Headscale) PollNetMapHandler(
return return
} }
machine, err := h.GetMachineByMachineKey(machineKey) node, err := h.GetNodeByMachineKey(machineKey)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn(). log.Warn().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String()) Msgf("Ignoring request, cannot find node with mkey %s", machineKey.String())
http.Error(writer, "", http.StatusUnauthorized) http.Error(writer, "", http.StatusUnauthorized)
@ -80,7 +80,7 @@ func (h *Headscale) PollNetMapHandler(
} }
log.Error(). log.Error().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String()) Msgf("Failed to fetch node from the database with Machine key: %s", machineKey.String())
http.Error(writer, "", http.StatusInternalServerError) http.Error(writer, "", http.StatusInternalServerError)
return return
@ -89,8 +89,8 @@ func (h *Headscale) PollNetMapHandler(
log.Trace(). log.Trace().
Str("handler", "PollNetMap"). Str("handler", "PollNetMap").
Str("id", machineKeyStr). Str("id", machineKeyStr).
Str("machine", machine.Hostname). Str("machine", node.Hostname).
Msg("A machine is entering polling via the legacy protocol") Msg("A node is entering polling via the legacy protocol")
h.handlePollCommon(writer, req.Context(), machine, mapRequest, false) h.handlePollCommon(writer, req.Context(), node, mapRequest, false)
} }

View File

@ -9,7 +9,7 @@ import (
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
) )
// // NoiseRegistrationHandler handles the actual registration process of a machine. // NoiseRegistrationHandler handles the actual registration process of a node.
func (t *ts2021App) NoiseRegistrationHandler( func (t *ts2021App) NoiseRegistrationHandler(
writer http.ResponseWriter, writer http.ResponseWriter,
req *http.Request, req *http.Request,
@ -27,7 +27,7 @@ func (t *ts2021App) NoiseRegistrationHandler(
Caller(). Caller().
Err(err). Err(err).
Msg("Cannot parse RegisterRequest") Msg("Cannot parse RegisterRequest")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() nodeRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Internal error", http.StatusInternalServerError) http.Error(writer, "Internal error", http.StatusInternalServerError)
return return

View File

@ -41,27 +41,27 @@ func (t *ts2021App) NoisePollNetMapHandler(
return return
} }
machine, err := t.headscale.GetMachineByAnyKey(t.conn.Peer(), mapRequest.NodeKey, key.NodePublic{}) node, err := t.headscale.GetNodeByAnyKey(t.conn.Peer(), mapRequest.NodeKey, key.NodePublic{})
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn(). log.Warn().
Str("handler", "NoisePollNetMap"). Str("handler", "NoisePollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mapRequest.NodeKey.String()) Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusNotFound) http.Error(writer, "Internal error", http.StatusNotFound)
return return
} }
log.Error(). log.Error().
Str("handler", "NoisePollNetMap"). Str("handler", "NoisePollNetMap").
Msgf("Failed to fetch machine from the database with node key: %s", mapRequest.NodeKey.String()) Msgf("Failed to fetch node from the database with node key: %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusInternalServerError) http.Error(writer, "Internal error", http.StatusInternalServerError)
return return
} }
log.Debug(). log.Debug().
Str("handler", "NoisePollNetMap"). Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname). Str("node", node.Hostname).
Msg("A machine is entering polling via the Noise protocol") Msg("A node is entering polling via the Noise protocol")
t.headscale.handlePollCommon(writer, req.Context(), machine, mapRequest, true) t.headscale.handlePollCommon(writer, req.Context(), node, mapRequest, true)
} }

View File

@ -32,7 +32,7 @@ var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
// User is the way Headscale implements the concept of users in Tailscale // User is the way Headscale implements the concept of users in Tailscale
// //
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users // At the end of the day, users in Tailscale are some kind of 'bubbles' or users
// that contain our machines. // that contain our nodes.
type User struct { type User struct {
gorm.Model gorm.Model
Name string `gorm:"unique"` Name string `gorm:"unique"`
@ -63,18 +63,18 @@ func (h *Headscale) CreateUser(name string) (*User, error) {
} }
// DestroyUser destroys a User. Returns error if the User does // DestroyUser destroys a User. Returns error if the User does
// not exist or if there are machines associated with it. // not exist or if there are nodes associated with it.
func (h *Headscale) DestroyUser(name string) error { func (h *Headscale) DestroyUser(name string) error {
user, err := h.GetUser(name) user, err := h.GetUser(name)
if err != nil { if err != nil {
return ErrUserNotFound return ErrUserNotFound
} }
machines, err := h.ListMachinesByUser(name) nodes, err := h.ListNodesByUser(name)
if err != nil { if err != nil {
return err return err
} }
if len(machines) > 0 { if len(nodes) > 0 {
return ErrUserStillHasNodes return ErrUserStillHasNodes
} }
@ -148,8 +148,8 @@ func (h *Headscale) ListUsers() ([]User, error) {
return users, nil return users, nil
} }
// ListMachinesByUser gets all the nodes in a given user. // ListNodesByUser gets all the nodes in a given user.
func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) { func (h *Headscale) ListNodesByUser(name string) ([]Node, error) {
err := CheckForFQDNRules(name) err := CheckForFQDNRules(name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -159,16 +159,16 @@ func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) {
return nil, err return nil, err
} }
machines := []Machine{} nodes := []Node{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil { if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Node{UserID: user.ID}).Find(&nodes).Error; err != nil {
return nil, err return nil, err
} }
return machines, nil return nodes, nil
} }
// SetMachineUser assigns a Machine to a user. // SetNodeUser assigns a Node to a user.
func (h *Headscale) SetMachineUser(machine *Machine, username string) error { func (h *Headscale) SetNodeUser(node *Node, username string) error {
err := CheckForFQDNRules(username) err := CheckForFQDNRules(username)
if err != nil { if err != nil {
return err return err
@ -177,8 +177,8 @@ func (h *Headscale) SetMachineUser(machine *Machine, username string) error {
if err != nil { if err != nil {
return err return err
} }
machine.User = *user node.User = *user
if result := h.db.Save(&machine); result.Error != nil { if result := h.db.Save(&node); result.Error != nil {
return result.Error return result.Error
} }
@ -212,11 +212,11 @@ func (n *User) toTailscaleLogin() *tailcfg.Login {
} }
func (h *Headscale) getMapResponseUserProfiles( func (h *Headscale) getMapResponseUserProfiles(
machine Machine, node Node,
peers Machines, peers Nodes,
) []tailcfg.UserProfile { ) []tailcfg.UserProfile {
userMap := make(map[string]User) userMap := make(map[string]User)
userMap[machine.User.Name] = machine.User userMap[node.User.Name] = node.User
for _, peer := range peers { for _, peer := range peers {
userMap[peer.User.Name] = peer.User // not worth checking if already is there userMap[peer.User.Name] = peer.User // not worth checking if already is there
} }

View File

@ -139,8 +139,8 @@ func decode(
return nil return nil
} }
func (h *Headscale) getAvailableIPs() (MachineAddresses, error) { func (h *Headscale) getAvailableIPs() (NodeAddresses, error) {
var ips MachineAddresses var ips NodeAddresses
var err error var err error
ipPrefixes := h.cfg.IPPrefixes ipPrefixes := h.cfg.IPPrefixes
for _, ipPrefix := range ipPrefixes { for _, ipPrefix := range ipPrefixes {
@ -201,12 +201,12 @@ func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) {
// but this was quick to get running and it should be enough // but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet. // to begin experimenting with a dual stack tailnet.
var addressesSlices []string var addressesSlices []string
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices) h.db.Model(&Node{}).Pluck("ip_addresses", &addressesSlices)
var ips netipx.IPSetBuilder var ips netipx.IPSetBuilder
for _, slice := range addressesSlices { for _, slice := range addressesSlices {
var machineAddresses MachineAddresses var nodeAddresses NodeAddresses
err := machineAddresses.Scan(slice) err := nodeAddresses.Scan(slice)
if err != nil { if err != nil {
return &netipx.IPSet{}, fmt.Errorf( return &netipx.IPSet{}, fmt.Errorf(
"failed to read ip from database: %w", "failed to read ip from database: %w",
@ -214,7 +214,7 @@ func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) {
) )
} }
for _, ip := range machineAddresses { for _, ip := range nodeAddresses {
ips.Add(ip) ips.Add(ip)
} }
} }