mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-16 15:38:07 +00:00
Compare commits
25 Commits
web-auth-f
...
oidc-clean
Author | SHA1 | Date | |
---|---|---|---|
![]() |
613d29478d | ||
![]() |
4034fbc6e9 | ||
![]() |
46df219ed3 | ||
![]() |
835288d864 | ||
![]() |
93d56362af | ||
![]() |
4799859be0 | ||
![]() |
8e44596171 | ||
![]() |
d479234058 | ||
![]() |
3fc5866de0 | ||
![]() |
f3c40086ac | ||
![]() |
09ed21edd8 | ||
![]() |
456479eaa1 | ||
![]() |
cb87852825 | ||
![]() |
69440058bb | ||
![]() |
9bc6ac0f35 | ||
![]() |
89ff5c83d2 | ||
![]() |
0a47d694be | ||
![]() |
73c84d4f6a | ||
![]() |
a9251d6652 | ||
![]() |
f9c44f11d6 | ||
![]() |
1f8bd24a0d | ||
![]() |
7bf2eb3d71 | ||
![]() |
f5a5437917 | ||
![]() |
9989657c0f | ||
![]() |
cb2790984f |
@@ -36,6 +36,9 @@ linters:
|
||||
- makezero
|
||||
- maintidx
|
||||
|
||||
# Limits the methods of an interface to 10. We have more in integration tests
|
||||
- interfacebloat
|
||||
|
||||
# We might want to enable this, but it might be a lot of work
|
||||
- cyclop
|
||||
- nestif
|
||||
|
@@ -134,7 +134,9 @@ var registerNodeCmd = &cobra.Command{
|
||||
return
|
||||
}
|
||||
|
||||
SuccessOutput(response.Machine, "Machine register", output)
|
||||
SuccessOutput(
|
||||
response.Machine,
|
||||
fmt.Sprintf("Machine %s registered", response.Machine.GivenName), output)
|
||||
},
|
||||
}
|
||||
|
||||
|
204
integration/auth_web_flow_test.go
Normal file
204
integration/auth_web_flow_test.go
Normal file
@@ -0,0 +1,204 @@
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
)
|
||||
|
||||
var errParseAuthPage = errors.New("failed to parse auth page")
|
||||
|
||||
type AuthWebFlowScenario struct {
|
||||
*Scenario
|
||||
}
|
||||
|
||||
func TestAuthWebFlowAuthenticationPingAll(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
|
||||
baseScenario, err := NewScenario()
|
||||
if err != nil {
|
||||
t.Errorf("failed to create scenario: %s", err)
|
||||
}
|
||||
|
||||
scenario := AuthWebFlowScenario{
|
||||
Scenario: baseScenario,
|
||||
}
|
||||
|
||||
spec := map[string]int{
|
||||
"namespace1": len(TailscaleVersions),
|
||||
"namespace2": len(TailscaleVersions),
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("webauthping"))
|
||||
if err != nil {
|
||||
t.Errorf("failed to create headscale environment: %s", err)
|
||||
}
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get clients: %s", err)
|
||||
}
|
||||
|
||||
allIps, err := scenario.ListTailscaleClientsIPs()
|
||||
if err != nil {
|
||||
t.Errorf("failed to get clients: %s", err)
|
||||
}
|
||||
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
if err != nil {
|
||||
t.Errorf("failed wait for tailscale clients to be in sync: %s", err)
|
||||
}
|
||||
|
||||
success := 0
|
||||
|
||||
for _, client := range allClients {
|
||||
for _, ip := range allIps {
|
||||
err := client.Ping(ip.String())
|
||||
if err != nil {
|
||||
t.Errorf("failed to ping %s from %s: %s", ip, client.Hostname(), err)
|
||||
} else {
|
||||
success++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps))
|
||||
|
||||
err = scenario.Shutdown()
|
||||
if err != nil {
|
||||
t.Errorf("failed to tear down scenario: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthWebFlowScenario) CreateHeadscaleEnv(
|
||||
namespaces map[string]int,
|
||||
opts ...hsic.Option,
|
||||
) error {
|
||||
headscale, err := s.Headscale(opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = headscale.WaitForReady()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for namespaceName, clientCount := range namespaces {
|
||||
log.Printf("creating namespace %s with %d clients", namespaceName, clientCount)
|
||||
err = s.CreateNamespace(namespaceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.CreateTailscaleNodesInNamespace(namespaceName, "all", clientCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.runTailscaleUp(namespaceName, headscale.GetEndpoint())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthWebFlowScenario) runTailscaleUp(
|
||||
namespaceStr, loginServer string,
|
||||
) error {
|
||||
log.Printf("running tailscale up for namespace %s", namespaceStr)
|
||||
if namespace, ok := s.namespaces[namespaceStr]; ok {
|
||||
for _, client := range namespace.Clients {
|
||||
namespace.joinWaitGroup.Add(1)
|
||||
|
||||
go func(c TailscaleClient) {
|
||||
defer namespace.joinWaitGroup.Done()
|
||||
|
||||
// TODO(juanfont): error handle this
|
||||
loginURL, err := c.UpWithLoginURL(loginServer)
|
||||
if err != nil {
|
||||
log.Printf("failed to run tailscale up: %s", err)
|
||||
}
|
||||
|
||||
err = s.runHeadscaleRegister(namespaceStr, loginURL)
|
||||
if err != nil {
|
||||
log.Printf("failed to register client: %s", err)
|
||||
}
|
||||
|
||||
err = c.WaitForReady()
|
||||
if err != nil {
|
||||
log.Printf("error waiting for client %s to be ready: %s", c.Hostname(), err)
|
||||
}
|
||||
}(client)
|
||||
}
|
||||
namespace.joinWaitGroup.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to up tailscale node: %w", errNoNamespaceAvailable)
|
||||
}
|
||||
|
||||
func (s *AuthWebFlowScenario) runHeadscaleRegister(namespaceStr string, loginURL *url.URL) error {
|
||||
headscale, err := s.Headscale()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Printf("loginURL: %s", loginURL)
|
||||
loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP())
|
||||
loginURL.Scheme = "http"
|
||||
|
||||
httpClient := &http.Client{}
|
||||
ctx := context.Background()
|
||||
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
// see api.go HTML template
|
||||
codeSep := strings.Split(string(body), "</code>")
|
||||
if len(codeSep) != 2 {
|
||||
return errParseAuthPage
|
||||
}
|
||||
|
||||
keySep := strings.Split(codeSep[0], "key ")
|
||||
if len(keySep) != 2 {
|
||||
return errParseAuthPage
|
||||
}
|
||||
key := keySep[1]
|
||||
log.Printf("registering node %s", key)
|
||||
|
||||
if headscale, err := s.Headscale(); err == nil {
|
||||
_, err = headscale.Execute(
|
||||
[]string{"headscale", "-n", namespaceStr, "nodes", "register", "--key", key},
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("failed to register node: %s", err)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable)
|
||||
}
|
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -36,11 +37,14 @@ func TestNamespaceCommand(t *testing.T) {
|
||||
"namespace2": 0,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec)
|
||||
err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("clins"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assert.NoError(t, err)
|
||||
|
||||
var listNamespaces []v1.Namespace
|
||||
err = executeAndUnmarshal(scenario.Headscale(),
|
||||
err = executeAndUnmarshal(headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"namespaces",
|
||||
@@ -61,7 +65,7 @@ func TestNamespaceCommand(t *testing.T) {
|
||||
result,
|
||||
)
|
||||
|
||||
_, err = scenario.Headscale().Execute(
|
||||
_, err = headscale.Execute(
|
||||
[]string{
|
||||
"headscale",
|
||||
"namespaces",
|
||||
@@ -75,7 +79,7 @@ func TestNamespaceCommand(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
var listAfterRenameNamespaces []v1.Namespace
|
||||
err = executeAndUnmarshal(scenario.Headscale(),
|
||||
err = executeAndUnmarshal(headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"namespaces",
|
||||
@@ -114,7 +118,10 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
namespace: 0,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec)
|
||||
err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("clipak"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assert.NoError(t, err)
|
||||
|
||||
keys := make([]*v1.PreAuthKey, count)
|
||||
@@ -123,7 +130,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
for index := 0; index < count; index++ {
|
||||
var preAuthKey v1.PreAuthKey
|
||||
err := executeAndUnmarshal(
|
||||
scenario.Headscale(),
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
@@ -149,7 +156,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
|
||||
var listedPreAuthKeys []v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
scenario.Headscale(),
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
@@ -202,7 +209,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test key expiry
|
||||
_, err = scenario.Headscale().Execute(
|
||||
_, err = headscale.Execute(
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
@@ -216,7 +223,7 @@ func TestPreAuthKeyCommand(t *testing.T) {
|
||||
|
||||
var listedPreAuthKeysAfterExpire []v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
scenario.Headscale(),
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
@@ -251,12 +258,15 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||
namespace: 0,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec)
|
||||
err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("clipaknaexp"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assert.NoError(t, err)
|
||||
|
||||
var preAuthKey v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
scenario.Headscale(),
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
@@ -273,7 +283,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) {
|
||||
|
||||
var listedPreAuthKeys []v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
scenario.Headscale(),
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
@@ -313,12 +323,15 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||
namespace: 0,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec)
|
||||
err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("clipakresueeph"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assert.NoError(t, err)
|
||||
|
||||
var preAuthReusableKey v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
scenario.Headscale(),
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
@@ -335,7 +348,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||
|
||||
var preAuthEphemeralKey v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
scenario.Headscale(),
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
@@ -355,7 +368,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
||||
|
||||
var listedPreAuthKeys []v1.PreAuthKey
|
||||
err = executeAndUnmarshal(
|
||||
scenario.Headscale(),
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
|
@@ -13,4 +13,7 @@ type ControlServer interface {
|
||||
CreateNamespace(namespace string) error
|
||||
CreateAuthKey(namespace string) (*v1.PreAuthKey, error)
|
||||
ListMachinesInNamespace(namespace string) ([]*v1.Machine, error)
|
||||
GetCert() []byte
|
||||
GetHostname() string
|
||||
GetIP() string
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
@@ -22,7 +23,7 @@ func TestPingAllByIP(t *testing.T) {
|
||||
"namespace2": len(TailscaleVersions),
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec)
|
||||
err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("pingallbyip"))
|
||||
if err != nil {
|
||||
t.Errorf("failed to create headscale environment: %s", err)
|
||||
}
|
||||
@@ -77,7 +78,7 @@ func TestPingAllByHostname(t *testing.T) {
|
||||
"namespace4": len(TailscaleVersions) - 1,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec)
|
||||
err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("pingallbyname"))
|
||||
if err != nil {
|
||||
t.Errorf("failed to create headscale environment: %s", err)
|
||||
}
|
||||
@@ -144,7 +145,7 @@ func TestTaildrop(t *testing.T) {
|
||||
"taildrop": len(TailscaleVersions) - 1,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec)
|
||||
err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("taildrop"))
|
||||
if err != nil {
|
||||
t.Errorf("failed to create headscale environment: %s", err)
|
||||
}
|
||||
@@ -168,7 +169,7 @@ func TestTaildrop(t *testing.T) {
|
||||
for _, client := range allClients {
|
||||
command := []string{"touch", fmt.Sprintf("/tmp/file_from_%s", client.Hostname())}
|
||||
|
||||
if _, err := client.Execute(command); err != nil {
|
||||
if _, _, err := client.Execute(command); err != nil {
|
||||
t.Errorf("failed to create taildrop file on %s, err: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
@@ -193,7 +194,7 @@ func TestTaildrop(t *testing.T) {
|
||||
client.Hostname(),
|
||||
peer.Hostname(),
|
||||
)
|
||||
_, err := client.Execute(command)
|
||||
_, _, err := client.Execute(command)
|
||||
|
||||
return err
|
||||
})
|
||||
@@ -214,7 +215,7 @@ func TestTaildrop(t *testing.T) {
|
||||
"get",
|
||||
"/tmp/",
|
||||
}
|
||||
if _, err := client.Execute(command); err != nil {
|
||||
if _, _, err := client.Execute(command); err != nil {
|
||||
t.Errorf("failed to get taildrop file on %s, err: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
@@ -234,7 +235,7 @@ func TestTaildrop(t *testing.T) {
|
||||
peer.Hostname(),
|
||||
)
|
||||
|
||||
result, err := client.Execute(command)
|
||||
result, _, err := client.Execute(command)
|
||||
if err != nil {
|
||||
t.Errorf("failed to execute command to ls taildrop: %s", err)
|
||||
}
|
||||
@@ -271,7 +272,7 @@ func TestResolveMagicDNS(t *testing.T) {
|
||||
"magicdns2": len(TailscaleVersions) - 1,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec)
|
||||
err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("magicdns"))
|
||||
if err != nil {
|
||||
t.Errorf("failed to create headscale environment: %s", err)
|
||||
}
|
||||
@@ -306,7 +307,7 @@ func TestResolveMagicDNS(t *testing.T) {
|
||||
"tailscale",
|
||||
"ip", peerFQDN,
|
||||
}
|
||||
result, err := client.Execute(command)
|
||||
result, _, err := client.Execute(command)
|
||||
if err != nil {
|
||||
t.Errorf(
|
||||
"failed to execute resolve/ip command %s from %s: %s",
|
||||
|
@@ -1,52 +1,83 @@
|
||||
package hsic
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale"
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||
"github.com/juanfont/headscale/integration/integrationutil"
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
)
|
||||
|
||||
const (
|
||||
hsicHashLength = 6
|
||||
dockerContextPath = "../."
|
||||
aclPolicyPath = "/etc/headscale/acl.hujson"
|
||||
hsicHashLength = 6
|
||||
dockerContextPath = "../."
|
||||
aclPolicyPath = "/etc/headscale/acl.hujson"
|
||||
tlsCertPath = "/etc/headscale/tls.cert"
|
||||
tlsKeyPath = "/etc/headscale/tls.key"
|
||||
headscaleDefaultPort = 8080
|
||||
)
|
||||
|
||||
var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok")
|
||||
|
||||
type HeadscaleInContainer struct {
|
||||
hostname string
|
||||
port int
|
||||
|
||||
pool *dockertest.Pool
|
||||
container *dockertest.Resource
|
||||
network *dockertest.Network
|
||||
|
||||
// optional config
|
||||
port int
|
||||
aclPolicy *headscale.ACLPolicy
|
||||
env []string
|
||||
tlsCert []byte
|
||||
tlsKey []byte
|
||||
}
|
||||
|
||||
type Option = func(c *HeadscaleInContainer)
|
||||
|
||||
func WithACLPolicy(acl *headscale.ACLPolicy) Option {
|
||||
return func(hsic *HeadscaleInContainer) {
|
||||
// TODO(kradalby): Move somewhere appropriate
|
||||
hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_ACL_POLICY_PATH=%s", aclPolicyPath))
|
||||
|
||||
hsic.aclPolicy = acl
|
||||
}
|
||||
}
|
||||
|
||||
func WithTLS() Option {
|
||||
return func(hsic *HeadscaleInContainer) {
|
||||
cert, key, err := createCertificate()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create certificates for headscale test: %s", err)
|
||||
}
|
||||
|
||||
// TODO(kradalby): Move somewhere appropriate
|
||||
hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_TLS_CERT_PATH=%s", tlsCertPath))
|
||||
hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_TLS_KEY_PATH=%s", tlsKeyPath))
|
||||
hsic.env = append(hsic.env, "HEADSCALE_TLS_CLIENT_AUTH_MODE=disabled")
|
||||
|
||||
hsic.tlsCert = cert
|
||||
hsic.tlsKey = key
|
||||
}
|
||||
}
|
||||
|
||||
func WithConfigEnv(configEnv map[string]string) Option {
|
||||
return func(hsic *HeadscaleInContainer) {
|
||||
env := []string{}
|
||||
@@ -59,9 +90,23 @@ func WithConfigEnv(configEnv map[string]string) Option {
|
||||
}
|
||||
}
|
||||
|
||||
func WithPort(port int) Option {
|
||||
return func(hsic *HeadscaleInContainer) {
|
||||
hsic.port = port
|
||||
}
|
||||
}
|
||||
|
||||
func WithTestName(testName string) Option {
|
||||
return func(hsic *HeadscaleInContainer) {
|
||||
hash, _ := headscale.GenerateRandomStringDNSSafe(hsicHashLength)
|
||||
|
||||
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
|
||||
hsic.hostname = hostname
|
||||
}
|
||||
}
|
||||
|
||||
func New(
|
||||
pool *dockertest.Pool,
|
||||
port int,
|
||||
network *dockertest.Network,
|
||||
opts ...Option,
|
||||
) (*HeadscaleInContainer, error) {
|
||||
@@ -71,11 +116,10 @@ func New(
|
||||
}
|
||||
|
||||
hostname := fmt.Sprintf("hs-%s", hash)
|
||||
portProto := fmt.Sprintf("%d/tcp", port)
|
||||
|
||||
hsic := &HeadscaleInContainer{
|
||||
hostname: hostname,
|
||||
port: port,
|
||||
port: headscaleDefaultPort,
|
||||
|
||||
pool: pool,
|
||||
network: network,
|
||||
@@ -85,9 +129,9 @@ func New(
|
||||
opt(hsic)
|
||||
}
|
||||
|
||||
if hsic.aclPolicy != nil {
|
||||
hsic.env = append(hsic.env, fmt.Sprintf("HEADSCALE_ACL_POLICY_PATH=%s", aclPolicyPath))
|
||||
}
|
||||
log.Println("NAME: ", hsic.hostname)
|
||||
|
||||
portProto := fmt.Sprintf("%d/tcp", hsic.port)
|
||||
|
||||
headscaleBuildOptions := &dockertest.BuildOptions{
|
||||
Dockerfile: "Dockerfile.debug",
|
||||
@@ -95,7 +139,7 @@ func New(
|
||||
}
|
||||
|
||||
runOptions := &dockertest.RunOptions{
|
||||
Name: hostname,
|
||||
Name: hsic.hostname,
|
||||
ExposedPorts: []string{portProto},
|
||||
Networks: []*dockertest.Network{network},
|
||||
// Cmd: []string{"headscale", "serve"},
|
||||
@@ -108,7 +152,7 @@ func New(
|
||||
// dockertest isnt very good at handling containers that has already
|
||||
// been created, this is an attempt to make sure this container isnt
|
||||
// present.
|
||||
err = pool.RemoveContainerByName(hostname)
|
||||
err = pool.RemoveContainerByName(hsic.hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -123,7 +167,7 @@ func New(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not start headscale container: %w", err)
|
||||
}
|
||||
log.Printf("Created %s container\n", hostname)
|
||||
log.Printf("Created %s container\n", hsic.hostname)
|
||||
|
||||
hsic.container = container
|
||||
|
||||
@@ -144,9 +188,25 @@ func New(
|
||||
}
|
||||
}
|
||||
|
||||
if hsic.hasTLS() {
|
||||
err = hsic.WriteFile(tlsCertPath, hsic.tlsCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
|
||||
}
|
||||
|
||||
err = hsic.WriteFile(tlsKeyPath, hsic.tlsKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write TLS key to container: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return hsic, nil
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) hasTLS() bool {
|
||||
return len(t.tlsCert) != 0 && len(t.tlsKey) != 0
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) Shutdown() error {
|
||||
return t.pool.Purge(t.container)
|
||||
}
|
||||
@@ -154,8 +214,6 @@ func (t *HeadscaleInContainer) Shutdown() error {
|
||||
func (t *HeadscaleInContainer) Execute(
|
||||
command []string,
|
||||
) (string, error) {
|
||||
log.Println("command", command)
|
||||
log.Printf("running command for %s\n", t.hostname)
|
||||
stdout, stderr, err := dockertestutil.ExecuteCommand(
|
||||
t.container,
|
||||
command,
|
||||
@@ -164,11 +222,11 @@ func (t *HeadscaleInContainer) Execute(
|
||||
if err != nil {
|
||||
log.Printf("command stderr: %s\n", stderr)
|
||||
|
||||
return "", err
|
||||
}
|
||||
if stdout != "" {
|
||||
log.Printf("command stdout: %s\n", stdout)
|
||||
}
|
||||
|
||||
if stdout != "" {
|
||||
log.Printf("command stdout: %s\n", stdout)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return stdout, nil
|
||||
@@ -179,17 +237,11 @@ func (t *HeadscaleInContainer) GetIP() string {
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) GetPort() string {
|
||||
portProto := fmt.Sprintf("%d/tcp", t.port)
|
||||
|
||||
return t.container.GetPort(portProto)
|
||||
return fmt.Sprintf("%d", t.port)
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) GetHealthEndpoint() string {
|
||||
hostEndpoint := fmt.Sprintf("%s:%d",
|
||||
t.GetIP(),
|
||||
t.port)
|
||||
|
||||
return fmt.Sprintf("http://%s/health", hostEndpoint)
|
||||
return fmt.Sprintf("%s/health", t.GetEndpoint())
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) GetEndpoint() string {
|
||||
@@ -197,17 +249,39 @@ func (t *HeadscaleInContainer) GetEndpoint() string {
|
||||
t.GetIP(),
|
||||
t.port)
|
||||
|
||||
if t.hasTLS() {
|
||||
return fmt.Sprintf("https://%s", hostEndpoint)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("http://%s", hostEndpoint)
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) GetCert() []byte {
|
||||
return t.tlsCert
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) GetHostname() string {
|
||||
return t.hostname
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) WaitForReady() error {
|
||||
url := t.GetHealthEndpoint()
|
||||
|
||||
log.Printf("waiting for headscale to be ready at %s", url)
|
||||
|
||||
client := &http.Client{}
|
||||
|
||||
if t.hasTLS() {
|
||||
insecureTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint
|
||||
insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint
|
||||
client = &http.Client{Transport: insecureTransport}
|
||||
}
|
||||
|
||||
return t.pool.Retry(func() error {
|
||||
resp, err := http.Get(url) //nolint
|
||||
resp, err := client.Get(url) //nolint
|
||||
if err != nil {
|
||||
log.Printf("ready err: %s", err)
|
||||
|
||||
return fmt.Errorf("headscale is not ready: %w", err)
|
||||
}
|
||||
|
||||
@@ -294,55 +368,87 @@ func (t *HeadscaleInContainer) ListMachinesInNamespace(
|
||||
}
|
||||
|
||||
func (t *HeadscaleInContainer) WriteFile(path string, data []byte) error {
|
||||
dirPath, fileName := filepath.Split(path)
|
||||
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
|
||||
}
|
||||
|
||||
file := bytes.NewReader(data)
|
||||
//nolint
|
||||
func createCertificate() ([]byte, []byte, error) {
|
||||
// From:
|
||||
// https://shaneutt.com/blog/golang-ca-and-signed-cert-go/
|
||||
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
|
||||
tarWriter := tar.NewWriter(buf)
|
||||
|
||||
header := &tar.Header{
|
||||
Name: fileName,
|
||||
Size: file.Size(),
|
||||
// Mode: int64(stat.Mode()),
|
||||
// ModTime: stat.ModTime(),
|
||||
}
|
||||
|
||||
err := tarWriter.WriteHeader(header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed write file header to tar: %w", err)
|
||||
}
|
||||
|
||||
_, err = io.Copy(tarWriter, file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to copy file to tar: %w", err)
|
||||
}
|
||||
|
||||
err = tarWriter.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close tar: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("tar: %s", buf.String())
|
||||
|
||||
// Ensure the directory is present inside the container
|
||||
_, err = t.Execute([]string{"mkdir", "-p", dirPath})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to ensure directory: %w", err)
|
||||
}
|
||||
|
||||
err = t.pool.Client.UploadToContainer(
|
||||
t.container.Container.ID,
|
||||
docker.UploadToContainerOptions{
|
||||
NoOverwriteDirNonDir: false,
|
||||
Path: dirPath,
|
||||
InputStream: bytes.NewReader(buf.Bytes()),
|
||||
ca := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(2019),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Headscale testing INC"},
|
||||
Country: []string{"NL"},
|
||||
Locality: []string{"Leiden"},
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(30 * time.Minute),
|
||||
IsCA: true,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
x509.ExtKeyUsageServerAuth,
|
||||
},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1658),
|
||||
Subject: pkix.Name{
|
||||
Organization: []string{"Headscale testing INC"},
|
||||
Country: []string{"NL"},
|
||||
Locality: []string{"Leiden"},
|
||||
},
|
||||
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(30 * time.Minute),
|
||||
SubjectKeyId: []byte{1, 2, 3, 4, 6},
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
}
|
||||
|
||||
certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
certBytes, err := x509.CreateCertificate(
|
||||
rand.Reader,
|
||||
cert,
|
||||
ca,
|
||||
&certPrivKey.PublicKey,
|
||||
caPrivKey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return nil
|
||||
certPEM := new(bytes.Buffer)
|
||||
|
||||
err = pem.Encode(certPEM, &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: certBytes,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
certPrivKeyPEM := new(bytes.Buffer)
|
||||
|
||||
err = pem.Encode(certPrivKeyPEM, &pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil
|
||||
}
|
||||
|
77
integration/integrationutil/util.go
Normal file
77
integration/integrationutil/util.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package integrationutil
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
)
|
||||
|
||||
func WriteFileToContainer(
|
||||
pool *dockertest.Pool,
|
||||
container *dockertest.Resource,
|
||||
path string,
|
||||
data []byte,
|
||||
) error {
|
||||
dirPath, fileName := filepath.Split(path)
|
||||
|
||||
file := bytes.NewReader(data)
|
||||
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
|
||||
tarWriter := tar.NewWriter(buf)
|
||||
|
||||
header := &tar.Header{
|
||||
Name: fileName,
|
||||
Size: file.Size(),
|
||||
// Mode: int64(stat.Mode()),
|
||||
// ModTime: stat.ModTime(),
|
||||
}
|
||||
|
||||
err := tarWriter.WriteHeader(header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed write file header to tar: %w", err)
|
||||
}
|
||||
|
||||
_, err = io.Copy(tarWriter, file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to copy file to tar: %w", err)
|
||||
}
|
||||
|
||||
err = tarWriter.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to close tar: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("tar: %s", buf.String())
|
||||
|
||||
// Ensure the directory is present inside the container
|
||||
_, _, err = dockertestutil.ExecuteCommand(
|
||||
container,
|
||||
[]string{"mkdir", "-p", dirPath},
|
||||
[]string{},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to ensure directory: %w", err)
|
||||
}
|
||||
|
||||
err = pool.Client.UploadToContainer(
|
||||
container.Container.ID,
|
||||
docker.UploadToContainerOptions{
|
||||
NoOverwriteDirNonDir: false,
|
||||
Path: dirPath,
|
||||
InputStream: bytes.NewReader(buf.Bytes()),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@@ -15,36 +15,46 @@ import (
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/puzpuzpuz/xsync/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
scenarioHashLength = 6
|
||||
maxWait = 60 * time.Second
|
||||
headscalePort = 8080
|
||||
)
|
||||
|
||||
var (
|
||||
errNoHeadscaleAvailable = errors.New("no headscale available")
|
||||
errNoNamespaceAvailable = errors.New("no namespace available")
|
||||
TailscaleVersions = []string{
|
||||
tailscaleVersions2021 = []string{
|
||||
"head",
|
||||
"unstable",
|
||||
"1.32.1",
|
||||
"1.30.2",
|
||||
"1.28.0",
|
||||
"1.26.2",
|
||||
}
|
||||
|
||||
tailscaleVersions2019 = []string{
|
||||
"1.24.2",
|
||||
"1.22.2",
|
||||
"1.20.4",
|
||||
"1.18.2",
|
||||
"1.16.2",
|
||||
|
||||
// These versions seem to fail when fetching from apt.
|
||||
// "1.14.6",
|
||||
// "1.12.4",
|
||||
// "1.10.2",
|
||||
// "1.8.7",
|
||||
}
|
||||
|
||||
// tailscaleVersionsUnavailable = []string{
|
||||
// // These versions seem to fail when fetching from apt.
|
||||
// "1.14.6",
|
||||
// "1.12.4",
|
||||
// "1.10.2",
|
||||
// "1.8.7",
|
||||
// }.
|
||||
|
||||
TailscaleVersions = append(
|
||||
tailscaleVersions2021,
|
||||
tailscaleVersions2019...,
|
||||
)
|
||||
)
|
||||
|
||||
type Namespace struct {
|
||||
@@ -59,12 +69,14 @@ type Namespace struct {
|
||||
type Scenario struct {
|
||||
// TODO(kradalby): support multiple headcales for later, currently only
|
||||
// use one.
|
||||
controlServers map[string]ControlServer
|
||||
controlServers *xsync.MapOf[string, ControlServer]
|
||||
|
||||
namespaces map[string]*Namespace
|
||||
|
||||
pool *dockertest.Pool
|
||||
network *dockertest.Network
|
||||
|
||||
headscaleLock sync.Mutex
|
||||
}
|
||||
|
||||
func NewScenario() (*Scenario, error) {
|
||||
@@ -99,7 +111,7 @@ func NewScenario() (*Scenario, error) {
|
||||
}
|
||||
|
||||
return &Scenario{
|
||||
controlServers: make(map[string]ControlServer),
|
||||
controlServers: xsync.NewMapOf[ControlServer](),
|
||||
namespaces: make(map[string]*Namespace),
|
||||
|
||||
pool: pool,
|
||||
@@ -108,12 +120,17 @@ func NewScenario() (*Scenario, error) {
|
||||
}
|
||||
|
||||
func (s *Scenario) Shutdown() error {
|
||||
for _, control := range s.controlServers {
|
||||
s.controlServers.Range(func(_ string, control ControlServer) bool {
|
||||
err := control.Shutdown()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to tear down control: %w", err)
|
||||
log.Printf(
|
||||
"Failed to shut down control: %s",
|
||||
fmt.Errorf("failed to tear down control: %w", err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
for namespaceName, namespace := range s.namespaces {
|
||||
for _, client := range namespace.Clients {
|
||||
@@ -150,36 +167,31 @@ func (s *Scenario) Namespaces() []string {
|
||||
// Note: These functions assume that there is a _single_ headscale instance for now
|
||||
|
||||
// TODO(kradalby): make port and headscale configurable, multiple instances support?
|
||||
func (s *Scenario) StartHeadscale() error {
|
||||
headscale, err := hsic.New(s.pool, headscalePort, s.network,
|
||||
hsic.WithACLPolicy(
|
||||
&headscale.ACLPolicy{
|
||||
ACLs: []headscale.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create headscale container: %w", err)
|
||||
func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) {
|
||||
s.headscaleLock.Lock()
|
||||
defer s.headscaleLock.Unlock()
|
||||
|
||||
if headscale, ok := s.controlServers.Load("headscale"); ok {
|
||||
return headscale, nil
|
||||
}
|
||||
|
||||
s.controlServers["headscale"] = headscale
|
||||
headscale, err := hsic.New(s.pool, s.network, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create headscale container: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
err = headscale.WaitForReady()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed reach headscale container: %w", err)
|
||||
}
|
||||
|
||||
func (s *Scenario) Headscale() *hsic.HeadscaleInContainer {
|
||||
//nolint
|
||||
return s.controlServers["headscale"].(*hsic.HeadscaleInContainer)
|
||||
s.controlServers.Store("headscale", headscale)
|
||||
|
||||
return headscale, nil
|
||||
}
|
||||
|
||||
func (s *Scenario) CreatePreAuthKey(namespace string) (*v1.PreAuthKey, error) {
|
||||
if headscale, ok := s.controlServers["headscale"]; ok {
|
||||
if headscale, err := s.Headscale(); err == nil {
|
||||
key, err := headscale.CreateAuthKey(namespace)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create namespace: %w", err)
|
||||
@@ -192,7 +204,7 @@ func (s *Scenario) CreatePreAuthKey(namespace string) (*v1.PreAuthKey, error) {
|
||||
}
|
||||
|
||||
func (s *Scenario) CreateNamespace(namespace string) error {
|
||||
if headscale, ok := s.controlServers["headscale"]; ok {
|
||||
if headscale, err := s.Headscale(); err == nil {
|
||||
err := headscale.CreateNamespace(namespace)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create namespace: %w", err)
|
||||
@@ -222,16 +234,36 @@ func (s *Scenario) CreateTailscaleNodesInNamespace(
|
||||
version = TailscaleVersions[i%len(TailscaleVersions)]
|
||||
}
|
||||
|
||||
headscale, err := s.Headscale()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create tailscale node: %w", err)
|
||||
}
|
||||
|
||||
cert := headscale.GetCert()
|
||||
hostname := headscale.GetHostname()
|
||||
|
||||
namespace.createWaitGroup.Add(1)
|
||||
|
||||
go func() {
|
||||
defer namespace.createWaitGroup.Done()
|
||||
|
||||
// TODO(kradalby): error handle this
|
||||
tsClient, err := tsic.New(s.pool, version, s.network)
|
||||
tsClient, err := tsic.New(
|
||||
s.pool,
|
||||
version,
|
||||
s.network,
|
||||
tsic.WithHeadscaleTLS(cert),
|
||||
tsic.WithHeadscaleName(hostname),
|
||||
)
|
||||
if err != nil {
|
||||
// return fmt.Errorf("failed to add tailscale node: %w", err)
|
||||
log.Printf("failed to add tailscale node: %s", err)
|
||||
log.Printf("failed to create tailscale node: %s", err)
|
||||
}
|
||||
|
||||
err = tsClient.WaitForReady()
|
||||
if err != nil {
|
||||
// return fmt.Errorf("failed to add tailscale node: %w", err)
|
||||
log.Printf("failed to wait for tailscaled: %s", err)
|
||||
}
|
||||
|
||||
namespace.Clients[tsClient.Hostname()] = tsClient
|
||||
@@ -258,7 +290,13 @@ func (s *Scenario) RunTailscaleUp(
|
||||
// TODO(kradalby): error handle this
|
||||
_ = c.Up(loginServer, authKey)
|
||||
}(client)
|
||||
|
||||
err := client.WaitForReady()
|
||||
if err != nil {
|
||||
log.Printf("error waiting for client %s to be ready: %s", client.Hostname(), err)
|
||||
}
|
||||
}
|
||||
|
||||
namespace.joinWaitGroup.Wait()
|
||||
|
||||
return nil
|
||||
@@ -300,13 +338,8 @@ func (s *Scenario) WaitForTailscaleSync() error {
|
||||
// CreateHeadscaleEnv is a conventient method returning a set up Headcale
|
||||
// test environment with nodes of all versions, joined to the server with X
|
||||
// namespaces.
|
||||
func (s *Scenario) CreateHeadscaleEnv(namespaces map[string]int) error {
|
||||
err := s.StartHeadscale()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.Headscale().WaitForReady()
|
||||
func (s *Scenario) CreateHeadscaleEnv(namespaces map[string]int, opts ...hsic.Option) error {
|
||||
headscale, err := s.Headscale(opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -327,7 +360,7 @@ func (s *Scenario) CreateHeadscaleEnv(namespaces map[string]int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.RunTailscaleUp(namespaceName, s.Headscale().GetEndpoint(), key.GetKey())
|
||||
err = s.RunTailscaleUp(namespaceName, headscale.GetEndpoint(), key.GetKey())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -34,12 +34,12 @@ func TestHeadscale(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("start-headscale", func(t *testing.T) {
|
||||
err = scenario.StartHeadscale()
|
||||
headscale, err := scenario.Headscale()
|
||||
if err != nil {
|
||||
t.Errorf("failed to create start headcale: %s", err)
|
||||
}
|
||||
|
||||
err = scenario.Headscale().WaitForReady()
|
||||
err = headscale.WaitForReady()
|
||||
if err != nil {
|
||||
t.Errorf("headscale failed to become ready: %s", err)
|
||||
}
|
||||
@@ -117,12 +117,11 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("start-headscale", func(t *testing.T) {
|
||||
err = scenario.StartHeadscale()
|
||||
headscale, err := scenario.Headscale()
|
||||
if err != nil {
|
||||
t.Errorf("failed to create start headcale: %s", err)
|
||||
}
|
||||
|
||||
headscale := scenario.Headscale()
|
||||
err = headscale.WaitForReady()
|
||||
if err != nil {
|
||||
t.Errorf("headscale failed to become ready: %s", err)
|
||||
@@ -157,7 +156,16 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
|
||||
t.Errorf("failed to create preauthkey: %s", err)
|
||||
}
|
||||
|
||||
err = scenario.RunTailscaleUp(namespace, scenario.Headscale().GetEndpoint(), key.GetKey())
|
||||
headscale, err := scenario.Headscale()
|
||||
if err != nil {
|
||||
t.Errorf("failed to create start headcale: %s", err)
|
||||
}
|
||||
|
||||
err = scenario.RunTailscaleUp(
|
||||
namespace,
|
||||
headscale.GetEndpoint(),
|
||||
key.GetKey(),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("failed to login: %s", err)
|
||||
}
|
||||
|
@@ -2,6 +2,7 @@ package integration
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
)
|
||||
@@ -10,11 +11,13 @@ type TailscaleClient interface {
|
||||
Hostname() string
|
||||
Shutdown() error
|
||||
Version() string
|
||||
Execute(command []string) (string, error)
|
||||
Execute(command []string) (string, string, error)
|
||||
Up(loginServer, authKey string) error
|
||||
UpWithLoginURL(loginServer string) (*url.URL, error)
|
||||
IPs() ([]netip.Addr, error)
|
||||
FQDN() (string, error)
|
||||
Status() (*ipnstate.Status, error)
|
||||
WaitForReady() error
|
||||
WaitForPeers(expected int) error
|
||||
Ping(hostnameOrIP string) error
|
||||
}
|
||||
|
@@ -6,11 +6,13 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/juanfont/headscale"
|
||||
"github.com/juanfont/headscale/integration/dockertestutil"
|
||||
"github.com/juanfont/headscale/integration/integrationutil"
|
||||
"github.com/ory/dockertest/v3"
|
||||
"github.com/ory/dockertest/v3/docker"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
@@ -19,12 +21,15 @@ import (
|
||||
const (
|
||||
tsicHashLength = 6
|
||||
dockerContextPath = "../."
|
||||
headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt"
|
||||
)
|
||||
|
||||
var (
|
||||
errTailscalePingFailed = errors.New("ping failed")
|
||||
errTailscaleNotLoggedIn = errors.New("tailscale not logged in")
|
||||
errTailscaleWrongPeerCount = errors.New("wrong peer count")
|
||||
errTailscalePingFailed = errors.New("ping failed")
|
||||
errTailscaleNotLoggedIn = errors.New("tailscale not logged in")
|
||||
errTailscaleWrongPeerCount = errors.New("wrong peer count")
|
||||
errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey")
|
||||
errTailscaleNotConnected = errors.New("tailscale not connected")
|
||||
)
|
||||
|
||||
type TailscaleInContainer struct {
|
||||
@@ -38,12 +43,51 @@ type TailscaleInContainer struct {
|
||||
// "cache"
|
||||
ips []netip.Addr
|
||||
fqdn string
|
||||
|
||||
// optional config
|
||||
headscaleCert []byte
|
||||
headscaleHostname string
|
||||
}
|
||||
|
||||
type Option = func(c *TailscaleInContainer)
|
||||
|
||||
func WithHeadscaleTLS(cert []byte) Option {
|
||||
return func(tsic *TailscaleInContainer) {
|
||||
tsic.headscaleCert = cert
|
||||
}
|
||||
}
|
||||
|
||||
func WithOrCreateNetwork(network *dockertest.Network) Option {
|
||||
return func(tsic *TailscaleInContainer) {
|
||||
if network != nil {
|
||||
tsic.network = network
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
network, err := dockertestutil.GetFirstOrCreateNetwork(
|
||||
tsic.pool,
|
||||
fmt.Sprintf("%s-network", tsic.hostname),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create network: %s", err)
|
||||
}
|
||||
|
||||
tsic.network = network
|
||||
}
|
||||
}
|
||||
|
||||
func WithHeadscaleName(hsName string) Option {
|
||||
return func(tsic *TailscaleInContainer) {
|
||||
tsic.headscaleHostname = hsName
|
||||
}
|
||||
}
|
||||
|
||||
func New(
|
||||
pool *dockertest.Pool,
|
||||
version string,
|
||||
network *dockertest.Network,
|
||||
opts ...Option,
|
||||
) (*TailscaleInContainer, error) {
|
||||
hash, err := headscale.GenerateRandomStringDNSSafe(tsicHashLength)
|
||||
if err != nil {
|
||||
@@ -52,20 +96,38 @@ func New(
|
||||
|
||||
hostname := fmt.Sprintf("ts-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
|
||||
|
||||
// TODO(kradalby): figure out why we need to "refresh" the network here.
|
||||
// network, err = dockertestutil.GetFirstOrCreateNetwork(pool, network.Network.Name)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
tsic := &TailscaleInContainer{
|
||||
version: version,
|
||||
hostname: hostname,
|
||||
|
||||
pool: pool,
|
||||
network: network,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(tsic)
|
||||
}
|
||||
|
||||
tailscaleOptions := &dockertest.RunOptions{
|
||||
Name: hostname,
|
||||
Networks: []*dockertest.Network{network},
|
||||
Cmd: []string{
|
||||
"tailscaled", "--tun=tsdev",
|
||||
// Cmd: []string{
|
||||
// "tailscaled", "--tun=tsdev",
|
||||
// },
|
||||
Entrypoint: []string{
|
||||
"/bin/bash",
|
||||
"-c",
|
||||
"/bin/sleep 3 ; update-ca-certificates ; tailscaled --tun=tsdev",
|
||||
},
|
||||
}
|
||||
|
||||
if tsic.headscaleHostname != "" {
|
||||
tailscaleOptions.ExtraHosts = []string{
|
||||
"host.docker.internal:host-gateway",
|
||||
fmt.Sprintf("%s:host-gateway", tsic.headscaleHostname),
|
||||
}
|
||||
}
|
||||
|
||||
// dockertest isnt very good at handling containers that has already
|
||||
// been created, this is an attempt to make sure this container isnt
|
||||
// present.
|
||||
@@ -86,14 +148,20 @@ func New(
|
||||
}
|
||||
log.Printf("Created %s container\n", hostname)
|
||||
|
||||
return &TailscaleInContainer{
|
||||
version: version,
|
||||
hostname: hostname,
|
||||
tsic.container = container
|
||||
|
||||
pool: pool,
|
||||
container: container,
|
||||
network: network,
|
||||
}, nil
|
||||
if tsic.hasTLS() {
|
||||
err = tsic.WriteFile(headscaleCertPath, tsic.headscaleCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return tsic, nil
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) hasTLS() bool {
|
||||
return len(t.headscaleCert) != 0
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) Shutdown() error {
|
||||
@@ -110,9 +178,7 @@ func (t *TailscaleInContainer) Version() string {
|
||||
|
||||
func (t *TailscaleInContainer) Execute(
|
||||
command []string,
|
||||
) (string, error) {
|
||||
log.Println("command", command)
|
||||
log.Printf("running command for %s\n", t.hostname)
|
||||
) (string, string, error) {
|
||||
stdout, stderr, err := dockertestutil.ExecuteCommand(
|
||||
t.container,
|
||||
command,
|
||||
@@ -126,13 +192,13 @@ func (t *TailscaleInContainer) Execute(
|
||||
}
|
||||
|
||||
if strings.Contains(stderr, "NeedsLogin") {
|
||||
return "", errTailscaleNotLoggedIn
|
||||
return stdout, stderr, errTailscaleNotLoggedIn
|
||||
}
|
||||
|
||||
return "", err
|
||||
return stdout, stderr, err
|
||||
}
|
||||
|
||||
return stdout, nil
|
||||
return stdout, stderr, nil
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) Up(
|
||||
@@ -149,13 +215,45 @@ func (t *TailscaleInContainer) Up(
|
||||
t.hostname,
|
||||
}
|
||||
|
||||
if _, err := t.Execute(command); err != nil {
|
||||
if _, _, err := t.Execute(command); err != nil {
|
||||
return fmt.Errorf("failed to join tailscale client: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) UpWithLoginURL(
|
||||
loginServer string,
|
||||
) (*url.URL, error) {
|
||||
command := []string{
|
||||
"tailscale",
|
||||
"up",
|
||||
"-login-server",
|
||||
loginServer,
|
||||
"--hostname",
|
||||
t.hostname,
|
||||
}
|
||||
|
||||
_, stderr, err := t.Execute(command)
|
||||
if errors.Is(err, errTailscaleNotLoggedIn) {
|
||||
return nil, errTailscaleCannotUpWithoutAuthkey
|
||||
}
|
||||
|
||||
urlStr := strings.ReplaceAll(stderr, "\nTo authenticate, visit:\n\n\t", "")
|
||||
urlStr = strings.TrimSpace(urlStr)
|
||||
|
||||
// parse URL
|
||||
loginURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
log.Printf("Could not parse login URL: %s", err)
|
||||
log.Printf("Original join command result: %s", stderr)
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return loginURL, nil
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
|
||||
if t.ips != nil && len(t.ips) != 0 {
|
||||
return t.ips, nil
|
||||
@@ -168,7 +266,7 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
|
||||
"ip",
|
||||
}
|
||||
|
||||
result, err := t.Execute(command)
|
||||
result, _, err := t.Execute(command)
|
||||
if err != nil {
|
||||
return []netip.Addr{}, fmt.Errorf("failed to join tailscale client: %w", err)
|
||||
}
|
||||
@@ -195,7 +293,7 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) {
|
||||
"--json",
|
||||
}
|
||||
|
||||
result, err := t.Execute(command)
|
||||
result, _, err := t.Execute(command)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute tailscale status command: %w", err)
|
||||
}
|
||||
@@ -222,6 +320,21 @@ func (t *TailscaleInContainer) FQDN() (string, error) {
|
||||
return status.Self.DNSName, nil
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) WaitForReady() error {
|
||||
return t.pool.Retry(func() error {
|
||||
status, err := t.Status()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to fetch tailscale status: %w", err)
|
||||
}
|
||||
|
||||
if status.CurrentTailnet != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errTailscaleNotConnected
|
||||
})
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) WaitForPeers(expected int) error {
|
||||
return t.pool.Retry(func() error {
|
||||
status, err := t.Status()
|
||||
@@ -248,7 +361,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string) error {
|
||||
hostnameOrIP,
|
||||
}
|
||||
|
||||
result, err := t.Execute(command)
|
||||
result, _, err := t.Execute(command)
|
||||
if err != nil {
|
||||
log.Printf(
|
||||
"failed to run ping command from %s to %s, err: %s",
|
||||
@@ -268,6 +381,10 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
|
||||
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
|
||||
}
|
||||
|
||||
func createTailscaleBuildOptions(version string) *dockertest.BuildOptions {
|
||||
var tailscaleBuildOptions *dockertest.BuildOptions
|
||||
switch version {
|
||||
|
93
oidc.go
93
oidc.go
@@ -76,20 +76,52 @@ func (h *Headscale) RegisterOIDC(
|
||||
) {
|
||||
vars := mux.Vars(req)
|
||||
nodeKeyStr, ok := vars["nkey"]
|
||||
if !ok || nodeKeyStr == "" {
|
||||
log.Error().
|
||||
Caller().
|
||||
Msg("Missing node key in URL")
|
||||
http.Error(writer, "Missing node key in URL", http.StatusBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Trace().
|
||||
Caller().
|
||||
Str("node_key", nodeKeyStr).
|
||||
Msg("Received oidc register call")
|
||||
|
||||
if !NodePublicKeyRegex.Match([]byte(nodeKeyStr)) {
|
||||
log.Warn().Str("node_key", nodeKeyStr).Msg("Invalid node key passed to registration url")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
_, err := writer.Write([]byte("Unauthorized"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// We need to make sure we dont open for XSS style injections, if the parameter that
|
||||
// is passed as a key is not parsable/validated as a NodePublic key, then fail to render
|
||||
// the template and log an error.
|
||||
var nodeKey key.NodePublic
|
||||
err := nodeKey.UnmarshalText(
|
||||
[]byte(NodePublicKeyEnsurePrefix(nodeKeyStr)),
|
||||
)
|
||||
|
||||
if !ok || nodeKeyStr == "" || err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to parse incoming nodekey")
|
||||
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("Wrong params"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
randomBlob := make([]byte, randomByteSize)
|
||||
if _, err := rand.Read(randomBlob); err != nil {
|
||||
log.Error().
|
||||
@@ -103,7 +135,7 @@ func (h *Headscale) RegisterOIDC(
|
||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the node key into the state cache, so it can be retrieved later
|
||||
h.registrationCache.Set(stateStr, nodeKeyStr, registerCacheExpiration)
|
||||
h.registrationCache.Set(stateStr, NodePublicKeyStripPrefix(nodeKey), registerCacheExpiration)
|
||||
|
||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
||||
@@ -405,8 +437,8 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
||||
claims *IDTokenClaims,
|
||||
) (*key.NodePublic, bool, error) {
|
||||
// retrieve machinekey from state cache
|
||||
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
|
||||
if !machineKeyFound {
|
||||
nodeKeyIf, nodeKeyFound := h.registrationCache.Get(state)
|
||||
if !nodeKeyFound {
|
||||
log.Error().
|
||||
Msg("requested machine state key expired before authorisation completed")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
@@ -419,20 +451,38 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
|
||||
return nil, false, errOIDCInvalidMachineState
|
||||
return nil, false, errOIDCNodeKeyMissing
|
||||
}
|
||||
|
||||
var nodeKey key.NodePublic
|
||||
nodeKeyFromCache, nodeKeyOK := machineKeyIf.(string)
|
||||
nodeKeyFromCache, nodeKeyOK := nodeKeyIf.(string)
|
||||
if !nodeKeyOK {
|
||||
log.Error().
|
||||
Msg("requested machine state key is not a string")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state is invalid"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
|
||||
return nil, false, errOIDCInvalidMachineState
|
||||
}
|
||||
|
||||
err := nodeKey.UnmarshalText(
|
||||
[]byte(NodePublicKeyEnsurePrefix(nodeKeyFromCache)),
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Str("nodeKey", nodeKeyFromCache).
|
||||
Bool("nodeKeyOK", nodeKeyOK).
|
||||
Msg("could not parse node public key")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, werr := writer.Write([]byte("could not parse public key"))
|
||||
_, werr := writer.Write([]byte("could not parse node public key"))
|
||||
if werr != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
@@ -443,21 +493,6 @@ func (h *Headscale) validateMachineForOIDCCallback(
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
if !nodeKeyOK {
|
||||
log.Error().Msg("could not get node key from cache")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
_, err := writer.Write([]byte("could not get node key from cache"))
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Err(err).
|
||||
Msg("Failed to write response")
|
||||
}
|
||||
|
||||
return nil, false, errOIDCNodeKeyMissing
|
||||
}
|
||||
|
||||
// retrieve machine information if it exist
|
||||
// The error is not important, because if it does not
|
||||
// exist, then this is a new machine and we will move
|
||||
|
Reference in New Issue
Block a user