move all go files from root to hscontrol

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby
2023-05-10 09:23:26 +02:00
committed by Juan Font
parent 22e397e0b6
commit 4a7921ead5
45 changed files with 0 additions and 0 deletions

863
hscontrol/acls.go Normal file
View File

@@ -0,0 +1,863 @@
package headscale
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/netip"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
"github.com/tailscale/hujson"
"go4.org/netipx"
"gopkg.in/yaml.v3"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
)
const (
errEmptyPolicy = Error("empty policy")
errInvalidAction = Error("invalid action")
errInvalidGroup = Error("invalid group")
errInvalidTag = Error("invalid tag")
errInvalidPortFormat = Error("invalid port format")
errWildcardIsNeeded = Error("wildcard as port is required for the protocol")
)
const (
Base8 = 8
Base10 = 10
BitSize16 = 16
BitSize32 = 32
BitSize64 = 64
portRangeBegin = 0
portRangeEnd = 65535
expectedTokenItems = 2
)
// For some reason golang.org/x/net/internal/iana is an internal package.
const (
protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management
protocolIPv4 = 4 // IPv4 encapsulation
protocolTCP = 6 // Transmission Control
protocolEGP = 8 // Exterior Gateway Protocol
protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
protocolUDP = 17 // User Datagram
protocolGRE = 47 // Generic Routing Encapsulation
protocolESP = 50 // Encap Security Payload
protocolAH = 51 // Authentication Header
protocolIPv6ICMP = 58 // ICMP for IPv6
protocolSCTP = 132 // Stream Control Transmission Protocol
ProtocolFC = 133 // Fibre Channel
)
var featureEnableSSH = envknob.RegisterBool("HEADSCALE_EXPERIMENTAL_FEATURE_SSH")
// LoadACLPolicy loads the ACL policy from the specify path, and generates the ACL rules.
func (h *Headscale) LoadACLPolicy(path string) error {
log.Debug().
Str("func", "LoadACLPolicy").
Str("path", path).
Msg("Loading ACL policy from path")
policyFile, err := os.Open(path)
if err != nil {
return err
}
defer policyFile.Close()
var policy ACLPolicy
policyBytes, err := io.ReadAll(policyFile)
if err != nil {
return err
}
switch filepath.Ext(path) {
case ".yml", ".yaml":
log.Debug().
Str("path", path).
Bytes("file", policyBytes).
Msg("Loading ACLs from YAML")
err := yaml.Unmarshal(policyBytes, &policy)
if err != nil {
return err
}
log.Trace().
Interface("policy", policy).
Msg("Loaded policy from YAML")
default:
ast, err := hujson.Parse(policyBytes)
if err != nil {
return err
}
ast.Standardize()
policyBytes = ast.Pack()
err = json.Unmarshal(policyBytes, &policy)
if err != nil {
return err
}
}
if policy.IsZero() {
return errEmptyPolicy
}
h.aclPolicy = &policy
return h.UpdateACLRules()
}
func (h *Headscale) UpdateACLRules() error {
machines, err := h.ListMachines()
if err != nil {
return err
}
if h.aclPolicy == nil {
return errEmptyPolicy
}
rules, err := h.aclPolicy.generateFilterRules(machines, h.cfg.OIDC.StripEmaildomain)
if err != nil {
return err
}
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
h.aclRules = rules
if featureEnableSSH() {
sshRules, err := h.generateSSHRules()
if err != nil {
return err
}
log.Trace().Interface("SSH", sshRules).Msg("SSH rules generated")
if h.sshPolicy == nil {
h.sshPolicy = &tailcfg.SSHPolicy{}
}
h.sshPolicy.Rules = sshRules
} else if h.aclPolicy != nil && len(h.aclPolicy.SSHs) > 0 {
log.Info().Msg("SSH ACLs has been defined, but HEADSCALE_EXPERIMENTAL_FEATURE_SSH is not enabled, this is a unstable feature, check docs before activating")
}
return nil
}
// generateFilterRules takes a set of machines and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *ACLPolicy) generateFilterRules(
machines []Machine,
stripEmailDomain bool,
) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
for index, acl := range pol.ACLs {
if acl.Action != "accept" {
return nil, errInvalidAction
}
srcIPs := []string{}
for srcIndex, src := range acl.Sources {
srcs, err := pol.getIPsFromSource(src, machines, stripEmailDomain)
if err != nil {
log.Error().
Interface("src", src).
Int("ACL index", index).
Int("Src index", srcIndex).
Msgf("Error parsing ACL")
return nil, err
}
srcIPs = append(srcIPs, srcs...)
}
protocols, needsWildcard, err := parseProtocol(acl.Protocol)
if err != nil {
log.Error().
Msgf("Error parsing ACL %d. protocol unknown %s", index, acl.Protocol)
return nil, err
}
destPorts := []tailcfg.NetPortRange{}
for destIndex, dest := range acl.Destinations {
dests, err := pol.getNetPortRangeFromDestination(
dest,
machines,
needsWildcard,
stripEmailDomain,
)
if err != nil {
log.Error().
Interface("dest", dest).
Int("ACL index", index).
Int("dest index", destIndex).
Msgf("Error parsing ACL")
return nil, err
}
destPorts = append(destPorts, dests...)
}
rules = append(rules, tailcfg.FilterRule{
SrcIPs: srcIPs,
DstPorts: destPorts,
IPProto: protocols,
})
}
return rules, nil
}
func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
rules := []*tailcfg.SSHRule{}
if h.aclPolicy == nil {
return nil, errEmptyPolicy
}
machines, err := h.ListMachines()
if err != nil {
return nil, err
}
acceptAction := tailcfg.SSHAction{
Message: "",
Reject: false,
Accept: true,
SessionDuration: 0,
AllowAgentForwarding: false,
HoldAndDelegate: "",
AllowLocalPortForwarding: true,
}
rejectAction := tailcfg.SSHAction{
Message: "",
Reject: true,
Accept: false,
SessionDuration: 0,
AllowAgentForwarding: false,
HoldAndDelegate: "",
AllowLocalPortForwarding: false,
}
for index, sshACL := range h.aclPolicy.SSHs {
action := rejectAction
switch sshACL.Action {
case "accept":
action = acceptAction
case "check":
checkAction, err := sshCheckAction(sshACL.CheckPeriod)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, check action with unparsable duration '%s'", index, sshACL.CheckPeriod)
} else {
action = *checkAction
}
default:
log.Error().
Msgf("Error parsing SSH %d, unknown action '%s'", index, sshACL.Action)
return nil, err
}
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
for innerIndex, rawSrc := range sshACL.Sources {
if isWildcard(rawSrc) {
principals = append(principals, &tailcfg.SSHPrincipal{
Any: true,
})
} else if isGroup(rawSrc) {
users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
return nil, err
}
for _, user := range users {
principals = append(principals, &tailcfg.SSHPrincipal{
UserLogin: user,
})
}
} else {
expandedSrcs, err := h.aclPolicy.expandAlias(
machines,
rawSrc,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
return nil, err
}
for _, expandedSrc := range expandedSrcs.Prefixes() {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc.Addr().String(),
})
}
}
}
userMap := make(map[string]string, len(sshACL.Users))
for _, user := range sshACL.Users {
userMap[user] = "="
}
rules = append(rules, &tailcfg.SSHRule{
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}
return rules, nil
}
func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
sessionLength, err := time.ParseDuration(duration)
if err != nil {
return nil, err
}
return &tailcfg.SSHAction{
Message: "",
Reject: false,
Accept: true,
SessionDuration: sessionLength,
AllowAgentForwarding: false,
HoldAndDelegate: "",
AllowLocalPortForwarding: true,
}, nil
}
// getIPsFromSource returns a set of Source IPs that would be associated
// with the given src alias.
func (pol *ACLPolicy) getIPsFromSource(
src string,
machines []Machine,
stripEmaildomain bool,
) ([]string, error) {
ipSet, err := pol.expandAlias(machines, src, stripEmaildomain)
if err != nil {
return []string{}, err
}
prefixes := []string{}
for _, prefix := range ipSet.Prefixes() {
prefixes = append(prefixes, prefix.String())
}
return prefixes, nil
}
// getNetPortRangeFromDestination returns a set of tailcfg.NetPortRange
// which are associated with the dest alias.
func (pol *ACLPolicy) getNetPortRangeFromDestination(
dest string,
machines []Machine,
needsWildcard bool,
stripEmaildomain bool,
) ([]tailcfg.NetPortRange, error) {
var tokens []string
log.Trace().Str("destination", dest).Msg("generating policy destination")
// Check if there is a IPv4/6:Port combination, IPv6 has more than
// three ":".
tokens = strings.Split(dest, ":")
if len(tokens) < expectedTokenItems || len(tokens) > 3 {
port := tokens[len(tokens)-1]
maybeIPv6Str := strings.TrimSuffix(dest, ":"+port)
log.Trace().Str("maybeIPv6Str", maybeIPv6Str).Msg("")
if maybeIPv6, err := netip.ParseAddr(maybeIPv6Str); err != nil && !maybeIPv6.Is6() {
log.Trace().Err(err).Msg("trying to parse as IPv6")
return nil, fmt.Errorf(
"failed to parse destination, tokens %v: %w",
tokens,
errInvalidPortFormat,
)
} else {
tokens = []string{maybeIPv6Str, port}
}
}
log.Trace().Strs("tokens", tokens).Msg("generating policy destination")
var alias string
// We can have here stuff like:
// git-server:*
// 192.168.1.0/24:22
// fd7a:115c:a1e0::2:22
// fd7a:115c:a1e0::2/128:22
// tag:montreal-webserver:80,443
// tag:api-server:443
// example-host-1:*
if len(tokens) == expectedTokenItems {
alias = tokens[0]
} else {
alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1])
}
expanded, err := pol.expandAlias(
machines,
alias,
stripEmaildomain,
)
if err != nil {
return nil, err
}
ports, err := expandPorts(tokens[len(tokens)-1], needsWildcard)
if err != nil {
return nil, err
}
dests := []tailcfg.NetPortRange{}
for _, dest := range expanded.Prefixes() {
for _, port := range *ports {
pr := tailcfg.NetPortRange{
IP: dest.String(),
Ports: port,
}
dests = append(dests, pr)
}
}
return dests, nil
}
// parseProtocol reads the proto field of the ACL and generates a list of
// protocols that will be allowed, following the IANA IP protocol number
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
//
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
// as per Tailscale behaviour (see tailcfg.FilterRule).
//
// Also returns a boolean indicating if the protocol
// requires all the destinations to use wildcard as port number (only TCP,
// UDP and SCTP support specifying ports).
func parseProtocol(protocol string) ([]int, bool, error) {
switch protocol {
case "":
return nil, false, nil
case "igmp":
return []int{protocolIGMP}, true, nil
case "ipv4", "ip-in-ip":
return []int{protocolIPv4}, true, nil
case "tcp":
return []int{protocolTCP}, false, nil
case "egp":
return []int{protocolEGP}, true, nil
case "igp":
return []int{protocolIGP}, true, nil
case "udp":
return []int{protocolUDP}, false, nil
case "gre":
return []int{protocolGRE}, true, nil
case "esp":
return []int{protocolESP}, true, nil
case "ah":
return []int{protocolAH}, true, nil
case "sctp":
return []int{protocolSCTP}, false, nil
case "icmp":
return []int{protocolICMP, protocolIPv6ICMP}, true, nil
default:
protocolNumber, err := strconv.Atoi(protocol)
if err != nil {
return nil, false, err
}
needsWildcard := protocolNumber != protocolTCP &&
protocolNumber != protocolUDP &&
protocolNumber != protocolSCTP
return []int{protocolNumber}, needsWildcard, nil
}
}
// expandalias has an input of either
// - a user
// - a group
// - a tag
// - a host
// - an ip
// - a cidr
// and transform these in IPAddresses.
func (pol *ACLPolicy) expandAlias(
machines Machines,
alias string,
stripEmailDomain bool,
) (*netipx.IPSet, error) {
if isWildcard(alias) {
return parseIPSet("*", nil)
}
build := netipx.IPSetBuilder{}
log.Debug().
Str("alias", alias).
Msg("Expanding")
// if alias is a group
if isGroup(alias) {
return pol.getIPsFromGroup(alias, machines, stripEmailDomain)
}
// if alias is a tag
if isTag(alias) {
return pol.getIPsFromTag(alias, machines, stripEmailDomain)
}
// if alias is a user
if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil {
return ips, err
}
// if alias is an host
// Note, this is recursive.
if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("expandAlias got hosts entry")
return pol.expandAlias(machines, h.String(), stripEmailDomain)
}
// if alias is an IP
if ip, err := netip.ParseAddr(alias); err == nil {
return pol.getIPsFromSingleIP(ip, machines)
}
// if alias is an IP Prefix (CIDR)
if prefix, err := netip.ParsePrefix(alias); err == nil {
return pol.getIPsFromIPPrefix(prefix, machines)
}
log.Warn().Msgf("No IPs found with the alias %v", alias)
return build.IPSet()
}
// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
// that are correctly tagged since they should not be listed as being in the user
// we assume in this function that we only have nodes from 1 user.
func excludeCorrectlyTaggedNodes(
aclPolicy *ACLPolicy,
nodes []Machine,
user string,
stripEmailDomain bool,
) []Machine {
out := []Machine{}
tags := []string{}
for tag := range aclPolicy.TagOwners {
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
ns := append(owners, user)
if contains(ns, user) {
tags = append(tags, tag)
}
}
// for each machine if tag is in tags list, don't append it.
for _, machine := range nodes {
hi := machine.GetHostInfo()
found := false
for _, t := range hi.RequestTags {
if contains(tags, t) {
found = true
break
}
}
if len(machine.ForcedTags) > 0 {
found = true
}
if !found {
out = append(out, machine)
}
}
return out
}
func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, error) {
if isWildcard(portsStr) {
return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
}, nil
}
if needsWildcard {
return nil, errWildcardIsNeeded
}
ports := []tailcfg.PortRange{}
for _, portStr := range strings.Split(portsStr, ",") {
log.Trace().Msgf("parsing portstring: %s", portStr)
rang := strings.Split(portStr, "-")
switch len(rang) {
case 1:
port, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(port),
Last: uint16(port),
})
case expectedTokenItems:
start, err := strconv.ParseUint(rang[0], Base10, BitSize16)
if err != nil {
return nil, err
}
last, err := strconv.ParseUint(rang[1], Base10, BitSize16)
if err != nil {
return nil, err
}
ports = append(ports, tailcfg.PortRange{
First: uint16(start),
Last: uint16(last),
})
default:
return nil, errInvalidPortFormat
}
}
return &ports, nil
}
func filterMachinesByUser(machines []Machine, user string) []Machine {
out := []Machine{}
for _, machine := range machines {
if machine.User.Name == user {
out = append(out, machine)
}
}
return out
}
// getTagOwners will return a list of user. An owner can be either a user or a group
// a group cannot be composed of groups.
func getTagOwners(
pol *ACLPolicy,
tag string,
stripEmailDomain bool,
) ([]string, error) {
var owners []string
ows, ok := pol.TagOwners[tag]
if !ok {
return []string{}, fmt.Errorf(
"%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners",
errInvalidTag,
tag,
)
}
for _, owner := range ows {
if isGroup(owner) {
gs, err := pol.getUsersInGroup(owner, stripEmailDomain)
if err != nil {
return []string{}, err
}
owners = append(owners, gs...)
} else {
owners = append(owners, owner)
}
}
return owners, nil
}
// getUsersInGroup will return the list of user inside the group
// after some validation.
func (pol *ACLPolicy) getUsersInGroup(
group string,
stripEmailDomain bool,
) ([]string, error) {
users := []string{}
log.Trace().Caller().Interface("pol", pol).Msg("test")
aclGroups, ok := pol.Groups[group]
if !ok {
return []string{}, fmt.Errorf(
"group %v isn't registered. %w",
group,
errInvalidGroup,
)
}
for _, group := range aclGroups {
if isGroup(group) {
return []string{}, fmt.Errorf(
"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups",
errInvalidGroup,
)
}
grp, err := NormalizeToFQDNRules(group, stripEmailDomain)
if err != nil {
return []string{}, fmt.Errorf(
"failed to normalize group %q, err: %w",
group,
errInvalidGroup,
)
}
users = append(users, grp)
}
return users, nil
}
func (pol *ACLPolicy) getIPsFromGroup(
group string,
machines Machines,
stripEmailDomain bool,
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
users, err := pol.getUsersInGroup(group, stripEmailDomain)
if err != nil {
return &netipx.IPSet{}, err
}
for _, user := range users {
filteredMachines := filterMachinesByUser(machines, user)
for _, machine := range filteredMachines {
machine.IPAddresses.AppendToIPSet(&build)
}
}
return build.IPSet()
}
func (pol *ACLPolicy) getIPsFromTag(
alias string,
machines Machines,
stripEmailDomain bool,
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
// check for forced tags
for _, machine := range machines {
if contains(machine.ForcedTags, alias) {
machine.IPAddresses.AppendToIPSet(&build)
}
}
// find tag owners
owners, err := getTagOwners(pol, alias, stripEmailDomain)
if err != nil {
if errors.Is(err, errInvalidTag) {
ipSet, _ := build.IPSet()
if len(ipSet.Prefixes()) == 0 {
return ipSet, fmt.Errorf(
"%w. %v isn't owned by a TagOwner and no forced tags are defined",
errInvalidTag,
alias,
)
}
return build.IPSet()
} else {
return nil, err
}
}
// filter out machines per tag owner
for _, user := range owners {
machines := filterMachinesByUser(machines, user)
for _, machine := range machines {
hi := machine.GetHostInfo()
if contains(hi.RequestTags, alias) {
machine.IPAddresses.AppendToIPSet(&build)
}
}
}
return build.IPSet()
}
func (pol *ACLPolicy) getIPsForUser(
user string,
machines Machines,
stripEmailDomain bool,
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
filteredMachines := filterMachinesByUser(machines, user)
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain)
// shortcurcuit if we have no machines to get ips from.
if len(filteredMachines) == 0 {
return nil, nil //nolint
}
for _, machine := range filteredMachines {
machine.IPAddresses.AppendToIPSet(&build)
}
return build.IPSet()
}
func (pol *ACLPolicy) getIPsFromSingleIP(
ip netip.Addr,
machines Machines,
) (*netipx.IPSet, error) {
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip")
matches := machines.FilterByIP(ip)
build := netipx.IPSetBuilder{}
build.Add(ip)
for _, machine := range matches {
machine.IPAddresses.AppendToIPSet(&build)
}
return build.IPSet()
}
func (pol *ACLPolicy) getIPsFromIPPrefix(
prefix netip.Prefix,
machines Machines,
) (*netipx.IPSet, error) {
log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix")
build := netipx.IPSetBuilder{}
build.AddPrefix(prefix)
// This is suboptimal and quite expensive, but if we only add the prefix, we will miss all the relevant IPv6
// addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers.
for _, machine := range machines {
for _, ip := range machine.IPAddresses {
// log.Trace().
// Msgf("checking if machine ip (%s) is part of prefix (%s): %v, is single ip prefix (%v), addr: %s", ip.String(), prefix.String(), prefix.Contains(ip), prefix.IsSingleIP(), prefix.Addr().String())
if prefix.Contains(ip) {
machine.IPAddresses.AppendToIPSet(&build)
}
}
}
return build.IPSet()
}
func isWildcard(str string) bool {
return str == "*"
}
func isGroup(str string) bool {
return strings.HasPrefix(str, "group:")
}
func isTag(str string) bool {
return strings.HasPrefix(str, "tag:")
}

1834
hscontrol/acls_test.go Normal file

File diff suppressed because it is too large Load Diff

145
hscontrol/acls_types.go Normal file
View File

@@ -0,0 +1,145 @@
package headscale
import (
"encoding/json"
"net/netip"
"strings"
"github.com/tailscale/hujson"
"gopkg.in/yaml.v3"
)
// ACLPolicy represents a Tailscale ACL Policy.
type ACLPolicy struct {
Groups Groups `json:"groups" yaml:"groups"`
Hosts Hosts `json:"hosts" yaml:"hosts"`
TagOwners TagOwners `json:"tagOwners" yaml:"tagOwners"`
ACLs []ACL `json:"acls" yaml:"acls"`
Tests []ACLTest `json:"tests" yaml:"tests"`
AutoApprovers AutoApprovers `json:"autoApprovers" yaml:"autoApprovers"`
SSHs []SSH `json:"ssh" yaml:"ssh"`
}
// ACL is a basic rule for the ACL Policy.
type ACL struct {
Action string `json:"action" yaml:"action"`
Protocol string `json:"proto" yaml:"proto"`
Sources []string `json:"src" yaml:"src"`
Destinations []string `json:"dst" yaml:"dst"`
}
// Groups references a series of alias in the ACL rules.
type Groups map[string][]string
// Hosts are alias for IP addresses or subnets.
type Hosts map[string]netip.Prefix
// TagOwners specify what users (users?) are allow to use certain tags.
type TagOwners map[string][]string
// ACLTest is not implemented, but should be use to check if a certain rule is allowed.
type ACLTest struct {
Source string `json:"src" yaml:"src"`
Accept []string `json:"accept" yaml:"accept"`
Deny []string `json:"deny,omitempty" yaml:"deny,omitempty"`
}
// AutoApprovers specify which users (users?), groups or tags have their advertised routes
// or exit node status automatically enabled.
type AutoApprovers struct {
Routes map[string][]string `json:"routes" yaml:"routes"`
ExitNode []string `json:"exitNode" yaml:"exitNode"`
}
// SSH controls who can ssh into which machines.
type SSH struct {
Action string `json:"action" yaml:"action"`
Sources []string `json:"src" yaml:"src"`
Destinations []string `json:"dst" yaml:"dst"`
Users []string `json:"users" yaml:"users"`
CheckPeriod string `json:"checkPeriod,omitempty" yaml:"checkPeriod,omitempty"`
}
// UnmarshalJSON allows to parse the Hosts directly into netip objects.
func (hosts *Hosts) UnmarshalJSON(data []byte) error {
newHosts := Hosts{}
hostIPPrefixMap := make(map[string]string)
ast, err := hujson.Parse(data)
if err != nil {
return err
}
ast.Standardize()
data = ast.Pack()
err = json.Unmarshal(data, &hostIPPrefixMap)
if err != nil {
return err
}
for host, prefixStr := range hostIPPrefixMap {
if !strings.Contains(prefixStr, "/") {
prefixStr += "/32"
}
prefix, err := netip.ParsePrefix(prefixStr)
if err != nil {
return err
}
newHosts[host] = prefix
}
*hosts = newHosts
return nil
}
// UnmarshalYAML allows to parse the Hosts directly into netip objects.
func (hosts *Hosts) UnmarshalYAML(data []byte) error {
newHosts := Hosts{}
hostIPPrefixMap := make(map[string]string)
err := yaml.Unmarshal(data, &hostIPPrefixMap)
if err != nil {
return err
}
for host, prefixStr := range hostIPPrefixMap {
prefix, err := netip.ParsePrefix(prefixStr)
if err != nil {
return err
}
newHosts[host] = prefix
}
*hosts = newHosts
return nil
}
// IsZero is perhaps a bit naive here.
func (pol ACLPolicy) IsZero() bool {
if len(pol.Groups) == 0 && len(pol.Hosts) == 0 && len(pol.ACLs) == 0 {
return true
}
return false
}
// Returns the list of autoApproving users, groups or tags for a given IPPrefix.
func (autoApprovers *AutoApprovers) GetRouteApprovers(
prefix netip.Prefix,
) ([]string, error) {
if prefix.Bits() == 0 {
return autoApprovers.ExitNode, nil // 0.0.0.0/0, ::/0 or equivalent
}
approverAliases := []string{}
for autoApprovedPrefix, autoApproverAliases := range autoApprovers.Routes {
autoApprovedPrefix, err := netip.ParsePrefix(autoApprovedPrefix)
if err != nil {
return nil, err
}
if prefix.Bits() >= autoApprovedPrefix.Bits() &&
autoApprovedPrefix.Contains(prefix.Masked().Addr()) {
approverAliases = append(approverAliases, autoApproverAliases...)
}
}
return approverAliases, nil
}

168
hscontrol/api.go Normal file
View File

