Merge branch 'main' into main

This commit is contained in:
unreality 2021-10-16 22:31:37 +08:00 committed by GitHub
commit afbfc1d370
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 621 additions and 99 deletions

View File

@ -20,6 +20,7 @@ builds:
- -mod=readonly - -mod=readonly
ldflags: ldflags:
- -s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=v{{.Version}} - -s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=v{{.Version}}
- id: linux-armhf - id: linux-armhf
main: ./cmd/headscale/headscale.go main: ./cmd/headscale/headscale.go
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
@ -49,9 +50,16 @@ builds:
- linux - linux
goarch: goarch:
- amd64 - amd64
goarm: main: ./cmd/headscale/headscale.go
- 6 mod_timestamp: '{{ .CommitTimestamp }}'
- 7 ldflags:
- -s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=v{{.Version}}
- id: linux-arm64
goos:
- linux
goarch:
- arm64
main: ./cmd/headscale/headscale.go main: ./cmd/headscale/headscale.go
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: '{{ .CommitTimestamp }}'
ldflags: ldflags:
@ -63,6 +71,7 @@ archives:
- darwin-amd64 - darwin-amd64
- linux-armhf - linux-armhf
- linux-amd64 - linux-amd64
- linux-arm64
name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}" name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
format: binary format: binary

View File

@ -1,9 +1,11 @@
# Headscale # headscale
[![Join the chat at https://gitter.im/headscale-dev/community](https://badges.gitter.im/headscale-dev/community.svg)](https://gitter.im/headscale-dev/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ![ci](https://github.com/juanfont/headscale/actions/workflows/test.yml/badge.svg) ![ci](https://github.com/juanfont/headscale/actions/workflows/test.yml/badge.svg)
An open source, self-hosted implementation of the Tailscale coordination server. An open source, self-hosted implementation of the Tailscale coordination server.
Join our [Discord](https://discord.gg/XcQxk2VHjx) server for a chat.
## Overview ## Overview
Tailscale is [a modern VPN](https://tailscale.com/) built on top of [Wireguard](https://www.wireguard.com/). It [works like an overlay network](https://tailscale.com/blog/how-tailscale-works/) between the computers of your networks - using all kinds of [NAT traversal sorcery](https://tailscale.com/blog/how-nat-traversal-works/). Tailscale is [a modern VPN](https://tailscale.com/) built on top of [Wireguard](https://www.wireguard.com/). It [works like an overlay network](https://tailscale.com/blog/how-tailscale-works/) between the computers of your networks - using all kinds of [NAT traversal sorcery](https://tailscale.com/blog/how-nat-traversal-works/).
@ -29,12 +31,12 @@ Headscale implements this coordination server.
- [x] DNS (passing DNS servers to nodes) - [x] DNS (passing DNS servers to nodes)
- [x] Share nodes between ~~users~~ namespaces - [x] Share nodes between ~~users~~ namespaces
- [x] SSO (via OIDC) - [x] SSO (via OIDC)
- [ ] MagicDNS / Smart DNS - [x] MagicDNS (see `docs/`)
## Client OS support ## Client OS support
| OS | Supports headscale | | OS | Supports headscale |
| --- | --- | | ------- | ----------------------------------------------------------------------------------------------------------------- |
| Linux | Yes | | Linux | Yes |
| OpenBSD | Yes | | OpenBSD | Yes |
| macOS | Yes (see `/apple` on your headscale for more information) | | macOS | Yes (see `/apple` on your headscale for more information) |

33
api.go
View File

@ -253,7 +253,7 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
Str("func", "getMapResponse"). Str("func", "getMapResponse").
Str("machine", req.Hostinfo.Hostname). Str("machine", req.Hostinfo.Hostname).
Msg("Creating Map response") Msg("Creating Map response")
node, err := m.toNode(true) node, err := m.toNode(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
@ -277,7 +277,7 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
DisplayName: m.Namespace.Name, DisplayName: m.Namespace.Name,
} }
nodePeers, err := peers.toNodes(true) nodePeers, err := peers.toNodes(h.cfg.BaseDomain, h.cfg.DNSConfig, true)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "getMapResponse"). Str("func", "getMapResponse").
@ -286,20 +286,25 @@ func (h *Headscale) getMapResponse(mKey wgkey.Key, req tailcfg.MapRequest, m *Ma
return nil, err return nil, err
} }
var dnsConfig *tailcfg.DNSConfig
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS is enabled
// Only inject the Search Domain of the current namespace - shared nodes should use their full FQDN
dnsConfig = h.cfg.DNSConfig.Clone()
dnsConfig.Domains = append(dnsConfig.Domains, fmt.Sprintf("%s.%s", m.Namespace.Name, h.cfg.BaseDomain))
} else {
dnsConfig = h.cfg.DNSConfig
}
resp := tailcfg.MapResponse{ resp := tailcfg.MapResponse{
KeepAlive: false, KeepAlive: false,
Node: node, Node: node,
Peers: nodePeers, Peers: nodePeers,
// TODO(kradalby): As per tailscale docs, if DNSConfig is nil, DNSConfig: dnsConfig,
// it means its not updated, maybe we can have some logic Domain: h.cfg.BaseDomain,
// to check and only pass updates when its updates.
// This is probably more relevant if we try to implement
// "MagicDNS"
DNSConfig: h.cfg.DNSConfig,
SearchPaths: []string{},
Domain: "headscale.net",
PacketFilter: *h.aclRules, PacketFilter: *h.aclRules,
DERPMap: h.cfg.DerpMap, DERPMap: h.cfg.DerpMap,
// TODO(juanfont): We should send the profiles of all the peers (this own namespace + those from the shared peers)
UserProfiles: []tailcfg.UserProfile{profile}, UserProfiles: []tailcfg.UserProfile{profile},
} }
log.Trace(). log.Trace().
@ -365,6 +370,11 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
resp := tailcfg.RegisterResponse{} resp := tailcfg.RegisterResponse{}
pak, err := h.checkKeyValidity(req.Auth.AuthKey) pak, err := h.checkKeyValidity(req.Auth.AuthKey)
if err != nil { if err != nil {
log.Error().
Str("func", "handleAuthKey").
Str("machine", m.Name).
Err(err).
Msg("Failed authentication via AuthKey")
resp.MachineAuthorized = false resp.MachineAuthorized = false
respBody, err := encode(resp, &idKey, h.privateKey) respBody, err := encode(resp, &idKey, h.privateKey)
if err != nil { if err != nil {
@ -415,6 +425,9 @@ func (h *Headscale) handleAuthKey(c *gin.Context, db *gorm.DB, idKey wgkey.Key,
h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys? h.updateMachineExpiry(&m) // TODO: do we want to do different expiry times for AuthKeys?
pak.Used = true
db.Save(&pak)
resp.MachineAuthorized = true resp.MachineAuthorized = true
resp.User = *pak.Namespace.toUser() resp.User = *pak.Namespace.toUser()
respBody, err := encode(resp, &idKey, h.privateKey) respBody, err := encode(resp, &idKey, h.privateKey)

15
app.go
View File

@ -16,12 +16,13 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/zsais/go-gin-prometheus" ginprometheus "github.com/zsais/go-gin-prometheus"
"golang.org/x/crypto/acme" "golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"gorm.io/gorm" "gorm.io/gorm"
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
) )
@ -33,6 +34,7 @@ type Config struct {
DerpMap *tailcfg.DERPMap DerpMap *tailcfg.DERPMap
EphemeralNodeInactivityTimeout time.Duration EphemeralNodeInactivityTimeout time.Duration
IPPrefix netaddr.IPPrefix IPPrefix netaddr.IPPrefix
BaseDomain string
DBtype string DBtype string
DBpath string DBpath string
@ -127,6 +129,17 @@ func NewHeadscale(cfg Config) (*Headscale, error) {
} }
} }
if h.cfg.DNSConfig != nil && h.cfg.DNSConfig.Proxied { // if MagicDNS
magicDNSDomains, err := generateMagicDNSRootDomains(h.cfg.IPPrefix, h.cfg.BaseDomain)
if err != nil {
return nil, err
}
h.cfg.DNSConfig.Routes = make(map[string][]dnstype.Resolver)
for _, d := range magicDNSDomains {
h.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil
}
}
return &h, nil return &h, nil
} }

