diff --git a/.goreleaser.yml b/.goreleaser.yml index 7b1ea605..f7355104 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -20,6 +20,7 @@ builds: - -mod=readonly ldflags: - -s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=v{{.Version}} + - id: linux-armhf main: ./cmd/headscale/headscale.go mod_timestamp: '{{ .CommitTimestamp }}' @@ -49,9 +50,16 @@ builds: - linux goarch: - amd64 - goarm: - - 6 - - 7 + main: ./cmd/headscale/headscale.go + mod_timestamp: '{{ .CommitTimestamp }}' + ldflags: + - -s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=v{{.Version}} + + - id: linux-arm64 + goos: + - linux + goarch: + - arm64 main: ./cmd/headscale/headscale.go mod_timestamp: '{{ .CommitTimestamp }}' ldflags: @@ -63,6 +71,7 @@ archives: - darwin-amd64 - linux-armhf - linux-amd64 + - linux-arm64 name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}" format: binary diff --git a/README.md b/README.md index 9d2ec159..0cb41bcb 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,11 @@ -# Headscale +# headscale -[![Join the chat at https://gitter.im/headscale-dev/community](https://badges.gitter.im/headscale-dev/community.svg)](https://gitter.im/headscale-dev/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ![ci](https://github.com/juanfont/headscale/actions/workflows/test.yml/badge.svg) +![ci](https://github.com/juanfont/headscale/actions/workflows/test.yml/badge.svg) An open source, self-hosted implementation of the Tailscale coordination server. +Join our [Discord](https://discord.gg/XcQxk2VHjx) server for a chat. + ## Overview Tailscale is [a modern VPN](https://tailscale.com/) built on top of [Wireguard](https://www.wireguard.com/). It [works like an overlay network](https://tailscale.com/blog/how-tailscale-works/) between the computers of your networks - using all kinds of [NAT traversal sorcery](https://tailscale.com/blog/how-nat-traversal-works/). @@ -29,18 +31,18 @@ Headscale implements this coordination server. - [x] DNS (passing DNS servers to nodes) - [x] Share nodes between ~~users~~ namespaces - [x] SSO (via OIDC) -- [ ] MagicDNS / Smart DNS +- [x] MagicDNS (see `docs/`) ## Client OS support -| OS | Supports headscale | -| --- | --- | -| Linux | Yes | -| OpenBSD | Yes | -| macOS | Yes (see `/apple` on your headscale for more information) | -| Windows | Yes | +| OS | Supports headscale | +| ------- | ----------------------------------------------------------------------------------------------------------------- | +| Linux | Yes | +| OpenBSD | Yes | +| macOS | Yes (see `/apple` on your headscale for more information) | +| Windows | Yes | | Android | [You need to compile the client yourself](https://github.com/juanfont/headscale/issues/58#issuecomment-885255270) | -| iOS | Not yet | +| iOS | Not yet | ## Roadmap 🤷 diff --git a/api.go b/api.go index d85221bc..cbe48072 100644 --- a/api.go +++ b/api.go @@ -253,7 +253,7 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma Str("func", "getMapResponse"). Str("machine", req.Hostinfo.Hostname). Msg("Creating Map response") - node, err := m.toNode(true) + node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { log.Error(). Str("func", "getMapResponse"). @@ -277,7 +277,7 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma DisplayName: m.Namespace.Name, } - nodePeers, err := peers.toNodes(true) + nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true) if err != nil { log.Error(). Str("func", "getMapResponse"). @@ -286,20 +286,25 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma return nil, err } + var dnsConfig *tailcfg.DNSConfig + if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS is enabled + // Only inject the Search Domain of the current namespace - shared nodes should use their full FQDN + dnsConfig = h.cfg.DNSConfig.Clone() + dnsConfig.Domains = append(dnsConfig.Domains, fmt.Sprintf("%s.%s", m.Namespace.Name, h.cfg.BaseDomain)) + } else { + dnsConfig = h.cfg.DNSConfig + } + resp := tailcfg.MapResponse{ - KeepAlive: false, - Node: node, - Peers: nodePeers, - // TODO(kradalby): As per tailscale docs, if DNSConfig is nil, - // it means its not updated, maybe we can have some logic - // to check and only pass updates when its updates. - // This is probably more relevant if we try to implement - // "MagicDNS" - DNSConfig: h.cfg.DNSConfig, - SearchPaths: []string{}, - Domain: "headscale.net", + KeepAlive: false, + Node: node, + Peers: nodePeers, + DNSConfig: dnsConfig, + Domain: h.cfg.BaseDomain, PacketFilter: *h.aclRules, DERPMap: h.cfg.DerpMap, + + // TODO(juanfont): We should send the profiles of all the peers (this own namespace + those from the shared peers) UserProfiles: []tailcfg.UserProfile{profile}, } log.Trace(). @@ -365,6 +370,11 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, resp := tailcfg.RegisterResponse{} pak, err := h.checkKeyValidity(req.Auth.AuthKey) if err != nil { + log.Error(). + Str("func", "handleAuthKey"). + Str("machine", m.Name). + Err(err). + Msg("Failed authentication via AuthKey") resp.MachineAuthorized = false respBody, err := encode(resp, &idKey, h.privateKey) if err != nil { @@ -414,6 +424,9 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key, db.Save(&m) h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys? + + pak.Used = true + db.Save(&pak) resp.MachineAuthorized = true resp.User = *pak.Namespace.toUser() diff --git a/app.go b/app.go index 239998c2..89c43589 100644 --- a/app.go +++ b/app.go @@ -16,12 +16,13 @@ import ( "github.com/rs/zerolog/log" "github.com/gin-gonic/gin" - "github.com/zsais/go-gin-prometheus" + ginprometheus "github.com/zsais/go-gin-prometheus" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" "gorm.io/gorm" "inet.af/netaddr" "tailscale.com/tailcfg" + "tailscale.com/types/dnstype" "tailscale.com/types/wgkey" ) @@ -33,6 +34,7 @@ type Config struct { DerpMap *tailcfg.DERPMap EphemeralNodeInactivityTimeout time.Duration IPPrefix netaddr.IPPrefix + BaseDomain string DBtype string DBpath string @@ -125,6 +127,17 @@ func NewHeadscale(cfg Config) (*Headscale, error) { if err != nil { return nil, err } + } + + if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS + magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain) + if err != nil { + return nil, err + } + h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver) + for _, d := range magicDNSDomains { + h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil + } } return &h, nil diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 5f30dc1b..12461924 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -129,6 +129,7 @@ var deleteNodeCmd = &cobra.Command{ return nil }, Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") h, err := getHeadscaleApp() if err != nil { log.Fatalf("Error initializing: %s", err) @@ -143,21 +144,32 @@ var deleteNodeCmd = &cobra.Command{ } confirm := false - prompt := &survey.Confirm{ - Message: fmt.Sprintf("Do you want to remove the node %s?", m.Name), - } - err = survey.AskOne(prompt, &confirm) - if err != nil { - return + force, _ := cmd.Flags().GetBool("force") + if !force { + prompt := &survey.Confirm{ + Message: fmt.Sprintf("Do you want to remove the node %s?", m.Name), + } + err = survey.AskOne(prompt, &confirm) + if err != nil { + return + } } - if confirm { + if confirm || force { err = h.DeleteMachine(m) + if strings.HasPrefix(output, "json") { + JsonOutput(map[string]string{"Result": "Node deleted"}, err, output) + return + } if err != nil { log.Fatalf("Error deleting node: %s", err) } fmt.Printf("Node deleted\n") } else { + if strings.HasPrefix(output, "json") { + JsonOutput(map[string]string{"Result": "Node not deleted"}, err, output) + return + } fmt.Printf("Node not deleted\n") } }, diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go index 1340267e..d7cebec1 100644 --- a/cmd/headscale/cli/preauthkeys.go +++ b/cmd/headscale/cli/preauthkeys.go @@ -57,7 +57,7 @@ var listPreAuthKeys = &cobra.Command{ return } - d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Expiration", "Created"}} + d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}} for _, k := range *keys { expiration := "-" if k.Expiration != nil { @@ -76,6 +76,7 @@ var listPreAuthKeys = &cobra.Command{ k.Key, reusable, strconv.FormatBool(k.Ephemeral), + fmt.Sprintf("%v", k.Used), expiration, k.CreatedAt.Format("2006-01-02 15:04:05"), }) @@ -130,7 +131,7 @@ var createPreAuthKeyCmd = &cobra.Command{ } var expirePreAuthKeyCmd = &cobra.Command{ - Use: "expire", + Use: "expire KEY", Short: "Expire a preauthkey", Args: func(cmd *cobra.Command, args []string) error { if len(args) < 1 { @@ -152,6 +153,10 @@ var expirePreAuthKeyCmd = &cobra.Command{ k, err := h.GetPreAuthKey(n, args[0]) if err != nil { + if strings.HasPrefix(o, "json") { + JsonOutput(k, err, o) + return + } log.Fatalf("Error getting the key: %s", err) } diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index 21857d85..794cd0d0 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -9,6 +9,7 @@ import ( func init() { rootCmd.PersistentFlags().StringP("output", "o", "", "Output format. Empty for human-readable, 'json' or 'json-line'") + rootCmd.PersistentFlags().Bool("force", false, "Disable prompts and forces the execution") } var rootCmd = &cobra.Command{ diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index 366e9597..ba8d34ad 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -76,7 +76,7 @@ func LoadConfig(path string) error { } -func GetDNSConfig() *tailcfg.DNSConfig { +func GetDNSConfig() (*tailcfg.DNSConfig, string) { if viper.IsSet("dns_config") { dnsConfig := &tailcfg.DNSConfig{} @@ -108,10 +108,27 @@ func GetDNSConfig() *tailcfg.DNSConfig { dnsConfig.Domains = viper.GetStringSlice("dns_config.domains") } - return dnsConfig + if viper.IsSet("dns_config.magic_dns") { + magicDNS := viper.GetBool("dns_config.magic_dns") + if len(dnsConfig.Nameservers) > 0 { + dnsConfig.Proxied = magicDNS + } else if magicDNS { + log.Warn(). + Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.") + } + } + + var baseDomain string + if viper.IsSet("dns_config.base_domain") { + baseDomain = viper.GetString("dns_config.base_domain") + } else { + baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled + } + + return dnsConfig, baseDomain } - return nil + return nil, "" } func absPath(path string) string { @@ -144,7 +161,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) { return nil, err } - // maxMachineRegistrationDuration is the maximum time a client can request for a client registration + + // maxMachineRegistrationDuration is the maximum time a client can request for a client registration maxMachineRegistrationDuration, _ := time.ParseDuration("10h") if viper.GetDuration("max_machine_registration_duration") >= time.Second { maxMachineRegistrationDuration = viper.GetDuration("max_machine_registration_duration") @@ -156,12 +174,15 @@ func getHeadscaleApp() (*headscale.Headscale, error) { defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration") } + dnsConfig, baseDomain := GetDNSConfig() + cfg := headscale.Config{ ServerURL: viper.GetString("server_url"), Addr: viper.GetString("listen_addr"), PrivateKeyPath: absPath(viper.GetString("private_key_path")), DerpMap: derpMap, IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")), + BaseDomain: baseDomain, EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"), @@ -181,6 +202,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) { TLSCertPath: absPath(viper.GetString("tls_cert_path")), TLSKeyPath: absPath(viper.GetString("tls_key_path")), + DNSConfig: dnsConfig, + ACMEEmail: viper.GetString("acme_email"), ACMEURL: viper.GetString("acme_url"), @@ -192,6 +215,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) { MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration + } h, err := headscale.NewHeadscale(cfg) @@ -261,3 +285,12 @@ func JsonOutput(result interface{}, errResult error, outputFormat string) { } fmt.Println(string(j)) } + +func HasJsonOutputFlag() bool { + for _, arg := range os.Args { + if arg == "json" || arg == "json-line" { + return true + } + } + return false +} diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index f815001e..6b1a8437 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -62,7 +62,8 @@ func main() { zerolog.SetGlobalLevel(zerolog.DebugLevel) } - if !viper.GetBool("disable_check_updates") { + jsonOutput := cli.HasJsonOutputFlag() + if !viper.GetBool("disable_check_updates") && !jsonOutput { if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" { githubTag := &latest.GithubTag{ Owner: "juanfont", diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 58bf5899..bddea94c 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -117,12 +117,12 @@ func (*Suite) TestDNSConfigLoading(c *check.C) { err = cli.LoadConfig(tmpDir) c.Assert(err, check.IsNil) - dnsConfig := cli.GetDNSConfig() - fmt.Println(dnsConfig) + dnsConfig, baseDomain := cli.GetDNSConfig() c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1") - c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1") + c.Assert(dnsConfig.Proxied, check.Equals, true) + c.Assert(baseDomain, check.Equals, "example.com") } func writeConfig(c *check.C, tmpDir string, configYaml []byte) { diff --git a/config.json.postgres.example b/config.json.postgres.example index e9118204..9b6f737f 100644 --- a/config.json.postgres.example +++ b/config.json.postgres.example @@ -22,6 +22,9 @@ "dns_config": { "nameservers": [ "1.1.1.1" - ] + ], + "domains": [], + "magic_dns": true, + "base_domain": "example.com" } } diff --git a/config.json.sqlite.example b/config.json.sqlite.example index 5afa450f..74e15902 100644 --- a/config.json.sqlite.example +++ b/config.json.sqlite.example @@ -18,6 +18,9 @@ "dns_config": { "nameservers": [ "1.1.1.1" - ] + ], + "domains": [], + "magic_dns": true, + "base_domain": "example.com" } } diff --git a/dns.go b/dns.go new file mode 100644 index 00000000..353e10be --- /dev/null +++ b/dns.go @@ -0,0 +1,73 @@ +package headscale + +import ( + "fmt" + "strings" + + "inet.af/netaddr" + "tailscale.com/util/dnsname" +) + +// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`. +// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS +// server (listening in 100.100.100.100 udp/53) should be used for. +// +// Tailscale.com includes in the list: +// - the `BaseDomain` of the user +// - the reverse DNS entry for IPv6 (0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa., see below more on IPv6) +// - the reverse DNS entries for the IPv4 subnets covered by the user's `IPPrefix`. +// In the public SaaS this is [64-127].100.in-addr.arpa. +// +// The main purpose of this function is then generating the list of IPv4 entries. For the 100.64.0.0/10, this +// is clear, and could be hardcoded. But we are allowing any range as `IPPrefix`, so we need to find out the +// subnets when we have 172.16.0.0/16 (i.e., [0-255].16.172.in-addr.arpa.), or any other subnet. +// +// How IN-ADDR.ARPA domains work is defined in RFC1035 (section 3.5). Tailscale.com seems to adhere to this, +// and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next +// class block only. + +// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask). +// This allows us to then calculate the subnets included in the subsequent class block and generate the entries. +func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ([]dnsname.FQDN, error) { + base, err := dnsname.ToFQDN(baseDomain) + if err != nil { + return nil, err + } + + // TODO(juanfont): we are not handing out IPv6 addresses yet + // and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network) + ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.") + fqdns := []dnsname.FQDN{base, ipv6base} + + // Conversion to the std lib net.IPnet, a bit easier to operate + netRange := ipPrefix.IPNet() + maskBits, _ := netRange.Mask.Size() + + // lastOctet is the last IP byte covered by the mask + lastOctet := maskBits / 8 + + // wildcardBits is the number of bits not under the mask in the lastOctet + wildcardBits := 8 - maskBits%8 + + // min is the value in the lastOctet byte of the IP + // max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1 + min := uint(netRange.IP[lastOctet]) + max := uint((min + 1<= 0; i-- { + rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i])) + } + rdnsSlice = append(rdnsSlice, "in-addr.arpa.") + rdnsBase := strings.Join(rdnsSlice, ".") + + for i := min; i <= max; i++ { + fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%d.%s", i, rdnsBase)) + if err != nil { + continue + } + fqdns = append(fqdns, fqdn) + } + return fqdns, nil +} diff --git a/dns_test.go b/dns_test.go new file mode 100644 index 00000000..87813203 --- /dev/null +++ b/dns_test.go @@ -0,0 +1,63 @@ +package headscale + +import ( + "gopkg.in/check.v1" + "inet.af/netaddr" +) + +func (s *Suite) TestMagicDNSRootDomains100(c *check.C) { + prefix := netaddr.MustParseIPPrefix("100.64.0.0/10") + domains, err := generateMagicDNSRootDomains(prefix, "headscale.net") + c.Assert(err, check.IsNil) + + found := false + for _, domain := range domains { + if domain == "64.100.in-addr.arpa." { + found = true + break + } + } + c.Assert(found, check.Equals, true) + + found = false + for _, domain := range domains { + if domain == "100.100.in-addr.arpa." { + found = true + break + } + } + c.Assert(found, check.Equals, true) + + found = false + for _, domain := range domains { + if domain == "127.100.in-addr.arpa." { + found = true + break + } + } + c.Assert(found, check.Equals, true) +} + +func (s *Suite) TestMagicDNSRootDomains172(c *check.C) { + prefix := netaddr.MustParseIPPrefix("172.16.0.0/16") + domains, err := generateMagicDNSRootDomains(prefix, "headscale.net") + c.Assert(err, check.IsNil) + + found := false + for _, domain := range domains { + if domain == "0.16.172.in-addr.arpa." { + found = true + break + } + } + c.Assert(found, check.Equals, true) + + found = false + for _, domain := range domains { + if domain == "255.16.172.in-addr.arpa." { + found = true + break + } + } + c.Assert(found, check.Equals, true) +} diff --git a/docs/DNS.md b/docs/DNS.md new file mode 100644 index 00000000..85bf9f44 --- /dev/null +++ b/docs/DNS.md @@ -0,0 +1,33 @@ +# DNS in Headscale + +Headscale supports Tailscale's DNS configuration and MagicDNS. Please have a look to their KB to better understand what this means: + +- https://tailscale.com/kb/1054/dns/ +- https://tailscale.com/kb/1081/magicdns/ +- https://tailscale.com/blog/2021-09-private-dns-with-magicdns/ + +Long story short, you can define the DNS servers you want to use in your tailnets, activate MagicDNS (so you don't have to remember the IP addresses of your nodes), define search domains, as well as predefined hosts. Headscale will inject that settings into your nodes. + + +## Configuration reference + +The setup is done via the `config.json` file, under the `dns_config` key. + +```json +{ + "server_url": "http://127.0.0.1:8001", + "listen_addr": "0.0.0.0:8001", + "private_key_path": "private.key", + //... + "dns_config": { + "nameservers": ["1.1.1.1", "8.8.8.8"], + "domains": [], + "magic_dns": true, + "base_domain": "example.com" + } +} +``` +- `nameservers`: The list of DNS servers to use. +- `domains`: Search domains to inject. +- `magic_dns`: Whether to use [MagicDNS](https://tailscale.com/kb/1081/magicdns/). Only works if there is at least a nameserver defined. +- `base_domain`: Defines the base domain to create the hostnames for MagicDNS. `base_domain` must be a FQDNs, without the trailing dot. The FQDN of the hosts will be `hostname.namespace.base_domain` (e.g., _myhost.mynamespace.example.com_). \ No newline at end of file diff --git a/integration_test.go b/integration_test.go index a7805646..3c51215d 100644 --- a/integration_test.go +++ b/integration_test.go @@ -504,43 +504,43 @@ func (s *IntegrationTestSuite) TestSharedNodes() { assert.Contains(s.T(), result, hostname) } - // TODO(kradalby): Figure out why these connections are not set up - // // TODO: See if we can have a more deterministic wait here. - // time.Sleep(100 * time.Second) + // TODO(juanfont): We have to find out why do we need to wait + time.Sleep(100 * time.Second) // Wait for the nodes to receive updates - // mainIps, err := getIPs(main.tailscales) - // assert.Nil(s.T(), err) + mainIps, err := getIPs(main.tailscales) + assert.Nil(s.T(), err) - // sharedIps, err := getIPs(shared.tailscales) - // assert.Nil(s.T(), err) + sharedIps, err := getIPs(shared.tailscales) + assert.Nil(s.T(), err) - // for hostname, tailscale := range main.tailscales { - // for peername, ip := range sharedIps { - // s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { - // // We currently cant ping ourselves, so skip that. - // if peername != hostname { - // // We are only interested in "direct ping" which means what we - // // might need a couple of more attempts before reaching the node. - // command := []string{ - // "tailscale", "ping", - // "--timeout=1s", - // "--c=20", - // "--until-direct=true", - // ip.String(), - // } + for hostname, tailscale := range main.tailscales { + for peername, ip := range sharedIps { + s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { + // We currently cant ping ourselves, so skip that. + if peername != hostname { + // We are only interested in "direct ping" which means what we + // might need a couple of more attempts before reaching the node. + command := []string{ + "tailscale", "ping", + "--timeout=15s", + "--c=20", + "--until-direct=true", + ip.String(), + } - // fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip) - // result, err := executeCommand( - // &tailscale, - // command, - // ) - // assert.Nil(t, err) - // fmt.Printf("Result for %s: %s\n", hostname, result) - // assert.Contains(t, result, "pong") - // } - // }) - // } - // } + fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip) + result, err := executeCommand( + &tailscale, + command, + []string{}, + ) + assert.Nil(t, err) + fmt.Printf("Result for %s: %s\n", hostname, result) + assert.Contains(t, result, "pong") + } + }) + } + } } func (s *IntegrationTestSuite) TestTailDrop() { @@ -592,7 +592,7 @@ func (s *IntegrationTestSuite) TestTailDrop() { _, err = executeCommand( &tailscale, command, - []string{"ALL_PROXY=socks5://localhost:1055/"}, + []string{"ALL_PROXY=socks5://localhost:1055"}, ) if err == nil { break @@ -645,6 +645,38 @@ func (s *IntegrationTestSuite) TestTailDrop() { } } +func (s *IntegrationTestSuite) TestMagicDNS() { + for namespace, scales := range s.namespaces { + ips, err := getIPs(scales.tailscales) + assert.Nil(s.T(), err) + for hostname, tailscale := range scales.tailscales { + for peername, ip := range ips { + s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { + if peername != hostname { + command := []string{ + "tailscale", "ping", + "--timeout=10s", + "--c=20", + "--until-direct=true", + fmt.Sprintf("%s.%s.headscale.net", peername, namespace), + } + + fmt.Printf("Pinging using Hostname (magicdns) from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip) + result, err := executeCommand( + &tailscale, + command, + []string{}, + ) + assert.Nil(t, err) + fmt.Printf("Result for %s: %s\n", hostname, result) + assert.Contains(t, result, "pong") + } + }) + } + } + } +} + func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, error) { ips := make(map[string]netaddr.IP) for hostname, tailscale := range tailscales { diff --git a/integration_test/etc/config.json b/integration_test/etc/config.json index 5454f2f7..dc23652d 100644 --- a/integration_test/etc/config.json +++ b/integration_test/etc/config.json @@ -7,5 +7,13 @@ "db_type": "sqlite3", "db_path": "/tmp/integration_test_db.sqlite3", "acl_policy_path": "", - "log_level": "trace" -} + "log_level": "trace", + "dns_config": { + "nameservers": [ + "1.1.1.1" + ], + "domains": [], + "magic_dns": true, + "base_domain": "headscale.net" + } +} \ No newline at end of file diff --git a/machine.go b/machine.go index 6eecbc6f..0057d3f1 100644 --- a/machine.go +++ b/machine.go @@ -94,7 +94,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { Msg("Finding direct peers") machines := Machines{} - if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered", + if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered", m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { log.Error().Err(err).Msg("Error accessing db") return Machines{}, err @@ -109,13 +109,13 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) { return machines, nil } +// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for func (h *Headscale) getShared(m *Machine) (Machines, error) { log.Trace(). Str("func", "getShared"). Str("machine", m.Name). Msg("Finding shared peers") - // We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for sharedMachines := []SharedMachine{} if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?", m.NamespaceID).Find(&sharedMachines).Error; err != nil { @@ -136,6 +136,37 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) { return peers, nil } +// getSharedTo fetches the machines of the namespaces this machine is shared in +func (h *Headscale) getSharedTo(m *Machine) (Machines, error) { + log.Trace(). + Str("func", "getSharedTo"). + Str("machine", m.Name). + Msg("Finding peers in namespaces this machine is shared with") + + sharedMachines := []SharedMachine{} + if err := h.db.Preload("Namespace").Preload("Machine").Where("machine_id = ?", + m.ID).Find(&sharedMachines).Error; err != nil { + return Machines{}, err + } + + peers := make(Machines, 0) + for _, sharedMachine := range sharedMachines { + namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name) + if err != nil { + return Machines{}, err + } + peers = append(peers, *namespaceMachines...) + } + + sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) + + log.Trace(). + Str("func", "getSharedTo"). + Str("machine", m.Name). + Msgf("Found peers we are shared with: %s", peers.String()) + return peers, nil +} + func (h *Headscale) getPeers(m *Machine) (Machines, error) { direct, err := h.getDirectPeers(m) if err != nil { @@ -149,13 +180,24 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) { shared, err := h.getShared(m) if err != nil { log.Error(). - Str("func", "getDirectPeers"). + Str("func", "getShared"). + Err(err). + Msg("Cannot fetch peers") + return Machines{}, err + } + + sharedTo, err := h.getSharedTo(m) + if err != nil { + log.Error(). + Str("func", "sharedTo"). Err(err). Msg("Cannot fetch peers") return Machines{}, err } peers := append(direct, shared...) + peers = append(peers, sharedTo...) + sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) log.Trace(). @@ -210,6 +252,11 @@ func (h *Headscale) UpdateMachine(m *Machine) error { // DeleteMachine softs deletes a Machine from the database func (h *Headscale) DeleteMachine(m *Machine) error { + err := h.RemoveSharedMachineFromAllNamespaces(m) + if err != nil && err != errorMachineNotShared { + return err + } + m.Registered = false namespaceID := m.NamespaceID h.db.Save(&m) // we mark it as unregistered, just in case @@ -222,10 +269,16 @@ func (h *Headscale) DeleteMachine(m *Machine) error { // HardDeleteMachine hard deletes a Machine from the database func (h *Headscale) HardDeleteMachine(m *Machine) error { + err := h.RemoveSharedMachineFromAllNamespaces(m) + if err != nil && err != errorMachineNotShared { + return err + } + namespaceID := m.NamespaceID if err := h.db.Unscoped().Delete(&m).Error; err != nil { return err } + return h.RequestMapUpdates(namespaceID) } @@ -304,11 +357,11 @@ func (ms MachinesP) String() string { return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (ms Machines) toNodes(includeRoutes bool) ([]*tailcfg.Node, error) { +func (ms Machines) toNodes(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) ([]*tailcfg.Node, error) { nodes := make([]*tailcfg.Node, len(ms)) for index, machine := range ms { - node, err := machine.toNode(includeRoutes) + node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes) if err != nil { return nil, err } @@ -321,7 +374,7 @@ func (ms Machines) toNodes(includeRoutes bool) ([]*tailcfg.Node, error) { // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // as per the expected behaviour in the official SaaS -func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) { +func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) (*tailcfg.Node, error) { nKey, err := wgkey.ParseHex(m.NodeKey) if err != nil { return nil, err @@ -416,10 +469,17 @@ func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) { keyExpiry = time.Time{} } + var hostname string + if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS + hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain) + } else { + hostname = m.Name + } + n := tailcfg.Node{ ID: tailcfg.NodeID(m.ID), // this is the actual ID StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent - Name: hostinfo.Hostname, + Name: hostname, User: tailcfg.UserID(m.NamespaceID), Key: tailcfg.NodeKey(nKey), KeyExpiry: keyExpiry, diff --git a/namespaces.go b/namespaces.go index 212df9a6..75ebe81b 100644 --- a/namespaces.go +++ b/namespaces.go @@ -91,7 +91,7 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) { } machines := []Machine{} - if err := h.db.Preload("AuthKey").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { + if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { return nil, err } return &machines, nil diff --git a/preauth_keys.go b/preauth_keys.go index cc849fc0..de10cdb7 100644 --- a/preauth_keys.go +++ b/preauth_keys.go @@ -11,7 +11,7 @@ import ( const errorAuthKeyNotFound = Error("AuthKey not found") const errorAuthKeyExpired = Error("AuthKey expired") -const errorAuthKeyNotReusableAlreadyUsed = Error("AuthKey not reusable already used") +const errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used") // PreAuthKey describes a pre-authorization key usable in a particular namespace type PreAuthKey struct { @@ -21,6 +21,7 @@ type PreAuthKey struct { Namespace Namespace Reusable bool Ephemeral bool `gorm:"default:false"` + Used bool `gorm:"default:false"` CreatedAt *time.Time Expiration *time.Time @@ -110,11 +111,10 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) { return nil, err } - if len(machines) != 0 { - return nil, errorAuthKeyNotReusableAlreadyUsed + if len(machines) != 0 || pak.Used { + return nil, errSingleUseAuthKeyHasBeenUsed } - // missing here validation on current usage return &pak, nil } diff --git a/preauth_keys_test.go b/preauth_keys_test.go index 37f2e4dd..f8973eaf 100644 --- a/preauth_keys_test.go +++ b/preauth_keys_test.go @@ -87,7 +87,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) { h.db.Save(&m) p, err := h.checkKeyValidity(pak.Key) - c.Assert(err, check.Equals, errorAuthKeyNotReusableAlreadyUsed) + c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed) c.Assert(p, check.IsNil) } @@ -180,3 +180,16 @@ func (*Suite) TestExpirePreauthKey(c *check.C) { c.Assert(err, check.Equals, errorAuthKeyExpired) c.Assert(p, check.IsNil) } + +func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { + n, err := h.CreateNamespace("test6") + c.Assert(err, check.IsNil) + + pak, err := h.CreatePreAuthKey(n.Name, false, false, nil) + c.Assert(err, check.IsNil) + pak.Used = true + h.db.Save(&pak) + + _, err = h.checkKeyValidity(pak.Key) + c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed) +} diff --git a/sharing.go b/sharing.go index 93c299c7..83ce5260 100644 --- a/sharing.go +++ b/sharing.go @@ -4,6 +4,7 @@ import "gorm.io/gorm" const errorSameNamespace = Error("Destination namespace same as origin") const errorMachineAlreadyShared = Error("Node already shared to this namespace") +const errorMachineNotShared = Error("Machine not shared to this namespace") // SharedMachine is a join table to support sharing nodes between namespaces type SharedMachine struct { @@ -35,3 +36,13 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error return nil } + +// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces +func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error { + sharedMachine := SharedMachine{} + if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil { + return result.Error + } + + return nil +} diff --git a/sharing_test.go b/sharing_test.go index baa90d0c..d8cd8029 100644 --- a/sharing_test.go +++ b/sharing_test.go @@ -274,7 +274,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { c.Assert(err, check.NotNil) m1 := &Machine{ - ID: 0, + ID: 1, MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", @@ -291,7 +291,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { c.Assert(err, check.IsNil) m2 := &Machine{ - ID: 1, + ID: 2, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -308,7 +308,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { c.Assert(err, check.IsNil) m3 := &Machine{ - ID: 2, + ID: 3, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -325,7 +325,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { c.Assert(err, check.IsNil) m4 := &Machine{ - ID: 3, + ID: 4, MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", @@ -341,6 +341,129 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { _, err = h.GetMachine(n1.Name, m4.Name) c.Assert(err, check.IsNil) + p1s, err := h.getPeers(m1) + c.Assert(err, check.IsNil) + c.Assert(len(p1s), check.Equals, 1) // node1 can see node4 + c.Assert(p1s[0].Name, check.Equals, "test_get_shared_nodes_4") + + err = h.AddSharedMachineToNamespace(m2, n1) + c.Assert(err, check.IsNil) + + p1sAfter, err := h.getPeers(m1) + c.Assert(err, check.IsNil) + c.Assert(len(p1sAfter), check.Equals, 2) // node1 can see node2 (shared) and node4 (same namespace) + c.Assert(p1sAfter[0].Name, check.Equals, "test_get_shared_nodes_2") + c.Assert(p1sAfter[1].Name, check.Equals, "test_get_shared_nodes_4") + + node1shared, err := h.getShared(m1) + c.Assert(err, check.IsNil) + c.Assert(len(node1shared), check.Equals, 1) // node1 can see node2 as shared + c.Assert(node1shared[0].Name, check.Equals, "test_get_shared_nodes_2") + + pAlone, err := h.getPeers(m3) + c.Assert(err, check.IsNil) + c.Assert(len(pAlone), check.Equals, 0) // node3 is alone + + pSharedTo, err := h.getPeers(m2) + c.Assert(err, check.IsNil) + c.Assert(len(pSharedTo), check.Equals, 2) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1 + c.Assert(pSharedTo[0].Name, check.Equals, "test_get_shared_nodes_1") + c.Assert(pSharedTo[1].Name, check.Equals, "test_get_shared_nodes_4") +} + +func (s *Suite) TestDeleteSharedMachine(c *check.C) { + n1, err := h.CreateNamespace("shared1") + c.Assert(err, check.IsNil) + + n2, err := h.CreateNamespace("shared2") + c.Assert(err, check.IsNil) + + n3, err := h.CreateNamespace("shared3") + c.Assert(err, check.IsNil) + + pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) + c.Assert(err, check.IsNil) + + pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil) + c.Assert(err, check.IsNil) + + pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil) + c.Assert(err, check.IsNil) + + pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil) + c.Assert(err, check.IsNil) + + _, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1") + c.Assert(err, check.NotNil) + + m1 := &Machine{ + ID: 1, + MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", + NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", + DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", + Name: "test_get_shared_nodes_1", + NamespaceID: n1.ID, + Registered: true, + RegisterMethod: "authKey", + IPAddress: "100.64.0.1", + AuthKeyID: uint(pak1n1.ID), + } + h.db.Save(m1) + + _, err = h.GetMachine(n1.Name, m1.Name) + c.Assert(err, check.IsNil) + + m2 := &Machine{ + ID: 2, + MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + Name: "test_get_shared_nodes_2", + NamespaceID: n2.ID, + Registered: true, + RegisterMethod: "authKey", + IPAddress: "100.64.0.2", + AuthKeyID: uint(pak2n2.ID), + } + h.db.Save(m2) + + _, err = h.GetMachine(n2.Name, m2.Name) + c.Assert(err, check.IsNil) + + m3 := &Machine{ + ID: 3, + MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + Name: "test_get_shared_nodes_3", + NamespaceID: n3.ID, + Registered: true, + RegisterMethod: "authKey", + IPAddress: "100.64.0.3", + AuthKeyID: uint(pak3n3.ID), + } + h.db.Save(m3) + + _, err = h.GetMachine(n3.Name, m3.Name) + c.Assert(err, check.IsNil) + + m4 := &Machine{ + ID: 4, + MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", + Name: "test_get_shared_nodes_4", + NamespaceID: n1.ID, + Registered: true, + RegisterMethod: "authKey", + IPAddress: "100.64.0.4", + AuthKeyID: uint(pak4n1.ID), + } + h.db.Save(m4) + + _, err = h.GetMachine(n1.Name, m4.Name) + c.Assert(err, check.IsNil) + p1s, err := h.getPeers(m1) c.Assert(err, check.IsNil) c.Assert(len(p1s), check.Equals, 1) // nodes 1 and 4 @@ -363,4 +486,15 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) { pAlone, err := h.getPeers(m3) c.Assert(err, check.IsNil) c.Assert(len(pAlone), check.Equals, 0) // node 3 is alone + + sharedMachines, err := h.ListSharedMachinesInNamespace(n1.Name) + c.Assert(err, check.IsNil) + c.Assert(len(*sharedMachines), check.Equals, 1) + + err = h.DeleteMachine(m2) + c.Assert(err, check.IsNil) + + sharedMachines, err = h.ListSharedMachinesInNamespace(n1.Name) + c.Assert(err, check.IsNil) + c.Assert(len(*sharedMachines), check.Equals, 0) }