@@ -0,0 +1,168 @@
package headscale
import (
"bytes"
"encoding/json"
"html/template"
"net/http"
"time"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"tailscale.com/types/key"
)
const (
// TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed.
registrationHoldoff = time.Second * 5
reservedResponseHeaderSize = 4
RegisterMethodAuthKey = "authkey"
RegisterMethodOIDC = "oidc"
RegisterMethodCLI = "cli"
ErrRegisterMethodCLIDoesNotSupportExpire = Error(
"machines registered with CLI does not support expire",
)
)
func (h *Headscale) HealthHandler(
writer http.ResponseWriter,
req *http.Request,
) {
respond := func(err error) {
writer.Header().Set("Content-Type", "application/health+json; charset=utf-8")
res := struct {
Status string `json:"status"`
}{
Status: "pass",
}
if err != nil {
writer.WriteHeader(http.StatusInternalServerError)
log.Error().Caller().Err(err).Msg("health check failed")
res.Status = "fail"
}
buf, err := json.Marshal(res)
if err != nil {
log.Error().Caller().Err(err).Msg("marshal failed")
}
_, err = writer.Write(buf)
if err != nil {
log.Error().Caller().Err(err).Msg("write failed")
}
}
if err := h.pingDB(req.Context()); err != nil {
respond(err)
return
}
respond(nil)
}
type registerWebAPITemplateConfig struct {
Key string
}
var registerWebAPITemplate = template.Must(
template.New("registerweb").Parse(`
<html>
<head>
<title>Registration - Headscale</title>
</head>
<body>
<h1>headscale</h1>
<h2>Machine registration</h2>
<p>
Run the command below in the headscale server to add this machine to your network:
</p>
<pre><code>headscale nodes register --user USERNAME --key {{.Key}}</code></pre>
</body>
</html>
`))
// RegisterWebAPI shows a simple message in the browser to point to the CLI
// Listens in /register/:nkey.
//
// This is not part of the Tailscale control API, as we could send whatever URL
// in the RegisterResponse.AuthURL field.
func (h *Headscale) RegisterWebAPI(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"]
if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyStr)),
)
if !ok || nodeKeyStr == "" || err != nil {
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
var content bytes.Buffer
if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{
Key: nodeKeyStr,
}); err != nil {
log.Error().
Str("func", "RegisterWebAPI").
Err(err).
Msg("Could not render register web API template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err = writer.Write([]byte("Could not render register web API template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}

113
hscontrol/api_common.go Normal file
View File

@@ -0,0 +1,113 @@
package headscale
import (
"time"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
func (h *Headscale) generateMapResponse(
mapRequest tailcfg.MapRequest,
machine *Machine,
) (*tailcfg.MapResponse, error) {
log.Trace().
Str("func", "generateMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname).
Msg("Creating Map response")
node, err := h.toNode(*machine, h.cfg.BaseDomain, h.cfg.DNSConfig)
if err != nil {
log.Error().
Caller().
Str("func", "generateMapResponse").
Err(err).
Msg("Cannot convert to node")
return nil, err
}
peers, err := h.getValidPeers(machine)
if err != nil {
log.Error().
Caller().
Str("func", "generateMapResponse").
Err(err).
Msg("Cannot fetch peers")
return nil, err
}
profiles := h.getMapResponseUserProfiles(*machine, peers)
nodePeers, err := h.toNodes(peers, h.cfg.BaseDomain, h.cfg.DNSConfig)
if err != nil {
log.Error().
Caller().
Str("func", "generateMapResponse").
Err(err).
Msg("Failed to convert peers to Tailscale nodes")
return nil, err
}
dnsConfig := getMapResponseDNSConfig(
h.cfg.DNSConfig,
h.cfg.BaseDomain,
*machine,
peers,
)
now := time.Now()
resp := tailcfg.MapResponse{
KeepAlive: false,
Node: node,
// TODO: Only send if updated
DERPMap: h.DERPMap,
// TODO: Only send if updated
Peers: nodePeers,
// TODO(kradalby): Implement:
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L1351-L1374
// PeersChanged
// PeersRemoved
// PeersChangedPatch
// PeerSeenChange
// OnlineChange
// TODO: Only send if updated
DNSConfig: dnsConfig,
// TODO: Only send if updated
Domain: h.cfg.BaseDomain,
// Do not instruct clients to collect services, we do not
// support or do anything with them
CollectServices: "false",
// TODO: Only send if updated
PacketFilter: h.aclRules,
UserProfiles: profiles,
// TODO: Only send if updated
SSHPolicy: h.sshPolicy,
ControlTime: &now,
Debug: &tailcfg.Debug{
DisableLogTail: !h.cfg.LogTail.Enabled,
RandomizeClientPort: h.cfg.RandomizeClientPort,
},
}
log.Trace().
Str("func", "generateMapResponse").
Str("machine", mapRequest.Hostinfo.Hostname).
// Interface("payload", resp).
Msgf("Generated map response: %s", tailMapResponseToString(resp))
return &resp, nil
}

157
hscontrol/api_key.go Normal file
View File

@@ -0,0 +1,157 @@
package headscale
import (
"fmt"
"strings"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"golang.org/x/crypto/bcrypt"
"google.golang.org/protobuf/types/known/timestamppb"
)
const (
apiPrefixLength = 7
apiKeyLength = 32
ErrAPIKeyFailedToParse = Error("Failed to parse ApiKey")
)
// APIKey describes the datamodel for API keys used to remotely authenticate with
// headscale.
type APIKey struct {
ID uint64 `gorm:"primary_key"`
Prefix string `gorm:"uniqueIndex"`
Hash []byte
CreatedAt *time.Time
Expiration *time.Time
LastSeen *time.Time
}
// CreateAPIKey creates a new ApiKey in a user, and returns it.
func (h *Headscale) CreateAPIKey(
expiration *time.Time,
) (string, *APIKey, error) {
prefix, err := GenerateRandomStringURLSafe(apiPrefixLength)
if err != nil {
return "", nil, err
}
toBeHashed, err := GenerateRandomStringURLSafe(apiKeyLength)
if err != nil {
return "", nil, err
}
// Key to return to user, this will only be visible _once_
keyStr := prefix + "." + toBeHashed
hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost)
if err != nil {
return "", nil, err
}
key := APIKey{
Prefix: prefix,
Hash: hash,
Expiration: expiration,
}
if err := h.db.Save(&key).Error; err != nil {
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
}
return keyStr, &key, nil
}
// ListAPIKeys returns the list of ApiKeys for a user.
func (h *Headscale) ListAPIKeys() ([]APIKey, error) {
keys := []APIKey{}
if err := h.db.Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// GetAPIKey returns a ApiKey for a given key.
func (h *Headscale) GetAPIKey(prefix string) (*APIKey, error) {
key := APIKey{}
if result := h.db.First(&key, "prefix = ?", prefix); result.Error != nil {
return nil, result.Error
}
return &key, nil
}
// GetAPIKeyByID returns a ApiKey for a given id.
func (h *Headscale) GetAPIKeyByID(id uint64) (*APIKey, error) {
key := APIKey{}
if result := h.db.Find(&APIKey{ID: id}).First(&key); result.Error != nil {
return nil, result.Error
}
return &key, nil
}
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
// does not exist.
func (h *Headscale) DestroyAPIKey(key APIKey) error {
if result := h.db.Unscoped().Delete(key); result.Error != nil {
return result.Error
}
return nil
}
// ExpireAPIKey marks a ApiKey as expired.
func (h *Headscale) ExpireAPIKey(key *APIKey) error {
if err := h.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
return err
}
return nil
}
func (h *Headscale) ValidateAPIKey(keyStr string) (bool, error) {
prefix, hash, found := strings.Cut(keyStr, ".")
if !found {
return false, ErrAPIKeyFailedToParse
}
key, err := h.GetAPIKey(prefix)
if err != nil {
return false, fmt.Errorf("failed to validate api key: %w", err)
}
if key.Expiration.Before(time.Now()) {
return false, nil
}
if err := bcrypt.CompareHashAndPassword(key.Hash, []byte(hash)); err != nil {
return false, err
}
return true, nil
}
func (key *APIKey) toProto() *v1.ApiKey {
protoKey := v1.ApiKey{
Id: key.ID,
Prefix: key.Prefix,
}
if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration)
}
if key.CreatedAt != nil {
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
}
if key.LastSeen != nil {
protoKey.LastSeen = timestamppb.New(*key.LastSeen)
}
return &protoKey
}

89
hscontrol/api_key_test.go Normal file
View File

@@ -0,0 +1,89 @@
package headscale
import (
"time"
"gopkg.in/check.v1"
)
func (*Suite) TestCreateAPIKey(c *check.C) {
apiKeyStr, apiKey, err := app.CreateAPIKey(nil)
c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil)
// Did we get a valid key?
c.Assert(apiKey.Prefix, check.NotNil)
c.Assert(apiKey.Hash, check.NotNil)
c.Assert(apiKeyStr, check.Not(check.Equals), "")
_, err = app.ListAPIKeys()
c.Assert(err, check.IsNil)
keys, err := app.ListAPIKeys()
c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1)
}
func (*Suite) TestAPIKeyDoesNotExist(c *check.C) {
key, err := app.GetAPIKey("does-not-exist")
c.Assert(err, check.NotNil)
c.Assert(key, check.IsNil)
}
func (*Suite) TestValidateAPIKeyOk(c *check.C) {
nowPlus2 := time.Now().Add(2 * time.Hour)
apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2)
c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil)
valid, err := app.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil)
c.Assert(valid, check.Equals, true)
}
func (*Suite) TestValidateAPIKeyNotOk(c *check.C) {
nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour)
apiKeyStr, apiKey, err := app.CreateAPIKey(&nowMinus2)
c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil)
valid, err := app.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil)
c.Assert(valid, check.Equals, false)
now := time.Now()
apiKeyStrNow, apiKey, err := app.CreateAPIKey(&now)
c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil)
validNow, err := app.ValidateAPIKey(apiKeyStrNow)
c.Assert(err, check.IsNil)
c.Assert(validNow, check.Equals, false)
validSilly, err := app.ValidateAPIKey("nota.validkey")
c.Assert(err, check.NotNil)
c.Assert(validSilly, check.Equals, false)
validWithErr, err := app.ValidateAPIKey("produceerrorkey")
c.Assert(err, check.NotNil)
c.Assert(validWithErr, check.Equals, false)
}
func (*Suite) TestExpireAPIKey(c *check.C) {
nowPlus2 := time.Now().Add(2 * time.Hour)
apiKeyStr, apiKey, err := app.CreateAPIKey(&nowPlus2)
c.Assert(err, check.IsNil)
c.Assert(apiKey, check.NotNil)
valid, err := app.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil)
c.Assert(valid, check.Equals, true)
err = app.ExpireAPIKey(apiKey)
c.Assert(err, check.IsNil)
c.Assert(apiKey.Expiration, check.NotNil)
notValid, err := app.ValidateAPIKey(apiKeyStr)
c.Assert(err, check.IsNil)
c.Assert(notValid, check.Equals, false)
}

1016
hscontrol/app.go Normal file

File diff suppressed because it is too large Load Diff

61
hscontrol/app_test.go Normal file
View File

@@ -0,0 +1,61 @@
package headscale
import (
"net/netip"
"os"
"testing"
"gopkg.in/check.v1"
)
func Test(t *testing.T) {
check.TestingT(t)
}
var _ = check.Suite(&Suite{})
type Suite struct{}
var (
tmpDir string
app Headscale
)
func (s *Suite) SetUpTest(c *check.C) {
s.ResetDB(c)
}
func (s *Suite) TearDownTest(c *check.C) {
os.RemoveAll(tmpDir)
}
func (s *Suite) ResetDB(c *check.C) {
if len(tmpDir) != 0 {
os.RemoveAll(tmpDir)
}
var err error
tmpDir, err = os.MkdirTemp("", "autoygg-client-test")
if err != nil {
c.Fatal(err)
}
cfg := Config{
IPPrefixes: []netip.Prefix{
netip.MustParsePrefix("10.27.0.0/23"),
},
}
app = Headscale{
cfg: &cfg,
dbType: "sqlite3",
dbString: tmpDir + "/headscale_test.db",
}
err = app.initDB()
if err != nil {
c.Fatal(err)
}
db, err := app.openDB()
if err != nil {
c.Fatal(err)
}
app.db = db
}

674
hscontrol/config.go Normal file
View File

@@ -0,0 +1,674 @@
package headscale
import (
"errors"
"fmt"
"io/fs"
"net/netip"
"net/url"
"os"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/prometheus/common/model"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"go4.org/netipx"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
)
const (
tlsALPN01ChallengeType = "TLS-ALPN-01"
http01ChallengeType = "HTTP-01"
JSONLogFormat = "json"
TextLogFormat = "text"
defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days
maxDuration time.Duration = 1<<63 - 1
)
var errOidcMutuallyExclusive = errors.New(
"oidc_client_secret and oidc_client_secret_path are mutually exclusive",
)
// Config contains the initial Headscale configuration.
type Config struct {
ServerURL string
Addr string
MetricsAddr string
GRPCAddr string
GRPCAllowInsecure bool
EphemeralNodeInactivityTimeout time.Duration
NodeUpdateCheckInterval time.Duration
IPPrefixes []netip.Prefix
PrivateKeyPath string
NoisePrivateKeyPath string
BaseDomain string
Log LogConfig
DisableUpdateCheck bool
DERP DERPConfig
DBtype string
DBpath string
DBhost string
DBport int
DBname string
DBuser string
DBpass string
DBssl string
TLS TLSConfig
ACMEURL string
ACMEEmail string
DNSConfig *tailcfg.DNSConfig
UnixSocket string
UnixSocketPermission fs.FileMode
OIDC OIDCConfig
LogTail LogTailConfig
RandomizeClientPort bool
CLI CLIConfig
ACL ACLConfig
}
type TLSConfig struct {
CertPath string
KeyPath string
LetsEncrypt LetsEncryptConfig
}
type LetsEncryptConfig struct {
Listen string
Hostname string
CacheDir string
ChallengeType string
}
type OIDCConfig struct {
OnlyStartIfOIDCIsAvailable bool
Issuer string
ClientID string
ClientSecret string
Scope []string
ExtraParams map[string]string
AllowedDomains []string
AllowedUsers []string
AllowedGroups []string
StripEmaildomain bool
Expiry time.Duration
UseExpiryFromToken bool
}
type DERPConfig struct {
ServerEnabled bool
ServerRegionID int
ServerRegionCode string
ServerRegionName string
STUNAddr string
URLs []url.URL
Paths []string
AutoUpdate bool
UpdateFrequency time.Duration
}
type LogTailConfig struct {
Enabled bool
}
type CLIConfig struct {
Address string
APIKey string
Timeout time.Duration
Insecure bool
}
type ACLConfig struct {
PolicyPath string
}
type LogConfig struct {
Format string
Level zerolog.Level
}
func LoadConfig(path string, isFile bool) error {
if isFile {
viper.SetConfigFile(path)
} else {
viper.SetConfigName("config")
if path == "" {
viper.AddConfigPath("/etc/headscale/")
viper.AddConfigPath("$HOME/.headscale")
viper.AddConfigPath(".")
} else {
// For testing
viper.AddConfigPath(path)
}
}
viper.SetEnvPrefix("headscale")
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
viper.AutomaticEnv()
viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache")
viper.SetDefault("tls_letsencrypt_challenge_type", http01ChallengeType)
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", TextLogFormat)
viper.SetDefault("dns_config", nil)
viper.SetDefault("dns_config.override_local_dns", true)
viper.SetDefault("derp.server.enabled", false)
viper.SetDefault("derp.server.stun.enabled", true)
viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock")
viper.SetDefault("unix_socket_permission", "0o770")
viper.SetDefault("grpc_listen_addr", ":50443")
viper.SetDefault("grpc_allow_insecure", false)
viper.SetDefault("cli.timeout", "5s")
viper.SetDefault("cli.insecure", false)
viper.SetDefault("db_ssl", false)
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
viper.SetDefault("oidc.strip_email_domain", true)
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
viper.SetDefault("oidc.expiry", "180d")
viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("logtail.enabled", false)
viper.SetDefault("randomize_client_port", false)
viper.SetDefault("ephemeral_node_inactivity_timeout", "120s")
viper.SetDefault("node_update_check_interval", "10s")
if IsCLIConfigured() {
return nil
}
if err := viper.ReadInConfig(); err != nil {
log.Warn().Err(err).Msg("Failed to read configuration from disk")
return fmt.Errorf("fatal error reading config file: %w", err)
}
// Collect any validation errors and return them all at once
var errorText string
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
((viper.GetString("tls_cert_path") != "") || (viper.GetString("tls_key_path") != "")) {
errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n"
}
if !viper.IsSet("noise") || viper.GetString("noise.private_key_path") == "" {
errorText += "Fatal config error: headscale now requires a new `noise.private_key_path` field in the config file for the Tailscale v2 protocol\n"
}
if (viper.GetString("tls_letsencrypt_hostname") != "") &&
(viper.GetString("tls_letsencrypt_challenge_type") == tlsALPN01ChallengeType) &&
(!strings.HasSuffix(viper.GetString("listen_addr"), ":443")) {
// this is only a warning because there could be something sitting in front of headscale that redirects the traffic (e.g. an iptables rule)
log.Warn().
Msg("Warning: when using tls_letsencrypt_hostname with TLS-ALPN-01 as challenge type, headscale must be reachable on port 443, i.e. listen_addr should probably end in :443")
}
if (viper.GetString("tls_letsencrypt_challenge_type") != http01ChallengeType) &&
(viper.GetString("tls_letsencrypt_challenge_type") != tlsALPN01ChallengeType) {
errorText += "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are HTTP-01 and TLS-ALPN-01\n"
}
if !strings.HasPrefix(viper.GetString("server_url"), "http://") &&
!strings.HasPrefix(viper.GetString("server_url"), "https://") {
errorText += "Fatal config error: server_url must start with https:// or http://\n"
}
// Minimum inactivity time out is keepalive timeout (60s) plus a few seconds
// to avoid races
minInactivityTimeout, _ := time.ParseDuration("65s")
if viper.GetDuration("ephemeral_node_inactivity_timeout") <= minInactivityTimeout {
errorText += fmt.Sprintf(
"Fatal config error: ephemeral_node_inactivity_timeout (%s) is set too low, must be more than %s",
viper.GetString("ephemeral_node_inactivity_timeout"),
minInactivityTimeout,
)
}
maxNodeUpdateCheckInterval, _ := time.ParseDuration("60s")
if viper.GetDuration("node_update_check_interval") > maxNodeUpdateCheckInterval {
errorText += fmt.Sprintf(
"Fatal config error: node_update_check_interval (%s) is set too high, must be less than %s",
viper.GetString("node_update_check_interval"),
maxNodeUpdateCheckInterval,
)
}
if errorText != "" {
//nolint
return errors.New(strings.TrimSuffix(errorText, "\n"))
} else {
return nil
}
}
func GetTLSConfig() TLSConfig {
return TLSConfig{
LetsEncrypt: LetsEncryptConfig{
Hostname: viper.GetString("tls_letsencrypt_hostname"),
Listen: viper.GetString("tls_letsencrypt_listen"),
CacheDir: AbsolutePathFromConfigPath(
viper.GetString("tls_letsencrypt_cache_dir"),
),
ChallengeType: viper.GetString("tls_letsencrypt_challenge_type"),
},
CertPath: AbsolutePathFromConfigPath(
viper.GetString("tls_cert_path"),
),
KeyPath: AbsolutePathFromConfigPath(
viper.GetString("tls_key_path"),
),
}
}
func GetDERPConfig() DERPConfig {
serverEnabled := viper.GetBool("derp.server.enabled")
serverRegionID := viper.GetInt("derp.server.region_id")
serverRegionCode := viper.GetString("derp.server.region_code")
serverRegionName := viper.GetString("derp.server.region_name")
stunAddr := viper.GetString("derp.server.stun_listen_addr")
if serverEnabled && stunAddr == "" {
log.Fatal().
Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true")
}
urlStrs := viper.GetStringSlice("derp.urls")
urls := make([]url.URL, len(urlStrs))
for index, urlStr := range urlStrs {
urlAddr, err := url.Parse(urlStr)
if err != nil {
log.Error().
Str("url", urlStr).
Err(err).
Msg("Failed to parse url, ignoring...")
}
urls[index] = *urlAddr
}
paths := viper.GetStringSlice("derp.paths")
autoUpdate := viper.GetBool("derp.auto_update_enabled")
updateFrequency := viper.GetDuration("derp.update_frequency")
return DERPConfig{
ServerEnabled: serverEnabled,
ServerRegionID: serverRegionID,
ServerRegionCode: serverRegionCode,
ServerRegionName: serverRegionName,
STUNAddr: stunAddr,
URLs: urls,
Paths: paths,
AutoUpdate: autoUpdate,
UpdateFrequency: updateFrequency,
}
}
func GetLogTailConfig() LogTailConfig {
enabled := viper.GetBool("logtail.enabled")
return LogTailConfig{
Enabled: enabled,
}
}
func GetACLConfig() ACLConfig {
policyPath := viper.GetString("acl_policy_path")
return ACLConfig{
PolicyPath: policyPath,
}
}
func GetLogConfig() LogConfig {
logLevelStr := viper.GetString("log.level")
logLevel, err := zerolog.ParseLevel(logLevelStr)
if err != nil {
logLevel = zerolog.DebugLevel
}
logFormatOpt := viper.GetString("log.format")
var logFormat string
switch logFormatOpt {
case "json":
logFormat = JSONLogFormat
case "text":
logFormat = TextLogFormat
case "":
logFormat = TextLogFormat
default:
log.Error().
Str("func", "GetLogConfig").
Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt)
}
return LogConfig{
Format: logFormat,
Level: logLevel,
}
}
func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{}
overrideLocalDNS := viper.GetBool("dns_config.override_local_dns")
if viper.IsSet("dns_config.nameservers") {
nameserversStr := viper.GetStringSlice("dns_config.nameservers")
nameservers := []netip.Addr{}
resolvers := []*dnstype.Resolver{}
for _, nameserverStr := range nameserversStr {
// Search for explicit DNS-over-HTTPS resolvers
if strings.HasPrefix(nameserverStr, "https://") {
resolvers = append(resolvers, &dnstype.Resolver{
Addr: nameserverStr,
})
// This nameserver can not be parsed as an IP address
continue
}
// Parse nameserver as a regular IP
nameserver, err := netip.ParseAddr(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse nameserver IP: %s", nameserverStr)
}
nameservers = append(nameservers, nameserver)
resolvers = append(resolvers, &dnstype.Resolver{
Addr: nameserver.String(),
})
}
dnsConfig.Nameservers = nameservers
if overrideLocalDNS {
dnsConfig.Resolvers = resolvers
} else {
dnsConfig.FallbackResolvers = resolvers
}
}
if viper.IsSet("dns_config.restricted_nameservers") {
dnsConfig.Routes = make(map[string][]*dnstype.Resolver)
domains := []string{}
restrictedDNS := viper.GetStringMapStringSlice(
"dns_config.restricted_nameservers",
)
for domain, restrictedNameservers := range restrictedDNS {
restrictedResolvers := make(
[]*dnstype.Resolver,
len(restrictedNameservers),
)
for index, nameserverStr := range restrictedNameservers {
nameserver, err := netip.ParseAddr(nameserverStr)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse restricted nameserver IP: %s", nameserverStr)
}
restrictedResolvers[index] = &dnstype.Resolver{
Addr: nameserver.String(),
}
}
dnsConfig.Routes[domain] = restrictedResolvers
domains = append(domains, domain)
}
dnsConfig.Domains = domains
}
if viper.IsSet("dns_config.domains") {
domains := viper.GetStringSlice("dns_config.domains")
if len(dnsConfig.Resolvers) > 0 {
dnsConfig.Domains = domains
} else if domains != nil {
log.Warn().
Msg("Warning: dns_config.domains is set, but no nameservers are configured. Ignoring domains.")
}
}
if viper.IsSet("dns_config.extra_records") {
var extraRecords []tailcfg.DNSRecord
err := viper.UnmarshalKey("dns_config.extra_records", &extraRecords)
if err != nil {
log.Error().
Str("func", "getDNSConfig").
Err(err).
Msgf("Could not parse dns_config.extra_records")
}
dnsConfig.ExtraRecords = extraRecords
}
if viper.IsSet("dns_config.magic_dns") {
dnsConfig.Proxied = viper.GetBool("dns_config.magic_dns")
}
var baseDomain string
if viper.IsSet("dns_config.base_domain") {
baseDomain = viper.GetString("dns_config.base_domain")
} else {
baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled
}
return dnsConfig, baseDomain
}
return nil, ""
}
func GetHeadscaleConfig() (*Config, error) {
if IsCLIConfigured() {
return &Config{
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
}, nil
}
dnsConfig, baseDomain := GetDNSConfig()
derpConfig := GetDERPConfig()
logConfig := GetLogTailConfig()
randomizeClientPort := viper.GetBool("randomize_client_port")
configuredPrefixes := viper.GetStringSlice("ip_prefixes")
parsedPrefixes := make([]netip.Prefix, 0, len(configuredPrefixes)+1)
for i, prefixInConfig := range configuredPrefixes {
prefix, err := netip.ParsePrefix(prefixInConfig)
if err != nil {
panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err))
}
if prefix.Addr().Is4() {
builder := netipx.IPSetBuilder{}
builder.AddPrefix(tsaddr.CGNATRange())
ipSet, _ := builder.IPSet()
if !ipSet.ContainsPrefix(prefix) {
log.Warn().
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
prefixInConfig, tsaddr.CGNATRange())
}
}
if prefix.Addr().Is6() {
builder := netipx.IPSetBuilder{}
builder.AddPrefix(tsaddr.TailscaleULARange())
ipSet, _ := builder.IPSet()
if !ipSet.ContainsPrefix(prefix) {
log.Warn().
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
prefixInConfig, tsaddr.TailscaleULARange())
}
}
parsedPrefixes = append(parsedPrefixes, prefix)
}
prefixes := make([]netip.Prefix, 0, len(parsedPrefixes))
{
// dedup
normalizedPrefixes := make(map[string]int, len(parsedPrefixes))
for i, p := range parsedPrefixes {
normalized, _ := netipx.RangeOfPrefix(p).Prefix()
normalizedPrefixes[normalized.String()] = i
}
// convert back to list
for _, i := range normalizedPrefixes {
prefixes = append(prefixes, parsedPrefixes[i])
}
}
if len(prefixes) < 1 {
prefixes = append(prefixes, netip.MustParsePrefix("100.64.0.0/10"))
log.Warn().
Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes)
}
oidcClientSecret := viper.GetString("oidc.client_secret")
oidcClientSecretPath := viper.GetString("oidc.client_secret_path")
if oidcClientSecretPath != "" && oidcClientSecret != "" {
return nil, errOidcMutuallyExclusive
}
if oidcClientSecretPath != "" {
secretBytes, err := os.ReadFile(os.ExpandEnv(oidcClientSecretPath))
if err != nil {
return nil, err
}
oidcClientSecret = string(secretBytes)
}
return &Config{
ServerURL: viper.GetString("server_url"),
Addr: viper.GetString("listen_addr"),
MetricsAddr: viper.GetString("metrics_listen_addr"),
GRPCAddr: viper.GetString("grpc_listen_addr"),
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
IPPrefixes: prefixes,
PrivateKeyPath: AbsolutePathFromConfigPath(
viper.GetString("private_key_path"),
),
NoisePrivateKeyPath: AbsolutePathFromConfigPath(
viper.GetString("noise.private_key_path"),
),
BaseDomain: baseDomain,
DERP: derpConfig,
EphemeralNodeInactivityTimeout: viper.GetDuration(
"ephemeral_node_inactivity_timeout",
),
NodeUpdateCheckInterval: viper.GetDuration(
"node_update_check_interval",
),
DBtype: viper.GetString("db_type"),
DBpath: AbsolutePathFromConfigPath(viper.GetString("db_path")),
DBhost: viper.GetString("db_host"),
DBport: viper.GetInt("db_port"),
DBname: viper.GetString("db_name"),
DBuser: viper.GetString("db_user"),
DBpass: viper.GetString("db_pass"),
DBssl: viper.GetString("db_ssl"),
TLS: GetTLSConfig(),
DNSConfig: dnsConfig,
ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"),
UnixSocket: viper.GetString("unix_socket"),
UnixSocketPermission: GetFileMode("unix_socket_permission"),
OIDC: OIDCConfig{
OnlyStartIfOIDCIsAvailable: viper.GetBool(
"oidc.only_start_if_oidc_is_available",
),
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
Expiry: func() time.Duration {
// if set to 0, we assume no expiry
if value := viper.GetString("oidc.expiry"); value == "0" {
return maxDuration
} else {
expiry, err := model.ParseDuration(value)
if err != nil {
log.Warn().Msg("failed to parse oidc.expiry, defaulting back to 180 days")
return defaultOIDCExpiryTime
}
return time.Duration(expiry)
}
}(),
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
},
LogTail: logConfig,
RandomizeClientPort: randomizeClientPort,
ACL: GetACLConfig(),
CLI: CLIConfig{
Address: viper.GetString("cli.address"),
APIKey: viper.GetString("cli.api_key"),
Timeout: viper.GetDuration("cli.timeout"),
Insecure: viper.GetBool("cli.insecure"),
},
Log: GetLogConfig(),
}, nil
}
func IsCLIConfigured() bool {
return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != ""
}

404
hscontrol/db.go Normal file
View File