View File

@ -129,6 +129,7 @@ var deleteNodeCmd = &cobra.Command{
return nil return nil
}, },
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
output, _ := cmd.Flags().GetString("output")
h, err := getHeadscaleApp() h, err := getHeadscaleApp()
if err != nil { if err != nil {
log.Fatalf("Error initializing: %s", err) log.Fatalf("Error initializing: %s", err)
@ -143,6 +144,8 @@ var deleteNodeCmd = &cobra.Command{
} }
confirm := false confirm := false
force, _ := cmd.Flags().GetBool("force")
if !force {
prompt := &survey.Confirm{ prompt := &survey.Confirm{
Message: fmt.Sprintf("Do you want to remove the node %s?", m.Name), Message: fmt.Sprintf("Do you want to remove the node %s?", m.Name),
} }
@ -150,14 +153,23 @@ var deleteNodeCmd = &cobra.Command{
if err != nil { if err != nil {
return return
} }
}
if confirm { if confirm || force {
err = h.DeleteMachine(m) err = h.DeleteMachine(m)
if strings.HasPrefix(output, "json") {
JsonOutput(map[string]string{"Result": "Node deleted"}, err, output)
return
}
if err != nil { if err != nil {
log.Fatalf("Error deleting node: %s", err) log.Fatalf("Error deleting node: %s", err)
} }
fmt.Printf("Node deleted\n") fmt.Printf("Node deleted\n")
} else { } else {
if strings.HasPrefix(output, "json") {
JsonOutput(map[string]string{"Result": "Node not deleted"}, err, output)
return
}
fmt.Printf("Node not deleted\n") fmt.Printf("Node not deleted\n")
} }
}, },

View File

