Mirror of https://github.com/juanfont/headscale.git (synced 2025-03-12 08:20:52 +00:00)
Experimental implementation of Policy v2 (#2214)
* utility iterator for ipset

* split policy -> policy and v1

  This commit splits out the common policy logic and the policy implementation
  into separate packages. policy contains functions that are independent of the
  policy implementation; this typically means logic that works on tailcfg types
  and generic formats. In addition, it defines the PolicyManager interface,
  which v1 implements. v1 is a subpackage which implements the PolicyManager
  using the "original" policy implementation.

* use policyv1 definitions in integration tests

  These can be marshalled back into JSON, which the new format might not be
  able to. Also, just don't change it all to JSON strings for now.

* formatter: break lines

* remove compareprefix, use tsaddr version

* remove getacl test, add back autoapprover

* use policy manager tag handling

* rename display helper for user

* introduce policy v2 package

  policy v2 is built from the ground up to be stricter and follow the same
  pattern for all types of resolvers.

  TODO: introduce alias resolver

* wire up policyv2 in integration testing

* split policy v2 tests into separate workflow to work around GitHub limit

* add policy manager output to /debug

* update changelog

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
Parent: b6fbd37539
Commit: 87326f5c4f
@@ -38,12 +38,13 @@ func findTests() []string {
 	return tests
 }
 
-func updateYAML(tests []string) {
+func updateYAML(tests []string, testPath string) {
 	testsForYq := fmt.Sprintf("[%s]", strings.Join(tests, ", "))
 
 	yqCommand := fmt.Sprintf(
-		"yq eval '.jobs.integration-test.strategy.matrix.test = %s' ./test-integration.yaml -i",
+		"yq eval '.jobs.integration-test.strategy.matrix.test = %s' %s -i",
 		testsForYq,
+		testPath,
 	)
 	cmd := exec.Command("bash", "-c", yqCommand)
 
@@ -58,7 +59,7 @@ func updateYAML(tests []string) {
 		log.Fatalf("failed to run yq command: %s", err)
 	}
 
-	fmt.Println("YAML file updated successfully")
+	fmt.Printf("YAML file (%s) updated successfully\n", testPath)
 }
 
 func main() {
@@ -69,5 +70,6 @@ func main() {
 		quotedTests[i] = fmt.Sprintf("\"%s\"", test)
 	}
 
-	updateYAML(quotedTests)
+	updateYAML(quotedTests, "./test-integration.yaml")
+	updateYAML(quotedTests, "./test-integration-policyv2.yaml")
 }
.github/workflows/test-integration-policyv2.yaml (vendored, new file, 159 lines)
@@ -0,0 +1,159 @@
name: Integration Tests (policy v2)
# To debug locally on a branch, and when needing secrets
# change this to include `push` so the build is run on
# the main repository.
on: [pull_request]
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true
jobs:
  integration-test:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        test:
          - TestACLHostsInNetMapTable
          - TestACLAllowUser80Dst
          - TestACLDenyAllPort80
          - TestACLAllowUserDst
          - TestACLAllowStarDst
          - TestACLNamedHostsCanReachBySubnet
          - TestACLNamedHostsCanReach
          - TestACLDevice1CanAccessDevice2
          - TestPolicyUpdateWhileRunningWithCLIInDatabase
          - TestAuthKeyLogoutAndReloginSameUser
          - TestAuthKeyLogoutAndReloginNewUser
          - TestAuthKeyLogoutAndReloginSameUserExpiredKey
          - TestOIDCAuthenticationPingAll
          - TestOIDCExpireNodesBasedOnTokenExpiry
          - TestOIDC024UserCreation
          - TestOIDCAuthenticationWithPKCE
          - TestOIDCReloginSameNodeNewUser
          - TestAuthWebFlowAuthenticationPingAll
          - TestAuthWebFlowLogoutAndRelogin
          - TestUserCommand
          - TestPreAuthKeyCommand
          - TestPreAuthKeyCommandWithoutExpiry
          - TestPreAuthKeyCommandReusableEphemeral
          - TestPreAuthKeyCorrectUserLoggedInCommand
          - TestApiKeyCommand
          - TestNodeTagCommand
          - TestNodeAdvertiseTagCommand
          - TestNodeCommand
          - TestNodeExpireCommand
          - TestNodeRenameCommand
          - TestNodeMoveCommand
          - TestPolicyCommand
          - TestPolicyBrokenConfigCommand
          - TestDERPVerifyEndpoint
          - TestResolveMagicDNS
          - TestResolveMagicDNSExtraRecordsPath
          - TestValidateResolvConf
          - TestDERPServerScenario
          - TestDERPServerWebsocketScenario
          - TestPingAllByIP
          - TestPingAllByIPPublicDERP
          - TestEphemeral
          - TestEphemeralInAlternateTimezone
          - TestEphemeral2006DeletedTooQuickly
          - TestPingAllByHostname
          - TestTaildrop
          - TestUpdateHostnameFromClient
          - TestExpireNode
          - TestNodeOnlineStatus
          - TestPingAllByIPManyUpDown
          - Test2118DeletingOnlineNodePanics
          - TestEnablingRoutes
          - TestHASubnetRouterFailover
          - TestEnableDisableAutoApprovedRoute
          - TestAutoApprovedSubRoute2068
          - TestSubnetRouteACL
          - TestEnablingExitRoutes
          - TestHeadscale
          - TestCreateTailscale
          - TestTailscaleNodesJoiningHeadcale
          - TestSSHOneUserToAll
          - TestSSHMultipleUsersAllToAll
          - TestSSHNoSSHConfigured
          - TestSSHIsBlockedInACL
          - TestSSHUserOnlyIsolation
        database: [postgres, sqlite]
    env:
      # Github does not allow us to access secrets in pull requests,
      # so this env var is used to check if we have the secret or not.
      # If we have the secrets, meaning we are running on push in a fork,
      # there might be secrets available for more debugging.
      # If TS_OAUTH_CLIENT_ID and TS_OAUTH_SECRET is set, then the job
      # will join a debug tailscale network, set up SSH and a tmux session.
      # The SSH will be configured to use the SSH key of the Github user
      # that triggered the build.
      HAS_TAILSCALE_SECRET: ${{ secrets.TS_OAUTH_CLIENT_ID }}
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 2
      - name: Get changed files
        id: changed-files
        uses: dorny/paths-filter@v3
        with:
          filters: |
            files:
              - '*.nix'
              - 'go.*'
              - '**/*.go'
              - 'integration_test/'
              - 'config-example.yaml'
      - name: Tailscale
        if: ${{ env.HAS_TAILSCALE_SECRET }}
        uses: tailscale/github-action@v2
        with:
          oauth-client-id: ${{ secrets.TS_OAUTH_CLIENT_ID }}
          oauth-secret: ${{ secrets.TS_OAUTH_SECRET }}
          tags: tag:gh
      - name: Setup SSH server for Actor
        if: ${{ env.HAS_TAILSCALE_SECRET }}
        uses: alexellis/setup-sshd-actor@master
      - uses: DeterminateSystems/nix-installer-action@main
        if: steps.changed-files.outputs.files == 'true'
      - uses: DeterminateSystems/magic-nix-cache-action@main
        if: steps.changed-files.outputs.files == 'true'
      - uses: satackey/action-docker-layer-caching@main
        if: steps.changed-files.outputs.files == 'true'
        continue-on-error: true
      - name: Run Integration Test
        uses: Wandalen/wretry.action@master
        if: steps.changed-files.outputs.files == 'true'
        env:
          USE_POSTGRES: ${{ matrix.database == 'postgres' && '1' || '0' }}
        with:
          attempt_limit: 5
          command: |
            nix develop --command -- docker run \
              --tty --rm \
              --volume ~/.cache/hs-integration-go:/go \
              --name headscale-test-suite \
              --volume $PWD:$PWD -w $PWD/integration \
              --volume /var/run/docker.sock:/var/run/docker.sock \
              --volume $PWD/control_logs:/tmp/control \
              --env HEADSCALE_INTEGRATION_POSTGRES=${{env.USE_POSTGRES}} \
              --env HEADSCALE_EXPERIMENTAL_POLICY_V2=1 \
              golang:1 \
                go run gotest.tools/gotestsum@latest -- ./... \
                  -failfast \
                  -timeout 120m \
                  -parallel 1 \
                  -run "^${{ matrix.test }}$"
      - uses: actions/upload-artifact@v4
        if: always() && steps.changed-files.outputs.files == 'true'
        with:
          name: ${{ matrix.test }}-${{matrix.database}}-${{matrix.policy}}-logs
          path: "control_logs/*.log"
      - uses: actions/upload-artifact@v4
        if: always() && steps.changed-files.outputs.files == 'true'
        with:
          name: ${{ matrix.test }}-${{matrix.database}}-${{matrix.policy}}-pprof
          path: "control_logs/*.pprof.tar"
      - name: Setup a blocking tmux session
        if: ${{ env.HAS_TAILSCALE_SECRET }}
        uses: alexellis/block-with-tmux-action@master
.github/workflows/test-integration.yaml (vendored, 5 lines changed)
@@ -137,6 +137,7 @@ jobs:
               --volume /var/run/docker.sock:/var/run/docker.sock \
               --volume $PWD/control_logs:/tmp/control \
               --env HEADSCALE_INTEGRATION_POSTGRES=${{env.USE_POSTGRES}} \
+              --env HEADSCALE_EXPERIMENTAL_POLICY_V2=0 \
               golang:1 \
                 go run gotest.tools/gotestsum@latest -- ./... \
                   -failfast \
@@ -146,12 +147,12 @@ jobs:
       - uses: actions/upload-artifact@v4
         if: always() && steps.changed-files.outputs.files == 'true'
         with:
-          name: ${{ matrix.test }}-${{matrix.database}}-logs
+          name: ${{ matrix.test }}-${{matrix.database}}-${{matrix.policy}}-logs
           path: "control_logs/*.log"
       - uses: actions/upload-artifact@v4
         if: always() && steps.changed-files.outputs.files == 'true'
         with:
-          name: ${{ matrix.test }}-${{matrix.database}}-pprof
+          name: ${{ matrix.test }}-${{matrix.database}}-${{matrix.policy}}-pprof
           path: "control_logs/*.pprof.tar"
       - name: Setup a blocking tmux session
         if: ${{ env.HAS_TAILSCALE_SECRET }}
CHANGELOG.md (64 lines changed)
@@ -4,13 +4,13 @@
 
 ### BREAKING
 
-Route internals have been rewritten, removing the dedicated route table in the database.
-This was done to simplify the codebase, which had grown unnecessarily complex after
-the routes were split into separate tables. The overhead of having to go via the database
-and keeping the state in sync made the code very hard to reason about and prone to errors.
-The majority of the route state is only relevant when headscale is running, and is now only
-kept in memory.
-As part of this, the CLI and API has been simplified to reflect the changes;
+Route internals have been rewritten, removing the dedicated route table in the
+database. This was done to simplify the codebase, which had grown unnecessarily
+complex after the routes were split into separate tables. The overhead of having
+to go via the database and keeping the state in sync made the code very hard to
+reason about and prone to errors. The majority of the route state is only
+relevant when headscale is running, and is now only kept in memory. As part of
+this, the CLI and API has been simplified to reflect the changes;
 
 ```console
 $ headscale nodes list-routes
@@ -27,15 +27,55 @@ ID | Hostname | Approved | Available | Serving
 2 | ts-unstable-fq7ob4 | | 0.0.0.0/0, ::/0 |
 ```
 
-Note that if an exit route is approved (0.0.0.0/0 or ::/0), both IPv4 and IPv6 will be approved.
+Note that if an exit route is approved (0.0.0.0/0 or ::/0), both IPv4 and IPv6
+will be approved.
 
-- Route API and CLI has been removed [#2422](https://github.com/juanfont/headscale/pull/2422)
-- Routes are now managed via the Node API [#2422](https://github.com/juanfont/headscale/pull/2422)
+- Route API and CLI has been removed
+  [#2422](https://github.com/juanfont/headscale/pull/2422)
+- Routes are now managed via the Node API
+  [#2422](https://github.com/juanfont/headscale/pull/2422)
+
+### Experimental Policy v2
+
+This release introduces a new experimental version of Headscale's policy
+implementation. In this context, experimental means that the feature is not yet
+fully tested and may contain bugs or unexpected behavior, and that we are still
+experimenting with how the final interface/behavior will be.
+
+#### Breaking changes
+
+- The policy is validated and "resolved" when loading, providing errors for
+  invalid rules and conditions.
+  - Previously this was done as a mix between load and runtime (when it was
+    applied to a node).
+  - This means that when you convert the first time, what was previously a
+    policy that loaded, but failed at runtime, will now fail at load time.
+- Error messages should be more descriptive and informative.
+  - There is still work to be done here, but it is already improved with "typing"
+    (e.g. only Users can be put in Groups).
+- All users must contain an `@` character.
+  - If your user naturally contains an `@`, like an email, this will just work.
+  - If it's based on usernames, or other identifiers not containing an `@`, an
+    `@` should be appended at the end. For example, if your user is `john`, it
+    must be written as `john@` in the policy.
+
+#### Current state
+
+The new policy is passing all tests, both integration and unit tests. This does
+not mean it is perfect, but it is a good start. Corner cases that are currently
+working in v1 and not tested might be broken in v2 (and vice versa).
+
+**We do need help testing this code**, and we think that most of the user facing
+API will not really change. We are not sure yet when this code will replace v1,
+but we are confident that it will, and all new changes and fixes will be made
+towards this code.
+
+The new policy can be used by setting the environment variable
+`HEADSCALE_EXPERIMENTAL_POLICY_V2` to `1`.
 
 ### Changes
 
-- Use Go 1.24
-  [#2427](https://github.com/juanfont/headscale/pull/2427)
+- Use Go 1.24 [#2427](https://github.com/juanfont/headscale/pull/2427)
 - `oidc.map_legacy_users` and `oidc.strip_email_domain` has been removed
   [#2411](https://github.com/juanfont/headscale/pull/2411)
 - Add more information to `/debug` endpoint
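As an illustration of the `@` requirement described under "Experimental Policy v2" in the CHANGELOG above (the user, group, and ACL below are made up, not taken from this change), a policy that previously referred to the plain username `john` would be written like this for policy v2:

```json
{
  // Plain usernames get a trailing "@" under policy v2; emails stay as-is.
  "groups": {
    "group:admins": ["john@", "alice@example.com"]
  },
  "acls": [
    {
      "action": "accept",
      "src": ["group:admins"],
      "dst": ["john@:*"]
    }
  ]
}
```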
@@ -194,10 +194,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
 
 	var magicDNSDomains []dnsname.FQDN
 	if cfg.PrefixV4 != nil {
-		magicDNSDomains = append(magicDNSDomains, util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
+		magicDNSDomains = append(
+			magicDNSDomains,
+			util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...)
 	}
 	if cfg.PrefixV6 != nil {
-		magicDNSDomains = append(magicDNSDomains, util.GenerateIPv6DNSRootDomain(*cfg.PrefixV6)...)
+		magicDNSDomains = append(
+			magicDNSDomains,
+			util.GenerateIPv6DNSRootDomain(*cfg.PrefixV6)...)
 	}
 
 	// we might have routes already from Split DNS
@@ -459,11 +463,13 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
 	router := mux.NewRouter()
 	router.Use(prometheusMiddleware)
 
-	router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost, http.MethodGet)
+	router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).
+		Methods(http.MethodPost, http.MethodGet)
 
 	router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet)
 	router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet)
-	router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).Methods(http.MethodGet)
+	router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler).
+		Methods(http.MethodGet)
 
 	if provider, ok := h.authProvider.(*AuthProviderOIDC); ok {
 		router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet)
@@ -523,7 +529,11 @@ func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *not
 // Maybe we should attempt a new in memory state and not go via the DB?
 // Maybe this should be implemented as an event bus?
 // A bool is returned indicating if a full update was sent to all nodes
-func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) (bool, error) {
+func nodesChangedHook(
+	db *db.HSDatabase,
+	polMan policy.PolicyManager,
+	notif *notifier.Notifier,
+) (bool, error) {
 	nodes, err := db.ListNodes()
 	if err != nil {
 		return false, err
@@ -1143,6 +1153,7 @@ func (h *Headscale) loadPolicyManager() error {
 		errOut = fmt.Errorf("creating policy manager: %w", err)
 		return
 	}
+	log.Info().Msgf("Using policy manager version: %d", h.polMan.Version())
 
 	if len(nodes) > 0 {
 		_, err = h.polMan.SSHPolicy(nodes[0])
@@ -22,6 +22,7 @@ import (
 	"gorm.io/gorm"
 	"gorm.io/gorm/logger"
 	"gorm.io/gorm/schema"
+	"tailscale.com/net/tsaddr"
 	"tailscale.com/util/set"
 	"zgo.at/zcache/v2"
 )
@@ -655,7 +656,7 @@ AND auth_key_id NOT IN (
 	}
 
 	for nodeID, routes := range nodeRoutes {
-		slices.SortFunc(routes, util.ComparePrefix)
+		tsaddr.SortPrefixes(routes)
 		slices.Compact(routes)
 
 		data, err := json.Marshal(routes)
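The change above replaces the custom `util.ComparePrefix` sort with `tsaddr.SortPrefixes` before deduplicating. A minimal, standalone sketch of that sort-then-compact pattern (the prefixes are made up):

```go
package main

import (
	"fmt"
	"net/netip"
	"slices"

	"tailscale.com/net/tsaddr"
)

func main() {
	routes := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/8"),
		netip.MustParsePrefix("0.0.0.0/0"),
		netip.MustParsePrefix("10.0.0.0/8"), // duplicate to be removed
	}

	// Sort into tsaddr's canonical order, then drop adjacent duplicates.
	tsaddr.SortPrefixes(routes)
	routes = slices.Compact(routes)

	fmt.Println(routes) // [0.0.0.0/0 10.0.0.0/8]
}
```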
@ -19,6 +19,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/check.v1"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
@ -146,105 +147,6 @@ func (s *Suite) TestListPeers(c *check.C) {
|
||||
c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10")
|
||||
}
|
||||
|
||||
func (s *Suite) TestGetACLFilteredPeers(c *check.C) {
|
||||
type base struct {
|
||||
user *types.User
|
||||
key *types.PreAuthKey
|
||||
}
|
||||
|
||||
stor := make([]base, 0)
|
||||
|
||||
for _, name := range []string{"test", "admin"} {
|
||||
user, err := db.CreateUser(types.User{Name: name})
|
||||
c.Assert(err, check.IsNil)
|
||||
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
|
||||
c.Assert(err, check.IsNil)
|
||||
stor = append(stor, base{user, pak})
|
||||
}
|
||||
|
||||
_, err := db.GetNodeByID(0)
|
||||
c.Assert(err, check.NotNil)
|
||||
|
||||
for index := 0; index <= 10; index++ {
|
||||
nodeKey := key.NewNode()
|
||||
machineKey := key.NewMachine()
|
||||
|
||||
v4 := netip.MustParseAddr(fmt.Sprintf("100.64.0.%d", index+1))
|
||||
node := types.Node{
|
||||
ID: types.NodeID(index),
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
IPv4: &v4,
|
||||
Hostname: "testnode" + strconv.Itoa(index),
|
||||
UserID: stor[index%2].user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(stor[index%2].key.ID),
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
c.Assert(trx.Error, check.IsNil)
|
||||
}
|
||||
|
||||
aclPolicy := &policy.ACLPolicy{
|
||||
Groups: map[string][]string{
|
||||
"group:test": {"admin"},
|
||||
},
|
||||
Hosts: map[string]netip.Prefix{},
|
||||
TagOwners: map[string][]string{},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"admin"},
|
||||
Destinations: []string{"*:*"},
|
||||
},
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"test"},
|
||||
Destinations: []string{"test:*"},
|
||||
},
|
||||
},
|
||||
Tests: []policy.ACLTest{},
|
||||
}
|
||||
|
||||
adminNode, err := db.GetNodeByID(1)
|
||||
c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User)
|
||||
c.Assert(adminNode.IPv4, check.NotNil)
|
||||
c.Assert(adminNode.IPv6, check.IsNil)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testNode, err := db.GetNodeByID(2)
|
||||
c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User)
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
adminPeers, err := db.ListPeers(adminNode.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(adminPeers), check.Equals, 9)
|
||||
|
||||
testPeers, err := db.ListPeers(testNode.ID)
|
||||
c.Assert(err, check.IsNil)
|
||||
c.Assert(len(testPeers), check.Equals, 9)
|
||||
|
||||
adminRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, adminNode, adminPeers, []types.User{*stor[0].user, *stor[1].user})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
testRules, _, err := policy.GenerateFilterAndSSHRulesForTests(aclPolicy, testNode, testPeers, []types.User{*stor[0].user, *stor[1].user})
|
||||
c.Assert(err, check.IsNil)
|
||||
|
||||
peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules)
|
||||
peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules)
|
||||
c.Log(peersOfAdminNode)
|
||||
c.Log(peersOfTestNode)
|
||||
|
||||
c.Assert(len(peersOfTestNode), check.Equals, 9)
|
||||
c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1")
|
||||
c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3")
|
||||
c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5")
|
||||
|
||||
c.Assert(len(peersOfAdminNode), check.Equals, 9)
|
||||
c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2")
|
||||
c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4")
|
||||
c.Assert(peersOfAdminNode[5].Hostname, check.Equals, "testnode7")
|
||||
}
|
||||
|
||||
func (s *Suite) TestExpireNode(c *check.C) {
|
||||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
c.Assert(err, check.IsNil)
|
||||
@ -456,143 +358,171 @@ func TestHeadscale_generateGivenName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(kradalby): replace this test
|
||||
// func TestAutoApproveRoutes(t *testing.T) {
|
||||
// tests := []struct {
|
||||
// name string
|
||||
// acl string
|
||||
// routes []netip.Prefix
|
||||
// want []netip.Prefix
|
||||
// }{
|
||||
// {
|
||||
// name: "2068-approve-issue-sub",
|
||||
// acl: `
|
||||
// {
|
||||
// "groups": {
|
||||
// "group:k8s": ["test"]
|
||||
// },
|
||||
func TestAutoApproveRoutes(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
acl string
|
||||
routes []netip.Prefix
|
||||
want []netip.Prefix
|
||||
want2 []netip.Prefix
|
||||
}{
|
||||
{
|
||||
name: "2068-approve-issue-sub-kube",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:k8s": ["test@"]
|
||||
},
|
||||
|
||||
// "acls": [
|
||||
// {"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
// ],
|
||||
|
||||
// "autoApprovers": {
|
||||
// "routes": {
|
||||
// "10.42.0.0/16": ["test"],
|
||||
// }
|
||||
// }
|
||||
// }`,
|
||||
// routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
// want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
// },
|
||||
// {
|
||||
// name: "2068-approve-issue-sub",
|
||||
// acl: `
|
||||
// {
|
||||
// "tagOwners": {
|
||||
// "tag:exit": ["test"],
|
||||
// },
|
||||
"autoApprovers": {
|
||||
"routes": {
|
||||
"10.42.0.0/16": ["test@"],
|
||||
}
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")},
|
||||
},
|
||||
{
|
||||
name: "2068-approve-issue-sub-exit-tag",
|
||||
acl: `
|
||||
{
|
||||
"tagOwners": {
|
||||
"tag:exit": ["test@"],
|
||||
},
|
||||
|
||||
// "groups": {
|
||||
// "group:test": ["test"]
|
||||
// },
|
||||
"groups": {
|
||||
"group:test": ["test@"]
|
||||
},
|
||||
|
||||
// "acls": [
|
||||
// {"action": "accept", "users": ["*"], "ports": ["*:*"]},
|
||||
// ],
|
||||
|
||||
// "autoApprovers": {
|
||||
// "exitNode": ["tag:exit"],
|
||||
// "routes": {
|
||||
// "10.10.0.0/16": ["group:test"],
|
||||
// "10.11.0.0/16": ["test"],
|
||||
// }
|
||||
// }
|
||||
// }`,
|
||||
// routes: []netip.Prefix{
|
||||
// tsaddr.AllIPv4(),
|
||||
// tsaddr.AllIPv6(),
|
||||
// netip.MustParsePrefix("10.10.0.0/16"),
|
||||
// netip.MustParsePrefix("10.11.0.0/24"),
|
||||
// },
|
||||
// want: []netip.Prefix{
|
||||
// tsaddr.AllIPv4(),
|
||||
// netip.MustParsePrefix("10.10.0.0/16"),
|
||||
// netip.MustParsePrefix("10.11.0.0/24"),
|
||||
// tsaddr.AllIPv6(),
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
"autoApprovers": {
|
||||
"exitNode": ["tag:exit"],
|
||||
"routes": {
|
||||
"10.10.0.0/16": ["group:test"],
|
||||
"10.11.0.0/16": ["test@"],
|
||||
"8.11.0.0/24": ["test2@"], // No nodes
|
||||
}
|
||||
}
|
||||
}`,
|
||||
routes: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
netip.MustParsePrefix("10.10.0.0/16"),
|
||||
netip.MustParsePrefix("10.11.0.0/24"),
|
||||
|
||||
// for _, tt := range tests {
|
||||
// t.Run(tt.name, func(t *testing.T) {
|
||||
// adb, err := newSQLiteTestDB()
|
||||
// require.NoError(t, err)
|
||||
// pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
|
||||
// Not approved
|
||||
netip.MustParsePrefix("8.11.0.0/24"),
|
||||
},
|
||||
want: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.10.0.0/16"),
|
||||
netip.MustParsePrefix("10.11.0.0/24"),
|
||||
},
|
||||
want2: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
tsaddr.AllIPv6(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// require.NoError(t, err)
|
||||
// require.NotNil(t, pol)
|
||||
for _, tt := range tests {
|
||||
pmfs := policy.PolicyManagerFuncsForTest([]byte(tt.acl))
|
||||
for i, pmf := range pmfs {
|
||||
version := i + 1
|
||||
t.Run(fmt.Sprintf("%s-policyv%d", tt.name, version), func(t *testing.T) {
|
||||
adb, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
// user, err := adb.CreateUser(types.User{Name: "test"})
|
||||
// require.NoError(t, err)
|
||||
suffix := ""
|
||||
if version == 1 {
|
||||
suffix = "@"
|
||||
}
|
||||
|
||||
// pak, err := adb.CreatePreAuthKey(types.UserID(user.ID), false, nil, nil)
|
||||
// require.NoError(t, err)
|
||||
user, err := adb.CreateUser(types.User{Name: "test" + suffix})
|
||||
require.NoError(t, err)
|
||||
_, err = adb.CreateUser(types.User{Name: "test2" + suffix})
|
||||
require.NoError(t, err)
|
||||
taggedUser, err := adb.CreateUser(types.User{Name: "tagged" + suffix})
|
||||
require.NoError(t, err)
|
||||
|
||||
// nodeKey := key.NewNode()
|
||||
// machineKey := key.NewMachine()
|
||||
node := types.Node{
|
||||
ID: 1,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.routes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
}
|
||||
|
||||
// v4 := netip.MustParseAddr("100.64.0.1")
|
||||
// node := types.Node{
|
||||
// ID: 0,
|
||||
// MachineKey: machineKey.Public(),
|
||||
// NodeKey: nodeKey.Public(),
|
||||
// Hostname: "test",
|
||||
// UserID: user.ID,
|
||||
// RegisterMethod: util.RegisterMethodAuthKey,
|
||||
// AuthKeyID: ptr.To(pak.ID),
|
||||
// Hostinfo: &tailcfg.Hostinfo{
|
||||
// RequestTags: []string{"tag:exit"},
|
||||
// RoutableIPs: tt.routes,
|
||||
// },
|
||||
// IPv4: &v4,
|
||||
// }
|
||||
err = adb.DB.Save(&node).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// trx := adb.DB.Save(&node)
|
||||
// require.NoError(t, trx.Error)
|
||||
nodeTagged := types.Node{
|
||||
ID: 2,
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "taggednode",
|
||||
UserID: taggedUser.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.routes,
|
||||
},
|
||||
ForcedTags: []string{"tag:exit"},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
||||
}
|
||||
|
||||
// sendUpdate, err := adb.SaveNodeRoutes(&node)
|
||||
// require.NoError(t, err)
|
||||
// assert.False(t, sendUpdate)
|
||||
err = adb.DB.Save(&nodeTagged).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// node0ByID, err := adb.GetNodeByID(0)
|
||||
// require.NoError(t, err)
|
||||
users, err := adb.ListUsers()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// users, err := adb.ListUsers()
|
||||
// assert.NoError(t, err)
|
||||
nodes, err := adb.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// nodes, err := adb.ListNodes()
|
||||
// assert.NoError(t, err)
|
||||
pm, err := pmf(users, nodes)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pm)
|
||||
|
||||
// pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
|
||||
// assert.NoError(t, err)
|
||||
changed1 := policy.AutoApproveRoutes(pm, &node)
|
||||
assert.True(t, changed1)
|
||||
|
||||
// // TODO(kradalby): Check state update
|
||||
// err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
|
||||
// require.NoError(t, err)
|
||||
err = adb.DB.Save(&node).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
|
||||
// require.NoError(t, err)
|
||||
// assert.Len(t, enabledRoutes, len(tt.want))
|
||||
_ = policy.AutoApproveRoutes(pm, &nodeTagged)
|
||||
|
||||
// tsaddr.SortPrefixes(enabledRoutes)
|
||||
err = adb.DB.Save(&nodeTagged).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
// if diff := cmp.Diff(tt.want, enabledRoutes, util.Comparers...); diff != "" {
|
||||
// t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
node1ByID, err := adb.GetNodeByID(1)
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.want, node1ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
node2ByID, err := adb.GetNodeByID(2)
|
||||
require.NoError(t, err)
|
||||
|
||||
if diff := cmp.Diff(tt.want2, node2ByID.SubnetRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEphemeralGarbageCollectorOrder(t *testing.T) {
|
||||
want := []types.NodeID{1, 3}
|
||||
|
@@ -105,6 +105,11 @@ func (h *Headscale) debugHTTPServer() *http.Server {
 		w.WriteHeader(http.StatusOK)
 		w.Write([]byte(h.primaryRoutes.String()))
 	}))
+	debug.Handle("policy-manager", "Policy Manager", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/plain")
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte(h.polMan.DebugString()))
+	}))
 
 	err := statsviz.Register(debugMux)
 	if err == nil {
@@ -348,7 +348,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
 			routes = append(routes, prefix)
 		}
 	}
-	slices.SortFunc(routes, util.ComparePrefix)
+	tsaddr.SortPrefixes(routes)
 	slices.Compact(routes)
 
 	node, err := db.Write(api.h.db.DB, func(tx *gorm.DB) (*types.Node, error) {
@@ -525,7 +525,12 @@ func nodesToProto(polMan policy.PolicyManager, isLikelyConnected *xsync.MapOf[ty
 			resp.Online = true
 		}
 
-		tags := polMan.Tags(node)
+		var tags []string
+		for _, tag := range node.RequestTags() {
+			if polMan.NodeCanHaveTag(node, tag) {
+				tags = append(tags, tag)
+			}
+		}
 		resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags...))
 		response[index] = resp
 	}
@ -11,6 +11,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
@ -246,7 +247,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
pol *policy.ACLPolicy
|
||||
pol []byte
|
||||
node *types.Node
|
||||
peers types.Nodes
|
||||
|
||||
@ -258,7 +259,7 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
// {
|
||||
// name: "empty-node",
|
||||
// node: types.Node{},
|
||||
// pol: &policy.ACLPolicy{},
|
||||
// pol: &policyv1.ACLPolicy{},
|
||||
// dnsConfig: &tailcfg.DNSConfig{},
|
||||
// baseDomain: "",
|
||||
// want: nil,
|
||||
@ -266,7 +267,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
// },
|
||||
{
|
||||
name: "no-pol-no-peers-map-response",
|
||||
pol: &policy.ACLPolicy{},
|
||||
node: mini,
|
||||
peers: types.Nodes{},
|
||||
derpMap: &tailcfg.DERPMap{},
|
||||
@ -284,10 +284,15 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
DNSConfig: &tailcfg.DNSConfig{},
|
||||
Domain: "",
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"}},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{
|
||||
ID: tailcfg.UserID(user1.ID),
|
||||
LoginName: "user1",
|
||||
DisplayName: "user1",
|
||||
},
|
||||
},
|
||||
PacketFilter: tailcfg.FilterAllowAll,
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
},
|
||||
@ -296,7 +301,6 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "no-pol-with-peer-map-response",
|
||||
pol: &policy.ACLPolicy{},
|
||||
node: mini,
|
||||
peers: types.Nodes{
|
||||
peer1,
|
||||
@ -318,13 +322,12 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
DNSConfig: &tailcfg.DNSConfig{},
|
||||
Domain: "",
|
||||
CollectServices: "false",
|
||||
PacketFilter: []tailcfg.FilterRule{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
|
||||
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
PacketFilter: tailcfg.FilterAllowAll,
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
},
|
||||
@ -333,18 +336,17 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "with-pol-map-response",
|
||||
pol: &policy.ACLPolicy{
|
||||
Hosts: policy.Hosts{
|
||||
"mini": netip.MustParsePrefix("100.64.0.1/32"),
|
||||
},
|
||||
ACLs: []policy.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"100.64.0.2"},
|
||||
Destinations: []string{"mini:*"},
|
||||
},
|
||||
},
|
||||
},
|
||||
pol: []byte(`
|
||||
{
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": ["100.64.0.2"],
|
||||
"dst": ["user1:*"],
|
||||
},
|
||||
],
|
||||
}
|
||||
`),
|
||||
node: mini,
|
||||
peers: types.Nodes{
|
||||
peer1,
|
||||
@ -374,11 +376,11 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{},
|
||||
UserProfiles: []tailcfg.UserProfile{
|
||||
{ID: tailcfg.UserID(user1.ID), LoginName: "user1", DisplayName: "user1"},
|
||||
{ID: tailcfg.UserID(user2.ID), LoginName: "user2", DisplayName: "user2"},
|
||||
},
|
||||
SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}},
|
||||
ControlTime: &time.Time{},
|
||||
Debug: &tailcfg.Debug{
|
||||
DisableLogTail: true,
|
||||
@ -390,7 +392,8 @@ func Test_fullMapResponse(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
|
||||
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
|
||||
require.NoError(t, err)
|
||||
primary := routes.New()
|
||||
|
||||
primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
|
||||
|
@@ -81,7 +81,12 @@ func tailNode(
 		return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
 	}
 
-	tags := polMan.Tags(node)
+	var tags []string
+	for _, tag := range node.RequestTags() {
+		if polMan.NodeCanHaveTag(node, tag) {
+			tags = append(tags, tag)
+		}
+	}
 	tags = lo.Uniq(append(tags, node.ForcedTags...))
 
 	tNode := tailcfg.Node{
@ -11,6 +11,7 @@ import (
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
@ -49,7 +50,7 @@ func TestTailNode(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
node *types.Node
|
||||
pol *policy.ACLPolicy
|
||||
pol []byte
|
||||
dnsConfig *tailcfg.DNSConfig
|
||||
baseDomain string
|
||||
want *tailcfg.Node
|
||||
@ -61,7 +62,6 @@ func TestTailNode(t *testing.T) {
|
||||
GivenName: "empty",
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
pol: &policy.ACLPolicy{},
|
||||
dnsConfig: &tailcfg.DNSConfig{},
|
||||
baseDomain: "",
|
||||
want: &tailcfg.Node{
|
||||
@ -117,7 +117,6 @@ func TestTailNode(t *testing.T) {
|
||||
ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24")},
|
||||
CreatedAt: created,
|
||||
},
|
||||
pol: &policy.ACLPolicy{},
|
||||
dnsConfig: &tailcfg.DNSConfig{},
|
||||
baseDomain: "",
|
||||
want: &tailcfg.Node{
|
||||
@ -179,7 +178,8 @@ func TestTailNode(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
|
||||
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node})
|
||||
require.NoError(t, err)
|
||||
primary := routes.New()
|
||||
cfg := &types.Config{
|
||||
BaseDomain: tt.baseDomain,
|
||||
@ -248,7 +248,7 @@ func TestNodeExpiry(t *testing.T) {
|
||||
tn, err := tailNode(
|
||||
node,
|
||||
0,
|
||||
&policy.PolicyManagerV1{},
|
||||
nil, // TODO(kradalby): removed in merge but error?
|
||||
nil,
|
||||
&types.Config{},
|
||||
)
|
||||
|
@@ -513,7 +513,7 @@ func renderOIDCCallbackTemplate(
 ) (*bytes.Buffer, error) {
 	var content bytes.Buffer
 	if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{
-		User: user.DisplayNameOrUsername(),
+		User: user.Display(),
 		Verb: verb,
 	}); err != nil {
 		return nil, fmt.Errorf("rendering OIDC callback template: %w", err)
@ -1,219 +1,81 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/rs/zerolog/log"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/envknob"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
var (
|
||||
polv2 = envknob.Bool("HEADSCALE_EXPERIMENTAL_POLICY_V2")
|
||||
)
|
||||
|
||||
type PolicyManager interface {
|
||||
Filter() []tailcfg.FilterRule
|
||||
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
|
||||
Tags(*types.Node) []string
|
||||
ApproversForRoute(netip.Prefix) []string
|
||||
ExpandAlias(string) (*netipx.IPSet, error)
|
||||
SetPolicy([]byte) (bool, error)
|
||||
SetUsers(users []types.User) (bool, error)
|
||||
SetNodes(nodes types.Nodes) (bool, error)
|
||||
// NodeCanHaveTag reports whether the given node can have the given tag.
|
||||
NodeCanHaveTag(*types.Node, string) bool
|
||||
|
||||
// NodeCanApproveRoute reports whether the given node can approve the given route.
|
||||
NodeCanApproveRoute(*types.Node, netip.Prefix) bool
|
||||
|
||||
Version() int
|
||||
DebugString() string
|
||||
}
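The PolicyManager interface defined above is the seam between headscale and the two policy implementations. A minimal sketch of how a caller might drive it (the helper function, the tag name, and the policy bytes are illustrative, not part of this change):

```go
package policyexample

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
)

// refreshPolicy is a sketch only: it assumes polMan was built with
// policy.NewPolicyManager and that node is already known to the manager.
func refreshPolicy(polMan policy.PolicyManager, node *types.Node, newPolicy []byte) error {
	changed, err := polMan.SetPolicy(newPolicy)
	if err != nil {
		return fmt.Errorf("setting policy: %w", err)
	}

	if changed {
		// The compiled filter is what eventually ends up in the netmap.
		fmt.Printf("policy changed, %d filter rules\n", len(polMan.Filter()))
	}

	if polMan.NodeCanHaveTag(node, "tag:exit") {
		fmt.Println("node may advertise tag:exit")
	}

	return nil
}
```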
|
||||
|
||||
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
policyFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer policyFile.Close()
|
||||
|
||||
policyBytes, err := io.ReadAll(policyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewPolicyManager(policyBytes, users, nodes)
|
||||
}
|
||||
|
||||
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
var pol *ACLPolicy
|
||||
// NewPolicyManager returns a new policy manager, the version is determined by
|
||||
// the environment flag "HEADSCALE_EXPERIMENTAL_POLICY_V2".
|
||||
func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
var polMan PolicyManager
|
||||
var err error
|
||||
if polB != nil && len(polB) > 0 {
|
||||
pol, err = LoadACLPolicyFromBytes(polB)
|
||||
if polv2 {
|
||||
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
polMan, err = policyv1.NewPolicyManager(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pm := PolicyManagerV1{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
|
||||
_, err = pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
return polMan, err
|
||||
}
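The diff rendering above interleaves the old and new constructor line by line; as a hedged, consolidated sketch (not a verbatim copy of the final file), the version selection comes down to an envknob read at startup:

```go
package policy

import (
	policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
	policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/envknob"
)

// polv2 is read once; setting HEADSCALE_EXPERIMENTAL_POLICY_V2=1 switches
// every constructed manager to the v2 implementation.
var polv2 = envknob.Bool("HEADSCALE_EXPERIMENTAL_POLICY_V2")

// NewPolicyManager returns a new policy manager; the version is determined by
// the environment flag above (sketch consolidated from the diff).
func NewPolicyManager(pol []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
	if polv2 {
		pm, err := policyv2.NewPolicyManager(pol, users, nodes)
		if err != nil {
			return nil, err
		}
		return pm, nil
	}

	pm, err := policyv1.NewPolicyManager(pol, users, nodes)
	if err != nil {
		return nil, err
	}
	return pm, nil
}
```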
|
||||
|
||||
func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||
pm := PolicyManagerV1{
|
||||
pol: pol,
|
||||
users: users,
|
||||
nodes: nodes,
|
||||
}
|
||||
// PolicyManagersForTest returns all available PostureManagers to be used
|
||||
// in tests to validate them in tests that try to determine that they
|
||||
// behave the same.
|
||||
func PolicyManagersForTest(pol []byte, users []types.User, nodes types.Nodes) ([]PolicyManager, error) {
|
||||
var polMans []PolicyManager
|
||||
|
||||
_, err := pm.updateLocked()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
type PolicyManagerV1 struct {
|
||||
mu sync.Mutex
|
||||
pol *ACLPolicy
|
||||
|
||||
users []types.User
|
||||
nodes types.Nodes
|
||||
|
||||
filterHash deephash.Sum
|
||||
filter []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||
// It must be called with the lock held.
|
||||
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
|
||||
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||
}
|
||||
|
||||
filterHash := deephash.Hash(&filter)
|
||||
if filterHash == pm.filterHash {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
pm.filterHash = filterHash
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
return pm.filter
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
|
||||
if len(polB) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
pol, err := LoadACLPolicyFromBytes(polB)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parsing policy: %w", err)
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.pol = pol
|
||||
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
pm.nodes = nodes
|
||||
return pm.updateLocked()
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
|
||||
if pm == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tags, invalid := pm.pol.TagsOfNode(pm.users, node)
|
||||
log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy")
|
||||
return tags
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
|
||||
// TODO(kradalby): This can be a parse error of the address in the policy,
|
||||
// in the new policy this will be typed and not a problem, in this policy
|
||||
// we will just return empty list
|
||||
if pm.pol == nil {
|
||||
return nil
|
||||
}
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
return approvers
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (pm *PolicyManagerV1) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
|
||||
if pm.pol == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||
|
||||
for _, approvedAlias := range approvers {
|
||||
if approvedAlias == node.User.Username() {
|
||||
return true
|
||||
} else {
|
||||
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// approvedIPs should contain all of node's IPs if it matches the rule, so check for first
|
||||
if ips.Contains(*node.IPv4) {
|
||||
return true
|
||||
}
|
||||
for _, pmf := range PolicyManagerFuncsForTest(pol) {
|
||||
pm, err := pmf(users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
polMans = append(polMans, pm)
|
||||
}
|
||||
|
||||
return false
|
||||
return polMans, nil
|
||||
}
|
||||
|
||||
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, types.Nodes) (PolicyManager, error) {
|
||||
var polmanFuncs []func([]types.User, types.Nodes) (PolicyManager, error)
|
||||
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) {
|
||||
return policyv1.NewPolicyManager(pol, u, n)
|
||||
})
|
||||
polmanFuncs = append(polmanFuncs, func(u []types.User, n types.Nodes) (PolicyManager, error) {
|
||||
return policyv2.NewPolicyManager(pol, u, n)
|
||||
})
|
||||
|
||||
return polmanFuncs
|
||||
}
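PolicyManagerFuncsForTest above is how the updated unit tests run the same assertions against both implementations. A hedged sketch of that pattern (the wildcard ACL and the empty user/node lists are illustrative and assume both versions accept them):

```go
package policy_test

import (
	"fmt"
	"testing"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/stretchr/testify/require"
)

// TestBothPolicyImplementations shows the v1/v2 comparison pattern; the ACL
// body is made up and not taken from the real test suite.
func TestBothPolicyImplementations(t *testing.T) {
	acl := []byte(`{"acls": [{"action": "accept", "src": ["*"], "dst": ["*:*"]}]}`)

	for i, pmf := range policy.PolicyManagerFuncsForTest(acl) {
		version := i + 1
		t.Run(fmt.Sprintf("policyv%d", version), func(t *testing.T) {
			pm, err := pmf([]types.User{}, types.Nodes{})
			require.NoError(t, err)
			require.NotNil(t, pm)
			t.Logf("version %d compiled %d filter rules", version, len(pm.Filter()))
		})
	}
}
```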
|
||||
|
hscontrol/policy/policy.go (new file, 109 lines)
@@ -0,0 +1,109 @@
package policy

import (
	"net/netip"
	"slices"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/samber/lo"
	"tailscale.com/net/tsaddr"
	"tailscale.com/tailcfg"
)

// FilterNodesByACL returns the list of peers authorized to be accessed from a given node.
func FilterNodesByACL(
	node *types.Node,
	nodes types.Nodes,
	filter []tailcfg.FilterRule,
) types.Nodes {
	var result types.Nodes

	for index, peer := range nodes {
		if peer.ID == node.ID {
			continue
		}

		if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
			result = append(result, peer)
		}
	}

	return result
}

// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
// that are not relevant to that particular node.
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
	ret := []tailcfg.FilterRule{}

	for _, rule := range rules {
		// record if the rule is actually relevant for the given node.
		var dests []tailcfg.NetPortRange
	DEST_LOOP:
		for _, dest := range rule.DstPorts {
			expanded, err := util.ParseIPSet(dest.IP, nil)
			// Fail closed, if we can't parse it, then we should not allow
			// access.
			if err != nil {
				continue DEST_LOOP
			}

			if node.InIPSet(expanded) {
				dests = append(dests, dest)
				continue DEST_LOOP
			}

			// If the node exposes routes, ensure they are not removed
			// when the filters are reduced.
			if node.Hostinfo != nil {
				if len(node.Hostinfo.RoutableIPs) > 0 {
					for _, routableIP := range node.Hostinfo.RoutableIPs {
						if expanded.OverlapsPrefix(routableIP) {
							dests = append(dests, dest)
							continue DEST_LOOP
						}
					}
				}
			}
		}

		if len(dests) > 0 {
			ret = append(ret, tailcfg.FilterRule{
				SrcIPs:   rule.SrcIPs,
				DstPorts: dests,
				IPProto:  rule.IPProto,
			})
		}
	}

	return ret
}

// AutoApproveRoutes approves any route that can be autoapproved from
// the node's perspective according to the given policy.
// It reports true if any routes were approved.
func AutoApproveRoutes(pm PolicyManager, node *types.Node) bool {
	if pm == nil {
		return false
	}
	var newApproved []netip.Prefix
	for _, route := range node.AnnouncedRoutes() {
		if pm.NodeCanApproveRoute(node, route) {
			newApproved = append(newApproved, route)
		}
	}
	if newApproved != nil {
		newApproved = append(newApproved, node.ApprovedRoutes...)
		tsaddr.SortPrefixes(newApproved)
		newApproved = slices.Compact(newApproved)
		newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
			return route.IsValid()
		})
		node.ApprovedRoutes = newApproved

		return true
	}

	return false
}
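As a worked example of the reduction above (the addresses are made up), a rule whose only destination neither covers the node's own IP nor overlaps its routable prefixes is dropped entirely:

```go
package main

import (
	"fmt"
	"net/netip"

	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
	"tailscale.com/tailcfg"
	"tailscale.com/types/ptr"
)

func main() {
	// A node holding 100.64.0.1, with no advertised subnet routes.
	node := &types.Node{
		IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
	}

	rules := []tailcfg.FilterRule{
		// Kept: the destination covers the node's own IP.
		{
			SrcIPs:   []string{"100.64.0.2/32"},
			DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
		},
		// Dropped: the destination is unrelated to this node.
		{
			SrcIPs:   []string{"100.64.0.2/32"},
			DstPorts: []tailcfg.NetPortRange{{IP: "192.168.99.0/24", Ports: tailcfg.PortRangeAny}},
		},
	}

	reduced := policy.ReduceFilterRules(node, rules)
	fmt.Println(len(reduced)) // 1
}
```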
hscontrol/policy/policy_test.go (new file, 1455 lines)
File diff suppressed because it is too large
@ -1,11 +1,10 @@
|
||||
package policy
|
||||
package v1
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
@ -18,7 +17,6 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/tailscale/hujson"
|
||||
"go4.org/netipx"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
@ -37,38 +35,6 @@ const (
|
||||
expectedTokenItems = 2
|
||||
)
|
||||
|
||||
var theInternetSet *netipx.IPSet
|
||||
|
||||
// theInternet returns the IPSet for the Internet.
|
||||
// https://www.youtube.com/watch?v=iDbyYGrswtg
|
||||
func theInternet() *netipx.IPSet {
|
||||
if theInternetSet != nil {
|
||||
return theInternetSet
|
||||
}
|
||||
|
||||
var internetBuilder netipx.IPSetBuilder
|
||||
internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
|
||||
internetBuilder.AddPrefix(tsaddr.AllIPv4())
|
||||
|
||||
// Delete Private network addresses
|
||||
// https://datatracker.ietf.org/doc/html/rfc1918
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))
|
||||
|
||||
// Delete Tailscale networks
|
||||
internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
|
||||
internetBuilder.RemovePrefix(tsaddr.CGNATRange())
|
||||
|
||||
// Delete "can't find DHCP networks"
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-local
|
||||
internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))
|
||||
|
||||
theInternetSet, _ := internetBuilder.IPSet()
|
||||
return theInternetSet
|
||||
}
|
||||
|
||||
// For some reason golang.org/x/net/internal/iana is an internal package.
|
||||
const (
|
||||
protocolICMP = 1 // Internet Control Message
|
||||
@ -240,53 +206,6 @@ func (pol *ACLPolicy) CompileFilterRules(
|
||||
return rules, nil
|
||||
}
|
||||
|
||||
// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations
|
||||
// that are not relevant to that particular node.
|
||||
func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule {
|
||||
// TODO(kradalby): Make this nil and not alloc unless needed
|
||||
ret := []tailcfg.FilterRule{}
|
||||
|
||||
for _, rule := range rules {
|
||||
// record if the rule is actually relevant for the given node.
|
||||
var dests []tailcfg.NetPortRange
|
||||
DEST_LOOP:
|
||||
for _, dest := range rule.DstPorts {
|
||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||
// Fail closed, if we can't parse it, then we should not allow
|
||||
// access.
|
||||
if err != nil {
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
if node.InIPSet(expanded) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
|
||||
// If the node exposes routes, ensure they are note removed
|
||||
// when the filters are reduced.
|
||||
if len(node.SubnetRoutes()) > 0 {
|
||||
for _, routableIP := range node.SubnetRoutes() {
|
||||
if expanded.OverlapsPrefix(routableIP) {
|
||||
dests = append(dests, dest)
|
||||
continue DEST_LOOP
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(dests) > 0 {
|
||||
ret = append(ret, tailcfg.FilterRule{
|
||||
SrcIPs: rule.SrcIPs,
|
||||
DstPorts: dests,
|
||||
IPProto: rule.IPProto,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func (pol *ACLPolicy) CompileSSHPolicy(
|
||||
node *types.Node,
|
||||
users []types.User,
|
||||
@ -418,7 +337,7 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing SSH policy, expanding alias, index: %d->%d: %w", index, innerIndex, err)
|
||||
}
|
||||
for addr := range ipSetAll(ips) {
|
||||
for addr := range util.IPSetAddrIter(ips) {
|
||||
principals = append(principals, &tailcfg.SSHPrincipal{
|
||||
NodeIP: addr.String(),
|
||||
})
|
||||
@ -441,19 +360,6 @@ func (pol *ACLPolicy) CompileSSHPolicy(
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ipSetAll returns a function that iterates over all the IPs in the IPSet.
|
||||
func ipSetAll(ipSet *netipx.IPSet) iter.Seq[netip.Addr] {
|
||||
return func(yield func(netip.Addr) bool) {
|
||||
for _, rng := range ipSet.Ranges() {
|
||||
for ip := rng.From(); ip.Compare(rng.To()) <= 0; ip = ip.Next() {
|
||||
if !yield(ip) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
||||
sessionLength, err := time.ParseDuration(duration)
|
||||
if err != nil {
|
||||
@ -950,7 +856,7 @@ func (pol *ACLPolicy) expandIPsFromIPPrefix(
|
||||
func expandAutoGroup(alias string) (*netipx.IPSet, error) {
|
||||
switch {
|
||||
case strings.HasPrefix(alias, "autogroup:internet"):
|
||||
return theInternet(), nil
|
||||
return util.TheInternet(), nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown autogroup %q", alias)
|
||||
@ -1084,24 +990,3 @@ func findUserFromToken(users []types.User, token string) (types.User, error) {
    return potentialUsers[0], nil
}

// FilterNodesByACL returns the list of peers authorized to be accessed from a given node.
func FilterNodesByACL(
    node *types.Node,
    nodes types.Nodes,
    filter []tailcfg.FilterRule,
) types.Nodes {
    var result types.Nodes

    for index, peer := range nodes {
        if peer.ID == node.ID {
            continue
        }

        if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) {
            result = append(result, peer)
        }
    }

    return result
}
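The check above makes peer visibility symmetric: a peer is included if either side may initiate. A rough sketch of that property, assuming FilterNodesByACL remains available from the shared policy package after this split; the addresses and node IDs are made up for illustration:

ip1 := netip.MustParseAddr("100.64.0.1")
ip2 := netip.MustParseAddr("100.64.0.2")
node1 := &types.Node{ID: 1, IPv4: &ip1, Hostinfo: &tailcfg.Hostinfo{}}
node2 := &types.Node{ID: 2, IPv4: &ip2, Hostinfo: &tailcfg.Hostinfo{}}

// One-way rule: node1 may reach node2.
filter := []tailcfg.FilterRule{{
    SrcIPs:   []string{"100.64.0.1/32"},
    DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.2/32", Ports: tailcfg.PortRangeAny}},
}}

nodes := types.Nodes{node1, node2}
// Both netmaps should contain the other peer, because one direction is allowed.
_ = policy.FilterNodesByACL(node1, nodes, filter) // expected to contain node2
_ = policy.FilterNodesByACL(node2, nodes, filter) // expected to contain node1 as well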
File diff suppressed because it is too large
@ -1,4 +1,4 @@
package policy
package v1

import (
    "encoding/json"
hscontrol/policy/v1/policy.go (new file, 187 lines)
@ -0,0 +1,187 @@
package v1

import (
    "fmt"
    "io"
    "net/netip"
    "os"
    "sync"

    "github.com/juanfont/headscale/hscontrol/types"
    "github.com/rs/zerolog/log"
    "tailscale.com/tailcfg"
    "tailscale.com/util/deephash"
)

func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
    policyFile, err := os.Open(path)
    if err != nil {
        return nil, err
    }
    defer policyFile.Close()

    policyBytes, err := io.ReadAll(policyFile)
    if err != nil {
        return nil, err
    }

    return NewPolicyManager(policyBytes, users, nodes)
}

func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
    var pol *ACLPolicy
    var err error
    if polB != nil && len(polB) > 0 {
        pol, err = LoadACLPolicyFromBytes(polB)
        if err != nil {
            return nil, fmt.Errorf("parsing policy: %w", err)
        }
    }

    pm := PolicyManager{
        pol:   pol,
        users: users,
        nodes: nodes,
    }

    _, err = pm.updateLocked()
    if err != nil {
        return nil, err
    }

    return &pm, nil
}

type PolicyManager struct {
    mu  sync.Mutex
    pol *ACLPolicy

    users []types.User
    nodes types.Nodes

    filterHash deephash.Sum
    filter     []tailcfg.FilterRule
}

// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManager) updateLocked() (bool, error) {
    filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
    if err != nil {
        return false, fmt.Errorf("compiling filter rules: %w", err)
    }

    filterHash := deephash.Hash(&filter)
    if filterHash == pm.filterHash {
        return false, nil
    }

    pm.filter = filter
    pm.filterHash = filterHash

    return true, nil
}

func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
    pm.mu.Lock()
    defer pm.mu.Unlock()
    return pm.filter
}

func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
    pm.mu.Lock()
    defer pm.mu.Unlock()

    return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
}

func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
    if len(polB) == 0 {
        return false, nil
    }

    pol, err := LoadACLPolicyFromBytes(polB)
    if err != nil {
        return false, fmt.Errorf("parsing policy: %w", err)
    }

    pm.mu.Lock()
    defer pm.mu.Unlock()

    pm.pol = pol

    return pm.updateLocked()
}

// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
    pm.mu.Lock()
    defer pm.mu.Unlock()

    pm.users = users
    return pm.updateLocked()
}

// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
    pm.mu.Lock()
    defer pm.mu.Unlock()
    pm.nodes = nodes
    return pm.updateLocked()
}

func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
    if pm == nil || pm.pol == nil {
        return false
    }

    pm.mu.Lock()
    defer pm.mu.Unlock()

    tags, invalid := pm.pol.TagsOfNode(pm.users, node)
    log.Debug().Strs("authorised_tags", tags).Strs("unauthorised_tags", invalid).Uint64("node.id", node.ID.Uint64()).Msg("tags provided by policy")

    for _, t := range tags {
        if t == tag {
            return true
        }
    }

    return false
}

func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
    if pm == nil || pm.pol == nil {
        return false
    }

    pm.mu.Lock()
    defer pm.mu.Unlock()

    approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)

    for _, approvedAlias := range approvers {
        if approvedAlias == node.User.Username() {
            return true
        } else {
            ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, approvedAlias)
            if err != nil {
                return false
            }

            // approvedIPs should contain all of node's IPs if it matches the rule, so check for first
            if ips.Contains(*node.IPv4) {
                return true
            }
        }
    }
    return false
}

func (pm *PolicyManager) Version() int {
    return 1
}

func (pm *PolicyManager) DebugString() string {
    return "not implemented for v1"
}
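The boolean returned by SetPolicy, SetUsers and SetNodes reports whether the compiled filter actually changed (via the deephash comparison above), so callers only need to push updates when it is true. A small caller-side sketch; users, nodes, updatedNodes and notifyAllNodes are placeholders, not identifiers from this diff:

pm, err := v1.NewPolicyManagerFromPath("/etc/headscale/acl.hujson", users, nodes)
if err != nil {
    log.Fatal().Err(err).Msg("loading policy")
}

changed, err := pm.SetNodes(updatedNodes)
if err != nil {
    log.Error().Err(err).Msg("recompiling filter")
} else if changed {
    notifyAllNodes(pm.Filter()) // hypothetical: push the new filter to connected clients
}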
@ -1,4 +1,4 @@
package policy
package v1

import (
    "testing"
hscontrol/policy/v2/filter.go (new file, 169 lines)
@ -0,0 +1,169 @@
package v2

import (
    "errors"
    "fmt"
    "time"

    "github.com/juanfont/headscale/hscontrol/types"
    "github.com/juanfont/headscale/hscontrol/util"
    "github.com/rs/zerolog/log"
    "go4.org/netipx"
    "tailscale.com/tailcfg"
)

var (
    ErrInvalidAction = errors.New("invalid action")
)

// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *Policy) compileFilterRules(
    users types.Users,
    nodes types.Nodes,
) ([]tailcfg.FilterRule, error) {
    if pol == nil {
        return tailcfg.FilterAllowAll, nil
    }

    var rules []tailcfg.FilterRule

    for _, acl := range pol.ACLs {
        if acl.Action != "accept" {
            return nil, ErrInvalidAction
        }

        srcIPs, err := acl.Sources.Resolve(pol, users, nodes)
        if err != nil {
            log.Trace().Err(err).Msgf("resolving source ips")
        }

        if len(srcIPs.Prefixes()) == 0 {
            continue
        }

        // TODO(kradalby): integrate type into schema
        // TODO(kradalby): figure out the _ is wildcard stuff
        protocols, _, err := parseProtocol(acl.Protocol)
        if err != nil {
            return nil, fmt.Errorf("parsing policy, protocol err: %w ", err)
        }

        var destPorts []tailcfg.NetPortRange
        for _, dest := range acl.Destinations {
            ips, err := dest.Alias.Resolve(pol, users, nodes)
            if err != nil {
                log.Trace().Err(err).Msgf("resolving destination ips")
            }

            for _, pref := range ips.Prefixes() {
                for _, port := range dest.Ports {
                    pr := tailcfg.NetPortRange{
                        IP:    pref.String(),
                        Ports: port,
                    }
                    destPorts = append(destPorts, pr)
                }
            }
        }

        if len(destPorts) == 0 {
            continue
        }

        rules = append(rules, tailcfg.FilterRule{
            SrcIPs:   ipSetToPrefixStringList(srcIPs),
            DstPorts: destPorts,
            IPProto:  protocols,
        })
    }

    return rules, nil
}

func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
    return tailcfg.SSHAction{
        Reject:                   !accept,
        Accept:                   accept,
        SessionDuration:          duration,
        AllowAgentForwarding:     true,
        AllowLocalPortForwarding: true,
    }
}

func (pol *Policy) compileSSHPolicy(
    users types.Users,
    node *types.Node,
    nodes types.Nodes,
) (*tailcfg.SSHPolicy, error) {
    if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
        return nil, nil
    }

    var rules []*tailcfg.SSHRule

    for index, rule := range pol.SSHs {
        var dest netipx.IPSetBuilder
        for _, src := range rule.Destinations {
            ips, err := src.Resolve(pol, users, nodes)
            if err != nil {
                log.Trace().Err(err).Msgf("resolving destination ips")
            }
            dest.AddSet(ips)
        }

        destSet, err := dest.IPSet()
        if err != nil {
            return nil, err
        }

        if !node.InIPSet(destSet) {
            continue
        }

        var action tailcfg.SSHAction
        switch rule.Action {
        case "accept":
            action = sshAction(true, 0)
        case "check":
            action = sshAction(true, rule.CheckPeriod)
        default:
            return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err)
        }

        var principals []*tailcfg.SSHPrincipal
        srcIPs, err := rule.Sources.Resolve(pol, users, nodes)
        if err != nil {
            log.Trace().Err(err).Msgf("resolving source ips")
        }

        for addr := range util.IPSetAddrIter(srcIPs) {
            principals = append(principals, &tailcfg.SSHPrincipal{
                NodeIP: addr.String(),
            })
        }

        userMap := make(map[string]string, len(rule.Users))
        for _, user := range rule.Users {
            userMap[user.String()] = "="
        }
        rules = append(rules, &tailcfg.SSHRule{
            Principals: principals,
            SSHUsers:   userMap,
            Action:     &action,
        })
    }

    return &tailcfg.SSHPolicy{
        Rules: rules,
    }, nil
}

func ipSetToPrefixStringList(ips *netipx.IPSet) []string {
    var out []string

    for _, pref := range ips.Prefixes() {
        out = append(out, pref.String())
    }
    return out
}
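For reference, the sshAction helper above maps the two accepted verbs onto a tailcfg.SSHAction; the values shown here are for a "check" rule with an assumed 12-hour period, purely as illustration:

action := sshAction(true, 12*time.Hour)
// action.Accept == true, action.Reject == false
// action.SessionDuration == 12 * time.Hour (0 for a plain "accept" rule)
// agent forwarding and local port forwarding are always allowed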
hscontrol/policy/v2/filter_test.go (new file, 378 lines)
@ -0,0 +1,378 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
func TestParsing(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "testuser"},
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
format string
|
||||
acl string
|
||||
want []tailcfg.FilterRule
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "invalid-hujson",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
`,
|
||||
want: []tailcfg.FilterRule{},
|
||||
wantErr: true,
|
||||
},
|
||||
// The new parser will ignore all that is irrelevant
|
||||
// {
|
||||
// name: "valid-hujson-invalid-content",
|
||||
// format: "hujson",
|
||||
// acl: `
|
||||
// {
|
||||
// "valid_json": true,
|
||||
// "but_a_policy_though": false
|
||||
// }
|
||||
// `,
|
||||
// want: []tailcfg.FilterRule{},
|
||||
// wantErr: true,
|
||||
// },
|
||||
// {
|
||||
// name: "invalid-cidr",
|
||||
// format: "hujson",
|
||||
// acl: `
|
||||
// {"example-host-1": "100.100.100.100/42"}
|
||||
// `,
|
||||
// want: []tailcfg.FilterRule{},
|
||||
// wantErr: true,
|
||||
// },
|
||||
{
|
||||
name: "basic-rule",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"subnet-1",
|
||||
"192.168.1.0/24"
|
||||
],
|
||||
"dst": [
|
||||
"*:22,3389",
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}},
|
||||
{IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}},
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "parse-protocol",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"proto": "tcp",
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"proto": "udp",
|
||||
"dst": [
|
||||
"host-1:53",
|
||||
],
|
||||
},
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"proto": "icmp",
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolTCP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}},
|
||||
},
|
||||
IPProto: []int{protocolUDP},
|
||||
},
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
IPProto: []int{protocolICMP, protocolIPv6ICMP},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-wildcard",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"Action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-range",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"subnet-1",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:5400-5500",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"100.100.101.0/24"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{
|
||||
IP: "100.100.100.100/32",
|
||||
Ports: tailcfg.PortRange{First: 5400, Last: 5500},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-group",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"groups": {
|
||||
"group:example": [
|
||||
"testuser@",
|
||||
],
|
||||
},
|
||||
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"group:example",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"200.200.200.200/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "port-user",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"testuser@",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"200.200.200.200/32"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
format: "hujson",
|
||||
acl: `
|
||||
{
|
||||
"hosts": {
|
||||
"host-1": "100.100.100.100/32",
|
||||
"subnet-1": "100.100.101.100/24",
|
||||
},
|
||||
|
||||
"acls": [
|
||||
{
|
||||
"action": "accept",
|
||||
"src": [
|
||||
"*",
|
||||
],
|
||||
"dst": [
|
||||
"host-1:*",
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
`,
|
||||
want: []tailcfg.FilterRule{
|
||||
{
|
||||
SrcIPs: []string{"0.0.0.0/0", "::/0"},
|
||||
DstPorts: []tailcfg.NetPortRange{
|
||||
{IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pol, err := policyFromBytes([]byte(tt.acl))
|
||||
if tt.wantErr && err == nil {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
} else if !tt.wantErr && err != nil {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := pol.compileFilterRules(
|
||||
users,
|
||||
types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.100.100.100"),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("200.200.200.200"),
|
||||
User: users[0],
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
},
|
||||
})
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, rules); diff != "" {
|
||||
t.Errorf("parsing() unexpected result (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
hscontrol/policy/v2/policy.go (new file, 283 lines)
@ -0,0 +1,283 @@
package v2

import (
    "encoding/json"
    "fmt"
    "net/netip"
    "strings"
    "sync"

    "github.com/juanfont/headscale/hscontrol/types"
    "go4.org/netipx"
    "tailscale.com/net/tsaddr"
    "tailscale.com/tailcfg"
    "tailscale.com/util/deephash"
)

type PolicyManager struct {
    mu    sync.Mutex
    pol   *Policy
    users []types.User
    nodes types.Nodes

    filterHash deephash.Sum
    filter     []tailcfg.FilterRule

    tagOwnerMapHash deephash.Sum
    tagOwnerMap     map[Tag]*netipx.IPSet

    autoApproveMapHash deephash.Sum
    autoApproveMap     map[netip.Prefix]*netipx.IPSet

    // Lazy map of SSH policies
    sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy
}

// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes.
// It returns an error if the policy file is invalid.
// The policy manager will update the filter rules based on the users and nodes.
func NewPolicyManager(b []byte, users []types.User, nodes types.Nodes) (*PolicyManager, error) {
    policy, err := policyFromBytes(b)
    if err != nil {
        return nil, fmt.Errorf("parsing policy: %w", err)
    }

    pm := PolicyManager{
        pol:          policy,
        users:        users,
        nodes:        nodes,
        sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, len(nodes)),
    }

    _, err = pm.updateLocked()
    if err != nil {
        return nil, err
    }

    return &pm, nil
}

// updateLocked updates the filter rules based on the current policy and nodes.
// It must be called with the lock held.
func (pm *PolicyManager) updateLocked() (bool, error) {
    filter, err := pm.pol.compileFilterRules(pm.users, pm.nodes)
    if err != nil {
        return false, fmt.Errorf("compiling filter rules: %w", err)
    }

    filterHash := deephash.Hash(&filter)
    filterChanged := filterHash == pm.filterHash
    pm.filter = filter
    pm.filterHash = filterHash

    // Order matters, tags might be used in autoapprovers, so we need to ensure
    // that the map for tag owners is resolved before resolving autoapprovers.
    // TODO(kradalby): Order might not matter after #2417
    tagMap, err := resolveTagOwners(pm.pol, pm.users, pm.nodes)
    if err != nil {
        return false, fmt.Errorf("resolving tag owners map: %w", err)
    }

    tagOwnerMapHash := deephash.Hash(&tagMap)
    tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
    pm.tagOwnerMap = tagMap
    pm.tagOwnerMapHash = tagOwnerMapHash

    autoMap, err := resolveAutoApprovers(pm.pol, pm.users, pm.nodes)
    if err != nil {
        return false, fmt.Errorf("resolving auto approvers map: %w", err)
    }

    autoApproveMapHash := deephash.Hash(&autoMap)
    autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
    pm.autoApproveMap = autoMap
    pm.autoApproveMapHash = autoApproveMapHash

    // If neither of the calculated values changed, no need to update nodes
    if !filterChanged && !tagOwnerChanged && !autoApproveChanged {
        return false, nil
    }

    // Clear the SSH policy map to ensure it's recalculated with the new policy.
    // TODO(kradalby): This could potentially be optimized by only clearing the
    // policies for nodes that have changed. Particularly if the only difference is
    // that nodes has been added or removed.
    clear(pm.sshPolicyMap)

    return true, nil
}

func (pm *PolicyManager) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
    pm.mu.Lock()
    defer pm.mu.Unlock()

    if sshPol, ok := pm.sshPolicyMap[node.ID]; ok {
        return sshPol, nil
    }

    sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes)
    if err != nil {
        return nil, fmt.Errorf("compiling SSH policy: %w", err)
    }
    pm.sshPolicyMap[node.ID] = sshPol

    return sshPol, nil
}

func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) {
    if len(polB) == 0 {
        return false, nil
    }

    pol, err := policyFromBytes(polB)
    if err != nil {
        return false, fmt.Errorf("parsing policy: %w", err)
    }

    pm.mu.Lock()
    defer pm.mu.Unlock()

    pm.pol = pol

    return pm.updateLocked()
}

// Filter returns the current filter rules for the entire tailnet.
func (pm *PolicyManager) Filter() []tailcfg.FilterRule {
    pm.mu.Lock()
    defer pm.mu.Unlock()
    return pm.filter
}

// SetUsers updates the users in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
    pm.mu.Lock()
    defer pm.mu.Unlock()
    pm.users = users
    return pm.updateLocked()
}

// SetNodes updates the nodes in the policy manager and updates the filter rules.
func (pm *PolicyManager) SetNodes(nodes types.Nodes) (bool, error) {
    pm.mu.Lock()
    defer pm.mu.Unlock()
    pm.nodes = nodes
    return pm.updateLocked()
}

func (pm *PolicyManager) NodeCanHaveTag(node *types.Node, tag string) bool {
    if pm == nil {
        return false
    }

    pm.mu.Lock()
    defer pm.mu.Unlock()

    if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok {
        for _, nodeAddr := range node.IPs() {
            if ips.Contains(nodeAddr) {
                return true
            }
        }
    }

    return false
}

func (pm *PolicyManager) NodeCanApproveRoute(node *types.Node, route netip.Prefix) bool {
    if pm == nil {
        return false
    }

    pm.mu.Lock()
    defer pm.mu.Unlock()

    // The fast path is that a node requests to approve a prefix
    // where there is an exact entry, e.g. 10.0.0.0/8, then
    // check and return quickly
    if _, ok := pm.autoApproveMap[route]; ok {
        for _, nodeAddr := range node.IPs() {
            if pm.autoApproveMap[route].Contains(nodeAddr) {
                return true
            }
        }
    }

    // The slow path is that the node tries to approve
    // 10.0.10.0/24, which is a part of 10.0.0.0/8, then we
    // cannot just lookup in the prefix map and have to check
    // if there is a "parent" prefix available.
    for prefix, approveAddrs := range pm.autoApproveMap {
        // We do not want the exit node entry to approve all
        // sorts of routes. The logic here is that it would be
        // unexpected behaviour to have specific routes approved
        // just because the node is allowed to designate itself as
        // an exit.
        if tsaddr.IsExitRoute(prefix) {
            continue
        }

        // Check if prefix is larger (so containing) and then overlaps
        // the route to see if the node can approve a subset of an autoapprover
        if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) {
            for _, nodeAddr := range node.IPs() {
                if approveAddrs.Contains(nodeAddr) {
                    return true
                }
            }
        }
    }

    return false
}

func (pm *PolicyManager) Version() int {
    return 2
}

func (pm *PolicyManager) DebugString() string {
    var sb strings.Builder

    fmt.Fprintf(&sb, "PolicyManager (v%d):\n\n", pm.Version())

    sb.WriteString("\n\n")

    if pm.pol != nil {
        pol, err := json.MarshalIndent(pm.pol, "", " ")
        if err == nil {
            sb.WriteString("Policy:\n")
            sb.Write(pol)
            sb.WriteString("\n\n")
        }
    }

    fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
    for prefix, approveAddrs := range pm.autoApproveMap {
        fmt.Fprintf(&sb, "\t%s:\n", prefix)
        for _, iprange := range approveAddrs.Ranges() {
            fmt.Fprintf(&sb, "\t\t%s\n", iprange)
        }
    }

    sb.WriteString("\n\n")

    fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
    for prefix, tagOwners := range pm.tagOwnerMap {
        fmt.Fprintf(&sb, "\t%s:\n", prefix)
        for _, iprange := range tagOwners.Ranges() {
            fmt.Fprintf(&sb, "\t\t%s\n", iprange)
        }
    }

    sb.WriteString("\n\n")
    if pm.filter != nil {
        filter, err := json.MarshalIndent(pm.filter, "", " ")
        if err == nil {
            sb.WriteString("Compiled filter:\n")
            sb.Write(filter)
            sb.WriteString("\n\n")
        }
    }

    return sb.String()
}
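The slow path in NodeCanApproveRoute boils down to a plain prefix-containment test on netip values. A self-contained illustration of why 10.0.10.0/24 is approved by an autoapprover entry for 10.0.0.0/8, while an exit-route entry is skipped:

parent := netip.MustParsePrefix("10.0.0.0/8")
requested := netip.MustParsePrefix("10.0.10.0/24")
fmt.Println(parent.Bits() <= requested.Bits() && parent.Overlaps(requested)) // true: /8 contains /24

exit := netip.MustParsePrefix("0.0.0.0/0")
fmt.Println(tsaddr.IsExitRoute(exit)) // true, so the loop skips it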
hscontrol/policy/v2/policy_test.go (new file, 58 lines)
@ -0,0 +1,58 @@
package v2

import (
    "testing"

    "github.com/google/go-cmp/cmp"
    "github.com/juanfont/headscale/hscontrol/types"
    "github.com/stretchr/testify/require"
    "gorm.io/gorm"
    "tailscale.com/tailcfg"
)

func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
    return &types.Node{
        ID:       0,
        Hostname: name,
        IPv4:     ap(ipv4),
        IPv6:     ap(ipv6),
        User:     user,
        UserID:   user.ID,
        Hostinfo: hostinfo,
    }
}

func TestPolicyManager(t *testing.T) {
    users := types.Users{
        {Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"},
        {Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"},
    }

    tests := []struct {
        name       string
        pol        string
        nodes      types.Nodes
        wantFilter []tailcfg.FilterRule
    }{
        {
            name:       "empty-policy",
            pol:        "{}",
            nodes:      types.Nodes{},
            wantFilter: nil,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes)
            require.NoError(t, err)

            filter := pm.Filter()
            if diff := cmp.Diff(filter, tt.wantFilter); diff != "" {
                t.Errorf("Filter() mismatch (-want +got):\n%s", diff)
            }

            // TODO(kradalby): Test SSH Policy
        })
    }
}
hscontrol/policy/v2/types.go (new file, 1005 lines; diff suppressed because it is too large)
hscontrol/policy/v2/types_test.go (new file, 1162 lines; diff suppressed because it is too large)
hscontrol/policy/v2/utils.go (new file, 164 lines)
@ -0,0 +1,164 @@
package v2

import (
    "errors"
    "fmt"
    "slices"
    "strconv"
    "strings"

    "tailscale.com/tailcfg"
)

// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
func splitDestinationAndPort(input string) (string, string, error) {
    // Find the last occurrence of the colon character
    lastColonIndex := strings.LastIndex(input, ":")

    // Check if the colon character is present and not at the beginning or end of the string
    if lastColonIndex == -1 {
        return "", "", errors.New("input must contain a colon character separating destination and port")
    }
    if lastColonIndex == 0 {
        return "", "", errors.New("input cannot start with a colon character")
    }
    if lastColonIndex == len(input)-1 {
        return "", "", errors.New("input cannot end with a colon character")
    }

    // Split the string into destination and port based on the last colon
    destination := input[:lastColonIndex]
    port := input[lastColonIndex+1:]

    return destination, port, nil
}

// parsePortRange parses a port definition string and returns a slice of PortRange structs.
func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
    if portDef == "*" {
        return []tailcfg.PortRange{tailcfg.PortRangeAny}, nil
    }

    var portRanges []tailcfg.PortRange
    parts := strings.Split(portDef, ",")

    for _, part := range parts {
        if strings.Contains(part, "-") {
            rangeParts := strings.Split(part, "-")
            rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
                return e == ""
            })
            if len(rangeParts) != 2 {
                return nil, errors.New("invalid port range format")
            }

            first, err := parsePort(rangeParts[0])
            if err != nil {
                return nil, err
            }

            last, err := parsePort(rangeParts[1])
            if err != nil {
                return nil, err
            }

            if first > last {
                return nil, errors.New("invalid port range: first port is greater than last port")
            }

            portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
        } else {
            port, err := parsePort(part)
            if err != nil {
                return nil, err
            }

            portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
        }
    }

    return portRanges, nil
}

// parsePort parses a single port number from a string.
func parsePort(portStr string) (uint16, error) {
    port, err := strconv.Atoi(portStr)
    if err != nil {
        return 0, errors.New("invalid port number")
    }

    if port < 0 || port > 65535 {
        return 0, errors.New("port number out of range")
    }

    return uint16(port), nil
}

// For some reason golang.org/x/net/internal/iana is an internal package.
const (
    protocolICMP     = 1   // Internet Control Message
    protocolIGMP     = 2   // Internet Group Management
    protocolIPv4     = 4   // IPv4 encapsulation
    protocolTCP      = 6   // Transmission Control
    protocolEGP      = 8   // Exterior Gateway Protocol
    protocolIGP      = 9   // any private interior gateway (used by Cisco for their IGRP)
    protocolUDP      = 17  // User Datagram
    protocolGRE      = 47  // Generic Routing Encapsulation
    protocolESP      = 50  // Encap Security Payload
    protocolAH       = 51  // Authentication Header
    protocolIPv6ICMP = 58  // ICMP for IPv6
    protocolSCTP     = 132 // Stream Control Transmission Protocol
    ProtocolFC       = 133 // Fibre Channel
)

// parseProtocol reads the proto field of the ACL and generates a list of
// protocols that will be allowed, following the IANA IP protocol number
// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml
//
// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP,
// as per Tailscale behaviour (see tailcfg.FilterRule).
//
// Also returns a boolean indicating if the protocol
// requires all the destinations to use wildcard as port number (only TCP,
// UDP and SCTP support specifying ports).
func parseProtocol(protocol string) ([]int, bool, error) {
    switch protocol {
    case "":
        return nil, false, nil
    case "igmp":
        return []int{protocolIGMP}, true, nil
    case "ipv4", "ip-in-ip":
        return []int{protocolIPv4}, true, nil
    case "tcp":
        return []int{protocolTCP}, false, nil
    case "egp":
        return []int{protocolEGP}, true, nil
    case "igp":
        return []int{protocolIGP}, true, nil
    case "udp":
        return []int{protocolUDP}, false, nil
    case "gre":
        return []int{protocolGRE}, true, nil
    case "esp":
        return []int{protocolESP}, true, nil
    case "ah":
        return []int{protocolAH}, true, nil
    case "sctp":
        return []int{protocolSCTP}, false, nil
    case "icmp":
        return []int{protocolICMP, protocolIPv6ICMP}, true, nil

    default:
        protocolNumber, err := strconv.Atoi(protocol)
        if err != nil {
            return nil, false, fmt.Errorf("parsing protocol number: %w", err)
        }

        // TODO(kradalby): What is this?
        needsWildcard := protocolNumber != protocolTCP &&
            protocolNumber != protocolUDP &&
            protocolNumber != protocolSCTP

        return []int{protocolNumber}, needsWildcard, nil
    }
}
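Taken together, these helpers turn a destination string into a host part plus concrete port ranges; expected values are shown as comments and follow directly from the functions above:

dst, portDef, _ := splitDestinationAndPort("tag:montreal-webserver:80,443")
// dst == "tag:montreal-webserver", portDef == "80,443"

ranges, _ := parsePortRange(portDef)
// ranges == []tailcfg.PortRange{{First: 80, Last: 80}, {First: 443, Last: 443}}

protos, needsWildcard, _ := parseProtocol("tcp")
// protos == []int{protocolTCP}, needsWildcard == false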
hscontrol/policy/v2/utils_test.go (new file, 102 lines)
@ -0,0 +1,102 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// TestParseDestinationAndPort tests the parseDestinationAndPort function using table-driven tests.
|
||||
func TestParseDestinationAndPort(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expectedDst string
|
||||
expectedPort string
|
||||
expectedErr error
|
||||
}{
|
||||
{"git-server:*", "git-server", "*", nil},
|
||||
{"192.168.1.0/24:22", "192.168.1.0/24", "22", nil},
|
||||
{"fd7a:115c:a1e0::2:22", "fd7a:115c:a1e0::2", "22", nil},
|
||||
{"fd7a:115c:a1e0::2/128:22", "fd7a:115c:a1e0::2/128", "22", nil},
|
||||
{"tag:montreal-webserver:80,443", "tag:montreal-webserver", "80,443", nil},
|
||||
{"tag:api-server:443", "tag:api-server", "443", nil},
|
||||
{"example-host-1:*", "example-host-1", "*", nil},
|
||||
{"hostname:80-90", "hostname", "80-90", nil},
|
||||
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
|
||||
{":invalid", "", "", errors.New("input cannot start with a colon character")},
|
||||
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
dst, port, err := splitDestinationAndPort(testCase.input)
|
||||
if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) {
|
||||
t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)",
|
||||
testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePort(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected uint16
|
||||
err string
|
||||
}{
|
||||
{"80", 80, ""},
|
||||
{"0", 0, ""},
|
||||
{"65535", 65535, ""},
|
||||
{"-1", 0, "port number out of range"},
|
||||
{"65536", 0, "port number out of range"},
|
||||
{"abc", 0, "invalid port number"},
|
||||
{"", 0, "invalid port number"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePort(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
if result != test.expected {
|
||||
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortRange(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
expected []tailcfg.PortRange
|
||||
err string
|
||||
}{
|
||||
{"80", []tailcfg.PortRange{{80, 80}}, ""},
|
||||
{"80-90", []tailcfg.PortRange{{80, 90}}, ""},
|
||||
{"80,90", []tailcfg.PortRange{{80, 80}, {90, 90}}, ""},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{80, 91}, {92, 92}, {93, 95}}, ""},
|
||||
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
|
||||
{"80-", nil, "invalid port range format"},
|
||||
{"-90", nil, "invalid port range format"},
|
||||
{"80-90,", nil, "invalid port number"},
|
||||
{"80,90-", nil, "invalid port range format"},
|
||||
{"80-90,abc", nil, "invalid port number"},
|
||||
{"80-90,65536", nil, "port number out of range"},
|
||||
{"80-90,90-80", nil, "invalid port range: first port is greater than last port"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePortRange(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
if diff := cmp.Diff(result, test.expected); diff != "" {
|
||||
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
|
||||
}
|
||||
}
|
||||
}
|
@ -10,10 +10,9 @@ import (
    "time"

    "github.com/juanfont/headscale/hscontrol/mapper"
    "github.com/juanfont/headscale/hscontrol/policy"
    "github.com/juanfont/headscale/hscontrol/types"
    "github.com/juanfont/headscale/hscontrol/util"
    "github.com/rs/zerolog/log"
    "github.com/samber/lo"
    "github.com/sasha-s/go-deadlock"
    xslices "golang.org/x/exp/slices"
    "tailscale.com/net/tsaddr"
@ -459,25 +458,10 @@ func (m *mapSession) handleEndpointUpdate() {
    // TODO(kradalby): I am not sure if we need this?
    nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)

    // Take all the routes presented to us by the node and check
    // if any of them should be auto approved by the policy.
    // If any of them are, add them to the approved routes of the node.
    // Keep all the old entries and compact the list to remove duplicates.
    var newApproved []netip.Prefix
    for _, route := range m.node.Hostinfo.RoutableIPs {
        if m.h.polMan.NodeCanApproveRoute(m.node, route) {
            newApproved = append(newApproved, route)
        }
    }
    if newApproved != nil {
        newApproved = append(newApproved, m.node.ApprovedRoutes...)
        slices.SortFunc(newApproved, util.ComparePrefix)
        slices.Compact(newApproved)
        newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
            return route.IsValid()
        })
        m.node.ApprovedRoutes = newApproved

    // Approve routes if they are auto-approved by the policy.
    // If any of them are approved, report them to the primary route tracker
    // and send updates accordingly.
    if policy.AutoApproveRoutes(m.h.polMan, m.node) {
        if m.h.primaryRoutes.SetRoutes(m.node.ID, m.node.SubnetRoutes()...) {
            ctx := types.NotifyCtx(m.ctx, "poll-primary-change", m.node.Hostname)
            m.h.nodeNotifier.NotifyAll(ctx, types.UpdateFull())
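policy.AutoApproveRoutes itself is not part of this hunk; based only on the PolicyManager methods introduced in this commit, a helper with that shape could look roughly like the sketch below. The name, the interface parameter and the return convention are assumptions, not the real implementation:

// autoApproveRoutes adds any announced route the policy auto-approves to the
// node's approved routes and reports whether anything changed.
func autoApproveRoutes(pm policy.PolicyManager, node *types.Node) bool {
    if node.Hostinfo == nil {
        return false
    }

    changed := false
    for _, route := range node.Hostinfo.RoutableIPs {
        if !pm.NodeCanApproveRoute(node, route) {
            continue
        }
        if !slices.Contains(node.ApprovedRoutes, route) {
            node.ApprovedRoutes = append(node.ApprovedRoutes, route)
            changed = true
        }
    }

    return changed
}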
@ -150,6 +150,68 @@ func (node *Node) IPs() []netip.Addr {
    return ret
}

// HasIP reports if a node has a given IP address.
func (node *Node) HasIP(i netip.Addr) bool {
    for _, ip := range node.IPs() {
        if ip.Compare(i) == 0 {
            return true
        }
    }
    return false
}

// IsTagged reports if a device is tagged
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
func (node *Node) IsTagged() bool {
    if len(node.ForcedTags) > 0 {
        return true
    }

    if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 {
        return true
    }

    if node.Hostinfo == nil {
        return false
    }

    // TODO(kradalby): Figure out how tagging should work
    // and hostinfo.requestedtags.
    // Do this in other work.

    return false
}

// HasTag reports if a node has a given tag.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys)
func (node *Node) HasTag(tag string) bool {
    if slices.Contains(node.ForcedTags, tag) {
        return true
    }

    if node.AuthKey != nil && slices.Contains(node.AuthKey.Tags, tag) {
        return true
    }

    // TODO(kradalby): Figure out how tagging should work
    // and hostinfo.requestedtags.
    // Do this in other work.

    return false
}

func (node *Node) RequestTags() []string {
    if node.Hostinfo == nil {
        return []string{}
    }

    return node.Hostinfo.RequestTags
}

func (node *Node) Prefixes() []netip.Prefix {
    addrs := []netip.Prefix{}
    for _, nodeAddress := range node.IPs() {
@ -163,12 +225,8 @@ func (node *Node) Prefixes() []netip.Prefix {
func (node *Node) IPsAsString() []string {
    var ret []string

    if node.IPv4 != nil {
        ret = append(ret, node.IPv4.String())
    }

    if node.IPv6 != nil {
        ret = append(ret, node.IPv6.String())
    for _, ip := range node.IPs() {
        ret = append(ret, ip.String())
    }

    return ret
@ -335,9 +393,9 @@ func (node *Node) SubnetRoutes() []netip.Prefix {
    return routes
}

// func (node *Node) String() string {
// 	return node.Hostname
// }
func (node *Node) String() string {
    return node.Hostname
}

// PeerChangeFromMapRequest takes a MapRequest and compares it to the node
// to produce a PeerChange struct that can be used to update the node and
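HasIP compares against both node addresses via IPs(); a quick illustration with made-up addresses:

ip4 := netip.MustParseAddr("100.64.0.1")
node := &types.Node{IPv4: &ip4}

fmt.Println(node.HasIP(netip.MustParseAddr("100.64.0.1"))) // true
fmt.Println(node.HasIP(netip.MustParseAddr("100.64.0.2"))) // false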
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
@ -18,6 +19,19 @@ import (
|
||||
|
||||
type UserID uint64
|
||||
|
||||
type Users []User
|
||||
|
||||
func (u Users) String() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[ ")
|
||||
for _, user := range u {
|
||||
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
|
||||
}
|
||||
sb.WriteString(" ]")
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
// User is the way Headscale implements the concept of users in Tailscale
|
||||
//
|
||||
// At the end of the day, users in Tailscale are some kind of 'bubbles' or users
|
||||
@ -74,12 +88,13 @@ func (u *User) Username() string {
|
||||
u.Email,
|
||||
u.Name,
|
||||
u.ProviderIdentifier.String,
|
||||
u.StringID())
|
||||
u.StringID(),
|
||||
)
|
||||
}
|
||||
|
||||
// DisplayNameOrUsername returns the DisplayName if it exists, otherwise
|
||||
// Display returns the DisplayName if it exists, otherwise
|
||||
// it will return the Username.
|
||||
func (u *User) DisplayNameOrUsername() string {
|
||||
func (u *User) Display() string {
|
||||
return cmp.Or(u.DisplayName, u.Username())
|
||||
}
|
||||
|
||||
@ -91,7 +106,7 @@ func (u *User) profilePicURL() string {
|
||||
func (u *User) TailscaleUser() *tailcfg.User {
|
||||
user := tailcfg.User{
|
||||
ID: tailcfg.UserID(u.ID),
|
||||
DisplayName: u.DisplayNameOrUsername(),
|
||||
DisplayName: u.Display(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
Created: u.CreatedAt,
|
||||
}
|
||||
@ -101,11 +116,10 @@ func (u *User) TailscaleUser() *tailcfg.User {
|
||||
|
||||
func (u *User) TailscaleLogin() *tailcfg.Login {
|
||||
login := tailcfg.Login{
|
||||
ID: tailcfg.LoginID(u.ID),
|
||||
// TODO(kradalby): this should reflect registration method.
|
||||
ID: tailcfg.LoginID(u.ID),
|
||||
Provider: u.Provider,
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.DisplayNameOrUsername(),
|
||||
DisplayName: u.Display(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
}
|
||||
|
||||
@ -116,7 +130,7 @@ func (u *User) TailscaleUserProfile() tailcfg.UserProfile {
|
||||
return tailcfg.UserProfile{
|
||||
ID: tailcfg.UserID(u.ID),
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.DisplayNameOrUsername(),
|
||||
DisplayName: u.Display(),
|
||||
ProfilePicURL: u.profilePicURL(),
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package util

import (
    "fmt"
    "iter"
    "net/netip"
    "strings"

@ -111,3 +112,16 @@ func StringToIPPrefix(prefixes []string) ([]netip.Prefix, error) {

    return result, nil
}

// IPSetAddrIter returns a function that iterates over all the IPs in the IPSet.
func IPSetAddrIter(ipSet *netipx.IPSet) iter.Seq[netip.Addr] {
    return func(yield func(netip.Addr) bool) {
        for _, rng := range ipSet.Ranges() {
            for ip := rng.From(); ip.Compare(rng.To()) <= 0; ip = ip.Next() {
                if !yield(ip) {
                    return
                }
            }
        }
    }
}
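The iterator makes the ranges in an IPSet consumable with a plain range loop (Go 1.23 range-over-func). A small self-contained example with illustrative addresses:

var b netipx.IPSetBuilder
b.AddRange(netipx.IPRangeFrom(
    netip.MustParseAddr("100.64.0.1"),
    netip.MustParseAddr("100.64.0.3"),
))
set, _ := b.IPSet()

for addr := range util.IPSetAddrIter(set) {
    fmt.Println(addr) // 100.64.0.1, 100.64.0.2, 100.64.0.3
}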
|
@ -1,10 +1,13 @@
package util

import (
    "cmp"
    "context"
    "net"
    "net/netip"
    "sync"

    "go4.org/netipx"
    "tailscale.com/net/tsaddr"
)

func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
@ -13,24 +16,6 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) {
    return d.DialContext(ctx, "unix", addr)
}

// TODO(kradalby): Remove when in stdlib;
// https://github.com/golang/go/issues/61642
// Compare returns an integer comparing two prefixes.
// The result will be 0 if p == p2, -1 if p < p2, and +1 if p > p2.
// Prefixes sort first by validity (invalid before valid), then
// address family (IPv4 before IPv6), then prefix length, then
// address.
func ComparePrefix(p, p2 netip.Prefix) int {
    if c := cmp.Compare(p.Addr().BitLen(), p2.Addr().BitLen()); c != 0 {
        return c
    }
    if c := cmp.Compare(p.Bits(), p2.Bits()); c != 0 {
        return c
    }

    return p.Addr().Compare(p2.Addr())
}

func PrefixesToString(prefixes []netip.Prefix) []string {
    ret := make([]string, 0, len(prefixes))
    for _, prefix := range prefixes {
@ -49,3 +34,29 @@ func MustStringsToPrefixes(strings []string) []netip.Prefix {

    return ret
}

// TheInternet returns the IPSet for the Internet.
// https://www.youtube.com/watch?v=iDbyYGrswtg
var TheInternet = sync.OnceValue(func() *netipx.IPSet {
    var internetBuilder netipx.IPSetBuilder
    internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3"))
    internetBuilder.AddPrefix(tsaddr.AllIPv4())

    // Delete Private network addresses
    // https://datatracker.ietf.org/doc/html/rfc1918
    internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7"))
    internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8"))
    internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12"))
    internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16"))

    // Delete Tailscale networks
    internetBuilder.RemovePrefix(tsaddr.TailscaleULARange())
    internetBuilder.RemovePrefix(tsaddr.CGNATRange())

    // Delete "can't find DHCP networks"
    internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-local
    internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16"))

    theInternetSet, _ := internetBuilder.IPSet()
    return theInternetSet
})
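The resulting set backs autogroup:internet; a quick check of what it does and does not contain, using example addresses:

inet := util.TheInternet()

fmt.Println(inet.Contains(netip.MustParseAddr("8.8.8.8")))         // true: public IPv4
fmt.Println(inet.Contains(netip.MustParseAddr("10.1.2.3")))        // false: RFC1918
fmt.Println(inet.Contains(netip.MustParseAddr("100.100.100.100"))) // false: CGNAT (tailnet range)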
|
@ -8,7 +8,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
|
||||
"github.com/juanfont/headscale/integration/hsic"
|
||||
"github.com/juanfont/headscale/integration/tsic"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -50,7 +50,7 @@ var veryLargeDestination = []string{
|
||||
|
||||
func aclScenario(
|
||||
t *testing.T,
|
||||
policy *policy.ACLPolicy,
|
||||
policy *policyv1.ACLPolicy,
|
||||
clientsPerUser int,
|
||||
) *Scenario {
|
||||
t.Helper()
|
||||
@ -77,6 +77,8 @@ func aclScenario(
|
||||
},
|
||||
hsic.WithACLPolicy(policy),
|
||||
hsic.WithTestName("acl"),
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
hsic.WithTLS(),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -100,7 +102,7 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
// they can access minus one (them self).
|
||||
tests := map[string]struct {
|
||||
users map[string]int
|
||||
policy policy.ACLPolicy
|
||||
policy policyv1.ACLPolicy
|
||||
want map[string]int
|
||||
}{
|
||||
// Test that when we have no ACL, each client netmap has
|
||||
@ -110,8 +112,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
policy: policyv1.ACLPolicy{
|
||||
ACLs: []policyv1.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"*"},
|
||||
@ -131,8 +133,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
policy: policyv1.ACLPolicy{
|
||||
ACLs: []policyv1.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -157,8 +159,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
policy: policyv1.ACLPolicy{
|
||||
ACLs: []policyv1.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -194,8 +196,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
policy: policyv1.ACLPolicy{
|
||||
ACLs: []policyv1.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@ -222,8 +224,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
|
||||
"user1": 2,
|
||||
"user2": 2,
|
||||
},
|
||||
policy: policy.ACLPolicy{
|
||||
ACLs: []policy.ACL{
|
||||
policy: policyv1.ACLPolicy{
|
||||
ACLs: []policyv1.ACL{
|
||||
{
|
||||
Action: "accept",
|
||||
Sources: []string{"user1"},
|
||||
@@ -250,8 +252,8 @@ func TestACLHostsInNetMapTable(t *testing.T) {
 "user1": 2,
 "user2": 2,
 },
-policy: policy.ACLPolicy{
-ACLs: []policy.ACL{
+policy: policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
@@ -306,8 +308,8 @@ func TestACLAllowUser80Dst(t *testing.T) {
 IntegrationSkip(t)

 scenario := aclScenario(t,
-&policy.ACLPolicy{
-ACLs: []policy.ACL{
+&policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"user1"},
@@ -360,11 +362,11 @@ func TestACLDenyAllPort80(t *testing.T) {
 IntegrationSkip(t)

 scenario := aclScenario(t,
-&policy.ACLPolicy{
+&policyv1.ACLPolicy{
 Groups: map[string][]string{
 "group:integration-acl-test": {"user1", "user2"},
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"group:integration-acl-test"},
@@ -407,8 +409,8 @@ func TestACLAllowUserDst(t *testing.T) {
 IntegrationSkip(t)

 scenario := aclScenario(t,
-&policy.ACLPolicy{
-ACLs: []policy.ACL{
+&policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"user1"},
@@ -463,8 +465,8 @@ func TestACLAllowStarDst(t *testing.T) {
 IntegrationSkip(t)

 scenario := aclScenario(t,
-&policy.ACLPolicy{
-ACLs: []policy.ACL{
+&policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"user1"},
@@ -520,11 +522,11 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) {
 IntegrationSkip(t)

 scenario := aclScenario(t,
-&policy.ACLPolicy{
-Hosts: policy.Hosts{
+&policyv1.ACLPolicy{
+Hosts: policyv1.Hosts{
 "all": netip.MustParsePrefix("100.64.0.0/24"),
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 // Everyone can curl test3
 {
 Action: "accept",
@@ -617,16 +619,16 @@ func TestACLNamedHostsCanReach(t *testing.T) {
 IntegrationSkip(t)

 tests := map[string]struct {
-policy policy.ACLPolicy
+policy policyv1.ACLPolicy
 }{
 "ipv4": {
-policy: policy.ACLPolicy{
-Hosts: policy.Hosts{
+policy: policyv1.ACLPolicy{
+Hosts: policyv1.Hosts{
 "test1": netip.MustParsePrefix("100.64.0.1/32"),
 "test2": netip.MustParsePrefix("100.64.0.2/32"),
 "test3": netip.MustParsePrefix("100.64.0.3/32"),
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 // Everyone can curl test3
 {
 Action: "accept",
@@ -643,13 +645,13 @@ func TestACLNamedHostsCanReach(t *testing.T) {
 },
 },
 "ipv6": {
-policy: policy.ACLPolicy{
-Hosts: policy.Hosts{
+policy: policyv1.ACLPolicy{
+Hosts: policyv1.Hosts{
 "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
 "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
 "test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"),
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 // Everyone can curl test3
 {
 Action: "accept",
@@ -866,11 +868,11 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
 IntegrationSkip(t)

 tests := map[string]struct {
-policy policy.ACLPolicy
+policy policyv1.ACLPolicy
 }{
 "ipv4": {
-policy: policy.ACLPolicy{
-ACLs: []policy.ACL{
+policy: policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"100.64.0.1"},
@@ -880,8 +882,8 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
 },
 },
 "ipv6": {
-policy: policy.ACLPolicy{
-ACLs: []policy.ACL{
+policy: policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"fd7a:115c:a1e0::1"},
@@ -891,12 +893,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
 },
 },
 "hostv4cidr": {
-policy: policy.ACLPolicy{
-Hosts: policy.Hosts{
+policy: policyv1.ACLPolicy{
+Hosts: policyv1.Hosts{
 "test1": netip.MustParsePrefix("100.64.0.1/32"),
 "test2": netip.MustParsePrefix("100.64.0.2/32"),
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"test1"},
@@ -906,12 +908,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
 },
 },
 "hostv6cidr": {
-policy: policy.ACLPolicy{
-Hosts: policy.Hosts{
+policy: policyv1.ACLPolicy{
+Hosts: policyv1.Hosts{
 "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"),
 "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"),
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"test1"},
@@ -921,12 +923,12 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) {
 },
 },
 "group": {
-policy: policy.ACLPolicy{
+policy: policyv1.ACLPolicy{
 Groups: map[string][]string{
 "group:one": {"user1"},
 "group:two": {"user2"},
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"group:one"},
@@ -1085,15 +1087,18 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {
 headscale, err := scenario.Headscale()
 require.NoError(t, err)

-p := policy.ACLPolicy{
-ACLs: []policy.ACL{
+p := policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"user1"},
 Destinations: []string{"user2:*"},
 },
 },
-Hosts: policy.Hosts{},
+Hosts: policyv1.Hosts{},
 }
+if usePolicyV2ForTest {
+hsic.RewritePolicyToV2(&p)
+}

 pBytes, _ := json.Marshal(p)
@@ -1118,7 +1123,7 @@ func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) {

 // Get the current policy and check
 // if it is the same as the one we set.
-var output *policy.ACLPolicy
+var output *policyv1.ACLPolicy
 err = executeAndUnmarshal(
 headscale,
 []string{

@@ -11,7 +11,7 @@ import (
 tcmp "github.com/google/go-cmp/cmp"
 "github.com/google/go-cmp/cmp/cmpopts"
 v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
-"github.com/juanfont/headscale/hscontrol/policy"
+policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
 "github.com/juanfont/headscale/hscontrol/types"
 "github.com/juanfont/headscale/integration/hsic"
 "github.com/juanfont/headscale/integration/tsic"
@@ -915,7 +915,7 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {

 tests := []struct {
 name string
-policy *policy.ACLPolicy
+policy *policyv1.ACLPolicy
 wantTag bool
 }{
 {
@@ -924,8 +924,8 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
 },
 {
 name: "with-policy-email",
-policy: &policy.ACLPolicy{
-ACLs: []policy.ACL{
+policy: &policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
@@ -940,8 +940,8 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
 },
 {
 name: "with-policy-username",
-policy: &policy.ACLPolicy{
-ACLs: []policy.ACL{
+policy: &policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
@@ -956,11 +956,11 @@ func TestNodeAdvertiseTagCommand(t *testing.T) {
 },
 {
 name: "with-policy-groups",
-policy: &policy.ACLPolicy{
-Groups: policy.Groups{
+policy: &policyv1.ACLPolicy{
+Groups: policyv1.Groups{
 "group:admins": []string{"user1"},
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
@@ -1726,7 +1726,7 @@ func TestPolicyCommand(t *testing.T) {
 defer scenario.ShutdownAssertNoPanics(t)

 spec := map[string]int{
-"policy-user": 0,
+"user1": 0,
 }

 err = scenario.CreateHeadscaleEnv(
@@ -1742,8 +1742,8 @@ func TestPolicyCommand(t *testing.T) {
 headscale, err := scenario.Headscale()
 assertNoErr(t, err)

-p := policy.ACLPolicy{
-ACLs: []policy.ACL{
+p := policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
@@ -1751,9 +1751,12 @@ func TestPolicyCommand(t *testing.T) {
 },
 },
 TagOwners: map[string][]string{
-"tag:exists": {"policy-user"},
+"tag:exists": {"user1"},
 },
 }
+if usePolicyV2ForTest {
+hsic.RewritePolicyToV2(&p)
+}

 pBytes, _ := json.Marshal(p)

@@ -1778,7 +1781,7 @@ func TestPolicyCommand(t *testing.T) {

 // Get the current policy and check
 // if it is the same as the one we set.
-var output *policy.ACLPolicy
+var output *policyv1.ACLPolicy
 err = executeAndUnmarshal(
 headscale,
 []string{
@@ -1794,7 +1797,11 @@ func TestPolicyCommand(t *testing.T) {

 assert.Len(t, output.TagOwners, 1)
 assert.Len(t, output.ACLs, 1)
-assert.Equal(t, output.TagOwners["tag:exists"], []string{"policy-user"})
+if usePolicyV2ForTest {
+assert.Equal(t, output.TagOwners["tag:exists"], []string{"user1@"})
+} else {
+assert.Equal(t, output.TagOwners["tag:exists"], []string{"user1"})
+}
 }

 func TestPolicyBrokenConfigCommand(t *testing.T) {
@@ -1806,7 +1813,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
 defer scenario.ShutdownAssertNoPanics(t)

 spec := map[string]int{
-"policy-user": 1,
+"user1": 1,
 }

 err = scenario.CreateHeadscaleEnv(
@@ -1822,8 +1829,8 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
 headscale, err := scenario.Headscale()
 assertNoErr(t, err)

-p := policy.ACLPolicy{
-ACLs: []policy.ACL{
+p := policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 // This is an unknown action, so it will return an error
 // and the config will not be applied.
@@ -1833,9 +1840,12 @@ func TestPolicyBrokenConfigCommand(t *testing.T) {
 },
 },
 TagOwners: map[string][]string{
-"tag:exists": {"policy-user"},
+"tag:exists": {"user1"},
 },
 }
+if usePolicyV2ForTest {
+hsic.RewritePolicyToV2(&p)
+}

 pBytes, _ := json.Marshal(p)

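The TestPolicyCommand and TestPolicyBrokenConfigCommand hunks above branch on usePolicyV2ForTest because policy v2 identifies users with a trailing "@": the CLI returns "user1@" where v1 returned "user1". A minimal sketch of that convention, assuming the helpers shown in this diff (policyv1.ACLPolicy, hsic.RewritePolicyToV2) keep the shapes used here; the helper itself is hypothetical and not part of the test suite.

package integration

import (
	policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
	"github.com/juanfont/headscale/integration/hsic"
)

// expectedTagOwners shows which tag owner value the assertions above expect
// for each mode.
func expectedTagOwners(usePolicyV2 bool) []string {
	p := policyv1.ACLPolicy{
		TagOwners: map[string][]string{
			"tag:exists": {"user1"},
		},
	}
	if usePolicyV2 {
		// RewritePolicyToV2 rewrites bare user references in place,
		// so "user1" becomes "user1@".
		hsic.RewritePolicyToV2(&p)
	}

	return p.TagOwners["tag:exists"] // ["user1"] for v1, ["user1@"] for v2
}
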
@@ -365,7 +365,11 @@ func TestTaildrop(t *testing.T) {
 "taildrop": len(MustTestVersions),
 }

-err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("taildrop"))
+err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{},
+hsic.WithTestName("taildrop"),
+hsic.WithEmbeddedDERPServerOnly(),
+hsic.WithTLS(),
+)
 assertNoErrHeadscaleEnv(t, err)

 allClients, err := scenario.ListTailscaleClients()

@@ -12,6 +12,7 @@ import (
 "net/netip"
 "os"
 "path"
+"regexp"
 "sort"
 "strconv"
 "strings"
@@ -19,7 +20,7 @@ import (

 "github.com/davecgh/go-spew/spew"
 v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
-"github.com/juanfont/headscale/hscontrol/policy"
+policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
 "github.com/juanfont/headscale/hscontrol/types"
 "github.com/juanfont/headscale/hscontrol/util"
 "github.com/juanfont/headscale/integration/dockertestutil"
@@ -64,12 +65,13 @@ type HeadscaleInContainer struct {
 extraPorts []string
 caCerts [][]byte
 hostPortBindings map[string][]string
-aclPolicy *policy.ACLPolicy
+aclPolicy *policyv1.ACLPolicy
 env map[string]string
 tlsCert []byte
 tlsKey []byte
 filesInContainer []fileInContainer
 postgres bool
+policyV2 bool
 }

 // Option represent optional settings that can be given to a
@@ -78,7 +80,7 @@ type Option = func(c *HeadscaleInContainer)

 // WithACLPolicy adds a hscontrol.ACLPolicy policy to the
 // HeadscaleInContainer instance.
-func WithACLPolicy(acl *policy.ACLPolicy) Option {
+func WithACLPolicy(acl *policyv1.ACLPolicy) Option {
 return func(hsic *HeadscaleInContainer) {
 if acl == nil {
 return
@@ -186,6 +188,14 @@ func WithPostgres() Option {
 }
 }

+// WithPolicyV2 tells the integration test to use the new v2 filter.
+func WithPolicyV2() Option {
+return func(hsic *HeadscaleInContainer) {
+hsic.policyV2 = true
+hsic.env["HEADSCALE_EXPERIMENTAL_POLICY_V2"] = "1"
+}
+}
+
 // WithIPAllocationStrategy sets the tests IP Allocation strategy.
 func WithIPAllocationStrategy(strategy types.IPAllocationStrategy) Option {
 return func(hsic *HeadscaleInContainer) {
@@ -403,6 +413,10 @@ func New(
 }

 if hsic.aclPolicy != nil {
+// Rewrite all user entries in the policy to have an @ at the end.
+if hsic.policyV2 {
+RewritePolicyToV2(hsic.aclPolicy)
+}
 data, err := json.Marshal(hsic.aclPolicy)
 if err != nil {
 return nil, fmt.Errorf("failed to marshal ACL Policy to JSON: %w", err)
@@ -869,3 +883,50 @@ func (t *HeadscaleInContainer) SendInterrupt() error {

 return nil
 }
+
+// TODO(kradalby): Remove this function when v1 is deprecated
+func rewriteUsersToV2(strs []string) []string {
+var result []string
+userPattern := regexp.MustCompile(`^user\d+$`)
+
+for _, username := range strs {
+parts := strings.Split(username, ":")
+if len(parts) == 0 {
+result = append(result, username)
+continue
+}
+firstPart := parts[0]
+if userPattern.MatchString(firstPart) {
+modifiedFirst := firstPart + "@"
+if len(parts) > 1 {
+rest := strings.Join(parts[1:], ":")
+username = modifiedFirst + ":" + rest
+} else {
+username = modifiedFirst
+}
+}
+result = append(result, username)
+}
+
+return result
+}
+
+// rewritePolicyToV2 rewrites the policy to v2 format.
+// This mostly means adding the @ prefix to user names.
+// replaces are done inplace
+func RewritePolicyToV2(pol *policyv1.ACLPolicy) {
+for idx := range pol.ACLs {
+pol.ACLs[idx].Sources = rewriteUsersToV2(pol.ACLs[idx].Sources)
+pol.ACLs[idx].Destinations = rewriteUsersToV2(pol.ACLs[idx].Destinations)
+}
+for idx := range pol.Groups {
+pol.Groups[idx] = rewriteUsersToV2(pol.Groups[idx])
+}
+for idx := range pol.TagOwners {
+pol.TagOwners[idx] = rewriteUsersToV2(pol.TagOwners[idx])
+}
+for idx := range pol.SSHs {
+pol.SSHs[idx].Sources = rewriteUsersToV2(pol.SSHs[idx].Sources)
+pol.SSHs[idx].Destinations = rewriteUsersToV2(pol.SSHs[idx].Destinations)
+}
+}

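The new hsic helpers above (WithPolicyV2, RewritePolicyToV2) are what let the existing v1-style test policies run against the v2 filter. A minimal sketch of how they compose when building a container, assuming only the option signatures shown in this diff; the function and test name below are made up for illustration.

package integration

import (
	policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
	"github.com/juanfont/headscale/integration/hsic"
)

// policyV2Options is a hypothetical example of wiring a v1-style policy
// into a headscale container that runs the experimental v2 filter.
func policyV2Options() []hsic.Option {
	acl := &policyv1.ACLPolicy{
		ACLs: []policyv1.ACL{
			{
				Action:       "accept",
				Sources:      []string{"user1"},
				Destinations: []string{"user2:*"},
			},
		},
	}

	return []hsic.Option{
		hsic.WithTestName("policyv2-sketch"), // hypothetical name
		hsic.WithACLPolicy(acl),
		// Sets HEADSCALE_EXPERIMENTAL_POLICY_V2=1 in the container and makes
		// New() call RewritePolicyToV2 on the policy before marshalling it.
		hsic.WithPolicyV2(),
	}
}
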
@@ -8,7 +8,7 @@ import (

 "github.com/google/go-cmp/cmp"
 v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
-"github.com/juanfont/headscale/hscontrol/policy"
+policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
 "github.com/juanfont/headscale/hscontrol/util"
 "github.com/juanfont/headscale/integration/hsic"
 "github.com/juanfont/headscale/integration/tsic"
@@ -29,7 +29,7 @@ func TestEnablingRoutes(t *testing.T) {
 IntegrationSkip(t)
 t.Parallel()

-user := "enable-routing"
+user := "user6"

 scenario, err := NewScenario(dockertestMaxWait())
 require.NoErrorf(t, err, "failed to create scenario: %s", err)
@@ -203,7 +203,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 IntegrationSkip(t)
 t.Parallel()

-user := "enable-routing"
+user := "user9"

 scenario, err := NewScenario(dockertestMaxWait())
 require.NoErrorf(t, err, "failed to create scenario: %s", err)
@@ -528,7 +528,7 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {

 expectedRoutes := "172.0.0.0/24"

-user := "enable-disable-routing"
+user := "user2"

 scenario, err := NewScenario(dockertestMaxWait())
 require.NoErrorf(t, err, "failed to create scenario: %s", err)
@@ -539,8 +539,8 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
 }

 err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{tsic.WithTags([]string{"tag:approve"})}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
-&policy.ACLPolicy{
-ACLs: []policy.ACL{
+&policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
@@ -550,7 +550,7 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) {
 TagOwners: map[string][]string{
 "tag:approve": {user},
 },
-AutoApprovers: policy.AutoApprovers{
+AutoApprovers: policyv1.AutoApprovers{
 Routes: map[string][]string{
 expectedRoutes: {"tag:approve"},
 },
@@ -640,8 +640,8 @@ func TestAutoApprovedSubRoute2068(t *testing.T) {
 hsic.WithEmbeddedDERPServerOnly(),
 hsic.WithTLS(),
 hsic.WithACLPolicy(
-&policy.ACLPolicy{
-ACLs: []policy.ACL{
+&policyv1.ACLPolicy{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
@@ -651,7 +651,7 @@ func TestAutoApprovedSubRoute2068(t *testing.T) {
 TagOwners: map[string][]string{
 "tag:approve": {user},
 },
-AutoApprovers: policy.AutoApprovers{
+AutoApprovers: policyv1.AutoApprovers{
 Routes: map[string][]string{
 "10.42.0.0/16": {"tag:approve"},
 },
@@ -696,7 +696,7 @@ func TestSubnetRouteACL(t *testing.T) {
 IntegrationSkip(t)
 t.Parallel()

-user := "subnet-route-acl"
+user := "user4"

 scenario, err := NewScenario(dockertestMaxWait())
 require.NoErrorf(t, err, "failed to create scenario: %s", err)
@@ -707,11 +707,11 @@ func TestSubnetRouteACL(t *testing.T) {
 }

 err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy(
-&policy.ACLPolicy{
-Groups: policy.Groups{
+&policyv1.ACLPolicy{
+Groups: policyv1.Groups{
 "group:admins": {user},
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"group:admins"},

@@ -33,6 +33,7 @@ const (
 )

 var usePostgresForTest = envknob.Bool("HEADSCALE_INTEGRATION_POSTGRES")
+var usePolicyV2ForTest = envknob.Bool("HEADSCALE_EXPERIMENTAL_POLICY_V2")

 var (
 errNoHeadscaleAvailable = errors.New("no headscale available")
@@ -230,6 +231,10 @@ func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) {
 opts = append(opts, hsic.WithPostgres())
 }

+if usePolicyV2ForTest {
+opts = append(opts, hsic.WithPolicyV2())
+}
+
 headscale, err := hsic.New(s.pool, s.network, opts...)
 if err != nil {
 return nil, fmt.Errorf("failed to create headscale container: %w", err)

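The scenario hunks above mirror the container option with a process-wide env knob, so the whole integration suite can be switched to policy v2 without touching individual tests. A sketch of that gating, assuming the envknob import path used by the integration code is tailscale.com/envknob; the helper name is hypothetical.

package integration

import (
	"tailscale.com/envknob"

	"github.com/juanfont/headscale/integration/hsic"
)

// withPolicyV2IfEnabled restates the gating above: when
// HEADSCALE_EXPERIMENTAL_POLICY_V2 is set for the test process, every
// headscale container in the scenario gets the v2 filter enabled.
func withPolicyV2IfEnabled(opts []hsic.Option) []hsic.Option {
	if envknob.Bool("HEADSCALE_EXPERIMENTAL_POLICY_V2") {
		opts = append(opts, hsic.WithPolicyV2())
	}

	return opts
}
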
@@ -7,7 +7,7 @@ import (
 "testing"
 "time"

-"github.com/juanfont/headscale/hscontrol/policy"
+policyv1 "github.com/juanfont/headscale/hscontrol/policy/v1"
 "github.com/juanfont/headscale/integration/hsic"
 "github.com/juanfont/headscale/integration/tsic"
 "github.com/stretchr/testify/assert"
@@ -48,7 +48,7 @@ var retry = func(times int, sleepInterval time.Duration,
 return result, stderr, err
 }

-func sshScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario {
+func sshScenario(t *testing.T, policy *policyv1.ACLPolicy, clientsPerUser int) *Scenario {
 t.Helper()
 scenario, err := NewScenario(dockertestMaxWait())
 assertNoErr(t, err)
@@ -92,18 +92,18 @@ func TestSSHOneUserToAll(t *testing.T) {
 t.Parallel()

 scenario := sshScenario(t,
-&policy.ACLPolicy{
+&policyv1.ACLPolicy{
 Groups: map[string][]string{
 "group:integration-test": {"user1"},
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
 Destinations: []string{"*:*"},
 },
 },
-SSHs: []policy.SSH{
+SSHs: []policyv1.SSH{
 {
 Action: "accept",
 Sources: []string{"group:integration-test"},
@@ -157,18 +157,18 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) {
 t.Parallel()

 scenario := sshScenario(t,
-&policy.ACLPolicy{
+&policyv1.ACLPolicy{
 Groups: map[string][]string{
 "group:integration-test": {"user1", "user2"},
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
 Destinations: []string{"*:*"},
 },
 },
-SSHs: []policy.SSH{
+SSHs: []policyv1.SSH{
 {
 Action: "accept",
 Sources: []string{"group:integration-test"},
@@ -210,18 +210,18 @@ func TestSSHNoSSHConfigured(t *testing.T) {
 t.Parallel()

 scenario := sshScenario(t,
-&policy.ACLPolicy{
+&policyv1.ACLPolicy{
 Groups: map[string][]string{
 "group:integration-test": {"user1"},
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
 Destinations: []string{"*:*"},
 },
 },
-SSHs: []policy.SSH{},
+SSHs: []policyv1.SSH{},
 },
 len(MustTestVersions),
 )
@@ -252,18 +252,18 @@ func TestSSHIsBlockedInACL(t *testing.T) {
 t.Parallel()

 scenario := sshScenario(t,
-&policy.ACLPolicy{
+&policyv1.ACLPolicy{
 Groups: map[string][]string{
 "group:integration-test": {"user1"},
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
 Destinations: []string{"*:80"},
 },
 },
-SSHs: []policy.SSH{
+SSHs: []policyv1.SSH{
 {
 Action: "accept",
 Sources: []string{"group:integration-test"},
@@ -301,19 +301,19 @@ func TestSSHUserOnlyIsolation(t *testing.T) {
 t.Parallel()

 scenario := sshScenario(t,
-&policy.ACLPolicy{
+&policyv1.ACLPolicy{
 Groups: map[string][]string{
 "group:ssh1": {"user1"},
 "group:ssh2": {"user2"},
 },
-ACLs: []policy.ACL{
+ACLs: []policyv1.ACL{
 {
 Action: "accept",
 Sources: []string{"*"},
 Destinations: []string{"*:*"},
 },
 },
-SSHs: []policy.SSH{
+SSHs: []policyv1.SSH{
 {
 Action: "accept",
 Sources: []string{"group:ssh1"},