@@ -0,0 +1,404 @@
package headscale
import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/netip"
"time"
"github.com/glebarez/sqlite"
"github.com/rs/zerolog/log"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"tailscale.com/tailcfg"
)
const (
dbVersion = "1"
errValueNotFound = Error("not found")
ErrCannotParsePrefix = Error("cannot parse prefix")
)
// KV is a key-value store in a psql table. For future use...
type KV struct {
Key string
Value string
}
func (h *Headscale) initDB() error {
db, err := h.openDB()
if err != nil {
return err
}
h.db = db
if h.dbType == Postgres {
db.Exec(`create extension if not exists "uuid-ossp";`)
}
_ = db.Migrator().RenameTable("namespaces", "users")
err = db.AutoMigrate(&User{})
if err != nil {
return err
}
_ = db.Migrator().RenameColumn(&Machine{}, "namespace_id", "user_id")
_ = db.Migrator().RenameColumn(&PreAuthKey{}, "namespace_id", "user_id")
_ = db.Migrator().RenameColumn(&Machine{}, "ip_address", "ip_addresses")
_ = db.Migrator().RenameColumn(&Machine{}, "name", "hostname")
// GivenName is used as the primary source of DNS names, make sure
// the field is populated and normalized if it was not when the
// machine was registered.
_ = db.Migrator().RenameColumn(&Machine{}, "nickname", "given_name")
// If the Machine table has a column for registered,
// find all occourences of "false" and drop them. Then
// remove the column.
if db.Migrator().HasColumn(&Machine{}, "registered") {
log.Info().
Msg(`Database has legacy "registered" column in machine, removing...`)
machines := Machines{}
if err := h.db.Not("registered").Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
}
for _, machine := range machines {
log.Info().
Str("machine", machine.Hostname).
Str("machine_key", machine.MachineKey).
Msg("Deleting unregistered machine")
if err := h.db.Delete(&Machine{}, machine.ID).Error; err != nil {
log.Error().
Err(err).
Str("machine", machine.Hostname).
Str("machine_key", machine.MachineKey).
Msg("Error deleting unregistered machine")
}
}
err := db.Migrator().DropColumn(&Machine{}, "registered")
if err != nil {
log.Error().Err(err).Msg("Error dropping registered column")
}
}
err = db.AutoMigrate(&Route{})
if err != nil {
return err
}
if db.Migrator().HasColumn(&Machine{}, "enabled_routes") {
log.Info().Msgf("Database has legacy enabled_routes column in machine, migrating...")
type MachineAux struct {
ID uint64
EnabledRoutes IPPrefixes
}
machinesAux := []MachineAux{}
err := db.Table("machines").Select("id, enabled_routes").Scan(&machinesAux).Error
if err != nil {
log.Fatal().Err(err).Msg("Error accessing db")
}
for _, machine := range machinesAux {
for _, prefix := range machine.EnabledRoutes {
if err != nil {
log.Error().
Err(err).
Str("enabled_route", prefix.String()).
Msg("Error parsing enabled_route")
continue
}
err = db.Preload("Machine").
Where("machine_id = ? AND prefix = ?", machine.ID, IPPrefix(prefix)).
First(&Route{}).
Error
if err == nil {
log.Info().
Str("enabled_route", prefix.String()).
Msg("Route already migrated to new table, skipping")
continue
}
route := Route{
MachineID: machine.ID,
Advertised: true,
Enabled: true,
Prefix: IPPrefix(prefix),
}
if err := h.db.Create(&route).Error; err != nil {
log.Error().Err(err).Msg("Error creating route")
} else {
log.Info().
Uint64("machine_id", route.MachineID).
Str("prefix", prefix.String()).
Msg("Route migrated")
}
}
}
err = db.Migrator().DropColumn(&Machine{}, "enabled_routes")
if err != nil {
log.Error().Err(err).Msg("Error dropping enabled_routes column")
}
}
err = db.AutoMigrate(&Machine{})
if err != nil {
return err
}
if db.Migrator().HasColumn(&Machine{}, "given_name") {
machines := Machines{}
if err := h.db.Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db")
}
for item, machine := range machines {
if machine.GivenName == "" {
normalizedHostname, err := NormalizeToFQDNRules(
machine.Hostname,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Caller().
Str("hostname", machine.Hostname).
Err(err).
Msg("Failed to normalize machine hostname in DB migration")
}
err = h.RenameMachine(&machines[item], normalizedHostname)
if err != nil {
log.Error().
Caller().
Str("hostname", machine.Hostname).
Err(err).
Msg("Failed to save normalized machine name in DB migration")
}
}
}
}
err = db.AutoMigrate(&KV{})
if err != nil {
return err
}
err = db.AutoMigrate(&PreAuthKey{})
if err != nil {
return err
}
err = db.AutoMigrate(&PreAuthKeyACLTag{})
if err != nil {
return err
}
_ = db.Migrator().DropTable("shared_machines")
err = db.AutoMigrate(&APIKey{})
if err != nil {
return err
}
err = h.setValue("db_version", dbVersion)
return err
}
func (h *Headscale) openDB() (*gorm.DB, error) {
var db *gorm.DB
var err error
var log logger.Interface
if h.dbDebug {
log = logger.Default
} else {
log = logger.Default.LogMode(logger.Silent)
}
switch h.dbType {
case Sqlite:
db, err = gorm.Open(
sqlite.Open(h.dbString+"?_synchronous=1&_journal_mode=WAL"),
&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: log,
},
)
db.Exec("PRAGMA foreign_keys=ON")
// The pure Go SQLite library does not handle locking in
// the same way as the C based one and we cant use the gorm
// connection pool as of 2022/02/23.
sqlDB, _ := db.DB()
sqlDB.SetMaxIdleConns(1)
sqlDB.SetMaxOpenConns(1)
sqlDB.SetConnMaxIdleTime(time.Hour)
case Postgres:
db, err = gorm.Open(postgres.Open(h.dbString), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
Logger: log,
})
}
if err != nil {
return nil, err
}
return db, nil
}
// getValue returns the value for the given key in KV.
func (h *Headscale) getValue(key string) (string, error) {
var row KV
if result := h.db.First(&row, "key = ?", key); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return "", errValueNotFound
}
return row.Value, nil
}
// setValue sets value for the given key in KV.
func (h *Headscale) setValue(key string, value string) error {
keyValue := KV{
Key: key,
Value: value,
}
if _, err := h.getValue(key); err == nil {
h.db.Model(&keyValue).Where("key = ?", key).Update("value", value)
return nil
}
if err := h.db.Create(keyValue).Error; err != nil {
return fmt.Errorf("failed to create key value pair in the database: %w", err)
}
return nil
}
func (h *Headscale) pingDB(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
db, err := h.db.DB()
if err != nil {
return err
}
return db.PingContext(ctx)
}
// This is a "wrapper" type around tailscales
// Hostinfo to allow us to add database "serialization"
// methods. This allows us to use a typed values throughout
// the code and not have to marshal/unmarshal and error
// check all over the code.
type HostInfo tailcfg.Hostinfo
func (hi *HostInfo) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, hi)
case string:
return json.Unmarshal([]byte(value), hi)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (hi HostInfo) Value() (driver.Value, error) {
bytes, err := json.Marshal(hi)
return string(bytes), err
}
type IPPrefix netip.Prefix
func (i *IPPrefix) Scan(destination interface{}) error {
switch value := destination.(type) {
case string:
prefix, err := netip.ParsePrefix(value)
if err != nil {
return err
}
*i = IPPrefix(prefix)
return nil
default:
return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i IPPrefix) Value() (driver.Value, error) {
prefixStr := netip.Prefix(i).String()
return prefixStr, nil
}
type IPPrefixes []netip.Prefix
func (i *IPPrefixes) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
case string:
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i IPPrefixes) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return string(bytes), err
}
type StringList []string
func (i *StringList) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, i)
case string:
return json.Unmarshal([]byte(value), i)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrMachineAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (i StringList) Value() (driver.Value, error) {
bytes, err := json.Marshal(i)
return string(bytes), err
}

157
hscontrol/derp.go Normal file
View File

@@ -0,0 +1,157 @@
package headscale
import (
"context"
"encoding/json"
"io"
"net/http"
"net/url"
"os"
"time"
"github.com/rs/zerolog/log"
"gopkg.in/yaml.v3"
"tailscale.com/tailcfg"
)
func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) {
derpFile, err := os.Open(path)
if err != nil {
return nil, err
}
defer derpFile.Close()
var derpMap tailcfg.DERPMap
b, err := io.ReadAll(derpFile)
if err != nil {
return nil, err
}
err = yaml.Unmarshal(b, &derpMap)
return &derpMap, err
}
func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) {
ctx, cancel := context.WithTimeout(context.Background(), HTTPReadTimeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil)
if err != nil {
return nil, err
}
client := http.Client{
Timeout: HTTPReadTimeout,
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
var derpMap tailcfg.DERPMap
err = json.Unmarshal(body, &derpMap)
return &derpMap, err
}
// mergeDERPMaps naively merges a list of DERPMaps into a single
// DERPMap, it will _only_ look at the Regions, an integer.
// If a region exists in two of the given DERPMaps, the region
// form the _last_ DERPMap will be preserved.
// An empty DERPMap list will result in a DERPMap with no regions.
func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap {
result := tailcfg.DERPMap{
OmitDefaultRegions: false,
Regions: map[int]*tailcfg.DERPRegion{},
}
for _, derpMap := range derpMaps {
for id, region := range derpMap.Regions {
result.Regions[id] = region
}
}
return &result
}
func GetDERPMap(cfg DERPConfig) *tailcfg.DERPMap {
derpMaps := make([]*tailcfg.DERPMap, 0)
for _, path := range cfg.Paths {
log.Debug().
Str("func", "GetDERPMap").
Str("path", path).
Msg("Loading DERPMap from path")
derpMap, err := loadDERPMapFromPath(path)
if err != nil {
log.Error().
Str("func", "GetDERPMap").
Str("path", path).
Err(err).
Msg("Could not load DERP map from path")
break
}
derpMaps = append(derpMaps, derpMap)
}
for _, addr := range cfg.URLs {
derpMap, err := loadDERPMapFromURL(addr)
log.Debug().
Str("func", "GetDERPMap").
Str("url", addr.String()).
Msg("Loading DERPMap from path")
if err != nil {
log.Error().
Str("func", "GetDERPMap").
Str("url", addr.String()).
Err(err).
Msg("Could not load DERP map from path")
break
}
derpMaps = append(derpMaps, derpMap)
}
derpMap := mergeDERPMaps(derpMaps)
log.Trace().Interface("derpMap", derpMap).Msg("DERPMap loaded")
if len(derpMap.Regions) == 0 {
log.Warn().
Msg("DERP map is empty, not a single DERP map datasource was loaded correctly or contained a region")
}
return derpMap
}
func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) {
log.Info().
Dur("frequency", h.cfg.DERP.UpdateFrequency).
Msg("Setting up a DERPMap update worker")
ticker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
for {
select {
case <-cancelChan:
return
case <-ticker.C:
log.Info().Msg("Fetching DERPMap updates")
h.DERPMap = GetDERPMap(h.cfg.DERP)
if h.cfg.DERP.ServerEnabled {
h.DERPMap.Regions[h.DERPServer.region.RegionID] = &h.DERPServer.region
}
h.setLastStateChangeToNow()
}
}
}

292
hscontrol/derp_server.go Normal file
View File

@@ -0,0 +1,292 @@
package headscale
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/netip"
"net/url"
"strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
"tailscale.com/derp"
"tailscale.com/net/stun"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// fastStartHeader is the header (with value "1") that signals to the HTTP
// server that the DERP HTTP client does not want the HTTP 101 response
// headers and it will begin writing & reading the DERP protocol immediately
// following its HTTP request.
const fastStartHeader = "Derp-Fast-Start"
type DERPServer struct {
tailscaleDERP *derp.Server
region tailcfg.DERPRegion
}
func (h *Headscale) NewDERPServer() (*DERPServer, error) {
log.Trace().Caller().Msg("Creating new embedded DERP server")
server := derp.NewServer(key.NodePrivate(*h.privateKey), log.Info().Msgf)
region, err := h.generateRegionLocalDERP()
if err != nil {
return nil, err
}
return &DERPServer{server, region}, nil
}
func (h *Headscale) generateRegionLocalDERP() (tailcfg.DERPRegion, error) {
serverURL, err := url.Parse(h.cfg.ServerURL)
if err != nil {
return tailcfg.DERPRegion{}, err
}
var host string
var port int
host, portStr, err := net.SplitHostPort(serverURL.Host)
if err != nil {
if serverURL.Scheme == "https" {
host = serverURL.Host
port = 443
} else {
host = serverURL.Host
port = 80
}
} else {
port, err = strconv.Atoi(portStr)
if err != nil {
return tailcfg.DERPRegion{}, err
}
}
localDERPregion := tailcfg.DERPRegion{
RegionID: h.cfg.DERP.ServerRegionID,
RegionCode: h.cfg.DERP.ServerRegionCode,
RegionName: h.cfg.DERP.ServerRegionName,
Avoid: false,
Nodes: []*tailcfg.DERPNode{
{
Name: fmt.Sprintf("%d", h.cfg.DERP.ServerRegionID),
RegionID: h.cfg.DERP.ServerRegionID,
HostName: host,
DERPPort: port,
},
},
}
_, portSTUNStr, err := net.SplitHostPort(h.cfg.DERP.STUNAddr)
if err != nil {
return tailcfg.DERPRegion{}, err
}
portSTUN, err := strconv.Atoi(portSTUNStr)
if err != nil {
return tailcfg.DERPRegion{}, err
}
localDERPregion.Nodes[0].STUNPort = portSTUN
log.Info().Caller().Msgf("DERP region: %+v", localDERPregion)
return localDERPregion, nil
}
func (h *Headscale) DERPHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().Caller().Msgf("/derp request from %v", req.RemoteAddr)
upgrade := strings.ToLower(req.Header.Get("Upgrade"))
if upgrade != "websocket" && upgrade != "derp" {
if upgrade != "" {
log.Warn().
Caller().
Msg("No Upgrade header in DERP server request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
}
writer.Header().Set("Content-Type", "text/plain")
writer.WriteHeader(http.StatusUpgradeRequired)
_, err := writer.Write([]byte("DERP requires connection upgrade"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
fastStart := req.Header.Get(fastStartHeader) == "1"
hijacker, ok := writer.(http.Hijacker)
if !ok {
log.Error().Caller().Msg("DERP requires Hijacker interface from Gin")
writer.Header().Set("Content-Type", "text/plain")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("HTTP does not support general TCP support"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
netConn, conn, err := hijacker.Hijack()
if err != nil {
log.Error().Caller().Err(err).Msgf("Hijack failed")
writer.Header().Set("Content-Type", "text/plain")
writer.WriteHeader(http.StatusInternalServerError)
_, err = writer.Write([]byte("HTTP does not support general TCP support"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
log.Trace().Caller().Msgf("Hijacked connection from %v", req.RemoteAddr)
if !fastStart {
pubKey := h.privateKey.Public()
pubKeyStr, _ := pubKey.MarshalText() //nolint
fmt.Fprintf(conn, "HTTP/1.1 101 Switching Protocols\r\n"+
"Upgrade: DERP\r\n"+
"Connection: Upgrade\r\n"+
"Derp-Version: %v\r\n"+
"Derp-Public-Key: %s\r\n\r\n",
derp.ProtocolVersion,
string(pubKeyStr))
}
h.DERPServer.tailscaleDERP.Accept(req.Context(), netConn, conn, netConn.RemoteAddr().String())
}
// DERPProbeHandler is the endpoint that js/wasm clients hit to measure
// DERP latency, since they can't do UDP STUN queries.
func (h *Headscale) DERPProbeHandler(
writer http.ResponseWriter,
req *http.Request,
) {
switch req.Method {
case http.MethodHead, http.MethodGet:
writer.Header().Set("Access-Control-Allow-Origin", "*")
writer.WriteHeader(http.StatusOK)
default:
writer.WriteHeader(http.StatusMethodNotAllowed)
_, err := writer.Write([]byte("bogus probe method"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
}
// DERPBootstrapDNSHandler implements the /bootsrap-dns endpoint
// Described in https://github.com/tailscale/tailscale/issues/1405,
// this endpoint provides a way to help a client when it fails to start up
// because its DNS are broken.
// The initial implementation is here https://github.com/tailscale/tailscale/pull/1406
// They have a cache, but not clear if that is really necessary at Headscale, uh, scale.
// An example implementation is found here https://derp.tailscale.com/bootstrap-dns
func (h *Headscale) DERPBootstrapDNSHandler(
writer http.ResponseWriter,
req *http.Request,
) {
dnsEntries := make(map[string][]net.IP)
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
defer cancel()
var resolver net.Resolver
for _, region := range h.DERPMap.Regions {
for _, node := range region.Nodes { // we don't care if we override some nodes
addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName)
if err != nil {
log.Trace().
Caller().
Err(err).
Msgf("bootstrap DNS lookup failed %q", node.HostName)
continue
}
dnsEntries[node.HostName] = addrs
}
}
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
err := json.NewEncoder(writer).Encode(dnsEntries)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
// ServeSTUN starts a STUN server on the configured addr.
func (h *Headscale) ServeSTUN() {
packetConn, err := net.ListenPacket("udp", h.cfg.DERP.STUNAddr)
if err != nil {
log.Fatal().Msgf("failed to open STUN listener: %v", err)
}
log.Info().Msgf("STUN server started at %s", packetConn.LocalAddr())
udpConn, ok := packetConn.(*net.UDPConn)
if !ok {
log.Fatal().Msg("STUN listener is not a UDP listener")
}
serverSTUNListener(context.Background(), udpConn)
}
func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
var buf [64 << 10]byte
var (
bytesRead int
udpAddr *net.UDPAddr
err error
)
for {
bytesRead, udpAddr, err = packetConn.ReadFromUDP(buf[:])
if err != nil {
if ctx.Err() != nil {
return
}
log.Error().Caller().Err(err).Msgf("STUN ReadFrom")
time.Sleep(time.Second)
continue
}
log.Trace().Caller().Msgf("STUN request from %v", udpAddr)
pkt := buf[:bytesRead]
if !stun.Is(pkt) {
log.Trace().Caller().Msgf("UDP packet is not STUN")
continue
}
txid, err := stun.ParseBindingRequest(pkt)
if err != nil {
log.Trace().Caller().Err(err).Msgf("STUN parse error")
continue
}
addr, _ := netip.AddrFromSlice(udpAddr.IP)
res := stun.Response(txid, netip.AddrPortFrom(addr, uint16(udpAddr.Port)))
_, err = packetConn.WriteTo(res, udpAddr)
if err != nil {
log.Trace().Caller().Err(err).Msgf("Issue writing to UDP")
continue
}
}
}

219
hscontrol/dns.go Normal file
View File

@@ -0,0 +1,219 @@
package headscale
import (
"fmt"
"net/netip"
"net/url"
"strings"
mapset "github.com/deckarep/golang-set/v2"
"go4.org/netipx"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/util/dnsname"
)
const (
ByteSize = 8
)
const (
ipv4AddressLength = 32
ipv6AddressLength = 128
)
const (
nextDNSDoHPrefix = "https://dns.nextdns.io"
)
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
// server (listening in 100.100.100.100 udp/53) should be used for.
//
// Tailscale.com includes in the list:
// - the `BaseDomain` of the user
// - the reverse DNS entry for IPv6 (0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa., see below more on IPv6)
// - the reverse DNS entries for the IPv4 subnets covered by the user's `IPPrefix`.
// In the public SaaS this is [64-127].100.in-addr.arpa.
//
// The main purpose of this function is then generating the list of IPv4 entries. For the 100.64.0.0/10, this
// is clear, and could be hardcoded. But we are allowing any range as `IPPrefix`, so we need to find out the
// subnets when we have 172.16.0.0/16 (i.e., [0-255].16.172.in-addr.arpa.), or any other subnet.
//
// How IN-ADDR.ARPA domains work is defined in RFC1035 (section 3.5). Tailscale.com seems to adhere to this,
// and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next
// class block only.
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
func generateMagicDNSRootDomains(ipPrefixes []netip.Prefix) []dnsname.FQDN {
fqdns := make([]dnsname.FQDN, 0, len(ipPrefixes))
for _, ipPrefix := range ipPrefixes {
var generateDNSRoot func(netip.Prefix) []dnsname.FQDN
switch ipPrefix.Addr().BitLen() {
case ipv4AddressLength:
generateDNSRoot = generateIPv4DNSRootDomain
case ipv6AddressLength:
generateDNSRoot = generateIPv6DNSRootDomain
default:
panic(
fmt.Sprintf(
"unsupported IP version with address length %d",
ipPrefix.Addr().BitLen(),
),
)
}
fqdns = append(fqdns, generateDNSRoot(ipPrefix)...)
}
return fqdns
}
func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
// Conversion to the std lib net.IPnet, a bit easier to operate
netRange := netipx.PrefixIPNet(ipPrefix)
maskBits, _ := netRange.Mask.Size()
// lastOctet is the last IP byte covered by the mask
lastOctet := maskBits / ByteSize
// wildcardBits is the number of bits not under the mask in the lastOctet
wildcardBits := ByteSize - maskBits%ByteSize
// min is the value in the lastOctet byte of the IP
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
min := uint(netRange.IP[lastOctet])
max := (min + 1<<uint(wildcardBits)) - 1
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
rdnsSlice := []string{}
for i := lastOctet - 1; i >= 0; i-- {
rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i]))
}
rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
rdnsBase := strings.Join(rdnsSlice, ".")
fqdns := make([]dnsname.FQDN, 0, max-min+1)
for i := min; i <= max; i++ {
fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%d.%s", i, rdnsBase))
if err != nil {
continue
}
fqdns = append(fqdns, fqdn)
}
return fqdns
}
func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN {
const nibbleLen = 4
maskBits, _ := netipx.PrefixIPNet(ipPrefix).Mask.Size()
expanded := ipPrefix.Addr().StringExpanded()
nibbleStr := strings.Map(func(r rune) rune {
if r == ':' {
return -1
}
return r
}, expanded)
// TODO?: that does not look the most efficient implementation,
// but the inputs are not so long as to cause problems,
// and from what I can see, the generateMagicDNSRootDomains
// function is called only once over the lifetime of a server process.
prefixConstantParts := []string{}
for i := 0; i < maskBits/nibbleLen; i++ {
prefixConstantParts = append(
[]string{string(nibbleStr[i])},
prefixConstantParts...)
}
makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) {
prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".")
return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix))
}
var fqdns []dnsname.FQDN
if maskBits%4 == 0 {
dom, _ := makeDomain()
fqdns = append(fqdns, dom)
} else {
domCount := 1 << (maskBits % nibbleLen)
fqdns = make([]dnsname.FQDN, 0, domCount)
for i := 0; i < domCount; i++ {
varNibble := fmt.Sprintf("%x", i)
dom, err := makeDomain(varNibble)
if err != nil {
continue
}
fqdns = append(fqdns, dom)
}
}
return fqdns
}
// If any nextdns DoH resolvers are present in the list of resolvers it will
// take metadata from the machine metadata and instruct tailscale to add it
// to the requests. This makes it possible to identify from which device the
// requests come in the NextDNS dashboard.
//
// This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, machine Machine) {
for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{
"device_name": []string{machine.Hostname},
"device_model": []string{machine.HostInfo.OS},
}
if len(machine.IPAddresses) > 0 {
attrs.Add("device_ip", machine.IPAddresses[0].String())
}
resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode())
}
}
}
func getMapResponseDNSConfig(
dnsConfigOrig *tailcfg.DNSConfig,
baseDomain string,
machine Machine,
peers Machines,
) *tailcfg.DNSConfig {
var dnsConfig *tailcfg.DNSConfig = dnsConfigOrig.Clone()
if dnsConfigOrig != nil && dnsConfigOrig.Proxied { // if MagicDNS is enabled
// Only inject the Search Domain of the current user - shared nodes should use their full FQDN
dnsConfig.Domains = append(
dnsConfig.Domains,
fmt.Sprintf(
"%s.%s",
machine.User.Name,
baseDomain,
),
)
userSet := mapset.NewSet[User]()
userSet.Add(machine.User)
for _, p := range peers {
userSet.Add(p.User)
}
for _, user := range userSet.ToSlice() {
dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain)
dnsConfig.Routes[dnsRoute] = nil
}
} else {
dnsConfig = dnsConfigOrig
}
addNextDNSMetadata(dnsConfig.Resolvers, machine)
return dnsConfig
}

394
hscontrol/dns_test.go Normal file
View File

@@ -0,0 +1,394 @@
package headscale
import (
"fmt"
"net/netip"
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
)
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("100.64.0.0/10"),
}
domains := generateMagicDNSRootDomains(prefixes)
found := false
for _, domain := range domains {
if domain == "64.100.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, domain := range domains {
if domain == "100.100.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, domain := range domains {
if domain == "127.100.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}
func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
}
domains := generateMagicDNSRootDomains(prefixes)
found := false
for _, domain := range domains {
if domain == "0.16.172.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, domain := range domains {
if domain == "255.16.172.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}
// Happens when netmask is a multiple of 4 bits (sounds likely).
func (s *Suite) TestMagicDNSRootDomainsIPv6Single(c *check.C) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd7a:115c:a1e0::/48"),
}
domains := generateMagicDNSRootDomains(prefixes)
c.Assert(len(domains), check.Equals, 1)
c.Assert(
domains[0].WithTrailingDot(),
check.Equals,
"0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.",
)
}
func (s *Suite) TestMagicDNSRootDomainsIPv6SingleMultiple(c *check.C) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd7a:115c:a1e0::/50"),
}
domains := generateMagicDNSRootDomains(prefixes)
yieldsRoot := func(dom string) bool {
for _, candidate := range domains {
if candidate.WithTrailingDot() == dom {
return true
}
}
return false
}
c.Assert(len(domains), check.Equals, 4)
c.Assert(yieldsRoot("0.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
c.Assert(yieldsRoot("1.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
c.Assert(yieldsRoot("2.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
c.Assert(yieldsRoot("3.0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa."), check.Equals, true)
}
func (s *Suite) TestDNSConfigMapResponseWithMagicDNS(c *check.C) {
userShared1, err := app.CreateUser("shared1")
c.Assert(err, check.IsNil)
userShared2, err := app.CreateUser("shared2")
c.Assert(err, check.IsNil)
userShared3, err := app.CreateUser("shared3")
c.Assert(err, check.IsNil)
preAuthKeyInShared1, err := app.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyInShared2, err := app.CreatePreAuthKey(
userShared2.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyInShared3, err := app.CreatePreAuthKey(
userShared3.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
PreAuthKey2InShared1, err := app.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
_, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
machineInShared1 := &Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID),
}
app.db.Save(machineInShared1)
_, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil)
machineInShared2 := &Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID,
User: *userShared2,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID),
}
app.db.Save(machineInShared2)
_, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil)
machineInShared3 := &Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID,
User: *userShared3,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID),
}
app.db.Save(machineInShared3)
_, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil)
machine2InShared1 := &Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(PreAuthKey2InShared1.ID),
}
app.db.Save(machine2InShared1)
baseDomain := "foobar.headscale.net"
dnsConfigOrig := tailcfg.DNSConfig{
Routes: make(map[string][]*dnstype.Resolver),
Domains: []string{baseDomain},
Proxied: true,
}
peersOfMachineInShared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil)
dnsConfig := getMapResponseDNSConfig(
&dnsConfigOrig,
baseDomain,
*machineInShared1,
peersOfMachineInShared1,
)
c.Assert(dnsConfig, check.NotNil)
c.Assert(len(dnsConfig.Routes), check.Equals, 3)
domainRouteShared1 := fmt.Sprintf("%s.%s", userShared1.Name, baseDomain)
_, ok := dnsConfig.Routes[domainRouteShared1]
c.Assert(ok, check.Equals, true)
domainRouteShared2 := fmt.Sprintf("%s.%s", userShared2.Name, baseDomain)
_, ok = dnsConfig.Routes[domainRouteShared2]
c.Assert(ok, check.Equals, true)
domainRouteShared3 := fmt.Sprintf("%s.%s", userShared3.Name, baseDomain)
_, ok = dnsConfig.Routes[domainRouteShared3]
c.Assert(ok, check.Equals, true)
}
func (s *Suite) TestDNSConfigMapResponseWithoutMagicDNS(c *check.C) {
userShared1, err := app.CreateUser("shared1")
c.Assert(err, check.IsNil)
userShared2, err := app.CreateUser("shared2")
c.Assert(err, check.IsNil)
userShared3, err := app.CreateUser("shared3")
c.Assert(err, check.IsNil)
preAuthKeyInShared1, err := app.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyInShared2, err := app.CreatePreAuthKey(
userShared2.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyInShared3, err := app.CreatePreAuthKey(
userShared3.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKey2InShared1, err := app.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
_, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
machineInShared1 := &Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyInShared1.ID),
}
app.db.Save(machineInShared1)
_, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil)
machineInShared2 := &Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID,
User: *userShared2,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyInShared2.ID),
}
app.db.Save(machineInShared2)
_, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil)
machineInShared3 := &Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID,
User: *userShared3,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyInShared3.ID),
}
app.db.Save(machineInShared3)
_, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil)
machine2InShared1 := &Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(preAuthKey2InShared1.ID),
}
app.db.Save(machine2InShared1)
baseDomain := "foobar.headscale.net"
dnsConfigOrig := tailcfg.DNSConfig{
Routes: make(map[string][]*dnstype.Resolver),
Domains: []string{baseDomain},
Proxied: false,
}
peersOfMachine1Shared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil)
dnsConfig := getMapResponseDNSConfig(
&dnsConfigOrig,
baseDomain,
*machineInShared1,
peersOfMachine1Shared1,
)
c.Assert(dnsConfig, check.NotNil)
c.Assert(len(dnsConfig.Routes), check.Equals, 0)
c.Assert(len(dnsConfig.Domains), check.Equals, 1)
}