@ -57,7 +57,7 @@ var listPreAuthKeys = &cobra.Command{
return return
} }
d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Expiration", "Created"}} d := pterm.TableData{{"ID", "Key", "Reusable", "Ephemeral", "Used", "Expiration", "Created"}}
for _, k := range *keys { for _, k := range *keys {
expiration := "-" expiration := "-"
if k.Expiration != nil { if k.Expiration != nil {
@ -76,6 +76,7 @@ var listPreAuthKeys = &cobra.Command{
k.Key, k.Key,
reusable, reusable,
strconv.FormatBool(k.Ephemeral), strconv.FormatBool(k.Ephemeral),
fmt.Sprintf("%v", k.Used),
expiration, expiration,
k.CreatedAt.Format("2006-01-02 15:04:05"), k.CreatedAt.Format("2006-01-02 15:04:05"),
}) })
@ -130,7 +131,7 @@ var createPreAuthKeyCmd = &cobra.Command{
} }
var expirePreAuthKeyCmd = &cobra.Command{ var expirePreAuthKeyCmd = &cobra.Command{
Use: "expire", Use: "expire KEY",
Short: "Expire a preauthkey", Short: "Expire a preauthkey",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 { if len(args) < 1 {
@ -152,6 +153,10 @@ var expirePreAuthKeyCmd = &cobra.Command{
k, err := h.GetPreAuthKey(n, args[0]) k, err := h.GetPreAuthKey(n, args[0])
if err != nil { if err != nil {
if strings.HasPrefix(o, "json") {
JsonOutput(k, err, o)
return
}
log.Fatalf("Error getting the key: %s", err) log.Fatalf("Error getting the key: %s", err)
} }

View File

@ -9,6 +9,7 @@ import (
func init() { func init() {
rootCmd.PersistentFlags().StringP("output", "o", "", "Output format. Empty for human-readable, 'json' or 'json-line'") rootCmd.PersistentFlags().StringP("output", "o", "", "Output format. Empty for human-readable, 'json' or 'json-line'")
rootCmd.PersistentFlags().Bool("force", false, "Disable prompts and forces the execution")
} }
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{

View File

@ -76,7 +76,7 @@ func LoadConfig(path string) error {
} }
func GetDNSConfig() *tailcfg.DNSConfig { func GetDNSConfig() (*tailcfg.DNSConfig, string) {
if viper.IsSet("dns_config") { if viper.IsSet("dns_config") {
dnsConfig := &tailcfg.DNSConfig{} dnsConfig := &tailcfg.DNSConfig{}
@ -108,10 +108,27 @@ func GetDNSConfig() *tailcfg.DNSConfig {
dnsConfig.Domains = viper.GetStringSlice("dns_config.domains") dnsConfig.Domains = viper.GetStringSlice("dns_config.domains")
} }
return dnsConfig if viper.IsSet("dns_config.magic_dns") {
magicDNS := viper.GetBool("dns_config.magic_dns")
if len(dnsConfig.Nameservers) > 0 {
dnsConfig.Proxied = magicDNS
} else if magicDNS {
log.Warn().
Msg("Warning: dns_config.magic_dns is set, but no nameservers are configured. Ignoring magic_dns.")
}
} }
return nil var baseDomain string
if viper.IsSet("dns_config.base_domain") {
baseDomain = viper.GetString("dns_config.base_domain")
} else {
baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled
}
return dnsConfig, baseDomain
}
return nil, ""
} }
func absPath(path string) string { func absPath(path string) string {
@ -144,6 +161,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
return nil, err return nil, err
} }
// maxMachineRegistrationDuration is the maximum time a client can request for a client registration // maxMachineRegistrationDuration is the maximum time a client can request for a client registration
maxMachineRegistrationDuration, _ := time.ParseDuration("10h") maxMachineRegistrationDuration, _ := time.ParseDuration("10h")
if viper.GetDuration("max_machine_registration_duration") >= time.Second { if viper.GetDuration("max_machine_registration_duration") >= time.Second {
@ -156,12 +174,15 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration") defaultMachineRegistrationDuration = viper.GetDuration("default_machine_registration_duration")
} }
dnsConfig, baseDomain := GetDNSConfig()
cfg := headscale.Config{ cfg := headscale.Config{
ServerURL: viper.GetString("server_url"), ServerURL: viper.GetString("server_url"),
Addr: viper.GetString("listen_addr"), Addr: viper.GetString("listen_addr"),
PrivateKeyPath: absPath(viper.GetString("private_key_path")), PrivateKeyPath: absPath(viper.GetString("private_key_path")),
DerpMap: derpMap, DerpMap: derpMap,
IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")), IPPrefix: netaddr.MustParseIPPrefix(viper.GetString("ip_prefix")),
BaseDomain: baseDomain,
EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"), EphemeralNodeInactivityTimeout: viper.GetDuration("ephemeral_node_inactivity_timeout"),
@ -181,6 +202,8 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
TLSCertPath: absPath(viper.GetString("tls_cert_path")), TLSCertPath: absPath(viper.GetString("tls_cert_path")),
TLSKeyPath: absPath(viper.GetString("tls_key_path")), TLSKeyPath: absPath(viper.GetString("tls_key_path")),
DNSConfig: dnsConfig,
ACMEEmail: viper.GetString("acme_email"), ACMEEmail: viper.GetString("acme_email"),
ACMEURL: viper.GetString("acme_url"), ACMEURL: viper.GetString("acme_url"),
@ -192,6 +215,7 @@ func getHeadscaleApp() (*headscale.Headscale, error) {
MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time MaxMachineRegistrationDuration: maxMachineRegistrationDuration, // the maximum duration a client may request for expiry time
DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration DefaultMachineRegistrationDuration: defaultMachineRegistrationDuration, // if a client does not request a specific expiry time, use this duration
} }
h, err := headscale.NewHeadscale(cfg) h, err := headscale.NewHeadscale(cfg)
@ -261,3 +285,12 @@ func JsonOutput(result interface{}, errResult error, outputFormat string) {
} }
fmt.Println(string(j)) fmt.Println(string(j))
} }
func HasJsonOutputFlag() bool {
for _, arg := range os.Args {
if arg == "json" || arg == "json-line" {
return true
}
}
return false
}

View File

@ -62,7 +62,8 @@ func main() {
zerolog.SetGlobalLevel(zerolog.DebugLevel) zerolog.SetGlobalLevel(zerolog.DebugLevel)
} }
if !viper.GetBool("disable_check_updates") { jsonOutput := cli.HasJsonOutputFlag()
if !viper.GetBool("disable_check_updates") && !jsonOutput {
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" { if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && cli.Version != "dev" {
githubTag := &latest.GithubTag{ githubTag := &latest.GithubTag{
Owner: "juanfont", Owner: "juanfont",

View File

@ -117,12 +117,12 @@ func (*Suite) TestDNSConfigLoading(c *check.C) {
err = cli.LoadConfig(tmpDir) err = cli.LoadConfig(tmpDir)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
dnsConfig := cli.GetDNSConfig() dnsConfig, baseDomain := cli.GetDNSConfig()
fmt.Println(dnsConfig)
c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1") c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1")
c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1") c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1")
c.Assert(dnsConfig.Proxied, check.Equals, true)
c.Assert(baseDomain, check.Equals, "example.com")
} }
func writeConfig(c *check.C, tmpDir string, configYaml []byte) { func writeConfig(c *check.C, tmpDir string, configYaml []byte) {

View File

@ -22,6 +22,9 @@
"dns_config": { "dns_config": {
"nameservers": [ "nameservers": [
"1.1.1.1" "1.1.1.1"
] ],
"domains": [],
"magic_dns": true,
"base_domain": "example.com"
} }
} }

View File

@ -18,6 +18,9 @@
"dns_config": { "dns_config": {
"nameservers": [ "nameservers": [
"1.1.1.1" "1.1.1.1"
] ],
"domains": [],
"magic_dns": true,
"base_domain": "example.com"
} }
} }

73
dns.go Normal file
View File

@ -0,0 +1,73 @@
package headscale
import (
"fmt"
"strings"
"inet.af/netaddr"
"tailscale.com/util/dnsname"
)
// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`.
// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS
// server (listening in 100.100.100.100 udp/53) should be used for.
//
// Tailscale.com includes in the list:
// - the `BaseDomain` of the user
// - the reverse DNS entry for IPv6 (0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa., see below more on IPv6)
// - the reverse DNS entries for the IPv4 subnets covered by the user's `IPPrefix`.
// In the public SaaS this is [64-127].100.in-addr.arpa.
//
// The main purpose of this function is then generating the list of IPv4 entries. For the 100.64.0.0/10, this
// is clear, and could be hardcoded. But we are allowing any range as `IPPrefix`, so we need to find out the
// subnets when we have 172.16.0.0/16 (i.e., [0-255].16.172.in-addr.arpa.), or any other subnet.
//
// How IN-ADDR.ARPA domains work is defined in RFC1035 (section 3.5). Tailscale.com seems to adhere to this,
// and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next
// class block only.
// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask).
// This allows us to then calculate the subnets included in the subsequent class block and generate the entries.
func generateMagicDNSRootDomains(ipPrefix netaddr.IPPrefix, baseDomain string) ([]dnsname.FQDN, error) {
base, err := dnsname.ToFQDN(baseDomain)
if err != nil {
return nil, err
}
// TODO(juanfont): we are not handing out IPv6 addresses yet
// and in fact this is Tailscale.com's range (note the fd7a:115c:a1e0: range in the fc00::/7 network)
ipv6base := dnsname.FQDN("0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.")
fqdns := []dnsname.FQDN{base, ipv6base}
// Conversion to the std lib net.IPnet, a bit easier to operate
netRange := ipPrefix.IPNet()
maskBits, _ := netRange.Mask.Size()
// lastOctet is the last IP byte covered by the mask
lastOctet := maskBits / 8
// wildcardBits is the number of bits not under the mask in the lastOctet
wildcardBits := 8 - maskBits%8
// min is the value in the lastOctet byte of the IP
// max is basically 2^wildcardBits - i.e., the value when all the wildcardBits are set to 1
min := uint(netRange.IP[lastOctet])
max := uint((min + 1<<uint(wildcardBits)) - 1)
// here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.)
rdnsSlice := []string{}
for i := lastOctet - 1; i >= 0; i-- {
rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i]))
}
rdnsSlice = append(rdnsSlice, "in-addr.arpa.")
rdnsBase := strings.Join(rdnsSlice, ".")
for i := min; i <= max; i++ {
fqdn, err := dnsname.ToFQDN(fmt.Sprintf("%d.%s", i, rdnsBase))
if err != nil {
continue
}
fqdns = append(fqdns, fqdn)
}
return fqdns, nil
}