553
hscontrol/grpcv1.go Normal file
View File

@@ -0,0 +1,553 @@
// nolint
package headscale
import (
"context"
"fmt"
"strings"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
type headscaleV1APIServer struct { // v1.HeadscaleServiceServer
v1.UnimplementedHeadscaleServiceServer
h *Headscale
}
func newHeadscaleV1APIServer(h *Headscale) v1.HeadscaleServiceServer {
return headscaleV1APIServer{
h: h,
}
}
func (api headscaleV1APIServer) GetUser(
ctx context.Context,
request *v1.GetUserRequest,
) (*v1.GetUserResponse, error) {
user, err := api.h.GetUser(request.GetName())
if err != nil {
return nil, err
}
return &v1.GetUserResponse{User: user.toProto()}, nil
}
func (api headscaleV1APIServer) CreateUser(
ctx context.Context,
request *v1.CreateUserRequest,
) (*v1.CreateUserResponse, error) {
user, err := api.h.CreateUser(request.GetName())
if err != nil {
return nil, err
}
return &v1.CreateUserResponse{User: user.toProto()}, nil
}
func (api headscaleV1APIServer) RenameUser(
ctx context.Context,
request *v1.RenameUserRequest,
) (*v1.RenameUserResponse, error) {
err := api.h.RenameUser(request.GetOldName(), request.GetNewName())
if err != nil {
return nil, err
}
user, err := api.h.GetUser(request.GetNewName())
if err != nil {
return nil, err
}
return &v1.RenameUserResponse{User: user.toProto()}, nil
}
func (api headscaleV1APIServer) DeleteUser(
ctx context.Context,
request *v1.DeleteUserRequest,
) (*v1.DeleteUserResponse, error) {
err := api.h.DestroyUser(request.GetName())
if err != nil {
return nil, err
}
return &v1.DeleteUserResponse{}, nil
}
func (api headscaleV1APIServer) ListUsers(
ctx context.Context,
request *v1.ListUsersRequest,
) (*v1.ListUsersResponse, error) {
users, err := api.h.ListUsers()
if err != nil {
return nil, err
}
response := make([]*v1.User, len(users))
for index, user := range users {
response[index] = user.toProto()
}
log.Trace().Caller().Interface("users", response).Msg("")
return &v1.ListUsersResponse{Users: response}, nil
}
func (api headscaleV1APIServer) CreatePreAuthKey(
ctx context.Context,
request *v1.CreatePreAuthKeyRequest,
) (*v1.CreatePreAuthKeyResponse, error) {
var expiration time.Time
if request.GetExpiration() != nil {
expiration = request.GetExpiration().AsTime()
}
for _, tag := range request.AclTags {
err := validateTag(tag)
if err != nil {
return &v1.CreatePreAuthKeyResponse{
PreAuthKey: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
}
preAuthKey, err := api.h.CreatePreAuthKey(
request.GetUser(),
request.GetReusable(),
request.GetEphemeral(),
&expiration,
request.AclTags,
)
if err != nil {
return nil, err
}
return &v1.CreatePreAuthKeyResponse{PreAuthKey: preAuthKey.toProto()}, nil
}
func (api headscaleV1APIServer) ExpirePreAuthKey(
ctx context.Context,
request *v1.ExpirePreAuthKeyRequest,
) (*v1.ExpirePreAuthKeyResponse, error) {
preAuthKey, err := api.h.GetPreAuthKey(request.GetUser(), request.Key)
if err != nil {
return nil, err
}
err = api.h.ExpirePreAuthKey(preAuthKey)
if err != nil {
return nil, err
}
return &v1.ExpirePreAuthKeyResponse{}, nil
}
func (api headscaleV1APIServer) ListPreAuthKeys(
ctx context.Context,
request *v1.ListPreAuthKeysRequest,
) (*v1.ListPreAuthKeysResponse, error) {
preAuthKeys, err := api.h.ListPreAuthKeys(request.GetUser())
if err != nil {
return nil, err
}
response := make([]*v1.PreAuthKey, len(preAuthKeys))
for index, key := range preAuthKeys {
response[index] = key.toProto()
}
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
}
func (api headscaleV1APIServer) RegisterMachine(
ctx context.Context,
request *v1.RegisterMachineRequest,
) (*v1.RegisterMachineResponse, error) {
log.Trace().
Str("user", request.GetUser()).
Str("node_key", request.GetKey()).
Msg("Registering machine")
machine, err := api.h.RegisterMachineFromAuthCallback(
request.GetKey(),
request.GetUser(),
nil,
RegisterMethodCLI,
)
if err != nil {
return nil, err
}
return &v1.RegisterMachineResponse{Machine: machine.toProto()}, nil
}
func (api headscaleV1APIServer) GetMachine(
ctx context.Context,
request *v1.GetMachineRequest,
) (*v1.GetMachineResponse, error) {
machine, err := api.h.GetMachineByID(request.GetMachineId())
if err != nil {
return nil, err
}
return &v1.GetMachineResponse{Machine: machine.toProto()}, nil
}
func (api headscaleV1APIServer) SetTags(
ctx context.Context,
request *v1.SetTagsRequest,
) (*v1.SetTagsResponse, error) {
machine, err := api.h.GetMachineByID(request.GetMachineId())
if err != nil {
return nil, err
}
for _, tag := range request.GetTags() {
err := validateTag(tag)
if err != nil {
return &v1.SetTagsResponse{
Machine: nil,
}, status.Error(codes.InvalidArgument, err.Error())
}
}
err = api.h.SetTags(machine, request.GetTags())
if err != nil {
return &v1.SetTagsResponse{
Machine: nil,
}, status.Error(codes.Internal, err.Error())
}
log.Trace().
Str("machine", machine.Hostname).
Strs("tags", request.GetTags()).
Msg("Changing tags of machine")
return &v1.SetTagsResponse{Machine: machine.toProto()}, nil
}
func validateTag(tag string) error {
if strings.Index(tag, "tag:") != 0 {
return fmt.Errorf("tag must start with the string 'tag:'")
}
if strings.ToLower(tag) != tag {
return fmt.Errorf("tag should be lowercase")
}
if len(strings.Fields(tag)) > 1 {
return fmt.Errorf("tag should not contains space")
}
return nil
}
func (api headscaleV1APIServer) DeleteMachine(
ctx context.Context,
request *v1.DeleteMachineRequest,
) (*v1.DeleteMachineResponse, error) {
machine, err := api.h.GetMachineByID(request.GetMachineId())
if err != nil {
return nil, err
}
err = api.h.DeleteMachine(
machine,
)
if err != nil {
return nil, err
}
return &v1.DeleteMachineResponse{}, nil
}
func (api headscaleV1APIServer) ExpireMachine(
ctx context.Context,
request *v1.ExpireMachineRequest,
) (*v1.ExpireMachineResponse, error) {
machine, err := api.h.GetMachineByID(request.GetMachineId())
if err != nil {
return nil, err
}
api.h.ExpireMachine(
machine,
)
log.Trace().
Str("machine", machine.Hostname).
Time("expiry", *machine.Expiry).
Msg("machine expired")
return &v1.ExpireMachineResponse{Machine: machine.toProto()}, nil
}
func (api headscaleV1APIServer) RenameMachine(
ctx context.Context,
request *v1.RenameMachineRequest,
) (*v1.RenameMachineResponse, error) {
machine, err := api.h.GetMachineByID(request.GetMachineId())
if err != nil {
return nil, err
}
err = api.h.RenameMachine(
machine,
request.GetNewName(),
)
if err != nil {
return nil, err
}
log.Trace().
Str("machine", machine.Hostname).
Str("new_name", request.GetNewName()).
Msg("machine renamed")
return &v1.RenameMachineResponse{Machine: machine.toProto()}, nil
}
func (api headscaleV1APIServer) ListMachines(
ctx context.Context,
request *v1.ListMachinesRequest,
) (*v1.ListMachinesResponse, error) {
if request.GetUser() != "" {
machines, err := api.h.ListMachinesByUser(request.GetUser())
if err != nil {
return nil, err
}
response := make([]*v1.Machine, len(machines))
for index, machine := range machines {
response[index] = machine.toProto()
}
return &v1.ListMachinesResponse{Machines: response}, nil
}
machines, err := api.h.ListMachines()
if err != nil {
return nil, err
}
response := make([]*v1.Machine, len(machines))
for index, machine := range machines {
m := machine.toProto()
validTags, invalidTags := getTags(
api.h.aclPolicy,
machine,
api.h.cfg.OIDC.StripEmaildomain,
)
m.InvalidTags = invalidTags
m.ValidTags = validTags
response[index] = m
}
return &v1.ListMachinesResponse{Machines: response}, nil
}
func (api headscaleV1APIServer) MoveMachine(
ctx context.Context,
request *v1.MoveMachineRequest,
) (*v1.MoveMachineResponse, error) {
machine, err := api.h.GetMachineByID(request.GetMachineId())
if err != nil {
return nil, err
}
err = api.h.SetMachineUser(machine, request.GetUser())
if err != nil {
return nil, err
}
return &v1.MoveMachineResponse{Machine: machine.toProto()}, nil
}
func (api headscaleV1APIServer) GetRoutes(
ctx context.Context,
request *v1.GetRoutesRequest,
) (*v1.GetRoutesResponse, error) {
routes, err := api.h.GetRoutes()
if err != nil {
return nil, err
}
return &v1.GetRoutesResponse{
Routes: Routes(routes).toProto(),
}, nil
}
func (api headscaleV1APIServer) EnableRoute(
ctx context.Context,
request *v1.EnableRouteRequest,
) (*v1.EnableRouteResponse, error) {
err := api.h.EnableRoute(request.GetRouteId())
if err != nil {
return nil, err
}
return &v1.EnableRouteResponse{}, nil
}
func (api headscaleV1APIServer) DisableRoute(
ctx context.Context,
request *v1.DisableRouteRequest,
) (*v1.DisableRouteResponse, error) {
err := api.h.DisableRoute(request.GetRouteId())
if err != nil {
return nil, err
}
return &v1.DisableRouteResponse{}, nil
}
func (api headscaleV1APIServer) GetMachineRoutes(
ctx context.Context,
request *v1.GetMachineRoutesRequest,
) (*v1.GetMachineRoutesResponse, error) {
machine, err := api.h.GetMachineByID(request.GetMachineId())
if err != nil {
return nil, err
}
routes, err := api.h.GetMachineRoutes(machine)
if err != nil {
return nil, err
}
return &v1.GetMachineRoutesResponse{
Routes: Routes(routes).toProto(),
}, nil
}
func (api headscaleV1APIServer) DeleteRoute(
ctx context.Context,
request *v1.DeleteRouteRequest,
) (*v1.DeleteRouteResponse, error) {
err := api.h.DeleteRoute(request.GetRouteId())
if err != nil {
return nil, err
}
return &v1.DeleteRouteResponse{}, nil
}
func (api headscaleV1APIServer) CreateApiKey(
ctx context.Context,
request *v1.CreateApiKeyRequest,
) (*v1.CreateApiKeyResponse, error) {
var expiration time.Time
if request.GetExpiration() != nil {
expiration = request.GetExpiration().AsTime()
}
apiKey, _, err := api.h.CreateAPIKey(
&expiration,
)
if err != nil {
return nil, err
}
return &v1.CreateApiKeyResponse{ApiKey: apiKey}, nil
}
func (api headscaleV1APIServer) ExpireApiKey(
ctx context.Context,
request *v1.ExpireApiKeyRequest,
) (*v1.ExpireApiKeyResponse, error) {
var apiKey *APIKey
var err error
apiKey, err = api.h.GetAPIKey(request.Prefix)
if err != nil {
return nil, err
}
err = api.h.ExpireAPIKey(apiKey)
if err != nil {
return nil, err
}
return &v1.ExpireApiKeyResponse{}, nil
}
func (api headscaleV1APIServer) ListApiKeys(
ctx context.Context,
request *v1.ListApiKeysRequest,
) (*v1.ListApiKeysResponse, error) {
apiKeys, err := api.h.ListAPIKeys()
if err != nil {
return nil, err
}
response := make([]*v1.ApiKey, len(apiKeys))
for index, key := range apiKeys {
response[index] = key.toProto()
}
return &v1.ListApiKeysResponse{ApiKeys: response}, nil
}
// The following service calls are for testing and debugging
func (api headscaleV1APIServer) DebugCreateMachine(
ctx context.Context,
request *v1.DebugCreateMachineRequest,
) (*v1.DebugCreateMachineResponse, error) {
user, err := api.h.GetUser(request.GetUser())
if err != nil {
return nil, err
}
routes, err := stringToIPPrefix(request.GetRoutes())
if err != nil {
return nil, err
}
log.Trace().
Caller().
Interface("route-prefix", routes).
Interface("route-str", request.GetRoutes()).
Msg("")
hostinfo := tailcfg.Hostinfo{
RoutableIPs: routes,
OS: "TestOS",
Hostname: "DebugTestMachine",
}
givenName, err := api.h.GenerateGivenName(request.GetKey(), request.GetName())
if err != nil {
return nil, err
}
newMachine := Machine{
MachineKey: request.GetKey(),
Hostname: request.GetName(),
GivenName: givenName,
User: *user,
Expiry: &time.Time{},
LastSeen: &time.Time{},
LastSuccessfulUpdate: &time.Time{},
HostInfo: HostInfo(hostinfo),
}
nodeKey := key.NodePublic{}
err = nodeKey.UnmarshalText([]byte(request.GetKey()))
if err != nil {
log.Panic().Msg("can not add machine for debug. invalid node key")
}
api.h.registrationCache.Set(
NodePublicKeyStripPrefix(nodeKey),
newMachine,
registerCacheExpiration,
)
return &v1.DebugCreateMachineResponse{Machine: newMachine.toProto()}, nil
}
func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {}

42
hscontrol/grpcv1_test.go Normal file
View File

@@ -0,0 +1,42 @@
package headscale
import "testing"
func Test_validateTag(t *testing.T) {
type args struct {
tag string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid tag",
args: args{tag: "tag:test"},
wantErr: false,
},
{
name: "tag without tag prefix",
args: args{tag: "test"},
wantErr: true,
},
{
name: "uppercase tag",
args: args{tag: "tag:tEST"},
wantErr: true,
},
{
name: "tag that contains space",
args: args{tag: "tag:this is a spaced tag"},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := validateTag(tt.args.tag); (err != nil) != tt.wantErr {
t.Errorf("validateTag() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -0,0 +1,15 @@
//go:build ts2019
package headscale
import (
"net/http"
"github.com/gorilla/mux"
)
func (h *Headscale) addLegacyHandlers(router *mux.Router) {
router.HandleFunc("/machine/{mkey}/map", h.PollNetMapHandler).
Methods(http.MethodPost)
router.HandleFunc("/machine/{mkey}", h.RegistrationHandler).Methods(http.MethodPost)
}

View File

@@ -0,0 +1,8 @@
//go:build !ts2019
package headscale
import "github.com/gorilla/mux"
func (h *Headscale) addLegacyHandlers(router *mux.Router) {
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,271 @@
// nolint
package headscale
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"net/netip"
"os"
"strconv"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
)
const (
headscaleNetwork = "headscale-test"
headscaleHostname = "headscale"
DOCKER_EXECUTE_TIMEOUT = 10 * time.Second
)
var (
errEnvVarEmpty = errors.New("getenv: environment variable empty")
IpPrefix4 = netip.MustParsePrefix("100.64.0.0/10")
IpPrefix6 = netip.MustParsePrefix("fd7a:115c:a1e0::/48")
tailscaleVersions = []string{
"head",
"unstable",
"1.38.4",
"1.36.2",
"1.34.2",
"1.32.3",
"1.30.2",
"1.28.0",
"1.26.2",
"1.24.2",
"1.22.2",
"1.20.4",
"1.18.2",
"1.16.2",
"1.14.3",
"1.12.3",
}
)
type ExecuteCommandConfig struct {
timeout time.Duration
}
type ExecuteCommandOption func(*ExecuteCommandConfig) error
func ExecuteCommandTimeout(timeout time.Duration) ExecuteCommandOption {
return ExecuteCommandOption(func(conf *ExecuteCommandConfig) error {
conf.timeout = timeout
return nil
})
}
func ExecuteCommand(
resource *dockertest.Resource,
cmd []string,
env []string,
options ...ExecuteCommandOption,
) (string, string, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer
execConfig := ExecuteCommandConfig{
timeout: DOCKER_EXECUTE_TIMEOUT,
}
for _, opt := range options {
if err := opt(&execConfig); err != nil {
return "", "", fmt.Errorf("execute-command/options: %w", err)
}
}
type result struct {
exitCode int
err error
}
resultChan := make(chan result, 1)
// Run your long running function in it's own goroutine and pass back it's
// response into our channel.
go func() {
exitCode, err := resource.Exec(
cmd,
dockertest.ExecOptions{
Env: append(env, "HEADSCALE_LOG_LEVEL=disabled"),
StdOut: &stdout,
StdErr: &stderr,
},
)
resultChan <- result{exitCode, err}
}()
// Listen on our channel AND a timeout channel - which ever happens first.
select {
case res := <-resultChan:
if res.err != nil {
return stdout.String(), stderr.String(), res.err
}
if res.exitCode != 0 {
fmt.Println("Command: ", cmd)
fmt.Println("stdout: ", stdout.String())
fmt.Println("stderr: ", stderr.String())
return stdout.String(), stderr.String(), fmt.Errorf(
"command failed with: %s",
stderr.String(),
)
}
return stdout.String(), stderr.String(), nil
case <-time.After(execConfig.timeout):
return stdout.String(), stderr.String(), fmt.Errorf(
"command timed out after %s",
execConfig.timeout,
)
}
}
func DockerRestartPolicy(config *docker.HostConfig) {
// set AutoRemove to true so that stopped container goes away by itself on error *immediately*.
// when set to false, containers remain until the end of the integration test.
config.AutoRemove = false
config.RestartPolicy = docker.RestartPolicy{
Name: "no",
}
}
func DockerAllowLocalIPv6(config *docker.HostConfig) {
if config.Sysctls == nil {
config.Sysctls = make(map[string]string, 1)
}
config.Sysctls["net.ipv6.conf.all.disable_ipv6"] = "0"
}
func DockerAllowNetworkAdministration(config *docker.HostConfig) {
config.CapAdd = append(config.CapAdd, "NET_ADMIN")
config.Mounts = append(config.Mounts, docker.HostMount{
Type: "bind",
Source: "/dev/net/tun",
Target: "/dev/net/tun",
})
}
func getDockerBuildOptions(version string) *dockertest.BuildOptions {
var tailscaleBuildOptions *dockertest.BuildOptions
switch version {
case "head":
tailscaleBuildOptions = &dockertest.BuildOptions{
Dockerfile: "Dockerfile.tailscale-HEAD",
ContextDir: ".",
BuildArgs: []docker.BuildArg{},
}
case "unstable":
tailscaleBuildOptions = &dockertest.BuildOptions{
Dockerfile: "Dockerfile.tailscale",
ContextDir: ".",
BuildArgs: []docker.BuildArg{
{
Name: "TAILSCALE_VERSION",
Value: "*", // Installs the latest version https://askubuntu.com/a/824926
},
{
Name: "TAILSCALE_CHANNEL",
Value: "unstable",
},
},
}
default:
tailscaleBuildOptions = &dockertest.BuildOptions{
Dockerfile: "Dockerfile.tailscale",
ContextDir: ".",
BuildArgs: []docker.BuildArg{
{
Name: "TAILSCALE_VERSION",
Value: version,
},
{
Name: "TAILSCALE_CHANNEL",
Value: "stable",
},
},
}
}
return tailscaleBuildOptions
}
func getDNSNames(
headscale *dockertest.Resource,
) ([]string, error) {
listAllResult, _, err := ExecuteCommand(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
[]string{},
)
if err != nil {
return nil, err
}
var listAll []v1.Machine
err = json.Unmarshal([]byte(listAllResult), &listAll)
if err != nil {
return nil, err
}
hostnames := make([]string, len(listAll))
for index := range listAll {
hostnames[index] = listAll[index].GetGivenName()
}
return hostnames, nil
}
func GetEnvStr(key string) (string, error) {
v := os.Getenv(key)
if v == "" {
return v, errEnvVarEmpty
}
return v, nil
}
func GetEnvBool(key string) (bool, error) {
s, err := GetEnvStr(key)
if err != nil {
return false, err
}
v, err := strconv.ParseBool(s)
if err != nil {
return false, err
}
return v, nil
}
func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (dockertest.Network, error) {
networks, err := pool.NetworksByName(name)
if err != nil || len(networks) == 0 {
if _, err := pool.CreateNetwork(name); err == nil {
// Create does not give us an updated version of the resource, so we need to
// get it again.
networks, err := pool.NetworksByName(name)
if err != nil {
return dockertest.Network{}, err
}
return networks[0], nil
}
}
return networks[0], nil
}

1214
hscontrol/machine.go Normal file

File diff suppressed because it is too large Load Diff

1389
hscontrol/machine_test.go Normal file

File diff suppressed because it is too large Load Diff

142
hscontrol/matcher.go Normal file
View File

@@ -0,0 +1,142 @@
package headscale
import (
"fmt"
"net/netip"
"strings"
"go4.org/netipx"
"tailscale.com/tailcfg"
)
// This is borrowed from, and updated to use IPSet
// https://github.com/tailscale/tailscale/blob/71029cea2ddf82007b80f465b256d027eab0f02d/wgengine/filter/tailcfg.go#L97-L162
// TODO(kradalby): contribute upstream and make public.
var (
zeroIP4 = netip.AddrFrom4([4]byte{})
zeroIP6 = netip.AddrFrom16([16]byte{})
)
// parseIPSet parses arg as one:
//
// - an IP address (IPv4 or IPv6)
// - the string "*" to match everything (both IPv4 & IPv6)
// - a CIDR (e.g. "192.168.0.0/16")
// - a range of two IPs, inclusive, separated by hyphen ("2eff::1-2eff::0800")
//
// bits, if non-nil, is the legacy SrcBits CIDR length to make a IP
// address (without a slash) treated as a CIDR of *bits length.
// nolint
func parseIPSet(arg string, bits *int) (*netipx.IPSet, error) {
var ipSet netipx.IPSetBuilder
if arg == "*" {
ipSet.AddPrefix(netip.PrefixFrom(zeroIP4, 0))
ipSet.AddPrefix(netip.PrefixFrom(zeroIP6, 0))
return ipSet.IPSet()
}
if strings.Contains(arg, "/") {
pfx, err := netip.ParsePrefix(arg)
if err != nil {
return nil, err
}
if pfx != pfx.Masked() {
return nil, fmt.Errorf("%v contains non-network bits set", pfx)
}
ipSet.AddPrefix(pfx)
return ipSet.IPSet()
}
if strings.Count(arg, "-") == 1 {
ip1s, ip2s, _ := strings.Cut(arg, "-")
ip1, err := netip.ParseAddr(ip1s)
if err != nil {
return nil, err
}
ip2, err := netip.ParseAddr(ip2s)
if err != nil {
return nil, err
}
r := netipx.IPRangeFrom(ip1, ip2)
if !r.IsValid() {
return nil, fmt.Errorf("invalid IP range %q", arg)
}
for _, prefix := range r.Prefixes() {
ipSet.AddPrefix(prefix)
}
return ipSet.IPSet()
}
ip, err := netip.ParseAddr(arg)
if err != nil {
return nil, fmt.Errorf("invalid IP address %q", arg)
}
bits8 := uint8(ip.BitLen())
if bits != nil {
if *bits < 0 || *bits > int(bits8) {
return nil, fmt.Errorf("invalid CIDR size %d for IP %q", *bits, arg)
}
bits8 = uint8(*bits)
}
ipSet.AddPrefix(netip.PrefixFrom(ip, int(bits8)))
return ipSet.IPSet()
}
type Match struct {
Srcs *netipx.IPSet
Dests *netipx.IPSet
}
func MatchFromFilterRule(rule tailcfg.FilterRule) Match {
srcs := new(netipx.IPSetBuilder)
dests := new(netipx.IPSetBuilder)
for _, srcIP := range rule.SrcIPs {
set, _ := parseIPSet(srcIP, nil)
srcs.AddSet(set)
}
for _, dest := range rule.DstPorts {
set, _ := parseIPSet(dest.IP, nil)
dests.AddSet(set)
}
srcsSet, _ := srcs.IPSet()
destsSet, _ := dests.IPSet()
match := Match{
Srcs: srcsSet,
Dests: destsSet,
}
return match
}
func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool {
for _, ip := range ips {
if m.Srcs.Contains(ip) {
return true
}
}
return false
}
func (m *Match) DestsContainsIP(ips []netip.Addr) bool {
for _, ip := range ips {
if m.Dests.Contains(ip) {
return true
}
}
return false
}

119
hscontrol/matcher_test.go Normal file
View File

@@ -0,0 +1,119 @@
package headscale
import (
"net/netip"
"reflect"
"testing"
"go4.org/netipx"
)
func Test_parseIPSet(t *testing.T) {
set := func(ips []string, prefixes []string) *netipx.IPSet {
var builder netipx.IPSetBuilder
for _, ip := range ips {
builder.Add(netip.MustParseAddr(ip))
}
for _, pre := range prefixes {
builder.AddPrefix(netip.MustParsePrefix(pre))
}
s, _ := builder.IPSet()
return s
}
type args struct {
arg string
bits *int
}
tests := []struct {
name string
args args
want *netipx.IPSet
wantErr bool
}{
{
name: "simple ip4",
args: args{
arg: "10.0.0.1",
bits: nil,
},
want: set([]string{
"10.0.0.1",
}, []string{}),
wantErr: false,
},
{
name: "simple ip6",
args: args{
arg: "2001:db8:abcd:1234::2",
bits: nil,
},
want: set([]string{
"2001:db8:abcd:1234::2",
}, []string{}),
wantErr: false,
},
{
name: "wildcard",
args: args{
arg: "*",
bits: nil,
},
want: set([]string{}, []string{
"0.0.0.0/0",
"::/0",
}),
wantErr: false,
},
{
name: "prefix4",
args: args{
arg: "192.168.0.0/16",
bits: nil,
},
want: set([]string{}, []string{
"192.168.0.0/16",
}),
wantErr: false,
},
{
name: "prefix6",
args: args{
arg: "2001:db8:abcd:1234::/64",
bits: nil,
},
want: set([]string{}, []string{
"2001:db8:abcd:1234::/64",
}),
wantErr: false,
},
{
name: "range4",
args: args{
arg: "192.168.0.0-192.168.255.255",
bits: nil,
},
want: set([]string{}, []string{
"192.168.0.0/16",
}),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseIPSet(tt.args.arg, tt.args.bits)
if (err != nil) != tt.wantErr {
t.Errorf("parseIPSet() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseIPSet() = %v, want %v", got, tt.want)
}
})
}
}

41
hscontrol/metrics.go Normal file
View File

@@ -0,0 +1,41 @@
package headscale
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
const prometheusNamespace = "headscale"
var (
// This is a high cardinality metric (user x machines), we might want to make this
// configurable/opt-in in the future.
lastStateUpdate = promauto.NewGaugeVec(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "last_update_seconds",
Help: "Time stamp in unix time when a machine or headscale was updated",
}, []string{"user", "machine"})
machineRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "machine_registrations_total",
Help: "The total amount of registered machine attempts",
}, []string{"action", "auth", "status", "user"})
updateRequestsFromNode = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "update_request_from_node_total",
Help: "The number of updates requested by a node/update function",
}, []string{"user", "machine", "state"})
updateRequestsSentToNode = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "update_request_sent_to_node_total",
Help: "The number of calls/messages issued on a specific nodes update channel",
}, []string{"user", "machine", "status"})
// TODO(kradalby): This is very debugging, we might want to remove it.
updateRequestsReceivedOnChannel = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "update_request_received_on_channel_total",
Help: "The number of update requests received on an update channel",
}, []string{"user", "machine"})
)

164
hscontrol/noise.go Normal file
View File

@@ -0,0 +1,164 @@
package headscale
import (
"encoding/binary"
"encoding/json"
"io"
"net/http"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"tailscale.com/control/controlbase"
"tailscale.com/control/controlhttp"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
const (
// ts2021UpgradePath is the path that the server listens on for the WebSockets upgrade.
ts2021UpgradePath = "/ts2021"
// The first 9 bytes from the server to client over Noise are either an HTTP/2
// settings frame (a normal HTTP/2 setup) or, as Tailscale added later, an "early payload"
// header that's also 9 bytes long: 5 bytes (earlyPayloadMagic) followed by 4 bytes
// of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise.
// The early payload is optional. Some servers may not send it... But we do!
earlyPayloadMagic = "\xff\xff\xffTS"
// EarlyNoise was added in protocol version 49.
earlyNoiseCapabilityVersion = 49
)
type noiseServer struct {
headscale *Headscale
httpBaseConfig *http.Server
http2Server *http2.Server
conn *controlbase.Conn
machineKey key.MachinePublic
nodeKey key.NodePublic
// EarlyNoise-related stuff
challenge key.ChallengePrivate
protocolVersion int
}
// NoiseUpgradeHandler is to upgrade the connection and hijack the net.Conn
// in order to use the Noise-based TS2021 protocol. Listens in /ts2021.
func (h *Headscale) NoiseUpgradeHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().Caller().Msgf("Noise upgrade handler for client %s", req.RemoteAddr)
upgrade := req.Header.Get("Upgrade")
if upgrade == "" {
// This probably means that the user is running Headscale behind an
// improperly configured reverse proxy. TS2021 requires WebSockets to
// be passed to Headscale. Let's give them a hint.
log.Warn().
Caller().
Msg("No Upgrade header in TS2021 request. If headscale is behind a reverse proxy, make sure it is configured to pass WebSockets through.")
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
noiseServer := noiseServer{
headscale: h,
challenge: key.NewChallenge(),
}
noiseConn, err := controlhttp.AcceptHTTP(
req.Context(),
writer,
req,
*h.noisePrivateKey,
noiseServer.earlyNoise,
)
if err != nil {
log.Error().Err(err).Msg("noise upgrade failed")
http.Error(writer, err.Error(), http.StatusInternalServerError)
return
}
noiseServer.conn = noiseConn
noiseServer.machineKey = noiseServer.conn.Peer()
noiseServer.protocolVersion = noiseServer.conn.ProtocolVersion()
// This router is served only over the Noise connection, and exposes only the new API.
//
// The HTTP2 server that exposes this router is created for
// a single hijacked connection from /ts2021, using netutil.NewOneConnListener
router := mux.NewRouter()
router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler).
Methods(http.MethodPost)
router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler)
server := http.Server{
ReadTimeout: HTTPReadTimeout,
}
noiseServer.httpBaseConfig = &http.Server{
Handler: router,
ReadHeaderTimeout: HTTPReadTimeout,
}
noiseServer.http2Server = &http2.Server{}
server.Handler = h2c.NewHandler(router, noiseServer.http2Server)
noiseServer.http2Server.ServeConn(
noiseConn,
&http2.ServeConnOpts{
BaseConfig: noiseServer.httpBaseConfig,
},
)
}
func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
log.Trace().
Caller().
Int("protocol_version", protocolVersion).
Str("challenge", ns.challenge.Public().String()).
Msg("earlyNoise called")
if protocolVersion < earlyNoiseCapabilityVersion {
log.Trace().
Caller().
Msgf("protocol version %d does not support early noise", protocolVersion)
return nil
}
earlyJSON, err := json.Marshal(&tailcfg.EarlyNoise{
NodeKeyChallenge: ns.challenge.Public(),
})
if err != nil {
return err
}
// 5 bytes that won't be mistaken for an HTTP/2 frame:
// https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 (Especially not
// an HTTP/2 settings frame, which isn't of type 'T')
var notH2Frame [5]byte
copy(notH2Frame[:], earlyPayloadMagic)
var lenBuf [4]byte
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON)))
// These writes are all buffered by caller, so fine to do them
// separately:
if _, err := writer.Write(notH2Frame[:]); err != nil {
return err
}
if _, err := writer.Write(lenBuf[:]); err != nil {
return err
}
if _, err := writer.Write(earlyJSON); err != nil {
return err
}
return nil
}

760
hscontrol/oidc.go Normal file
View File

@@ -0,0 +1,760 @@
package headscale
import (
"bytes"
"context"
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"html/template"
"net/http"
"strings"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"tailscale.com/types/key"
)
const (
randomByteSize = 16
errEmptyOIDCCallbackParams = Error("empty OIDC callback params")
errNoOIDCIDToken = Error("could not extract ID Token for OIDC callback")
errOIDCAllowedDomains = Error("authenticated principal does not match any allowed domain")
errOIDCAllowedGroups = Error("authenticated principal is not in any allowed group")
errOIDCAllowedUsers = Error("authenticated principal does not match any allowed user")
errOIDCInvalidMachineState = Error(
"requested machine state key expired before authorisation completed",
)
errOIDCNodeKeyMissing = Error("could not get node key from cache")
)
type IDTokenClaims struct {
Name string `json:"name,omitempty"`
Groups []string `json:"groups,omitempty"`
Email string `json:"email"`
Username string `json:"preferred_username,omitempty"`
}
func (h *Headscale) initOIDC() error {
var err error
// grab oidc config if it hasn't been already
if h.oauth2Config == nil {
h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer)
if err != nil {
log.Error().
Err(err).
Caller().
Msgf("Could not retrieve OIDC Config: %s", err.Error())
return err
}
h.oauth2Config = &oauth2.Config{
ClientID: h.cfg.OIDC.ClientID,
ClientSecret: h.cfg.OIDC.ClientSecret,
Endpoint: h.oidcProvider.Endpoint(),
RedirectURL: fmt.Sprintf(
"%s/oidc/callback",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
),
Scopes: h.cfg.OIDC.Scope,
}
}
return nil
}
func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.Time {
if h.cfg.OIDC.UseExpiryFromToken {
return idTokenExpiration
}
return time.Now().Add(h.cfg.OIDC.Expiry)
}
// RegisterOIDC redirects to the OIDC provider for authentication
// Puts NodeKey in cache so the callback can retrieve it using the oidc state param
// Listens in /oidc/register/:nKey.
func (h *Headscale) RegisterOIDC(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
nodeKeyStr, ok := vars["nkey"]
log.Debug().
Caller().
Str("node_key", nodeKeyStr).
Bool("ok", ok).
Msg("Received oidc register call")
if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err := writer.Write([]byte("Unauthorized"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
// We need to make sure we dont open for XSS style injections, if the parameter that
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
// the template and log an error.
var nodeKey key.NodePublic
err := nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyStr)),
)
if !ok || nodeKeyStr == "" || err != nil {
log.Warn().
Err(err).
Msg("Failed to parse incoming nodekey in OIDC registration")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
randomBlob := make([]byte, randomByteSize)
if _, err := rand.Read(randomBlob); err != nil {
log.Error().
Caller().
Msg("could not read 16 bytes from rand")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
stateStr := hex.EncodeToString(randomBlob)[:32]
// place the node key into the state cache, so it can be retrieved later
h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration)
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
for k, v := range h.cfg.OIDC.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v))
}
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
http.Redirect(writer, req, authURL, http.StatusFound)
}
type oidcCallbackTemplateConfig struct {
User string
Verb string
}
var oidcCallbackTemplate = template.Must(
template.New("oidccallback").Parse(`<html>
<body>
<h1>headscale</h1>
<p>
{{.Verb}} as {{.User}}, you can now close this window.
</p>
</body>
</html>`),
)
// OIDCCallback handles the callback from the OIDC endpoint
// Retrieves the nkey from the state cache and adds the machine to the users email user
// TODO: A confirmation page for new machines should be added to avoid phishing vulnerabilities
// TODO: Add groups information from OIDC tokens into machine HostInfo
// Listens in /oidc/callback.
func (h *Headscale) OIDCCallback(
writer http.ResponseWriter,
req *http.Request,
) {
code, state, err := validateOIDCCallbackParams(writer, req)
if err != nil {
return
}
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
if err != nil {
return
}
idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken)
if err != nil {
return
}
idTokenExpiry := h.determineTokenExpiration(idToken.Expiry)
// TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc)
// userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token))
// if err != nil {
// c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo"))
// return
// }
claims, err := extractIDTokenClaims(writer, idToken)
if err != nil {
return
}
if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil {
return
}
if err := validateOIDCAllowedGroups(writer, h.cfg.OIDC.AllowedGroups, claims); err != nil {
return
}
if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil {
return
}
nodeKey, machineExists, err := h.validateMachineForOIDCCallback(
writer,
state,
claims,
idTokenExpiry,
)
if err != nil || machineExists {
return
}
userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain)
if err != nil {
return
}
// register the machine if it's new
log.Debug().Msg("Registering new machine after successful callback")
user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName)
if err != nil {
return
}
if err := h.registerMachineForOIDCCallback(writer, user, nodeKey, idTokenExpiry); err != nil {
return
}
content, err := renderOIDCCallbackTemplate(writer, claims)
if err != nil {
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
func validateOIDCCallbackParams(
writer http.ResponseWriter,
req *http.Request,
) (string, string, error) {
code := req.URL.Query().Get("code")
state := req.URL.Query().Get("state")
if code == "" || state == "" {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return "", "", errEmptyOIDCCallbackParams
}
return code, state, nil
}
func (h *Headscale) getIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
code, state string,
) (string, error) {
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
if err != nil {
log.Error().
Err(err).
Caller().
Msg("Could not exchange code for token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Could not exchange code for token"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return "", err
}
log.Trace().
Caller().
Str("code", code).
Str("state", state).
Msg("Got oidc callback")
rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string)
if !rawIDTokenOK {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Could not extract ID Token"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return "", errNoOIDCIDToken
}
return rawIDToken, nil
}
func (h *Headscale) verifyIDTokenForOIDCCallback(
ctx context.Context,
writer http.ResponseWriter,
rawIDToken string,
) (*oidc.IDToken, error) {
verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID})
idToken, err := verifier.Verify(ctx, rawIDToken)
if err != nil {
log.Error().
Err(err).
Caller().
Msg("failed to verify id token")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to verify id token"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return nil, err
}
return idToken, nil
}
func extractIDTokenClaims(
writer http.ResponseWriter,
idToken *oidc.IDToken,
) (*IDTokenClaims, error) {
var claims IDTokenClaims
if err := idToken.Claims(&claims); err != nil {
log.Error().
Err(err).
Caller().
Msg("Failed to decode id token claims")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("Failed to decode id token claims"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return nil, err
}
return &claims, nil
}
// validateOIDCAllowedDomains checks that if AllowedDomains is provided,
// that the authenticated principal ends with @<alloweddomain>.
func validateOIDCAllowedDomains(
writer http.ResponseWriter,
allowedDomains []string,
claims *IDTokenClaims,
) error {
if len(allowedDomains) > 0 {
if at := strings.LastIndex(claims.Email, "@"); at < 0 ||
!IsStringInSlice(allowedDomains, claims.Email[at+1:]) {
log.Error().Msg("authenticated principal does not match any allowed domain")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (domain mismatch)"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return errOIDCAllowedDomains
}
}
return nil
}
// validateOIDCAllowedGroups checks if AllowedGroups is provided,
// and that the user has one group in the list.
// claims.Groups can be populated by adding a client scope named
// 'groups' that contains group membership.
func validateOIDCAllowedGroups(
writer http.ResponseWriter,
allowedGroups []string,
claims *IDTokenClaims,
) error {
if len(allowedGroups) > 0 {
for _, group := range allowedGroups {
if IsStringInSlice(claims.Groups, group) {
return nil
}
}
log.Error().Msg("authenticated principal not in any allowed groups")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (allowed groups)"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return errOIDCAllowedGroups
}
return nil
}
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
// that the authenticated principal is part of that list.
func validateOIDCAllowedUsers(
writer http.ResponseWriter,
allowedUsers []string,
claims *IDTokenClaims,
) error {
if len(allowedUsers) > 0 &&
!IsStringInSlice(allowedUsers, claims.Email) {
log.Error().Msg("authenticated principal does not match any allowed user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("unauthorized principal (user mismatch)"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return errOIDCAllowedUsers
}
return nil
}
// validateMachine retrieves machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
func (h *Headscale) validateMachineForOIDCCallback(
writer http.ResponseWriter,
state string,
claims *IDTokenClaims,
expiry time.Time,
) (*key.NodePublic, bool, error) {
// retrieve machinekey from state cache
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
if !nodeKeyFound {
log.Error().
Msg("requested machine state key expired before authorisation completed")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state has expired"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return nil, false, errOIDCNodeKeyMissing
}
var nodeKey key.NodePublic
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
if !nodeKeyOK {
log.Error().
Msg("requested machine state key is not a string")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("state is invalid"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return nil, false, errOIDCInvalidMachineState
}
err := nodeKey.UnmarshalText(
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
)
if err != nil {
log.Error().
Str("nodeKey", nodeKeyFromCache).
Bool("nodeKeyOK", nodeKeyOK).
Msg("could not parse node public key")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, werr := writer.Write([]byte("could not parse node public key"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return nil, false, err
}
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByNodeKey(nodeKey)
if machine != nil {
log.Trace().
Caller().
Str("machine", machine.Hostname).
Msg("machine already registered, reauthenticating")
err := h.RefreshMachine(machine, expiry)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to refresh machine")
http.Error(
writer,
"Failed to refresh machine",
http.StatusInternalServerError,
)
return nil, true, err
}
log.Debug().
Str("machine", machine.Hostname).
Str("expiresAt", fmt.Sprintf("%v", expiry)).
Msg("successfully refreshed machine")
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email,
Verb: "Reauthenticated",
}); err != nil {
log.Error().
Str("func", "OIDCCallback").
Str("type", "reauthenticate").
Err(err).
Msg("Could not render OIDC callback template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return nil, true, err
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return nil, true, nil
}
return &nodeKey, false, nil
}
func getUserName(
writer http.ResponseWriter,
claims *IDTokenClaims,
stripEmaildomain bool,
) (string, error) {
userName, err := NormalizeToFQDNRules(
claims.Email,
stripEmaildomain,
)
if err != nil {
log.Error().Err(err).Caller().Msgf("couldn't normalize email")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("couldn't normalize email"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return "", err
}
return userName, nil
}
func (h *Headscale) findOrCreateNewUserForOIDCCallback(
writer http.ResponseWriter,
userName string,
) (*User, error) {
user, err := h.GetUser(userName)
if errors.Is(err, ErrUserNotFound) {
user, err = h.CreateUser(userName)
if err != nil {
log.Error().
Err(err).
Caller().
Msgf("could not create new user '%s'", userName)
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not create user"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return nil, err
}
} else if err != nil {
log.Error().
Caller().
Err(err).
Str("user", userName).
Msg("could not find or create user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not find or create user"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return nil, err
}
return user, nil
}
func (h *Headscale) registerMachineForOIDCCallback(
writer http.ResponseWriter,
user *User,
nodeKey *key.NodePublic,
expiry time.Time,
) error {
if _, err := h.RegisterMachineFromAuthCallback(
nodeKey.String(),
user.Name,
&expiry,
RegisterMethodOIDC,
); err != nil {
log.Error().
Caller().
Err(err).
Msg("could not register machine")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("could not register machine"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return err
}
return nil
}
func renderOIDCCallbackTemplate(
writer http.ResponseWriter,
claims *IDTokenClaims,
) (*bytes.Buffer, error) {
var content bytes.Buffer
if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
User: claims.Email,
Verb: "Authenticated",
}); err != nil {
log.Error().
Str("func", "OIDCCallback").
Str("type", "authenticate").
Err(err).
Msg("Could not render OIDC callback template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not render OIDC callback template"))
if werr != nil {
log.Error().
Caller().
Err(werr).
Msg("Failed to write response")
}
return nil, err
}
return &content, nil
}

View File

@@ -0,0 +1,408 @@
package headscale
import (
"bytes"
_ "embed"
"html/template"
"net/http"
textTemplate "text/template"
"github.com/gofrs/uuid/v5"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
)
//go:embed templates/apple.html
var appleTemplate string
//go:embed templates/windows.html
var windowsTemplate string
// WindowsConfigMessage shows a simple message in the browser for how to configure the Windows Tailscale client.
func (h *Headscale) WindowsConfigMessage(
writer http.ResponseWriter,
req *http.Request,
) {
winTemplate := template.Must(template.New("windows").Parse(windowsTemplate))
config := map[string]interface{}{
"URL": h.cfg.ServerURL,
}
var payload bytes.Buffer
if err := winTemplate.Execute(&payload, config); err != nil {
log.Error().
Str("handler", "WindowsRegConfig").
Err(err).
Msg("Could not render Windows index template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Windows index template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(payload.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
// WindowsRegConfig generates and serves a .reg file configured with the Headscale server address.
func (h *Headscale) WindowsRegConfig(
writer http.ResponseWriter,
req *http.Request,
) {
config := WindowsRegistryConfig{
URL: h.cfg.ServerURL,
}
var content bytes.Buffer
if err := windowsRegTemplate.Execute(&content, config); err != nil {
log.Error().
Str("handler", "WindowsRegConfig").
Err(err).
Msg("Could not render Apple macOS template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Windows registry template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
writer.Header().Set("Content-Type", "text/x-ms-regedit; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
func (h *Headscale) AppleConfigMessage(
writer http.ResponseWriter,
req *http.Request,
) {
appleTemplate := template.Must(template.New("apple").Parse(appleTemplate))
config := map[string]interface{}{
"URL": h.cfg.ServerURL,
}
var payload bytes.Buffer
if err := appleTemplate.Execute(&payload, config); err != nil {
log.Error().
Str("handler", "AppleMobileConfig").
Err(err).
Msg("Could not render Apple index template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Apple index template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(payload.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
func (h *Headscale) ApplePlatformConfig(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
platform, ok := vars["platform"]
if !ok {
log.Error().
Str("handler", "ApplePlatformConfig").
Msg("No platform specified")
http.Error(writer, "No platform specified", http.StatusBadRequest)
return
}
id, err := uuid.NewV4()
if err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Failed to create UUID"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
contentID, err := uuid.NewV4()
if err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Failed not create UUID")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Failed to create content UUID"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
platformConfig := AppleMobilePlatformConfig{
UUID: contentID,
URL: h.cfg.ServerURL,
}
var payload bytes.Buffer
handleMacError := func(ierr error) {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(ierr).
Msg("Could not render Apple macOS template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Apple macOS template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
switch platform {
case "macos-standalone":
if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil {
handleMacError(err)
return
}
case "macos-app-store":
if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil {
handleMacError(err)
return
}
case "ios":
if err := iosTemplate.Execute(&payload, platformConfig); err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple iOS template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Apple iOS template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
default:
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write(
[]byte("Invalid platform. Only ios, macos-app-store and macos-standalone are supported"),
)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
config := AppleMobileConfig{
UUID: id,
URL: h.cfg.ServerURL,
Payload: payload.String(),
}
var content bytes.Buffer
if err := commonTemplate.Execute(&content, config); err != nil {
log.Error().
Str("handler", "ApplePlatformConfig").
Err(err).
Msg("Could not render Apple platform template")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Apple platform template"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
writer.Header().
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(content.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
type WindowsRegistryConfig struct {
URL string
}
type AppleMobileConfig struct {
UUID uuid.UUID
URL string
Payload string
}
type AppleMobilePlatformConfig struct {
UUID uuid.UUID
URL string
}
var windowsRegTemplate = textTemplate.Must(
textTemplate.New("windowsconfig").Parse(`Windows Registry Editor Version 5.00
[HKEY_LOCAL_MACHINE\SOFTWARE\Tailscale IPN]
"UnattendedMode"="always"
"LoginURL"="{{.URL}}"
`))
var commonTemplate = textTemplate.Must(
textTemplate.New("mobileconfig").Parse(`<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>PayloadUUID</key>
<string>{{.UUID}}</string>
<key>PayloadDisplayName</key>
<string>Headscale</string>
<key>PayloadDescription</key>
<string>Configure Tailscale login server to: {{.URL}}</string>
<key>PayloadIdentifier</key>
<string>com.github.juanfont.headscale</string>
<key>PayloadRemovalDisallowed</key>
<false/>
<key>PayloadType</key>
<string>Configuration</string>
<key>PayloadVersion</key>
<integer>1</integer>
<key>PayloadContent</key>
<array>
{{.Payload}}
</array>
</dict>
</plist>`),
)
var iosTemplate = textTemplate.Must(textTemplate.New("iosTemplate").Parse(`
<dict>
<key>PayloadType</key>
<string>io.tailscale.ipn.ios</string>
<key>PayloadUUID</key>
<string>{{.UUID}}</string>
<key>PayloadIdentifier</key>
<string>com.github.juanfont.headscale</string>
<key>PayloadVersion</key>
<integer>1</integer>
<key>PayloadEnabled</key>
<true/>
<key>ControlURL</key>
<string>{{.URL}}</string>
</dict>
`))
var macosAppStoreTemplate = template.Must(template.New("macosTemplate").Parse(`
<dict>
<key>PayloadType</key>
<string>io.tailscale.ipn.macos</string>
<key>PayloadUUID</key>
<string>{{.UUID}}</string>
<key>PayloadIdentifier</key>
<string>com.github.juanfont.headscale</string>
<key>PayloadVersion</key>
<integer>1</integer>
<key>PayloadEnabled</key>
<true/>
<key>ControlURL</key>
<string>{{.URL}}</string>
</dict>
`))
var macosStandaloneTemplate = template.Must(template.New("macosStandaloneTemplate").Parse(`
<dict>
<key>PayloadType</key>
<string>io.tailscale.ipn.macsys</string>
<key>PayloadUUID</key>
<string>{{.UUID}}</string>
<key>PayloadIdentifier</key>
<string>com.github.juanfont.headscale</string>
<key>PayloadVersion</key>
<integer>1</integer>
<key>PayloadEnabled</key>
<true/>
<key>ControlURL</key>
<string>{{.URL}}</string>
</dict>
`))

242
hscontrol/preauth_keys.go Normal file
View File

@@ -0,0 +1,242 @@
package headscale
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"strconv"
"strings"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
)
const (
ErrPreAuthKeyNotFound = Error("AuthKey not found")
ErrPreAuthKeyExpired = Error("AuthKey expired")
ErrSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
ErrUserMismatch = Error("user mismatch")
ErrPreAuthKeyACLTagInvalid = Error("AuthKey tag is invalid")
)
// PreAuthKey describes a pre-authorization key usable in a particular user.
type PreAuthKey struct {
ID uint64 `gorm:"primary_key"`
Key string
UserID uint
User User
Reusable bool
Ephemeral bool `gorm:"default:false"`
Used bool `gorm:"default:false"`
ACLTags []PreAuthKeyACLTag
CreatedAt *time.Time
Expiration *time.Time
}
// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey.
type PreAuthKeyACLTag struct {
ID uint64 `gorm:"primary_key"`
PreAuthKeyID uint64
Tag string
}
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
func (h *Headscale) CreatePreAuthKey(
userName string,
reusable bool,
ephemeral bool,
expiration *time.Time,
aclTags []string,
) (*PreAuthKey, error) {
user, err := h.GetUser(userName)
if err != nil {
return nil, err
}
for _, tag := range aclTags {
if !strings.HasPrefix(tag, "tag:") {
return nil, fmt.Errorf("%w: '%s' did not begin with 'tag:'", ErrPreAuthKeyACLTagInvalid, tag)
}
}
now := time.Now().UTC()
kstr, err := h.generateKey()
if err != nil {
return nil, err
}
key := PreAuthKey{
Key: kstr,
UserID: user.ID,
User: *user,
Reusable: reusable,
Ephemeral: ephemeral,
CreatedAt: &now,
Expiration: expiration,
}
err = h.db.Transaction(func(db *gorm.DB) error {
if err := db.Save(&key).Error; err != nil {
return fmt.Errorf("failed to create key in the database: %w", err)
}
if len(aclTags) > 0 {
seenTags := map[string]bool{}
for _, tag := range aclTags {
if !seenTags[tag] {
if err := db.Save(&PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil {
return fmt.Errorf(
"failed to ceate key tag in the database: %w",
err,
)
}
seenTags[tag] = true
}
}
}
return nil
})
if err != nil {
return nil, err
}
return &key, nil
}
// ListPreAuthKeys returns the list of PreAuthKeys for a user.
func (h *Headscale) ListPreAuthKeys(userName string) ([]PreAuthKey, error) {
user, err := h.GetUser(userName)
if err != nil {
return nil, err
}
keys := []PreAuthKey{}
if err := h.db.Preload("User").Preload("ACLTags").Where(&PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// GetPreAuthKey returns a PreAuthKey for a given key.
func (h *Headscale) GetPreAuthKey(user string, key string) (*PreAuthKey, error) {
pak, err := h.checkKeyValidity(key)
if err != nil {
return nil, err
}
if pak.User.Name != user {
return nil, ErrUserMismatch
}
return pak, nil
}
// DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey
// does not exist.
func (h *Headscale) DestroyPreAuthKey(pak PreAuthKey) error {
return h.db.Transaction(func(db *gorm.DB) error {
if result := db.Unscoped().Where(PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&PreAuthKeyACLTag{}); result.Error != nil {
return result.Error
}
if result := db.Unscoped().Delete(pak); result.Error != nil {
return result.Error
}
return nil
})
}
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
func (h *Headscale) ExpirePreAuthKey(k *PreAuthKey) error {
if err := h.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil {
return err
}
return nil
}
// UsePreAuthKey marks a PreAuthKey as used.
func (h *Headscale) UsePreAuthKey(k *PreAuthKey) error {
k.Used = true
if err := h.db.Save(k).Error; err != nil {
return fmt.Errorf("failed to update key used status in the database: %w", err)
}
return nil
}
// checkKeyValidity does the heavy lifting for validation of the PreAuthKey coming from a node
// If returns no error and a PreAuthKey, it can be used.
func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
pak := PreAuthKey{}
if result := h.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrPreAuthKeyNotFound
}
if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
return nil, ErrPreAuthKeyExpired
}
if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before
return &pak, nil
}
machines := []Machine{}
if err := h.db.Preload("AuthKey").Where(&Machine{AuthKeyID: uint(pak.ID)}).Find(&machines).Error; err != nil {
return nil, err
}
if len(machines) != 0 || pak.Used {
return nil, ErrSingleUseAuthKeyHasBeenUsed
}
return &pak, nil
}
func (h *Headscale) generateKey() (string, error) {
size := 24
bytes := make([]byte, size)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}
func (key *PreAuthKey) toProto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
User: key.User.Name,
Id: strconv.FormatUint(key.ID, Base10),
Key: key.Key,
Ephemeral: key.Ephemeral,
Reusable: key.Reusable,
Used: key.Used,
AclTags: make([]string, len(key.ACLTags)),
}
if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration)
}
if key.CreatedAt != nil {
protoKey.CreatedAt = timestamppb.New(*key.CreatedAt)
}
for idx := range key.ACLTags {
protoKey.AclTags[idx] = key.ACLTags[idx].Tag
}
return &protoKey
}

View File

@@ -0,0 +1,209 @@
package headscale
import (
"time"
"gopkg.in/check.v1"
)
func (*Suite) TestCreatePreAuthKey(c *check.C) {
_, err := app.CreatePreAuthKey("bogus", true, false, nil, nil)
c.Assert(err, check.NotNil)
user, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
key, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil)
// Did we get a valid key?
c.Assert(key.Key, check.NotNil)
c.Assert(len(key.Key), check.Equals, 48)
// Make sure the User association is populated
c.Assert(key.User.Name, check.Equals, user.Name)
_, err = app.ListPreAuthKeys("bogus")
c.Assert(err, check.NotNil)
keys, err := app.ListPreAuthKeys(user.Name)
c.Assert(err, check.IsNil)
c.Assert(len(keys), check.Equals, 1)
// Make sure the User association is populated
c.Assert((keys)[0].User.Name, check.Equals, user.Name)
}
func (*Suite) TestExpiredPreAuthKey(c *check.C) {
user, err := app.CreateUser("test2")
c.Assert(err, check.IsNil)
now := time.Now()
pak, err := app.CreatePreAuthKey(user.Name, true, false, &now, nil)
c.Assert(err, check.IsNil)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
c.Assert(key, check.IsNil)
}
func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) {
key, err := app.checkKeyValidity("potatoKey")
c.Assert(err, check.Equals, ErrPreAuthKeyNotFound)
c.Assert(key, check.IsNil)
}
func (*Suite) TestValidateKeyOk(c *check.C) {
user, err := app.CreateUser("test3")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil)
c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestAlreadyUsedKey(c *check.C) {
user, err := app.CreateUser("test4")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testest",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
c.Assert(key, check.IsNil)
}
func (*Suite) TestReusableBeingUsedKey(c *check.C) {
user, err := app.CreateUser("test5")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil)
machine := Machine{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testest",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil)
c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) {
user, err := app.CreateUser("test6")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.IsNil)
c.Assert(key.ID, check.Equals, pak.ID)
}
func (*Suite) TestEphemeralKey(c *check.C) {
user, err := app.CreateUser("test7")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, true, nil, nil)
c.Assert(err, check.IsNil)
now := time.Now()
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testest",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
_, err = app.checkKeyValidity(pak.Key)
// Ephemeral keys are by definition reusable
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test7", "testest")
c.Assert(err, check.IsNil)
app.expireEphemeralNodesWorker()
// The machine record should have been deleted
_, err = app.GetMachine("test7", "testest")
c.Assert(err, check.NotNil)
}
func (*Suite) TestExpirePreauthKey(c *check.C) {
user, err := app.CreateUser("test3")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, true, false, nil, nil)
c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.IsNil)
err = app.ExpirePreAuthKey(pak)
c.Assert(err, check.IsNil)
c.Assert(pak.Expiration, check.NotNil)
key, err := app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrPreAuthKeyExpired)
c.Assert(key, check.IsNil)
}
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
user, err := app.CreateUser("test6")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
pak.Used = true
app.db.Save(&pak)
_, err = app.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed)
}
func (*Suite) TestPreAuthKeyACLTags(c *check.C) {
user, err := app.CreateUser("test8")
c.Assert(err, check.IsNil)
_, err = app.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"})
c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected
tags := []string{"tag:test1", "tag:test2"}
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = app.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate)
c.Assert(err, check.IsNil)
listedPaks, err := app.ListPreAuthKeys("test8")
c.Assert(err, check.IsNil)
c.Assert(listedPaks[0].toProto().AclTags, check.DeepEquals, tags)
}

View File

@@ -0,0 +1,839 @@
package headscale
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
const (
// The CapabilityVersion is used by Tailscale clients to indicate
// their codebase version. Tailscale clients can communicate over TS2021
// from CapabilityVersion 28, but we only have good support for it
// since https://github.com/tailscale/tailscale/pull/4323 (Noise in any HTTPS port).
//
// Related to this change, there is https://github.com/tailscale/tailscale/pull/5379,
// where CapabilityVersion 39 is introduced to indicate #4323 was merged.
//
// See also https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go
NoiseCapabilityVersion = 39
)
// KeyHandler provides the Headscale pub key
// Listens in /key.
func (h *Headscale) KeyHandler(
writer http.ResponseWriter,
req *http.Request,
) {
// New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion
clientCapabilityStr := req.URL.Query().Get("v")
if clientCapabilityStr != "" {
log.Debug().
Str("handler", "/key").
Str("v", clientCapabilityStr).
Msg("New noise client")
clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr)
if err != nil {
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusBadRequest)
_, err := writer.Write([]byte("Wrong params"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
// TS2021 (Tailscale v2 protocol) requires to have a different key
if clientCapabilityVersion >= NoiseCapabilityVersion {
resp := tailcfg.OverTLSPublicKeyResponse{
LegacyPublicKey: h.privateKey.Public(),
PublicKey: h.noisePrivateKey.Public(),
}
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
err = json.NewEncoder(writer).Encode(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
}
log.Debug().
Str("handler", "/key").
Msg("New legacy client")
// Old clients don't send a 'v' parameter, so we send the legacy public key
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write([]byte(MachinePublicKeyStripPrefix(h.privateKey.Public())))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
// handleRegisterCommon is the common logic for registering a client in the legacy and Noise protocols
//
// When using Noise, the machineKey is Zero.
func (h *Headscale) handleRegisterCommon(
writer http.ResponseWriter,
req *http.Request,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
isNoise bool,
) {
now := time.Now().UTC()
machine, err := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if errors.Is(err, gorm.ErrRecordNotFound) {
// If the machine has AuthKey set, handle registration via PreAuthKeys
if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, registerRequest, machineKey, isNoise)
return
}
// Check if the node is waiting for interactive login.
//
// TODO(juan): We could use this field to improve our protocol implementation,
// and hold the request until the client closes it, or the interactive
// login is completed (i.e., the user registers the machine).
// This is not implemented yet, as it is no strictly required. The only side-effect
// is that the client will hammer headscale with requests until it gets a
// successful RegisterResponse.
if registerRequest.Followup != "" {
if _, ok := h.registrationCache.Get(NodePublicKeyStripPrefix(registerRequest.NodeKey)); ok {
log.Debug().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Bool("noise", isNoise).
Msg("Machine is waiting for interactive login")
select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
h.handleNewMachineCommon(writer, registerRequest, machineKey, isNoise)
return
}
}
}
log.Info().
Caller().
Str("machine", registerRequest.Hostinfo.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("follow_up", registerRequest.Followup).
Bool("noise", isNoise).
Msg("New machine not yet in the database")
givenName, err := h.GenerateGivenName(
machineKey.String(),
registerRequest.Hostinfo.Hostname,
)
if err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err)
return
}
// The machine did not have a key to authenticate, which means
// that we rely on a method that calls back some how (OpenID or CLI)
// We create the machine and then keep it around until a callback
// happens
newMachine := Machine{
MachineKey: MachinePublicKeyStripPrefix(machineKey),
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
NodeKey: NodePublicKeyStripPrefix(registerRequest.NodeKey),
LastSeen: &now,
Expiry: &time.Time{},
}
if !registerRequest.Expiry.IsZero() {
log.Trace().
Caller().
Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname).
Time("expiry", registerRequest.Expiry).
Msg("Non-zero expiry time requested")
newMachine.Expiry = &registerRequest.Expiry
}
h.registrationCache.Set(
newMachine.NodeKey,
newMachine,
registerCacheExpiration,
)
h.handleNewMachineCommon(writer, registerRequest, machineKey, isNoise)
return
}
// The machine is already in the DB. This could mean one of the following:
// - The machine is authenticated and ready to /map
// - We are doing a key refresh
// - The machine is logged out (or expired) and pending to be authorized. TODO(juan): We need to keep alive the connection here
if machine != nil {
// (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021,
// due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054
// So if we have a not valid MachineKey (but we were able to fetch the machine with the NodeKeys), we update it.
var storedMachineKey key.MachinePublic
err = storedMachineKey.UnmarshalText(
[]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)),
)
if err != nil || storedMachineKey.IsZero() {
machine.MachineKey = MachinePublicKeyStripPrefix(machineKey)
if err := h.db.Save(&machine).Error; err != nil {
log.Error().
Caller().
Str("func", "RegistrationHandler").
Str("machine", machine.Hostname).
Err(err).
Msg("Error saving machine key to database")
return
}
}
// If the NodeKey stored in headscale is the same as the key presented in a registration
// request, then we have a node that is either:
// - Trying to log out (sending a expiry in the past)
// - A valid, registered machine, looking for /map
// - Expired machine wanting to reauthenticate
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.NodeKey) {
// The client sends an Expiry in the past if the client is requesting to expire the key (aka logout)
// https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648
if !registerRequest.Expiry.IsZero() &&
registerRequest.Expiry.UTC().Before(now) {
h.handleMachineLogOutCommon(writer, *machine, machineKey, isNoise)
return
}
// If machine is not expired, and it is register, we have a already accepted this machine,
// let it proceed with a valid registration
if !machine.isExpired() {
h.handleMachineValidRegistrationCommon(writer, *machine, machineKey, isNoise)
return
}
}
// The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration
if machine.NodeKey == NodePublicKeyStripPrefix(registerRequest.OldNodeKey) &&
!machine.isExpired() {
h.handleMachineRefreshKeyCommon(
writer,
registerRequest,
*machine,
machineKey,
isNoise,
)
return
}
if registerRequest.Followup != "" {
select {
case <-req.Context().Done():
return
case <-time.After(registrationHoldoff):
}
}
// The machine has expired or it is logged out
h.handleMachineExpiredOrLoggedOutCommon(writer, registerRequest, *machine, machineKey, isNoise)
// TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use
machine.Expiry = &time.Time{}
// If we are here it means the client needs to be reauthorized,
// we need to make sure the NodeKey matches the one in the request
// TODO(juan): What happens when using fast user switching between two
// headscale-managed tailnets?
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
h.registrationCache.Set(
NodePublicKeyStripPrefix(registerRequest.NodeKey),
*machine,
registerCacheExpiration,
)
return
}
}
// handleAuthKeyCommon contains the logic to manage auth key client registration
// It is used both by the legacy and the new Noise protocol.
// When using Noise, the machineKey is Zero.
//
// TODO: check if any locks are needed around IP allocation.
func (h *Headscale) handleAuthKeyCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
isNoise bool,
) {
log.Debug().
Str("func", "handleAuthKeyCommon").
Str("machine", registerRequest.Hostinfo.Hostname).
Bool("noise", isNoise).
Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname)
resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(registerRequest.Auth.AuthKey)
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
log.Error().
Caller().
Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusUnauthorized)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Failed to write response")
}
log.Error().
Caller().
Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Failed authentication via AuthKey")
if pak != nil {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
} else {
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", "unknown").Inc()
}
return
}
log.Debug().
Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Authentication key was valid, proceeding to acquire IP addresses")
nodeKey := NodePublicKeyStripPrefix(registerRequest.NodeKey)
// retrieve machine information if it exist
// The error is not important, because if it does not
// exist, then this is a new machine and we will move
// on to registration.
machine, _ := h.GetMachineByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey)
if machine != nil {
log.Trace().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("machine was already registered before, refreshing with new auth key")
machine.NodeKey = nodeKey
machine.AuthKeyID = uint(pak.ID)
err := h.RefreshMachine(machine, registerRequest.Expiry)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to refresh machine")
return
}
aclTags := pak.toProto().AclTags
if len(aclTags) > 0 {
// This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login
err = h.SetTags(machine, aclTags)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Strs("aclTags", aclTags).
Err(err).
Msg("Failed to set tags after refreshing machine")
return
}
}
} else {
now := time.Now().UTC()
givenName, err := h.GenerateGivenName(MachinePublicKeyStripPrefix(machineKey), registerRequest.Hostinfo.Hostname)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Str("func", "RegistrationHandler").
Str("hostinfo.name", registerRequest.Hostinfo.Hostname).
Err(err)
return
}
machineToRegister := Machine{
Hostname: registerRequest.Hostinfo.Hostname,
GivenName: givenName,
UserID: pak.User.ID,
MachineKey: MachinePublicKeyStripPrefix(machineKey),
RegisterMethod: RegisterMethodAuthKey,
Expiry: &registerRequest.Expiry,
NodeKey: nodeKey,
LastSeen: &now,
AuthKeyID: uint(pak.ID),
ForcedTags: pak.toProto().AclTags,
}
machine, err = h.RegisterMachine(
machineToRegister,
)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("could not register machine")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
}
err = h.UsePreAuthKey(pak)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Failed to use pre-auth key")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.MachineAuthorized = true
resp.User = *pak.User.toTailscaleUser()
// Provide LoginName when registering with pre-auth key
// Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName*
resp.Login = *pak.User.toTailscaleLogin()
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Str("func", "handleAuthKeyCommon").
Str("machine", registerRequest.Hostinfo.Hostname).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "error", pak.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("new", RegisterMethodAuthKey, "success", pak.User.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Failed to write response")
}
log.Info().
Str("func", "handleAuthKeyCommon").
Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname).
Str("ips", strings.Join(machine.IPAddresses.ToStringSlice(), ", ")).
Msg("Successfully authenticated via AuthKey")
}
// handleNewMachineCommon exposes for both legacy and Noise the functionality to get a URL
// for authorizing the machine. This url is then showed to the user by the local Tailscale client.
func (h *Headscale) handleNewMachineCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machineKey key.MachinePublic,
isNoise bool,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is new, redirect the client to the registration URL
log.Debug().
Caller().
Bool("noise", isNoise).
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("The node seems to be new, sending auth url")
if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf(
"%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
registerRequest.NodeKey,
)
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
registerRequest.NodeKey)
}
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Bool("noise", isNoise).
Caller().
Err(err).
Msg("Failed to write response")
}
log.Info().
Caller().
Bool("noise", isNoise).
Str("AuthURL", resp.AuthURL).
Str("machine", registerRequest.Hostinfo.Hostname).
Msg("Successfully sent auth url")
}
func (h *Headscale) handleMachineLogOutCommon(
writer http.ResponseWriter,
machine Machine,
machineKey key.MachinePublic,
isNoise bool,
) {
resp := tailcfg.RegisterResponse{}
log.Info().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client requested logout")
err := h.ExpireMachine(&machine)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Str("func", "handleMachineLogOutCommon").
Err(err).
Msg("Failed to expire machine")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.AuthURL = ""
resp.MachineAuthorized = false
resp.NodeKeyExpired = true
resp.User = *machine.User.toTailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Bool("noise", isNoise).
Caller().
Err(err).
Msg("Failed to write response")
return
}
if machine.isEphemeral() {
err = h.HardDeleteMachine(&machine)
if err != nil {
log.Error().
Err(err).
Str("machine", machine.Hostname).
Msg("Cannot delete ephemeral machine from the database")
}
return
}
log.Info().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Successfully logged out")
}
func (h *Headscale) handleMachineValidRegistrationCommon(
writer http.ResponseWriter,
machine Machine,
machineKey key.MachinePublic,
isNoise bool,
) {
resp := tailcfg.RegisterResponse{}
// The machine registration is valid, respond with redirect to /map
log.Debug().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client is registered and we have the current NodeKey. All clear to /map")
resp.AuthURL = ""
resp.MachineAuthorized = true
resp.User = *machine.User.toTailscaleUser()
resp.Login = *machine.User.toTailscaleLogin()
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("update", "web", "error", machine.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("update", "web", "success", machine.User.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Failed to write response")
}
log.Info().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Machine successfully authorized")
}
func (h *Headscale) handleMachineRefreshKeyCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machine Machine,
machineKey key.MachinePublic,
isNoise bool,
) {
resp := tailcfg.RegisterResponse{}
log.Info().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("We have the OldNodeKey in the database. This is a key refresh")
machine.NodeKey = NodePublicKeyStripPrefix(registerRequest.NodeKey)
if err := h.db.Save(&machine).Error; err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to update machine key in the database")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
resp.AuthURL = ""
resp.User = *machine.User.toTailscaleUser()
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Cannot encode message")
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Failed to write response")
}
log.Info().
Caller().
Bool("noise", isNoise).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("old_node_key", registerRequest.OldNodeKey.ShortString()).
Str("machine", machine.Hostname).
Msg("Node key successfully refreshed")
}
func (h *Headscale) handleMachineExpiredOrLoggedOutCommon(
writer http.ResponseWriter,
registerRequest tailcfg.RegisterRequest,
machine Machine,
machineKey key.MachinePublic,
isNoise bool,
) {
resp := tailcfg.RegisterResponse{}
if registerRequest.Auth.AuthKey != "" {
h.handleAuthKeyCommon(writer, registerRequest, machineKey, isNoise)
return
}
// The client has registered before, but has expired or logged out
log.Trace().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Msg("Machine registration has expired or logged out. Sending a auth url to register")
if h.oauth2Config != nil {
resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
registerRequest.NodeKey)
} else {
resp.AuthURL = fmt.Sprintf("%s/register/%s",
strings.TrimSuffix(h.cfg.ServerURL, "/"),
registerRequest.NodeKey)
}
respBody, err := h.marshalResponse(resp, machineKey, isNoise)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Cannot encode message")
machineRegistrations.WithLabelValues("reauth", "web", "error", machine.User.Name).
Inc()
http.Error(writer, "Internal server error", http.StatusInternalServerError)
return
}
machineRegistrations.WithLabelValues("reauth", "web", "success", machine.User.Name).
Inc()
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err = writer.Write(respBody)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Err(err).
Msg("Failed to write response")
}
log.Trace().
Caller().
Bool("noise", isNoise).
Str("machine_key", machineKey.ShortString()).
Str("node_key", registerRequest.NodeKey.ShortString()).
Str("node_key_old", registerRequest.OldNodeKey.ShortString()).
Str("machine", machine.Hostname).
Msg("Machine logged out. Sent AuthURL for reauthentication")
}