63
dns_test.go Normal file
View File

@ -0,0 +1,63 @@
package headscale
import (
"gopkg.in/check.v1"
"inet.af/netaddr"
)
func (s *Suite) TestMagicDNSRootDomains100(c *check.C) {
prefix := netaddr.MustParseIPPrefix("100.64.0.0/10")
domains, err := generateMagicDNSRootDomains(prefix, "headscale.net")
c.Assert(err, check.IsNil)
found := false
for _, domain := range domains {
if domain == "64.100.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, domain := range domains {
if domain == "100.100.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, domain := range domains {
if domain == "127.100.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}
func (s *Suite) TestMagicDNSRootDomains172(c *check.C) {
prefix := netaddr.MustParseIPPrefix("172.16.0.0/16")
domains, err := generateMagicDNSRootDomains(prefix, "headscale.net")
c.Assert(err, check.IsNil)
found := false
for _, domain := range domains {
if domain == "0.16.172.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
found = false
for _, domain := range domains {
if domain == "255.16.172.in-addr.arpa." {
found = true
break
}
}
c.Assert(found, check.Equals, true)
}

33
docs/DNS.md Normal file
View File

@ -0,0 +1,33 @@
# DNS in Headscale
Headscale supports Tailscale's DNS configuration and MagicDNS. Please have a look to their KB to better understand what this means:
- https://tailscale.com/kb/1054/dns/
- https://tailscale.com/kb/1081/magicdns/
- https://tailscale.com/blog/2021-09-private-dns-with-magicdns/
Long story short, you can define the DNS servers you want to use in your tailnets, activate MagicDNS (so you don't have to remember the IP addresses of your nodes), define search domains, as well as predefined hosts. Headscale will inject that settings into your nodes.
## Configuration reference
The setup is done via the `config.json` file, under the `dns_config` key.
```json
{
"server_url": "http://127.0.0.1:8001",
"listen_addr": "0.0.0.0:8001",
"private_key_path": "private.key",
//...
"dns_config": {
"nameservers": ["1.1.1.1", "8.8.8.8"],
"domains": [],
"magic_dns": true,
"base_domain": "example.com"
}
}
```
- `nameservers`: The list of DNS servers to use.
- `domains`: Search domains to inject.
- `magic_dns`: Whether to use [MagicDNS](https://tailscale.com/kb/1081/magicdns/). Only works if there is at least a nameserver defined.
- `base_domain`: Defines the base domain to create the hostnames for MagicDNS. `base_domain` must be a FQDNs, without the trailing dot. The FQDN of the hosts will be `hostname.namespace.base_domain` (e.g., _myhost.mynamespace.example.com_).

View File

@ -504,43 +504,43 @@ func (s *IntegrationTestSuite) TestSharedNodes() {
assert.Contains(s.T(), result, hostname) assert.Contains(s.T(), result, hostname)
} }
// TODO(kradalby): Figure out why these connections are not set up // TODO(juanfont): We have to find out why do we need to wait
// // TODO: See if we can have a more deterministic wait here. time.Sleep(100 * time.Second) // Wait for the nodes to receive updates
// time.Sleep(100 * time.Second)
// mainIps, err := getIPs(main.tailscales) mainIps, err := getIPs(main.tailscales)
// assert.Nil(s.T(), err) assert.Nil(s.T(), err)
// sharedIps, err := getIPs(shared.tailscales) sharedIps, err := getIPs(shared.tailscales)
// assert.Nil(s.T(), err) assert.Nil(s.T(), err)
// for hostname, tailscale := range main.tailscales { for hostname, tailscale := range main.tailscales {
// for peername, ip := range sharedIps { for peername, ip := range sharedIps {
// s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) { s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
// // We currently cant ping ourselves, so skip that. // We currently cant ping ourselves, so skip that.
// if peername != hostname { if peername != hostname {
// // We are only interested in "direct ping" which means what we // We are only interested in "direct ping" which means what we
// // might need a couple of more attempts before reaching the node. // might need a couple of more attempts before reaching the node.
// command := []string{ command := []string{
// "tailscale", "ping", "tailscale", "ping",
// "--timeout=1s", "--timeout=15s",
// "--c=20", "--c=20",
// "--until-direct=true", "--until-direct=true",
// ip.String(), ip.String(),
// } }
// fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip) fmt.Printf("Pinging from %s (%s) to %s (%s)\n", hostname, mainIps[hostname], peername, ip)
// result, err := executeCommand( result, err := executeCommand(
// &tailscale, &tailscale,
// command, command,
// ) []string{},
// assert.Nil(t, err) )
// fmt.Printf("Result for %s: %s\n", hostname, result) assert.Nil(t, err)
// assert.Contains(t, result, "pong") fmt.Printf("Result for %s: %s\n", hostname, result)
// } assert.Contains(t, result, "pong")
// }) }
// } })
// } }
}
} }
func (s *IntegrationTestSuite) TestTailDrop() { func (s *IntegrationTestSuite) TestTailDrop() {
@ -592,7 +592,7 @@ func (s *IntegrationTestSuite) TestTailDrop() {
_, err = executeCommand( _, err = executeCommand(
&tailscale, &tailscale,
command, command,
[]string{"ALL_PROXY=socks5://localhost:1055/"}, []string{"ALL_PROXY=socks5://localhost:1055"},
) )
if err == nil { if err == nil {
break break
@ -645,6 +645,38 @@ func (s *IntegrationTestSuite) TestTailDrop() {
} }
} }
func (s *IntegrationTestSuite) TestMagicDNS() {
for namespace, scales := range s.namespaces {
ips, err := getIPs(scales.tailscales)
assert.Nil(s.T(), err)
for hostname, tailscale := range scales.tailscales {
for peername, ip := range ips {
s.T().Run(fmt.Sprintf("%s-%s", hostname, peername), func(t *testing.T) {
if peername != hostname {
command := []string{
"tailscale", "ping",
"--timeout=10s",
"--c=20",
"--until-direct=true",
fmt.Sprintf("%s.%s.headscale.net", peername, namespace),
}
fmt.Printf("Pinging using Hostname (magicdns) from %s (%s) to %s (%s)\n", hostname, ips[hostname], peername, ip)
result, err := executeCommand(
&tailscale,
command,
[]string{},
)
assert.Nil(t, err)
fmt.Printf("Result for %s: %s\n", hostname, result)
assert.Contains(t, result, "pong")
}
})
}
}
}
}
func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, error) { func getIPs(tailscales map[string]dockertest.Resource) (map[string]netaddr.IP, error) {
ips := make(map[string]netaddr.IP) ips := make(map[string]netaddr.IP)
for hostname, tailscale := range tailscales { for hostname, tailscale := range tailscales {

View File

@ -7,5 +7,13 @@
"db_type": "sqlite3", "db_type": "sqlite3",
"db_path": "/tmp/integration_test_db.sqlite3", "db_path": "/tmp/integration_test_db.sqlite3",
"acl_policy_path": "", "acl_policy_path": "",
"log_level": "trace" "log_level": "trace",
"dns_config": {
"nameservers": [
"1.1.1.1"
],
"domains": [],
"magic_dns": true,
"base_domain": "headscale.net"
}
} }

View File

@ -94,7 +94,7 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
Msg("Finding direct peers") Msg("Finding direct peers")
machines := Machines{} machines := Machines{}
if err := h.db.Where("namespace_id = ? AND machine_key <> ? AND registered", if err := h.db.Preload("Namespace").Where("namespace_id = ? AND machine_key <> ? AND registered",
m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil { m.NamespaceID, m.MachineKey).Find(&machines).Error; err != nil {
log.Error().Err(err).Msg("Error accessing db") log.Error().Err(err).Msg("Error accessing db")
return Machines{}, err return Machines{}, err
@ -109,13 +109,13 @@ func (h *Headscale) getDirectPeers(m *Machine) (Machines, error) {
return machines, nil return machines, nil
} }
// getShared fetches machines that are shared to the `Namespace` of the machine we are getting peers for
func (h *Headscale) getShared(m *Machine) (Machines, error) { func (h *Headscale) getShared(m *Machine) (Machines, error) {
log.Trace(). log.Trace().
Str("func", "getShared"). Str("func", "getShared").
Str("machine", m.Name). Str("machine", m.Name).
Msg("Finding shared peers") Msg("Finding shared peers")
// We fetch here machines that are shared to the `Namespace` of the machine we are getting peers for
sharedMachines := []SharedMachine{} sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?", if err := h.db.Preload("Namespace").Preload("Machine").Where("namespace_id = ?",
m.NamespaceID).Find(&sharedMachines).Error; err != nil { m.NamespaceID).Find(&sharedMachines).Error; err != nil {
@ -136,6 +136,37 @@ func (h *Headscale) getShared(m *Machine) (Machines, error) {
return peers, nil return peers, nil
} }
// getSharedTo fetches the machines of the namespaces this machine is shared in
func (h *Headscale) getSharedTo(m *Machine) (Machines, error) {
log.Trace().
Str("func", "getSharedTo").
Str("machine", m.Name).
Msg("Finding peers in namespaces this machine is shared with")
sharedMachines := []SharedMachine{}
if err := h.db.Preload("Namespace").Preload("Machine").Where("machine_id = ?",
m.ID).Find(&sharedMachines).Error; err != nil {
return Machines{}, err
}
peers := make(Machines, 0)
for _, sharedMachine := range sharedMachines {
namespaceMachines, err := h.ListMachinesInNamespace(sharedMachine.Namespace.Name)
if err != nil {
return Machines{}, err
}
peers = append(peers, *namespaceMachines...)
}
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
log.Trace().
Str("func", "getSharedTo").
Str("machine", m.Name).
Msgf("Found peers we are shared with: %s", peers.String())
return peers, nil
}
func (h *Headscale) getPeers(m *Machine) (Machines, error) { func (h *Headscale) getPeers(m *Machine) (Machines, error) {
direct, err := h.getDirectPeers(m) direct, err := h.getDirectPeers(m)
if err != nil { if err != nil {
@ -149,13 +180,24 @@ func (h *Headscale) getPeers(m *Machine) (Machines, error) {
shared, err := h.getShared(m) shared, err := h.getShared(m)
if err != nil { if err != nil {
log.Error(). log.Error().
Str("func", "getDirectPeers"). Str("func", "getShared").
Err(err).
Msg("Cannot fetch peers")
return Machines{}, err
}
sharedTo, err := h.getSharedTo(m)
if err != nil {
log.Error().
Str("func", "sharedTo").
Err(err). Err(err).
Msg("Cannot fetch peers") Msg("Cannot fetch peers")
return Machines{}, err return Machines{}, err
} }
peers := append(direct, shared...) peers := append(direct, shared...)
peers = append(peers, sharedTo...)
sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID }) sort.Slice(peers, func(i, j int) bool { return peers[i].ID < peers[j].ID })
log.Trace(). log.Trace().
@ -210,6 +252,11 @@ func (h *Headscale) UpdateMachine(m *Machine) error {
// DeleteMachine softs deletes a Machine from the database // DeleteMachine softs deletes a Machine from the database
func (h *Headscale) DeleteMachine(m *Machine) error { func (h *Headscale) DeleteMachine(m *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m)
if err != nil && err != errorMachineNotShared {
return err
}
m.Registered = false m.Registered = false
namespaceID := m.NamespaceID namespaceID := m.NamespaceID
h.db.Save(&m) // we mark it as unregistered, just in case h.db.Save(&m) // we mark it as unregistered, just in case
@ -222,10 +269,16 @@ func (h *Headscale) DeleteMachine(m *Machine) error {
// HardDeleteMachine hard deletes a Machine from the database // HardDeleteMachine hard deletes a Machine from the database
func (h *Headscale) HardDeleteMachine(m *Machine) error { func (h *Headscale) HardDeleteMachine(m *Machine) error {
err := h.RemoveSharedMachineFromAllNamespaces(m)
if err != nil && err != errorMachineNotShared {
return err
}
namespaceID := m.NamespaceID namespaceID := m.NamespaceID
if err := h.db.Unscoped().Delete(&m).Error; err != nil { if err := h.db.Unscoped().Delete(&m).Error; err != nil {
return err return err
} }
return h.RequestMapUpdates(namespaceID) return h.RequestMapUpdates(namespaceID)
} }
@ -304,11 +357,11 @@ func (ms MachinesP) String() string {
return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp))
} }
func (ms Machines) toNodes(includeRoutes bool) ([]*tailcfg.Node, error) { func (ms Machines) toNodes(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) ([]*tailcfg.Node, error) {
nodes := make([]*tailcfg.Node, len(ms)) nodes := make([]*tailcfg.Node, len(ms))
for index, machine := range ms { for index, machine := range ms {
node, err := machine.toNode(includeRoutes) node, err := machine.toNode(baseDomain, dnsConfig, includeRoutes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -321,7 +374,7 @@ func (ms Machines) toNodes(includeRoutes bool) ([]*tailcfg.Node, error) {
// toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes // toNode converts a Machine into a Tailscale Node. includeRoutes is false for shared nodes
// as per the expected behaviour in the official SaaS // as per the expected behaviour in the official SaaS
func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) { func (m Machine) toNode(baseDomain string, dnsConfig *tailcfg.DNSConfig, includeRoutes bool) (*tailcfg.Node, error) {
nKey, err := wgkey.ParseHex(m.NodeKey) nKey, err := wgkey.ParseHex(m.NodeKey)
if err != nil { if err != nil {
return nil, err return nil, err
@ -416,10 +469,17 @@ func (m Machine) toNode(includeRoutes bool) (*tailcfg.Node, error) {
keyExpiry = time.Time{} keyExpiry = time.Time{}
} }
var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS
hostname = fmt.Sprintf("%s.%s.%s", m.Name, m.Namespace.Name, baseDomain)
} else {
hostname = m.Name
}
n := tailcfg.Node{ n := tailcfg.Node{
ID: tailcfg.NodeID(m.ID), // this is the actual ID ID: tailcfg.NodeID(m.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent StableID: tailcfg.StableNodeID(strconv.FormatUint(m.ID, 10)), // in headscale, unlike tailcontrol server, IDs are permanent
Name: hostinfo.Hostname, Name: hostname,
User: tailcfg.UserID(m.NamespaceID), User: tailcfg.UserID(m.NamespaceID),
Key: tailcfg.NodeKey(nKey), Key: tailcfg.NodeKey(nKey),
KeyExpiry: keyExpiry, KeyExpiry: keyExpiry,

View File

@ -91,7 +91,7 @@ func (h *Headscale) ListMachinesInNamespace(name string) (*[]Machine, error) {
} }
machines := []Machine{} machines := []Machine{}
if err := h.db.Preload("AuthKey").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil { if err := h.db.Preload("AuthKey").Preload("AuthKey.Namespace").Preload("Namespace").Where(&Machine{NamespaceID: n.ID}).Find(&machines).Error; err != nil {
return nil, err return nil, err
} }
return &machines, nil return &machines, nil

View File

@ -11,7 +11,7 @@ import (
const errorAuthKeyNotFound = Error("AuthKey not found") const errorAuthKeyNotFound = Error("AuthKey not found")
const errorAuthKeyExpired = Error("AuthKey expired") const errorAuthKeyExpired = Error("AuthKey expired")
const errorAuthKeyNotReusableAlreadyUsed = Error("AuthKey not reusable already used") const errSingleUseAuthKeyHasBeenUsed = Error("AuthKey has already been used")
// PreAuthKey describes a pre-authorization key usable in a particular namespace // PreAuthKey describes a pre-authorization key usable in a particular namespace
type PreAuthKey struct { type PreAuthKey struct {
@ -21,6 +21,7 @@ type PreAuthKey struct {
Namespace Namespace Namespace Namespace
Reusable bool Reusable bool
Ephemeral bool `gorm:"default:false"` Ephemeral bool `gorm:"default:false"`
Used bool `gorm:"default:false"`
CreatedAt *time.Time CreatedAt *time.Time
Expiration *time.Time Expiration *time.Time
@ -110,11 +111,10 @@ func (h *Headscale) checkKeyValidity(k string) (*PreAuthKey, error) {
return nil, err return nil, err
} }
if len(machines) != 0 { if len(machines) != 0 || pak.Used {
return nil, errorAuthKeyNotReusableAlreadyUsed return nil, errSingleUseAuthKeyHasBeenUsed
} }
// missing here validation on current usage
return &pak, nil return &pak, nil
} }

View File

@ -87,7 +87,7 @@ func (*Suite) TestAlreadyUsedKey(c *check.C) {
h.db.Save(&m) h.db.Save(&m)
p, err := h.checkKeyValidity(pak.Key) p, err := h.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errorAuthKeyNotReusableAlreadyUsed) c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
c.Assert(p, check.IsNil) c.Assert(p, check.IsNil)
} }
@ -180,3 +180,16 @@ func (*Suite) TestExpirePreauthKey(c *check.C) {
c.Assert(err, check.Equals, errorAuthKeyExpired) c.Assert(err, check.Equals, errorAuthKeyExpired)
c.Assert(p, check.IsNil) c.Assert(p, check.IsNil)
} }
func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) {
n, err := h.CreateNamespace("test6")
c.Assert(err, check.IsNil)
pak, err := h.CreatePreAuthKey(n.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak.Used = true
h.db.Save(&pak)
_, err = h.checkKeyValidity(pak.Key)
c.Assert(err, check.Equals, errSingleUseAuthKeyHasBeenUsed)
}

View File

@ -4,6 +4,7 @@ import "gorm.io/gorm"
const errorSameNamespace = Error("Destination namespace same as origin") const errorSameNamespace = Error("Destination namespace same as origin")
const errorMachineAlreadyShared = Error("Node already shared to this namespace") const errorMachineAlreadyShared = Error("Node already shared to this namespace")
const errorMachineNotShared = Error("Machine not shared to this namespace")
// SharedMachine is a join table to support sharing nodes between namespaces // SharedMachine is a join table to support sharing nodes between namespaces
type SharedMachine struct { type SharedMachine struct {
@ -35,3 +36,13 @@ func (h *Headscale) AddSharedMachineToNamespace(m *Machine, ns *Namespace) error
return nil return nil
} }
// RemoveSharedMachineFromAllNamespaces removes a machine as a shared node from all namespaces
func (h *Headscale) RemoveSharedMachineFromAllNamespaces(m *Machine) error {
sharedMachine := SharedMachine{}
if result := h.db.Where("machine_id = ?", m.ID).Unscoped().Delete(&sharedMachine); result.Error != nil {
return result.Error
}
return nil
}

View File

@ -274,7 +274,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
c.Assert(err, check.NotNil) c.Assert(err, check.NotNil)
m1 := &Machine{ m1 := &Machine{
ID: 0, ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66", DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
@ -291,7 +291,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m2 := &Machine{ m2 := &Machine{
ID: 1, ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
@ -308,7 +308,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m3 := &Machine{ m3 := &Machine{
ID: 2, ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
@ -325,7 +325,7 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
m4 := &Machine{ m4 := &Machine{
ID: 3, ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863", DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
@ -341,6 +341,129 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
_, err = h.GetMachine(n1.Name, m4.Name) _, err = h.GetMachine(n1.Name, m4.Name)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 1) // node1 can see node4
c.Assert(p1s[0].Name, check.Equals, "test_get_shared_nodes_4")
err = h.AddSharedMachineToNamespace(m2, n1)
c.Assert(err, check.IsNil)
p1sAfter, err := h.getPeers(m1)
c.Assert(err, check.IsNil)
c.Assert(len(p1sAfter), check.Equals, 2) // node1 can see node2 (shared) and node4 (same namespace)
c.Assert(p1sAfter[0].Name, check.Equals, "test_get_shared_nodes_2")
c.Assert(p1sAfter[1].Name, check.Equals, "test_get_shared_nodes_4")
node1shared, err := h.getShared(m1)
c.Assert(err, check.IsNil)
c.Assert(len(node1shared), check.Equals, 1) // node1 can see node2 as shared
c.Assert(node1shared[0].Name, check.Equals, "test_get_shared_nodes_2")
pAlone, err := h.getPeers(m3)
c.Assert(err, check.IsNil)
c.Assert(len(pAlone), check.Equals, 0) // node3 is alone
pSharedTo, err := h.getPeers(m2)
c.Assert(err, check.IsNil)
c.Assert(len(pSharedTo), check.Equals, 2) // node2 should see node1 (sharedTo) and node4 (sharedTo), as is shared in namespace1
c.Assert(pSharedTo[0].Name, check.Equals, "test_get_shared_nodes_1")
c.Assert(pSharedTo[1].Name, check.Equals, "test_get_shared_nodes_4")
}
func (s *Suite) TestDeleteSharedMachine(c *check.C) {
n1, err := h.CreateNamespace("shared1")
c.Assert(err, check.IsNil)
n2, err := h.CreateNamespace("shared2")
c.Assert(err, check.IsNil)
n3, err := h.CreateNamespace("shared3")
c.Assert(err, check.IsNil)
pak1n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak2n2, err := h.CreatePreAuthKey(n2.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak3n3, err := h.CreatePreAuthKey(n3.Name, false, false, nil)
c.Assert(err, check.IsNil)
pak4n1, err := h.CreatePreAuthKey(n1.Name, false, false, nil)
c.Assert(err, check.IsNil)
_, err = h.GetMachine(n1.Name, "test_get_shared_nodes_1")
c.Assert(err, check.NotNil)
m1 := &Machine{
ID: 1,
MachineKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
NodeKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
DiscoKey: "686824e749f3b7f2a5927ee6c1e422aee5292592d9179a271ed7b3e659b44a66",
Name: "test_get_shared_nodes_1",
NamespaceID: n1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.1",
AuthKeyID: uint(pak1n1.ID),
}
h.db.Save(m1)
_, err = h.GetMachine(n1.Name, m1.Name)
c.Assert(err, check.IsNil)
m2 := &Machine{
ID: 2,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_2",
NamespaceID: n2.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.2",
AuthKeyID: uint(pak2n2.ID),
}
h.db.Save(m2)
_, err = h.GetMachine(n2.Name, m2.Name)
c.Assert(err, check.IsNil)
m3 := &Machine{
ID: 3,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_3",
NamespaceID: n3.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.3",
AuthKeyID: uint(pak3n3.ID),
}
h.db.Save(m3)
_, err = h.GetMachine(n3.Name, m3.Name)
c.Assert(err, check.IsNil)
m4 := &Machine{
ID: 4,
MachineKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
NodeKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
DiscoKey: "dec46ef9dc45c7d2f03bfcd5a640d9e24e3cc68ce3d9da223867c9bc6d5e9863",
Name: "test_get_shared_nodes_4",
NamespaceID: n1.ID,
Registered: true,
RegisterMethod: "authKey",
IPAddress: "100.64.0.4",
AuthKeyID: uint(pak4n1.ID),
}
h.db.Save(m4)
_, err = h.GetMachine(n1.Name, m4.Name)
c.Assert(err, check.IsNil)
p1s, err := h.getPeers(m1) p1s, err := h.getPeers(m1)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(p1s), check.Equals, 1) // nodes 1 and 4 c.Assert(len(p1s), check.Equals, 1) // nodes 1 and 4
@ -363,4 +486,15 @@ func (s *Suite) TestComplexSharingAcrossNamespaces(c *check.C) {
pAlone, err := h.getPeers(m3) pAlone, err := h.getPeers(m3)
c.Assert(err, check.IsNil) c.Assert(err, check.IsNil)
c.Assert(len(pAlone), check.Equals, 0) // node 3 is alone c.Assert(len(pAlone), check.Equals, 0) // node 3 is alone
sharedMachines, err := h.ListSharedMachinesInNamespace(n1.Name)
c.Assert(err, check.IsNil)
c.Assert(len(*sharedMachines), check.Equals, 1)
err = h.DeleteMachine(m2)
c.Assert(err, check.IsNil)
sharedMachines, err = h.ListSharedMachinesInNamespace(n1.Name)
c.Assert(err, check.IsNil)
c.Assert(len(*sharedMachines), check.Equals, 0)
} }