View File

@@ -0,0 +1,698 @@
package headscale
import (
"context"
"fmt"
"net/http"
"time"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
const (
keepAliveInterval = 60 * time.Second
)
type contextKey string
const machineNameContextKey = contextKey("machineName")
// handlePollCommon is the common code for the legacy and Noise protocols to
// managed the poll loop.
func (h *Headscale) handlePollCommon(
writer http.ResponseWriter,
ctx context.Context,
machine *Machine,
mapRequest tailcfg.MapRequest,
isNoise bool,
) {
machine.Hostname = mapRequest.Hostinfo.Hostname
machine.HostInfo = HostInfo(*mapRequest.Hostinfo)
machine.DiscoKey = DiscoPublicKeyStripPrefix(mapRequest.DiscoKey)
now := time.Now().UTC()
err := h.processMachineRoutes(machine)
if err != nil {
log.Error().
Caller().
Err(err).
Str("machine", machine.Hostname).
Msg("Error processing machine routes")
}
// update ACLRules with peer informations (to update server tags if necessary)
if h.aclPolicy != nil {
err := h.UpdateACLRules()
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Err(err)
}
// update routes with peer information
err = h.EnableAutoApprovedRoutes(machine)
if err != nil {
log.Error().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Err(err).
Msg("Error running auto approved routes")
}
}
// From Tailscale client:
//
// ReadOnly is whether the client just wants to fetch the MapResponse,
// without updating their Endpoints. The Endpoints field will be ignored and
// LastSeen will not be updated and peers will not be notified of changes.
//
// The intended use is for clients to discover the DERP map at start-up
// before their first real endpoint update.
if !mapRequest.ReadOnly {
machine.Endpoints = mapRequest.Endpoints
machine.LastSeen = &now
}
if err := h.db.Updates(machine).Error; err != nil {
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("node_key", machine.NodeKey).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to persist/update machine in the database")
http.Error(writer, "", http.StatusInternalServerError)
return
}
}
mapResp, err := h.getMapResponseData(mapRequest, machine, isNoise)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("node_key", machine.NodeKey).
Str("machine", machine.Hostname).
Err(err).
Msg("Failed to get Map response")
http.Error(writer, "", http.StatusInternalServerError)
return
}
// We update our peers if the client is not sending ReadOnly in the MapRequest
// so we don't distribute its initial request (it comes with
// empty endpoints to peers)
// Details on the protocol can be found in https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L696
log.Debug().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Bool("readOnly", mapRequest.ReadOnly).
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Msg("Client map request processed")
if mapRequest.ReadOnly {
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client is starting up. Probably interested in a DERP map")
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(mapResp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
if f, ok := writer.(http.Flusher); ok {
f.Flush()
}
return
}
// There has been an update to _any_ of the nodes that the other nodes would
// need to know about
h.setLastStateChangeToNow()
// The request is not ReadOnly, so we need to set up channels for updating
// peers via longpoll
// Only create update channel if it has not been created
log.Trace().
Caller().
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Loading or creating update channel")
const chanSize = 8
updateChan := make(chan struct{}, chanSize)
pollDataChan := make(chan []byte, chanSize)
defer closeChanWithLog(pollDataChan, machine.Hostname, "pollDataChan")
keepAliveChan := make(chan []byte)
if mapRequest.OmitPeers && !mapRequest.Stream {
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client sent endpoint update and is ok with a response without peer list")
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(mapResp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
// It sounds like we should update the nodes when we have received a endpoint update
// even tho the comments in the tailscale code dont explicitly say so.
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "endpoint-update").
Inc()
updateChan <- struct{}{}
return
} else if mapRequest.OmitPeers && mapRequest.Stream {
log.Warn().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Ignoring request, don't know how to handle it")
http.Error(writer, "", http.StatusBadRequest)
return
}
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Client is ready to access the tailnet")
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Sending initial map")
pollDataChan <- mapResp
log.Info().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Notifying peers")
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "full-update").
Inc()
updateChan <- struct{}{}
h.pollNetMapStream(
writer,
ctx,
machine,
mapRequest,
pollDataChan,
keepAliveChan,
updateChan,
isNoise,
)
log.Trace().
Str("handler", "PollNetMap").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Finished stream, closing PollNetMap session")
}
// pollNetMapStream stream logic for /machine/map,
// ensuring we communicate updates and data to the connected clients.
func (h *Headscale) pollNetMapStream(
writer http.ResponseWriter,
ctxReq context.Context,
machine *Machine,
mapRequest tailcfg.MapRequest,
pollDataChan chan []byte,
keepAliveChan chan []byte,
updateChan chan struct{},
isNoise bool,
) {
h.pollNetMapStreamWG.Add(1)
defer h.pollNetMapStreamWG.Done()
ctx := context.WithValue(ctxReq, machineNameContextKey, machine.Hostname)
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go h.scheduledPollWorker(
ctx,
updateChan,
keepAliveChan,
mapRequest,
machine,
isNoise,
)
log.Trace().
Str("handler", "pollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("Waiting for data to stream...")
log.Trace().
Str("handler", "pollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msgf("pollData is %#v, keepAliveChan is %#v, updateChan is %#v", pollDataChan, keepAliveChan, updateChan)
for {
select {
case data := <-pollDataChan:
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Sending data received via pollData channel")
_, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
Msg("Cannot write data")
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Msg("Cannot cast writer to http.Flusher")
} else {
flusher.Flush()
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Data from pollData channel written successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
return
}
now := time.Now().UTC()
machine.LastSeen = &now
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname).
Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
err = h.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Err(err).
Msg("Cannot update machine LastSuccessfulUpdate")
return
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "pollData").
Int("bytes", len(data)).
Msg("Machine entry in database updated successfully after sending data")
case data := <-keepAliveChan:
log.Trace().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Sending keep alive message")
_, err := writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot write keep alive message")
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Msg("Cannot cast writer to http.Flusher")
} else {
flusher.Flush()
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Keep alive sent successfully")
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
return
}
now := time.Now().UTC()
machine.LastSeen = &now
err = h.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Err(err).
Msg("Cannot update machine LastSeen")
return
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "keepAlive").
Int("bytes", len(data)).
Msg("Machine updated successfully after sending keep alive")
case <-updateChan:
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Received a request for update")
updateRequestsReceivedOnChannel.WithLabelValues(machine.User.Name, machine.Hostname).
Inc()
if h.isOutdated(machine) {
var lastUpdate time.Time
if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate
}
log.Debug().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.User)).
Msgf("There has been updates since the last successful update to %s", machine.Hostname)
data, err := h.getMapResponseData(mapRequest, machine, isNoise)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Could not get the map update")
return
}
_, err = writer.Write(data)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Could not write the map response")
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "failed").
Inc()
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
log.Error().
Caller().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Cannot cast writer to http.Flusher")
} else {
flusher.Flush()
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Msg("Updated Map has been sent")
updateRequestsSentToNode.WithLabelValues(machine.User.Name, machine.Hostname, "success").
Inc()
// Keep track of the last successful update,
// we sometimes end in a state were the update
// is not picked up by a client and we use this
// to determine if we should "force" an update.
// TODO(kradalby): Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err = h.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
return
}
now := time.Now().UTC()
lastStateUpdate.WithLabelValues(machine.User.Name, machine.Hostname).
Set(float64(now.Unix()))
machine.LastSuccessfulUpdate = &now
err = h.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "update").
Err(err).
Msg("Cannot update machine LastSuccessfulUpdate")
return
}
} else {
var lastUpdate time.Time
if machine.LastSuccessfulUpdate != nil {
lastUpdate = *machine.LastSuccessfulUpdate
}
log.Trace().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Time("last_successful_update", lastUpdate).
Time("last_state_change", h.getLastStateChange(machine.User)).
Msgf("%s is up to date", machine.Hostname)
}
case <-ctx.Done():
log.Info().
Str("handler", "PollNetMapStream").
Str("machine", machine.Hostname).
Msg("The client has closed the connection")
// TODO: Abstract away all the database calls, this can cause race conditions
// when an outdated machine object is kept alive, e.g. db is update from
// command line, but then overwritten.
err := h.UpdateMachineFromDatabase(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine from database")
// client has been removed from database
// since the stream opened, terminate connection.
return
}
now := time.Now().UTC()
machine.LastSeen = &now
err = h.TouchMachine(machine)
if err != nil {
log.Error().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Str("channel", "Done").
Err(err).
Msg("Cannot update machine LastSeen")
}
// The connection has been closed, so we can stop polling.
return
case <-h.shutdownChan:
log.Info().
Str("handler", "PollNetMapStream").
Bool("noise", isNoise).
Str("machine", machine.Hostname).
Msg("The long-poll handler is shutting down")
return
}
}
}
func (h *Headscale) scheduledPollWorker(
ctx context.Context,
updateChan chan struct{},
keepAliveChan chan []byte,
mapRequest tailcfg.MapRequest,
machine *Machine,
isNoise bool,
) {
keepAliveTicker := time.NewTicker(keepAliveInterval)
updateCheckerTicker := time.NewTicker(h.cfg.NodeUpdateCheckInterval)
defer closeChanWithLog(
updateChan,
fmt.Sprint(ctx.Value(machineNameContextKey)),
"updateChan",
)
defer closeChanWithLog(
keepAliveChan,
fmt.Sprint(ctx.Value(machineNameContextKey)),
"keepAliveChan",
)
for {
select {
case <-ctx.Done():
return
case <-keepAliveTicker.C:
data, err := h.getMapKeepAliveResponseData(mapRequest, machine, isNoise)
if err != nil {
log.Error().
Str("func", "keepAlive").
Bool("noise", isNoise).
Err(err).
Msg("Error generating the keep alive msg")
return
}
log.Debug().
Str("func", "keepAlive").
Str("machine", machine.Hostname).
Bool("noise", isNoise).
Msg("Sending keepalive")
select {
case keepAliveChan <- data:
case <-ctx.Done():
return
}
case <-updateCheckerTicker.C:
log.Debug().
Str("func", "scheduledPollWorker").
Str("machine", machine.Hostname).
Bool("noise", isNoise).
Msg("Sending update request")
updateRequestsFromNode.WithLabelValues(machine.User.Name, machine.Hostname, "scheduled-update").
Inc()
select {
case updateChan <- struct{}{}:
case <-ctx.Done():
return
}
}
}
}
func closeChanWithLog[C chan []byte | chan struct{}](channel C, machine, name string) {
log.Trace().
Str("handler", "PollNetMap").
Str("machine", machine).
Str("channel", "Done").
Msg(fmt.Sprintf("Closing %s channel", name))
close(channel)
}

View File

@@ -0,0 +1,150 @@
package headscale
import (
"encoding/binary"
"encoding/json"
"sync"
"github.com/klauspost/compress/zstd"
"github.com/rs/zerolog/log"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (h *Headscale) getMapResponseData(
mapRequest tailcfg.MapRequest,
machine *Machine,
isNoise bool,
) ([]byte, error) {
mapResponse, err := h.generateMapResponse(mapRequest, machine)
if err != nil {
return nil, err
}
if isNoise {
return h.marshalMapResponse(mapResponse, key.MachinePublic{}, mapRequest.Compress, isNoise)
}
var machineKey key.MachinePublic
err = machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
return h.marshalMapResponse(mapResponse, machineKey, mapRequest.Compress, isNoise)
}
func (h *Headscale) getMapKeepAliveResponseData(
mapRequest tailcfg.MapRequest,
machine *Machine,
isNoise bool,
) ([]byte, error) {
keepAliveResponse := tailcfg.MapResponse{
KeepAlive: true,
}
if isNoise {
return h.marshalMapResponse(keepAliveResponse, key.MachinePublic{}, mapRequest.Compress, isNoise)
}
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machine.MachineKey)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse client key")
return nil, err
}
return h.marshalMapResponse(keepAliveResponse, machineKey, mapRequest.Compress, isNoise)
}
func (h *Headscale) marshalResponse(
resp interface{},
machineKey key.MachinePublic,
isNoise bool,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal response")
return nil, err
}
if isNoise {
return jsonBody, nil
}
return h.privateKey.SealTo(machineKey, jsonBody), nil
}
func (h *Headscale) marshalMapResponse(
resp interface{},
machineKey key.MachinePublic,
compression string,
isNoise bool,
) ([]byte, error) {
jsonBody, err := json.Marshal(resp)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot marshal map response")
}
var respBody []byte
if compression == ZstdCompression {
respBody = zstdEncode(jsonBody)
if !isNoise { // if legacy protocol
respBody = h.privateKey.SealTo(machineKey, respBody)
}
} else {
if !isNoise { // if legacy protocol
respBody = h.privateKey.SealTo(machineKey, jsonBody)
} else {
respBody = jsonBody
}
}
data := make([]byte, reservedResponseHeaderSize)
binary.LittleEndian.PutUint32(data, uint32(len(respBody)))
data = append(data, respBody...)
return data, nil
}
func zstdEncode(in []byte) []byte {
encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder)
if !ok {
panic("invalid type in sync pool")
}
out := encoder.EncodeAll(in, nil)
_ = encoder.Close()
zstdEncoderPool.Put(encoder)
return out
}
var zstdEncoderPool = &sync.Pool{
New: func() any {
encoder, err := smallzstd.NewEncoder(
nil,
zstd.WithEncoderLevel(zstd.SpeedFastest))
if err != nil {
panic(err)
}
return encoder
},
}

View File

@@ -0,0 +1,60 @@
//go:build ts2019
package headscale
import (
"io"
"net/http"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// RegistrationHandler handles the actual registration process of a machine
// Endpoint /machine/:mkey.
func (h *Headscale) RegistrationHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "RegistrationHandler").
Msg("No machine ID in request")
http.Error(writer, "No machine ID in request", http.StatusBadRequest)
return
}
body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse machine key")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Cannot parse machine key", http.StatusBadRequest)
return
}
registerRequest := tailcfg.RegisterRequest{}
err = decode(body, &registerRequest, &machineKey, h.privateKey)
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot decode message")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return
}
h.handleRegisterCommon(writer, req, registerRequest, machineKey, false)
}

View File

@@ -0,0 +1,96 @@
//go:build ts2019
package headscale
import (
"errors"
"io"
"net/http"
"github.com/gorilla/mux"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// PollNetMapHandler takes care of /machine/:id/map
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (h *Headscale) PollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
vars := mux.Vars(req)
machineKeyStr, ok := vars["mkey"]
if !ok || machineKeyStr == "" {
log.Error().
Str("handler", "PollNetMap").
Msg("No machine key in request")
http.Error(writer, "No machine key in request", http.StatusBadRequest)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Msg("PollNetMapHandler called")
body, _ := io.ReadAll(req.Body)
var machineKey key.MachinePublic
err := machineKey.UnmarshalText([]byte(MachinePublicKeyEnsurePrefix(machineKeyStr)))
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot parse client key")
http.Error(writer, "Cannot parse client key", http.StatusBadRequest)
return
}
mapRequest := tailcfg.MapRequest{}
err = decode(body, &mapRequest, &machineKey, h.privateKey)
if err != nil {
log.Error().
Str("handler", "PollNetMap").
Err(err).
Msg("Cannot decode message")
http.Error(writer, "Cannot decode message", http.StatusBadRequest)
return
}
machine, err := h.GetMachineByMachineKey(machineKey)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "PollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", machineKey.String())
http.Error(writer, "", http.StatusUnauthorized)
return
}
log.Error().
Str("handler", "PollNetMap").
Msgf("Failed to fetch machine from the database with Machine key: %s", machineKey.String())
http.Error(writer, "", http.StatusInternalServerError)
return
}
log.Trace().
Str("handler", "PollNetMap").
Str("id", machineKeyStr).
Str("machine", machine.Hostname).
Msg("A machine is entering polling via the legacy protocol")
h.handlePollCommon(writer, req.Context(), machine, mapRequest, false)
}

View File

@@ -0,0 +1,44 @@
package headscale
import (
"encoding/json"
"io"
"net/http"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
// // NoiseRegistrationHandler handles the actual registration process of a machine.
func (ns *noiseServer) NoiseRegistrationHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().Caller().Msgf("Noise registration handler for client %s", req.RemoteAddr)
if req.Method != http.MethodPost {
http.Error(writer, "Wrong method", http.StatusMethodNotAllowed)
return
}
log.Trace().
Any("headers", req.Header).
Msg("Headers")
body, _ := io.ReadAll(req.Body)
registerRequest := tailcfg.RegisterRequest{}
if err := json.Unmarshal(body, &registerRequest); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse RegisterRequest")
machineRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc()
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
ns.nodeKey = registerRequest.NodeKey
ns.headscale.handleRegisterCommon(writer, req, registerRequest, ns.conn.Peer(), true)
}

View File

@@ -0,0 +1,74 @@
package headscale
import (
"encoding/json"
"errors"
"io"
"net/http"
"github.com/rs/zerolog/log"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol
//
// This is the busiest endpoint, as it keeps the HTTP long poll that updates
// the clients when something in the network changes.
//
// The clients POST stuff like HostInfo and their Endpoints here, but
// only after their first request (marked with the ReadOnly field).
//
// At this moment the updates are sent in a quite horrendous way, but they kinda work.
func (ns *noiseServer) NoisePollNetMapHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().
Str("handler", "NoisePollNetMap").
Msg("PollNetMapHandler called")
log.Trace().
Any("headers", req.Header).
Msg("Headers")
body, _ := io.ReadAll(req.Body)
mapRequest := tailcfg.MapRequest{}
if err := json.Unmarshal(body, &mapRequest); err != nil {
log.Error().
Caller().
Err(err).
Msg("Cannot parse MapRequest")
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
ns.nodeKey = mapRequest.NodeKey
machine, err := ns.headscale.GetMachineByAnyKey(ns.conn.Peer(), mapRequest.NodeKey, key.NodePublic{})
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
log.Warn().
Str("handler", "NoisePollNetMap").
Msgf("Ignoring request, cannot find machine with key %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusNotFound)
return
}
log.Error().
Str("handler", "NoisePollNetMap").
Msgf("Failed to fetch machine from the database with node key: %s", mapRequest.NodeKey.String())
http.Error(writer, "Internal error", http.StatusInternalServerError)
return
}
log.Debug().
Str("handler", "NoisePollNetMap").
Str("machine", machine.Hostname).
Msg("A machine is entering polling via the Noise protocol")
ns.headscale.handlePollCommon(writer, req.Context(), machine, mapRequest, true)
}

428
hscontrol/routes.go Normal file
View File

@@ -0,0 +1,428 @@
package headscale
import (
"errors"
"fmt"
"net/netip"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
)
const (
ErrRouteIsNotAvailable = Error("route is not available")
)
var (
ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0")
ExitRouteV6 = netip.MustParsePrefix("::/0")
)
type Route struct {
gorm.Model
MachineID uint64
Machine Machine
Prefix IPPrefix
Advertised bool
Enabled bool
IsPrimary bool
}
type Routes []Route
func (r *Route) String() string {
return fmt.Sprintf("%s:%s", r.Machine, netip.Prefix(r.Prefix).String())
}
func (r *Route) isExitRoute() bool {
return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6
}
func (rs Routes) toPrefixes() []netip.Prefix {
prefixes := make([]netip.Prefix, len(rs))
for i, r := range rs {
prefixes[i] = netip.Prefix(r.Prefix)
}
return prefixes
}
func (h *Headscale) GetRoutes() ([]Route, error) {
var routes []Route
err := h.db.Preload("Machine").Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func (h *Headscale) GetMachineRoutes(m *Machine) ([]Route, error) {
var routes []Route
err := h.db.
Preload("Machine").
Where("machine_id = ?", m.ID).
Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
return routes, nil
}
func (h *Headscale) GetRoute(id uint64) (*Route, error) {
var route Route
err := h.db.Preload("Machine").First(&route, id).Error
if err != nil {
return nil, err
}
return &route, nil
}
func (h *Headscale) EnableRoute(id uint64) error {
route, err := h.GetRoute(id)
if err != nil {
return err
}
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if route.isExitRoute() {
return h.enableRoutes(&route.Machine, ExitRouteV4.String(), ExitRouteV6.String())
}
return h.enableRoutes(&route.Machine, netip.Prefix(route.Prefix).String())
}
func (h *Headscale) DisableRoute(id uint64) error {
route, err := h.GetRoute(id)
if err != nil {
return err
}
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.isExitRoute() {
route.Enabled = false
route.IsPrimary = false
err = h.db.Save(route).Error
if err != nil {
return err
}
return h.handlePrimarySubnetFailover()
}
routes, err := h.GetMachineRoutes(&route.Machine)
if err != nil {
return err
}
for i := range routes {
if routes[i].isExitRoute() {
routes[i].Enabled = false
routes[i].IsPrimary = false
err = h.db.Save(&routes[i]).Error
if err != nil {
return err
}
}
}
return h.handlePrimarySubnetFailover()
}
func (h *Headscale) DeleteRoute(id uint64) error {
route, err := h.GetRoute(id)
if err != nil {
return err
}
// Tailscale requires both IPv4 and IPv6 exit routes to
// be enabled at the same time, as per
// https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002
if !route.isExitRoute() {
if err := h.db.Unscoped().Delete(&route).Error; err != nil {
return err
}
return h.handlePrimarySubnetFailover()
}
routes, err := h.GetMachineRoutes(&route.Machine)
if err != nil {
return err
}
routesToDelete := []Route{}
for _, r := range routes {
if r.isExitRoute() {
routesToDelete = append(routesToDelete, r)
}
}
if err := h.db.Unscoped().Delete(&routesToDelete).Error; err != nil {
return err
}
return h.handlePrimarySubnetFailover()
}
func (h *Headscale) DeleteMachineRoutes(m *Machine) error {
routes, err := h.GetMachineRoutes(m)
if err != nil {
return err
}
for i := range routes {
if err := h.db.Unscoped().Delete(&routes[i]).Error; err != nil {
return err
}
}
return h.handlePrimarySubnetFailover()
}
// isUniquePrefix returns if there is another machine providing the same route already.
func (h *Headscale) isUniquePrefix(route Route) bool {
var count int64
h.db.
Model(&Route{}).
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
route.MachineID,
true, true).Count(&count)
return count == 0
}
func (h *Headscale) getPrimaryRoute(prefix netip.Prefix) (*Route, error) {
var route Route
err := h.db.
Preload("Machine").
Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", IPPrefix(prefix), true, true, true).
First(&route).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, err
}
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return &route, nil
}
// getMachinePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover)
// Exit nodes are not considered for this, as they are never marked as Primary.
func (h *Headscale) getMachinePrimaryRoutes(m *Machine) ([]Route, error) {
var routes []Route
err := h.db.
Preload("Machine").
Where("machine_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", m.ID, true, true, true).
Find(&routes).Error
if err != nil {
return nil, err
}
return routes, nil
}
func (h *Headscale) processMachineRoutes(machine *Machine) error {
currentRoutes := []Route{}
err := h.db.Where("machine_id = ?", machine.ID).Find(&currentRoutes).Error
if err != nil {
return err
}
advertisedRoutes := map[netip.Prefix]bool{}
for _, prefix := range machine.HostInfo.RoutableIPs {
advertisedRoutes[prefix] = false
}
for pos, route := range currentRoutes {
if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok {
if !route.Advertised {
currentRoutes[pos].Advertised = true
err := h.db.Save(&currentRoutes[pos]).Error
if err != nil {
return err
}
}
advertisedRoutes[netip.Prefix(route.Prefix)] = true
} else if route.Advertised {
currentRoutes[pos].Advertised = false
currentRoutes[pos].Enabled = false
err := h.db.Save(&currentRoutes[pos]).Error
if err != nil {
return err
}
}
}
for prefix, exists := range advertisedRoutes {
if !exists {
route := Route{
MachineID: machine.ID,
Prefix: IPPrefix(prefix),
Advertised: true,
Enabled: false,
}
err := h.db.Create(&route).Error
if err != nil {
return err
}
}
}
return nil
}
func (h *Headscale) handlePrimarySubnetFailover() error {
// first, get all the enabled routes
var routes []Route
err := h.db.
Preload("Machine").
Where("advertised = ? AND enabled = ?", true, true).
Find(&routes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().Err(err).Msg("error getting routes")
}
routesChanged := false
for pos, route := range routes {
if route.isExitRoute() {
continue
}
if !route.IsPrimary {
_, err := h.getPrimaryRoute(netip.Prefix(route.Prefix))
if h.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) {
log.Info().
Str("prefix", netip.Prefix(route.Prefix).String()).
Str("machine", route.Machine.GivenName).
Msg("Setting primary route")
routes[pos].IsPrimary = true
err := h.db.Save(&routes[pos]).Error
if err != nil {
log.Error().Err(err).Msg("error marking route as primary")
return err
}
routesChanged = true
continue
}
}
if route.IsPrimary {
if route.Machine.isOnline() {
continue
}
// machine offline, find a new primary
log.Info().
Str("machine", route.Machine.Hostname).
Str("prefix", netip.Prefix(route.Prefix).String()).
Msgf("machine offline, finding a new primary subnet")
// find a new primary route
var newPrimaryRoutes []Route
err := h.db.
Preload("Machine").
Where("prefix = ? AND machine_id != ? AND advertised = ? AND enabled = ?",
route.Prefix,
route.MachineID,
true, true).
Find(&newPrimaryRoutes).Error
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
log.Error().Err(err).Msg("error finding new primary route")
return err
}
var newPrimaryRoute *Route
for pos, r := range newPrimaryRoutes {
if r.Machine.isOnline() {
newPrimaryRoute = &newPrimaryRoutes[pos]
break
}
}
if newPrimaryRoute == nil {
log.Warn().
Str("machine", route.Machine.Hostname).
Str("prefix", netip.Prefix(route.Prefix).String()).
Msgf("no alternative primary route found")
continue
}
log.Info().
Str("old_machine", route.Machine.Hostname).
Str("prefix", netip.Prefix(route.Prefix).String()).
Str("new_machine", newPrimaryRoute.Machine.Hostname).
Msgf("found new primary route")
// disable the old primary route
routes[pos].IsPrimary = false
err = h.db.Save(&routes[pos]).Error
if err != nil {
log.Error().Err(err).Msg("error disabling old primary route")
return err
}
// enable the new primary route
newPrimaryRoute.IsPrimary = true
err = h.db.Save(&newPrimaryRoute).Error
if err != nil {
log.Error().Err(err).Msg("error enabling new primary route")
return err
}
routesChanged = true
}
}
if routesChanged {
h.setLastStateChangeToNow()
}
return nil
}
func (rs Routes) toProto() []*v1.Route {
protoRoutes := []*v1.Route{}
for _, route := range rs {
protoRoute := v1.Route{
Id: uint64(route.ID),
Machine: route.Machine.toProto(),
Prefix: netip.Prefix(route.Prefix).String(),
Advertised: route.Advertised,
Enabled: route.Enabled,
IsPrimary: route.IsPrimary,
CreatedAt: timestamppb.New(route.CreatedAt),
UpdatedAt: timestamppb.New(route.UpdatedAt),
}
if route.DeletedAt.Valid {
protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time)
}
protoRoutes = append(protoRoutes, &protoRoute)
}
return protoRoutes
}

550
hscontrol/routes_test.go Normal file
View File

@@ -0,0 +1,550 @@
package headscale
import (
"net/netip"
"time"
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func (s *Suite) TestGetRoutes(c *check.C) {
user, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "test_get_route_machine")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix("10.0.0.0/24")
c.Assert(err, check.IsNil)
hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route},
}
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_get_route_machine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo),
}
app.db.Save(&machine)
err = app.processMachineRoutes(&machine)
c.Assert(err, check.IsNil)
advertisedRoutes, err := app.GetAdvertisedRoutes(&machine)
c.Assert(err, check.IsNil)
c.Assert(len(advertisedRoutes), check.Equals, 1)
err = app.enableRoutes(&machine, "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = app.enableRoutes(&machine, "10.0.0.0/24")
c.Assert(err, check.IsNil)
}
func (s *Suite) TestGetEnableRoutes(c *check.C) {
user, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
route2, err := netip.ParsePrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hostInfo := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route, route2},
}
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_machine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo),
}
app.db.Save(&machine)
err = app.processMachineRoutes(&machine)
c.Assert(err, check.IsNil)
availableRoutes, err := app.GetAdvertisedRoutes(&machine)
c.Assert(err, check.IsNil)
c.Assert(err, check.IsNil)
c.Assert(len(availableRoutes), check.Equals, 2)
noEnabledRoutes, err := app.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil)
c.Assert(len(noEnabledRoutes), check.Equals, 0)
err = app.enableRoutes(&machine, "192.168.0.0/24")
c.Assert(err, check.NotNil)
err = app.enableRoutes(&machine, "10.0.0.0/24")
c.Assert(err, check.IsNil)
enabledRoutes, err := app.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes), check.Equals, 1)
// Adding it twice will just let it pass through
err = app.enableRoutes(&machine, "10.0.0.0/24")
c.Assert(err, check.IsNil)
enableRoutesAfterDoubleApply, err := app.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil)
c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1)
err = app.enableRoutes(&machine, "150.0.10.0/25")
c.Assert(err, check.IsNil)
enabledRoutesWithAdditionalRoute, err := app.GetEnabledRoutes(&machine)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2)
}
func (s *Suite) TestIsUniquePrefix(c *check.C) {
user, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
route, err := netip.ParsePrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
route2, err := netip.ParsePrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hostInfo1 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route, route2},
}
machine1 := Machine{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_machine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo1),
}
app.db.Save(&machine1)
err = app.processMachineRoutes(&machine1)
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine1, route.String())
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine1, route2.String())
c.Assert(err, check.IsNil)
hostInfo2 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{route2},
}
machine2 := Machine{
ID: 2,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_machine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo2),
}
app.db.Save(&machine2)
err = app.processMachineRoutes(&machine2)
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine2, route2.String())
c.Assert(err, check.IsNil)
enabledRoutes1, err := app.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2)
enabledRoutes2, err := app.GetEnabledRoutes(&machine2)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1)
routes, err := app.getMachinePrimaryRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2)
routes, err = app.getMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0)
}
func (s *Suite) TestSubnetFailover(c *check.C) {
user, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
prefix2, err := netip.ParsePrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hostInfo1 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2},
}
now := time.Now()
machine1 := Machine{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_machine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo1),
LastSeen: &now,
}
app.db.Save(&machine1)
err = app.processMachineRoutes(&machine1)
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine1, prefix.String())
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine1, prefix2.String())
c.Assert(err, check.IsNil)
err = app.handlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
enabledRoutes1, err := app.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2)
route, err := app.getPrimaryRoute(prefix)
c.Assert(err, check.IsNil)
c.Assert(route.MachineID, check.Equals, machine1.ID)
hostInfo2 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix2},
}
machine2 := Machine{
ID: 2,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_machine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo2),
LastSeen: &now,
}
app.db.Save(&machine2)
err = app.processMachineRoutes(&machine2)
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine2, prefix2.String())
c.Assert(err, check.IsNil)
err = app.handlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
enabledRoutes1, err = app.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 2)
enabledRoutes2, err := app.GetEnabledRoutes(&machine2)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes2), check.Equals, 1)
routes, err := app.getMachinePrimaryRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2)
routes, err = app.getMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0)
// lets make machine1 lastseen 10 mins ago
before := now.Add(-10 * time.Minute)
machine1.LastSeen = &before
err = app.db.Save(&machine1).Error
c.Assert(err, check.IsNil)
err = app.handlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
routes, err = app.getMachinePrimaryRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1)
routes, err = app.getMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1)
machine2.HostInfo = HostInfo(tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2},
})
err = app.db.Save(&machine2).Error
c.Assert(err, check.IsNil)
err = app.processMachineRoutes(&machine2)
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine2, prefix.String())
c.Assert(err, check.IsNil)
err = app.handlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
routes, err = app.getMachinePrimaryRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 0)
routes, err = app.getMachinePrimaryRoutes(&machine2)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2)
}
// TestAllowedIPRoutes tests that the AllowedIPs are correctly set for a node,
// including both the primary routes the node is responsible for, and the
// exit node routes if enabled.
func (s *Suite) TestAllowedIPRoutes(c *check.C) {
user, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
prefix2, err := netip.ParsePrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
prefixExitNodeV4, err := netip.ParsePrefix(
"0.0.0.0/0",
)
c.Assert(err, check.IsNil)
prefixExitNodeV6, err := netip.ParsePrefix(
"::/0",
)
c.Assert(err, check.IsNil)
hostInfo1 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2, prefixExitNodeV4, prefixExitNodeV6},
}
nodeKey := key.NewNode()
discoKey := key.NewDisco()
machineKey := key.NewMachine()
now := time.Now()
machine1 := Machine{
ID: 1,
MachineKey: MachinePublicKeyStripPrefix(machineKey.Public()),
NodeKey: NodePublicKeyStripPrefix(nodeKey.Public()),
DiscoKey: DiscoPublicKeyStripPrefix(discoKey.Public()),
Hostname: "test_enable_route_machine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo1),
LastSeen: &now,
}
app.db.Save(&machine1)
err = app.processMachineRoutes(&machine1)
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine1, prefix.String())
c.Assert(err, check.IsNil)
// We do not enable this one on purpose to test that it is not enabled
// err = app.enableRoutes(&machine1, prefix2.String())
// c.Assert(err, check.IsNil)
routes, err := app.GetMachineRoutes(&machine1)
c.Assert(err, check.IsNil)
for _, route := range routes {
if route.isExitRoute() {
err = app.EnableRoute(uint64(route.ID))
c.Assert(err, check.IsNil)
// We only enable one exit route, so we can test that both are enabled
break
}
}
err = app.handlePrimarySubnetFailover()
c.Assert(err, check.IsNil)
enabledRoutes1, err := app.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 3)
peer, err := app.toNode(machine1, "headscale.net", nil)
c.Assert(err, check.IsNil)
c.Assert(len(peer.AllowedIPs), check.Equals, 3)
foundExitNodeV4 := false
foundExitNodeV6 := false
for _, allowedIP := range peer.AllowedIPs {
if allowedIP == prefixExitNodeV4 {
foundExitNodeV4 = true
}
if allowedIP == prefixExitNodeV6 {
foundExitNodeV6 = true
}
}
c.Assert(foundExitNodeV4, check.Equals, true)
c.Assert(foundExitNodeV6, check.Equals, true)
// Now we disable only one of the exit routes
// and we see if both are disabled
var exitRouteV4 Route
for _, route := range routes {
if route.isExitRoute() && netip.Prefix(route.Prefix) == prefixExitNodeV4 {
exitRouteV4 = route
break
}
}
err = app.DisableRoute(uint64(exitRouteV4.ID))
c.Assert(err, check.IsNil)
enabledRoutes1, err = app.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1)
// and now we delete only one of the exit routes
// and we check if both are deleted
routes, err = app.GetMachineRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 4)
err = app.DeleteRoute(uint64(exitRouteV4.ID))
c.Assert(err, check.IsNil)
routes, err = app.GetMachineRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 2)
}
func (s *Suite) TestDeleteRoutes(c *check.C) {
user, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "test_enable_route_machine")
c.Assert(err, check.NotNil)
prefix, err := netip.ParsePrefix(
"10.0.0.0/24",
)
c.Assert(err, check.IsNil)
prefix2, err := netip.ParsePrefix(
"150.0.10.0/25",
)
c.Assert(err, check.IsNil)
hostInfo1 := tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2},
}
now := time.Now()
machine1 := Machine{
ID: 1,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "test_enable_route_machine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: HostInfo(hostInfo1),
LastSeen: &now,
}
app.db.Save(&machine1)
err = app.processMachineRoutes(&machine1)
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine1, prefix.String())
c.Assert(err, check.IsNil)
err = app.enableRoutes(&machine1, prefix2.String())
c.Assert(err, check.IsNil)
routes, err := app.GetMachineRoutes(&machine1)
c.Assert(err, check.IsNil)
err = app.DeleteRoute(uint64(routes[0].ID))
c.Assert(err, check.IsNil)
enabledRoutes1, err := app.GetEnabledRoutes(&machine1)
c.Assert(err, check.IsNil)
c.Assert(len(enabledRoutes1), check.Equals, 1)
}

94
hscontrol/swagger.go Normal file
View File

@@ -0,0 +1,94 @@
package headscale
import (
"bytes"
_ "embed"
"html/template"
"net/http"
"github.com/rs/zerolog/log"
)
//go:embed gen/openapiv2/headscale/v1/headscale.swagger.json
var apiV1JSON []byte
func SwaggerUI(
writer http.ResponseWriter,
req *http.Request,
) {
swaggerTemplate := template.Must(template.New("swagger").Parse(`
<html>
<head>
<link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css">
<script src="https://unpkg.com/swagger-ui-dist@3/swagger-ui-standalone-preset.js"></script>
<script src="https://unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"></script>
</head>
<body>
<div id="swagger-ui"></div>
<script>
window.addEventListener('load', (event) => {
const ui = SwaggerUIBundle({
url: "/swagger/v1/openapiv2.json",
dom_id: '#swagger-ui',
presets: [
SwaggerUIBundle.presets.apis,
SwaggerUIBundle.SwaggerUIStandalonePreset
],
plugins: [
SwaggerUIBundle.plugins.DownloadUrl
],
deepLinking: true,
// TODO(kradalby): Figure out why this does not work
// layout: "StandaloneLayout",
})
window.ui = ui
});
</script>
</body>
</html>`))
var payload bytes.Buffer
if err := swaggerTemplate.Execute(&payload, struct{}{}); err != nil {
log.Error().
Caller().
Err(err).
Msg("Could not render Swagger")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, err := writer.Write([]byte("Could not render Swagger"))
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
return
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
_, err := writer.Write(payload.Bytes())
if err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}
func SwaggerAPIv1(
writer http.ResponseWriter,
req *http.Request,
) {
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(apiV1JSON); err != nil {
log.Error().
Caller().
Err(err).
Msg("Failed to write response")
}
}

301
hscontrol/users.go Normal file
View File

@@ -0,0 +1,301 @@
package headscale
import (
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"time"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
"tailscale.com/tailcfg"
)
const (
ErrUserExists = Error("User already exists")
ErrUserNotFound = Error("User not found")
ErrUserStillHasNodes = Error("User not empty: node(s) found")
ErrInvalidUserName = Error("Invalid user name")
)
const (
// value related to RFC 1123 and 952.
labelHostnameLength = 63
)
var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
// User is the way Headscale implements the concept of users in Tailscale
//
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
// that contain our machines.
type User struct {
gorm.Model
Name string `gorm:"unique"`
}
// CreateUser creates a new User. Returns error if could not be created
// or another user already exists.
func (h *Headscale) CreateUser(name string) (*User, error) {
err := CheckForFQDNRules(name)
if err != nil {
return nil, err
}
user := User{}
if err := h.db.Where("name = ?", name).First(&user).Error; err == nil {
return nil, ErrUserExists
}
user.Name = name
if err := h.db.Create(&user).Error; err != nil {
log.Error().
Str("func", "CreateUser").
Err(err).
Msg("Could not create row")
return nil, err
}
return &user, nil
}
// DestroyUser destroys a User. Returns error if the User does
// not exist or if there are machines associated with it.
func (h *Headscale) DestroyUser(name string) error {
user, err := h.GetUser(name)
if err != nil {
return ErrUserNotFound
}
machines, err := h.ListMachinesByUser(name)
if err != nil {
return err
}
if len(machines) > 0 {
return ErrUserStillHasNodes
}
keys, err := h.ListPreAuthKeys(name)
if err != nil {
return err
}
for _, key := range keys {
err = h.DestroyPreAuthKey(key)
if err != nil {
return err
}
}
if result := h.db.Unscoped().Delete(&user); result.Error != nil {
return result.Error
}
return nil
}
// RenameUser renames a User. Returns error if the User does
// not exist or if another User exists with the new name.
func (h *Headscale) RenameUser(oldName, newName string) error {
var err error
oldUser, err := h.GetUser(oldName)
if err != nil {
return err
}
err = CheckForFQDNRules(newName)
if err != nil {
return err
}
_, err = h.GetUser(newName)
if err == nil {
return ErrUserExists
}
if !errors.Is(err, ErrUserNotFound) {
return err
}
oldUser.Name = newName
if result := h.db.Save(&oldUser); result.Error != nil {
return result.Error
}
return nil
}
// GetUser fetches a user by name.
func (h *Headscale) GetUser(name string) (*User, error) {
user := User{}
if result := h.db.First(&user, "name = ?", name); errors.Is(
result.Error,
gorm.ErrRecordNotFound,
) {
return nil, ErrUserNotFound
}
return &user, nil
}
// ListUsers gets all the existing users.
func (h *Headscale) ListUsers() ([]User, error) {
users := []User{}
if err := h.db.Find(&users).Error; err != nil {
return nil, err
}
return users, nil
}
// ListMachinesByUser gets all the nodes in a given user.
func (h *Headscale) ListMachinesByUser(name string) ([]Machine, error) {
err := CheckForFQDNRules(name)
if err != nil {
return nil, err
}
user, err := h.GetUser(name)
if err != nil {
return nil, err
}
machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&Machine{UserID: user.ID}).Find(&machines).Error; err != nil {
return nil, err
}
return machines, nil
}
// SetMachineUser assigns a Machine to a user.
func (h *Headscale) SetMachineUser(machine *Machine, username string) error {
err := CheckForFQDNRules(username)
if err != nil {
return err
}
user, err := h.GetUser(username)
if err != nil {
return err
}
machine.User = *user
if result := h.db.Save(&machine); result.Error != nil {
return result.Error
}
return nil
}
func (n *User) toTailscaleUser() *tailcfg.User {
user := tailcfg.User{
ID: tailcfg.UserID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
ProfilePicURL: "",
Domain: "headscale.net",
Logins: []tailcfg.LoginID{},
Created: time.Time{},
}
return &user
}
func (n *User) toTailscaleLogin() *tailcfg.Login {
login := tailcfg.Login{
ID: tailcfg.LoginID(n.ID),
LoginName: n.Name,
DisplayName: n.Name,
ProfilePicURL: "",
Domain: "headscale.net",
}
return &login
}
func (h *Headscale) getMapResponseUserProfiles(
machine Machine,
peers Machines,
) []tailcfg.UserProfile {
userMap := make(map[string]User)
userMap[machine.User.Name] = machine.User
for _, peer := range peers {
userMap[peer.User.Name] = peer.User // not worth checking if already is there
}
profiles := []tailcfg.UserProfile{}
for _, user := range userMap {
displayName := user.Name
if h.cfg.BaseDomain != "" {
displayName = fmt.Sprintf("%s@%s", user.Name, h.cfg.BaseDomain)
}
profiles = append(profiles,
tailcfg.UserProfile{
ID: tailcfg.UserID(user.ID),
LoginName: user.Name,
DisplayName: displayName,
})
}
return profiles
}
func (n *User) toProto() *v1.User {
return &v1.User{
Id: strconv.FormatUint(uint64(n.ID), Base10),
Name: n.Name,
CreatedAt: timestamppb.New(n.CreatedAt),
}
}
// NormalizeToFQDNRules will replace forbidden chars in user
// it can also return an error if the user doesn't respect RFC 952 and 1123.
func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) {
name = strings.ToLower(name)
name = strings.ReplaceAll(name, "'", "")
atIdx := strings.Index(name, "@")
if stripEmailDomain && atIdx > 0 {
name = name[:atIdx]
} else {
name = strings.ReplaceAll(name, "@", ".")
}
name = invalidCharsInUserRegex.ReplaceAllString(name, "-")
for _, elt := range strings.Split(name, ".") {
if len(elt) > labelHostnameLength {
return "", fmt.Errorf(
"label %v is more than 63 chars: %w",
elt,
ErrInvalidUserName,
)
}
}
return name, nil
}
func CheckForFQDNRules(name string) error {
if len(name) > labelHostnameLength {
return fmt.Errorf(
"DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w",
name,
ErrInvalidUserName,
)
}
if strings.ToLower(name) != name {
return fmt.Errorf(
"DNS segment should be lowercase. %v doesn't comply with this rule: %w",
name,
ErrInvalidUserName,
)
}
if invalidCharsInUserRegex.MatchString(name) {
return fmt.Errorf(
"DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w",
name,
ErrInvalidUserName,
)
}
return nil
}

415
hscontrol/users_test.go Normal file
View File

@@ -0,0 +1,415 @@
package headscale
import (
"net/netip"
"testing"
"gopkg.in/check.v1"
"gorm.io/gorm"
)
func (s *Suite) TestCreateAndDestroyUser(c *check.C) {
user, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
c.Assert(user.Name, check.Equals, "test")
users, err := app.ListUsers()
c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = app.DestroyUser("test")
c.Assert(err, check.IsNil)
_, err = app.GetUser("test")
c.Assert(err, check.NotNil)
}
func (s *Suite) TestDestroyUserErrors(c *check.C) {
err := app.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserNotFound)
user, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
err = app.DestroyUser("test")
c.Assert(err, check.IsNil)
result := app.db.Preload("User").First(&pak, "key = ?", pak.Key)
// destroying a user also deletes all associated preauthkeys
c.Assert(result.Error, check.Equals, gorm.ErrRecordNotFound)
user, err = app.CreateUser("test")
c.Assert(err, check.IsNil)
pak, err = app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
err = app.DestroyUser("test")
c.Assert(err, check.Equals, ErrUserStillHasNodes)
}
func (s *Suite) TestRenameUser(c *check.C) {
userTest, err := app.CreateUser("test")
c.Assert(err, check.IsNil)
c.Assert(userTest.Name, check.Equals, "test")
users, err := app.ListUsers()
c.Assert(err, check.IsNil)
c.Assert(len(users), check.Equals, 1)
err = app.RenameUser("test", "test-renamed")
c.Assert(err, check.IsNil)
_, err = app.GetUser("test")
c.Assert(err, check.Equals, ErrUserNotFound)
_, err = app.GetUser("test-renamed")
c.Assert(err, check.IsNil)
err = app.RenameUser("test-does-not-exit", "test")
c.Assert(err, check.Equals, ErrUserNotFound)
userTest2, err := app.CreateUser("test2")
c.Assert(err, check.IsNil)
c.Assert(userTest2.Name, check.Equals, "test2")
err = app.RenameUser("test2", "test-renamed")
c.Assert(err, check.Equals, ErrUserExists)
}
func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) {
userShared1, err := app.CreateUser("shared1")
c.Assert(err, check.IsNil)
userShared2, err := app.CreateUser("shared2")
c.Assert(err, check.IsNil)
userShared3, err := app.CreateUser("shared3")
c.Assert(err, check.IsNil)
preAuthKeyShared1, err := app.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyShared2, err := app.CreatePreAuthKey(
userShared2.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKeyShared3, err := app.CreatePreAuthKey(
userShared3.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
preAuthKey2Shared1, err := app.CreatePreAuthKey(
userShared1.Name,
false,
false,
nil,
nil,
)
c.Assert(err, check.IsNil)
_, err = app.GetMachine(userShared1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
machineInShared1 := &Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Hostname: "test_get_shared_nodes_1",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
AuthKeyID: uint(preAuthKeyShared1.ID),
}
app.db.Save(machineInShared1)
_, err = app.GetMachine(userShared1.Name, machineInShared1.Hostname)
c.Assert(err, check.IsNil)
machineInShared2 := &Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_2",
UserID: userShared2.ID,
User: *userShared2,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
AuthKeyID: uint(preAuthKeyShared2.ID),
}
app.db.Save(machineInShared2)
_, err = app.GetMachine(userShared2.Name, machineInShared2.Hostname)
c.Assert(err, check.IsNil)
machineInShared3 := &Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_3",
UserID: userShared3.ID,
User: *userShared3,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
AuthKeyID: uint(preAuthKeyShared3.ID),
}
app.db.Save(machineInShared3)
_, err = app.GetMachine(userShared3.Name, machineInShared3.Hostname)
c.Assert(err, check.IsNil)
machine2InShared1 := &Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Hostname: "test_get_shared_nodes_4",
UserID: userShared1.ID,
User: *userShared1,
RegisterMethod: RegisterMethodAuthKey,
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")},
AuthKeyID: uint(preAuthKey2Shared1.ID),
}
app.db.Save(machine2InShared1)
peersOfMachine1InShared1, err := app.getPeers(machineInShared1)
c.Assert(err, check.IsNil)
userProfiles := app.getMapResponseUserProfiles(
*machineInShared1,
peersOfMachine1InShared1,
)
c.Assert(len(userProfiles), check.Equals, 3)
found := false
for _, userProfiles := range userProfiles {
if userProfiles.DisplayName == userShared1.Name {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, userProfile := range userProfiles {
if userProfile.DisplayName == userShared2.Name {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}
func TestNormalizeToFQDNRules(t *testing.T) {
type args struct {
name string
stripEmailDomain bool
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "normalize simple name",
args: args{
name: "normalize-simple.name",
stripEmailDomain: false,
},
want: "normalize-simple.name",
wantErr: false,
},
{
name: "normalize an email",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: false,
},
want: "foo.bar.example.com",
wantErr: false,
},
{
name: "normalize an email domain should be removed",
args: args{
name: "foo.bar@example.com",
stripEmailDomain: true,
},
want: "foo.bar",
wantErr: false,
},
{
name: "strip enabled no email passed as argument",
args: args{
name: "not-email-and-strip-enabled",
stripEmailDomain: true,
},
want: "not-email-and-strip-enabled",
wantErr: false,
},
{
name: "normalize complex email",
args: args{
name: "foo.bar+complex-email@example.com",
stripEmailDomain: false,
},
want: "foo.bar-complex-email.example.com",
wantErr: false,
},
{
name: "user name with space",
args: args{
name: "name space",
stripEmailDomain: false,
},
want: "name-space",
wantErr: false,
},
{
name: "user with quote",
args: args{
name: "Jamie's iPhone 5",
stripEmailDomain: false,
},
want: "jamies-iphone-5",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain)
if (err != nil) != tt.wantErr {
t.Errorf(
"NormalizeToFQDNRules() error = %v, wantErr %v",
err,
tt.wantErr,
)
return
}
if got != tt.want {
t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckForFQDNRules(t *testing.T) {
type args struct {
name string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "valid: user",
args: args{name: "valid-user"},
wantErr: false,
},
{
name: "invalid: capitalized user",
args: args{name: "Invalid-CapItaLIzed-user"},
wantErr: true,
},
{
name: "invalid: email as user",
args: args{name: "foo.bar@example.com"},
wantErr: true,
},
{
name: "invalid: chars in user name",
args: args{name: "super-user+name"},
wantErr: true,
},
{
name: "invalid: too long name for user",
args: args{
name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr {
t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func (s *Suite) TestSetMachineUser(c *check.C) {
oldUser, err := app.CreateUser("old")
c.Assert(err, check.IsNil)
newUser, err := app.CreateUser("new")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(oldUser.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: oldUser.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
c.Assert(machine.UserID, check.Equals, oldUser.ID)
err = app.SetMachineUser(&machine, newUser.Name)
c.Assert(err, check.IsNil)
c.Assert(machine.UserID, check.Equals, newUser.ID)
c.Assert(machine.User.Name, check.Equals, newUser.Name)
err = app.SetMachineUser(&machine, "non-existing-user")
c.Assert(err, check.Equals, ErrUserNotFound)
err = app.SetMachineUser(&machine, newUser.Name)
c.Assert(err, check.IsNil)
c.Assert(machine.UserID, check.Equals, newUser.ID)
c.Assert(machine.User.Name, check.Equals, newUser.Name)
}

361
hscontrol/utils.go Normal file
View File

@@ -0,0 +1,361 @@
// Codehere is mostly taken from github.com/tailscale/tailscale
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headscale
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"github.com/rs/zerolog/log"
"github.com/spf13/viper"
"go4.org/netipx"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
const (
ErrCannotDecryptResponse = Error("cannot decrypt response")
ErrCouldNotAllocateIP = Error("could not find any suitable IP")
// These constants are copied from the upstream tailscale.com/types/key
// library, because they are not exported.
// https://github.com/tailscale/tailscale/tree/main/types/key
// nodePublicHexPrefix is the prefix used to identify a
// hex-encoded node public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
nodePublicHexPrefix = "nodekey:"
// machinePublicHexPrefix is the prefix used to identify a
// hex-encoded machine public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
machinePublicHexPrefix = "mkey:"
// discoPublicHexPrefix is the prefix used to identify a
// hex-encoded disco public key.
//
// This prefix is used in the control protocol, so cannot be
// changed.
discoPublicHexPrefix = "discokey:"
// privateKey prefix.
privateHexPrefix = "privkey:"
PermissionFallback = 0o700
ZstdCompression = "zstd"
)
var NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+")
func MachinePublicKeyStripPrefix(machineKey key.MachinePublic) string {
return strings.TrimPrefix(machineKey.String(), machinePublicHexPrefix)
}
func NodePublicKeyStripPrefix(nodeKey key.NodePublic) string {
return strings.TrimPrefix(nodeKey.String(), nodePublicHexPrefix)
}
func DiscoPublicKeyStripPrefix(discoKey key.DiscoPublic) string {
return strings.TrimPrefix(discoKey.String(), discoPublicHexPrefix)
}
func MachinePublicKeyEnsurePrefix(machineKey string) string {
if !strings.HasPrefix(machineKey, machinePublicHexPrefix) {
return machinePublicHexPrefix + machineKey
}
return machineKey
}
func NodePublicKeyEnsurePrefix(nodeKey string) string {
if !strings.HasPrefix(nodeKey, nodePublicHexPrefix) {
return nodePublicHexPrefix + nodeKey
}
return nodeKey
}
func DiscoPublicKeyEnsurePrefix(discoKey string) string {
if !strings.HasPrefix(discoKey, discoPublicHexPrefix) {
return discoPublicHexPrefix + discoKey
}
return discoKey
}
func PrivateKeyEnsurePrefix(privateKey string) string {
if !strings.HasPrefix(privateKey, privateHexPrefix) {
return privateHexPrefix + privateKey
}
return privateKey
}
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
type Error string
func (e Error) Error() string { return string(e) }
func decode(
msg []byte,
output interface{},
pubKey *key.MachinePublic,
privKey *key.MachinePrivate,
) error {
log.Trace().
Str("pubkey", pubKey.ShortString()).
Int("length", len(msg)).
Msg("Trying to decrypt")
decrypted, ok := privKey.OpenFrom(*pubKey, msg)
if !ok {
return ErrCannotDecryptResponse
}
if err := json.Unmarshal(decrypted, output); err != nil {
return err
}
return nil
}
func (h *Headscale) getAvailableIPs() (MachineAddresses, error) {
var ips MachineAddresses
var err error
ipPrefixes := h.cfg.IPPrefixes
for _, ipPrefix := range ipPrefixes {
var ip *netip.Addr
ip, err = h.getAvailableIP(ipPrefix)
if err != nil {
return ips, err
}
ips = append(ips, *ip)
}
return ips, err
}
func GetIPPrefixEndpoints(na netip.Prefix) (netip.Addr, netip.Addr) {
var network, broadcast netip.Addr
ipRange := netipx.RangeOfPrefix(na)
network = ipRange.From()
broadcast = ipRange.To()
return network, broadcast
}
func (h *Headscale) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) {
usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
}
ipPrefixNetworkAddress, ipPrefixBroadcastAddress := GetIPPrefixEndpoints(ipPrefix)
// Get the first IP in our prefix
ip := ipPrefixNetworkAddress.Next()
for {
if !ipPrefix.Contains(ip) {
return nil, ErrCouldNotAllocateIP
}
switch {
case ip.Compare(ipPrefixBroadcastAddress) == 0:
fallthrough
case usedIps.Contains(ip):
fallthrough
case ip == netip.Addr{} || ip.IsLoopback():
ip = ip.Next()
continue
default:
return &ip, nil
}
}
}
func (h *Headscale) getUsedIPs() (*netipx.IPSet, error) {
// FIXME: This really deserves a better data model,
// but this was quick to get running and it should be enough
// to begin experimenting with a dual stack tailnet.
var addressesSlices []string
h.db.Model(&Machine{}).Pluck("ip_addresses", &addressesSlices)
var ips netipx.IPSetBuilder
for _, slice := range addressesSlices {
var machineAddresses MachineAddresses
err := machineAddresses.Scan(slice)
if err != nil {
return &netipx.IPSet{}, fmt.Errorf(
"failed to read ip from database: %w",
err,
)
}
for _, ip := range machineAddresses {
ips.Add(ip)
}
}
ipSet, err := ips.IPSet()
if err != nil {
return &netipx.IPSet{}, fmt.Errorf(
"failed to build IP Set: %w",
err,
)
}
return ipSet, nil
}
func tailNodesToString(nodes []*tailcfg.Node) string {
temp := make([]string, len(nodes))
for index, node := range nodes {
temp[index] = node.Name
}
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
}
func tailMapResponseToString(resp tailcfg.MapResponse) string {
return fmt.Sprintf(
"{ Node: %s, Peers: %s }",
resp.Node.Name,
tailNodesToString(resp.Peers),
)
}
func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", addr)
}
func stringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {
result := make([]netip.Prefix, len(prefixes))
for index, prefixStr := range prefixes {
prefix, err := netip.ParsePrefix(prefixStr)
if err != nil {
return []netip.Prefix{}, err
}
result[index] = prefix
}
return result, nil
}
func contains[T string | netip.Prefix](ts []T, t T) bool {
for _, v := range ts {
if reflect.DeepEqual(v, t) {
return true
}
}
return false
}
// GenerateRandomBytes returns securely generated random bytes.
// It will return an error if the system's secure random
// number generator fails to function correctly, in which
// case the caller should not continue.
func GenerateRandomBytes(n int) ([]byte, error) {
bytes := make([]byte, n)
// Note that err == nil only if we read len(b) bytes.
if _, err := rand.Read(bytes); err != nil {
return nil, err
}
return bytes, nil
}
// GenerateRandomStringURLSafe returns a URL-safe, base64 encoded
// securely generated random string.
// It will return an error if the system's secure random
// number generator fails to function correctly, in which
// case the caller should not continue.
func GenerateRandomStringURLSafe(n int) (string, error) {
b, err := GenerateRandomBytes(n)
return base64.RawURLEncoding.EncodeToString(b), err
}
// GenerateRandomStringDNSSafe returns a DNS-safe
// securely generated random string.
// It will return an error if the system's secure random
// number generator fails to function correctly, in which
// case the caller should not continue.
func GenerateRandomStringDNSSafe(size int) (string, error) {
var str string
var err error
for len(str) < size {
str, err = GenerateRandomStringURLSafe(size)
if err != nil {
return "", err
}
str = strings.ToLower(
strings.ReplaceAll(strings.ReplaceAll(str, "_", ""), "-", ""),
)
}
return str[:size], nil
}
func IsStringInSlice(slice []string, str string) bool {
for _, s := range slice {
if s == str {
return true
}
}
return false
}
func AbsolutePathFromConfigPath(path string) string {
// If a relative path is provided, prefix it with the directory where
// the config file was found.
if (path != "") && !strings.HasPrefix(path, string(os.PathSeparator)) {
dir, _ := filepath.Split(viper.ConfigFileUsed())
if dir != "" {
path = filepath.Join(dir, path)
}
}
return path
}
func GetFileMode(key string) fs.FileMode {
modeStr := viper.GetString(key)
mode, err := strconv.ParseUint(modeStr, Base8, BitSize64)
if err != nil {
return PermissionFallback
}
return fs.FileMode(mode)
}

201
hscontrol/utils_test.go Normal file
View File

@@ -0,0 +1,201 @@
package headscale
import (
"net/netip"
"go4.org/netipx"
"gopkg.in/check.v1"
)
func (s *Suite) TestGetAvailableIp(c *check.C) {
ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())
}
func (s *Suite) TestGetUsedIps(c *check.C) {
ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil)
user, err := app.CreateUser("test-ip")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
app.db.Save(&machine)
usedIps, err := app.getUsedIPs()
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
expectedIPSetBuilder := netipx.IPSetBuilder{}
expectedIPSetBuilder.Add(expected)
expectedIPSet, _ := expectedIPSetBuilder.IPSet()
c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true)
c.Assert(usedIps.Contains(expected), check.Equals, true)
machine1, err := app.GetMachineByID(0)
c.Assert(err, check.IsNil)
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
c.Assert(machine1.IPAddresses[0], check.Equals, expected)
}
func (s *Suite) TestGetMultiIp(c *check.C) {
user, err := app.CreateUser("test-ip-multi")
c.Assert(err, check.IsNil)
for index := 1; index <= 350; index++ {
app.ipAllocationMutex.Lock()
ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
machine := Machine{
ID: uint64(index),
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
IPAddresses: ips,
}
app.db.Save(&machine)
app.ipAllocationMutex.Unlock()
}
usedIps, err := app.getUsedIPs()
c.Assert(err, check.IsNil)
expected0 := netip.MustParseAddr("10.27.0.1")
expected9 := netip.MustParseAddr("10.27.0.10")
expected300 := netip.MustParseAddr("10.27.0.45")
notExpectedIPSetBuilder := netipx.IPSetBuilder{}
notExpectedIPSetBuilder.Add(expected0)
notExpectedIPSetBuilder.Add(expected9)
notExpectedIPSetBuilder.Add(expected300)
notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet()
c.Assert(err, check.IsNil)
// We actually expect it to be a lot larger
c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false)
c.Assert(usedIps.Contains(expected0), check.Equals, true)
c.Assert(usedIps.Contains(expected9), check.Equals, true)
c.Assert(usedIps.Contains(expected300), check.Equals, true)
// Check that we can read back the IPs
machine1, err := app.GetMachineByID(1)
c.Assert(err, check.IsNil)
c.Assert(len(machine1.IPAddresses), check.Equals, 1)
c.Assert(
machine1.IPAddresses[0],
check.Equals,
netip.MustParseAddr("10.27.0.1"),
)
machine50, err := app.GetMachineByID(50)
c.Assert(err, check.IsNil)
c.Assert(len(machine50.IPAddresses), check.Equals, 1)
c.Assert(
machine50.IPAddresses[0],
check.Equals,
netip.MustParseAddr("10.27.0.50"),
)
expectedNextIP := netip.MustParseAddr("10.27.1.95")
nextIP, err := app.getAvailableIPs()
c.Assert(err, check.IsNil)
c.Assert(len(nextIP), check.Equals, 1)
c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String())
// If we call get Available again, we should receive
// the same IP, as it has not been reserved.
nextIP2, err := app.getAvailableIPs()
c.Assert(err, check.IsNil)
c.Assert(len(nextIP2), check.Equals, 1)
c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String())
}
func (s *Suite) TestGetAvailableIpMachineWithoutIP(c *check.C) {
ips, err := app.getAvailableIPs()
c.Assert(err, check.IsNil)
expected := netip.MustParseAddr("10.27.0.1")
c.Assert(len(ips), check.Equals, 1)
c.Assert(ips[0].String(), check.Equals, expected.String())
user, err := app.CreateUser("test-ip")
c.Assert(err, check.IsNil)
pak, err := app.CreatePreAuthKey(user.Name, false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = app.GetMachine("test", "testmachine")
c.Assert(err, check.NotNil)
machine := Machine{
ID: 0,
MachineKey: "foo",
NodeKey: "bar",
DiscoKey: "faa",
Hostname: "testmachine",
UserID: user.ID,
RegisterMethod: RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
}
app.db.Save(&machine)
ips2, err := app.getAvailableIPs()
c.Assert(err, check.IsNil)
c.Assert(len(ips2), check.Equals, 1)
c.Assert(ips2[0].String(), check.Equals, expected.String())
}
func (s *Suite) TestGenerateRandomStringDNSSafe(c *check.C) {
for i := 0; i < 100000; i++ {
str, err := GenerateRandomStringDNSSafe(8)
if err != nil {
c.Error(err)
}
if len(str) != 8 {
c.Error("invalid length", len(str), str)
}
}
}