mirror of
https://github.com/juanfont/headscale.git
synced 2024-11-27 03:55:20 +00:00
wrap policy in policy manager interface (#2255)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
parent
2c1ad6d11a
commit
f7b0cbbbea
3
.github/workflows/test-integration.yaml
vendored
3
.github/workflows/test-integration.yaml
vendored
@ -31,8 +31,7 @@ jobs:
|
|||||||
- TestPreAuthKeyCorrectUserLoggedInCommand
|
- TestPreAuthKeyCorrectUserLoggedInCommand
|
||||||
- TestApiKeyCommand
|
- TestApiKeyCommand
|
||||||
- TestNodeTagCommand
|
- TestNodeTagCommand
|
||||||
- TestNodeAdvertiseTagNoACLCommand
|
- TestNodeAdvertiseTagCommand
|
||||||
- TestNodeAdvertiseTagWithACLCommand
|
|
||||||
- TestNodeCommand
|
- TestNodeCommand
|
||||||
- TestNodeExpireCommand
|
- TestNodeExpireCommand
|
||||||
- TestNodeRenameCommand
|
- TestNodeRenameCommand
|
||||||
|
163
hscontrol/app.go
163
hscontrol/app.go
@ -88,7 +88,8 @@ type Headscale struct {
|
|||||||
DERPMap *tailcfg.DERPMap
|
DERPMap *tailcfg.DERPMap
|
||||||
DERPServer *derpServer.DERPServer
|
DERPServer *derpServer.DERPServer
|
||||||
|
|
||||||
ACLPolicy *policy.ACLPolicy
|
polManOnce sync.Once
|
||||||
|
polMan policy.PolicyManager
|
||||||
|
|
||||||
mapper *mapper.Mapper
|
mapper *mapper.Mapper
|
||||||
nodeNotifier *notifier.Notifier
|
nodeNotifier *notifier.Notifier
|
||||||
@ -153,6 +154,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err = app.loadPolicyManager(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load ACL policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
var authProvider AuthProvider
|
var authProvider AuthProvider
|
||||||
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
||||||
if cfg.OIDC.Issuer != "" {
|
if cfg.OIDC.Issuer != "" {
|
||||||
@ -165,6 +170,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||||||
app.db,
|
app.db,
|
||||||
app.nodeNotifier,
|
app.nodeNotifier,
|
||||||
app.ipAlloc,
|
app.ipAlloc,
|
||||||
|
app.polMan,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
|
if cfg.OIDC.OnlyStartIfOIDCIsAvailable {
|
||||||
@ -475,6 +481,52 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||||||
return router
|
return router
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||||
|
// Maybe we should attempt a new in memory state and not go via the DB?
|
||||||
|
func usersChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
|
||||||
|
users, err := db.ListUsers()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
changed, err := polMan.SetUsers(users)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if changed {
|
||||||
|
ctx := types.NotifyCtx(context.Background(), "acl-users-change", "all")
|
||||||
|
notif.NotifyAll(ctx, types.StateUpdate{
|
||||||
|
Type: types.StateFullUpdate,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(kradalby): Do a variant of this, and polman which only updates the node that has changed.
|
||||||
|
// Maybe we should attempt a new in memory state and not go via the DB?
|
||||||
|
func nodesChangedHook(db *db.HSDatabase, polMan policy.PolicyManager, notif *notifier.Notifier) error {
|
||||||
|
nodes, err := db.ListNodes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
changed, err := polMan.SetNodes(nodes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if changed {
|
||||||
|
ctx := types.NotifyCtx(context.Background(), "acl-nodes-change", "all")
|
||||||
|
notif.NotifyAll(ctx, types.StateUpdate{
|
||||||
|
Type: types.StateFullUpdate,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||||
func (h *Headscale) Serve() error {
|
func (h *Headscale) Serve() error {
|
||||||
if profilingEnabled {
|
if profilingEnabled {
|
||||||
@ -490,19 +542,13 @@ func (h *Headscale) Serve() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if err = h.loadACLPolicy(); err != nil {
|
|
||||||
return fmt.Errorf("failed to load ACL policy: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dumpConfig {
|
if dumpConfig {
|
||||||
spew.Dump(h.cfg)
|
spew.Dump(h.cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch an initial DERP Map before we start serving
|
// Fetch an initial DERP Map before we start serving
|
||||||
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
|
h.DERPMap = derp.GetDERPMap(h.cfg.DERP)
|
||||||
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier)
|
h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier, h.polMan)
|
||||||
|
|
||||||
if h.cfg.DERP.ServerEnabled {
|
if h.cfg.DERP.ServerEnabled {
|
||||||
// When embedded DERP is enabled we always need a STUN server
|
// When embedded DERP is enabled we always need a STUN server
|
||||||
@ -772,12 +818,21 @@ func (h *Headscale) Serve() error {
|
|||||||
Str("signal", sig.String()).
|
Str("signal", sig.String()).
|
||||||
Msg("Received SIGHUP, reloading ACL and Config")
|
Msg("Received SIGHUP, reloading ACL and Config")
|
||||||
|
|
||||||
// TODO(kradalby): Reload config on SIGHUP
|
if err := h.loadPolicyManager(); err != nil {
|
||||||
if err := h.loadACLPolicy(); err != nil {
|
log.Error().Err(err).Msg("failed to reload Policy")
|
||||||
log.Error().Err(err).Msg("failed to reload ACL policy")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if h.ACLPolicy != nil {
|
pol, err := h.policyBytes()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to get policy blob")
|
||||||
|
}
|
||||||
|
|
||||||
|
changed, err := h.polMan.SetPolicy(pol)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to set new policy")
|
||||||
|
}
|
||||||
|
|
||||||
|
if changed {
|
||||||
log.Info().
|
log.Info().
|
||||||
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
Msg("ACL policy successfully reloaded, notifying nodes of change")
|
||||||
|
|
||||||
@ -996,27 +1051,46 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
|
|||||||
return &machineKey, nil
|
return &machineKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Headscale) loadACLPolicy() error {
|
// policyBytes returns the appropriate policy for the
|
||||||
var (
|
// current configuration as a []byte array.
|
||||||
pol *policy.ACLPolicy
|
func (h *Headscale) policyBytes() ([]byte, error) {
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
switch h.cfg.Policy.Mode {
|
switch h.cfg.Policy.Mode {
|
||||||
case types.PolicyModeFile:
|
case types.PolicyModeFile:
|
||||||
path := h.cfg.Policy.Path
|
path := h.cfg.Policy.Path
|
||||||
|
|
||||||
// It is fine to start headscale without a policy file.
|
// It is fine to start headscale without a policy file.
|
||||||
if len(path) == 0 {
|
if len(path) == 0 {
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
absPath := util.AbsolutePathFromConfigPath(path)
|
absPath := util.AbsolutePathFromConfigPath(path)
|
||||||
pol, err = policy.LoadACLPolicyFromPath(absPath)
|
policyFile, err := os.Open(absPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to load ACL policy from file: %w", err)
|
return nil, err
|
||||||
|
}
|
||||||
|
defer policyFile.Close()
|
||||||
|
|
||||||
|
return io.ReadAll(policyFile)
|
||||||
|
|
||||||
|
case types.PolicyModeDB:
|
||||||
|
p, err := h.db.GetPolicy()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, types.ErrPolicyNotFound) {
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(p.Data), err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unsupported policy mode: %s", h.cfg.Policy.Mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Headscale) loadPolicyManager() error {
|
||||||
|
var errOut error
|
||||||
|
h.polManOnce.Do(func() {
|
||||||
// Validate and reject configuration that would error when applied
|
// Validate and reject configuration that would error when applied
|
||||||
// when creating a map response. This requires nodes, so there is still
|
// when creating a map response. This requires nodes, so there is still
|
||||||
// a scenario where they might be allowed if the server has no nodes
|
// a scenario where they might be allowed if the server has no nodes
|
||||||
@ -1027,46 +1101,35 @@ func (h *Headscale) loadACLPolicy() error {
|
|||||||
// allowed to be written to the database.
|
// allowed to be written to the database.
|
||||||
nodes, err := h.db.ListNodes()
|
nodes, err := h.db.ListNodes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
errOut = fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
users, err := h.db.ListUsers()
|
users, err := h.db.ListUsers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("loading users from database to validate policy: %w", err)
|
errOut = fmt.Errorf("loading users from database to validate policy: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = pol.CompileFilterRules(users, nodes)
|
pol, err := h.policyBytes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("verifying policy rules: %w", err)
|
errOut = fmt.Errorf("loading policy bytes: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.polMan, err = policy.NewPolicyManager(pol, users, nodes)
|
||||||
|
if err != nil {
|
||||||
|
errOut = fmt.Errorf("creating policy manager: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(nodes) > 0 {
|
if len(nodes) > 0 {
|
||||||
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
|
_, err = h.polMan.SSHPolicy(nodes[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("verifying SSH rules: %w", err)
|
errOut = fmt.Errorf("verifying SSH rules: %w", err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
|
||||||
case types.PolicyModeDB:
|
return errOut
|
||||||
p, err := h.db.GetPolicy()
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, types.ErrPolicyNotFound) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("failed to get policy from database: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pol, err = policy.LoadACLPolicyFromBytes([]byte(p.Data))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse policy: %w", err)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
log.Fatal().
|
|
||||||
Str("mode", string(h.cfg.Policy.Mode)).
|
|
||||||
Msg("Unknown ACL policy mode")
|
|
||||||
}
|
|
||||||
|
|
||||||
h.ACLPolicy = pol
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
@ -384,6 +384,13 @@ func (h *Headscale) handleAuthKey(
|
|||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = nodesChangedHook(h.db, h.polMan, h.nodeNotifier)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(writer, "Internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = h.db.Write(func(tx *gorm.DB) error {
|
err = h.db.Write(func(tx *gorm.DB) error {
|
||||||
|
@ -563,7 +563,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||||||
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
|
pol, err := policy.LoadACLPolicyFromBytes([]byte(tt.acl))
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, pol)
|
require.NotNil(t, pol)
|
||||||
|
|
||||||
user, err := adb.CreateUser("test")
|
user, err := adb.CreateUser("test")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@ -600,8 +600,17 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||||||
node0ByID, err := adb.GetNodeByID(0)
|
node0ByID, err := adb.GetNodeByID(0)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
users, err := adb.ListUsers()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
nodes, err := adb.ListNodes()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
pm, err := policy.NewPolicyManager([]byte(tt.acl), users, nodes)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// TODO(kradalby): Check state update
|
// TODO(kradalby): Check state update
|
||||||
err = adb.EnableAutoApprovedRoutes(pol, node0ByID)
|
err = adb.EnableAutoApprovedRoutes(pm, node0ByID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
|
enabledRoutes, err := adb.GetEnabledRoutes(node0ByID)
|
||||||
|
@ -598,18 +598,18 @@ func failoverRoute(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
func (hsdb *HSDatabase) EnableAutoApprovedRoutes(
|
||||||
aclPolicy *policy.ACLPolicy,
|
polMan policy.PolicyManager,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
) error {
|
) error {
|
||||||
return hsdb.Write(func(tx *gorm.DB) error {
|
return hsdb.Write(func(tx *gorm.DB) error {
|
||||||
return EnableAutoApprovedRoutes(tx, aclPolicy, node)
|
return EnableAutoApprovedRoutes(tx, polMan, node)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
|
// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy.
|
||||||
func EnableAutoApprovedRoutes(
|
func EnableAutoApprovedRoutes(
|
||||||
tx *gorm.DB,
|
tx *gorm.DB,
|
||||||
aclPolicy *policy.ACLPolicy,
|
polMan policy.PolicyManager,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
) error {
|
) error {
|
||||||
if node.IPv4 == nil && node.IPv6 == nil {
|
if node.IPv4 == nil && node.IPv6 == nil {
|
||||||
@ -630,12 +630,7 @@ func EnableAutoApprovedRoutes(
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers(
|
routeApprovers := polMan.ApproversForRoute(netip.Prefix(advertisedRoute.Prefix))
|
||||||
netip.Prefix(advertisedRoute.Prefix),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to resolve autoApprovers for route(%d) for node(%s %d): %w", advertisedRoute.ID, node.Hostname, node.ID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Trace().
|
log.Trace().
|
||||||
Str("node", node.Hostname).
|
Str("node", node.Hostname).
|
||||||
@ -648,13 +643,8 @@ func EnableAutoApprovedRoutes(
|
|||||||
if approvedAlias == node.User.Username() {
|
if approvedAlias == node.User.Username() {
|
||||||
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
approvedRoutes = append(approvedRoutes, advertisedRoute)
|
||||||
} else {
|
} else {
|
||||||
users, err := ListUsers(tx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("looking up users to expand route alias: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(kradalby): figure out how to get this to depend on less stuff
|
// TODO(kradalby): figure out how to get this to depend on less stuff
|
||||||
approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, users, approvedAlias)
|
approvedIps, err := polMan.ExpandAlias(approvedAlias)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
|
return fmt.Errorf("expanding alias %q for autoApprovers: %w", approvedAlias, err)
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,6 @@ import (
|
|||||||
|
|
||||||
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/policy"
|
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
)
|
)
|
||||||
@ -58,6 +57,11 @@ func (api headscaleV1APIServer) CreateUser(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("updating resources using user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &v1.CreateUserResponse{User: user.Proto()}, nil
|
return &v1.CreateUserResponse{User: user.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,6 +101,11 @@ func (api headscaleV1APIServer) DeleteUser(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = usersChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("updating resources using user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &v1.DeleteUserResponse{}, nil
|
return &v1.DeleteUserResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,6 +250,11 @@ func (api headscaleV1APIServer) RegisterNode(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = nodesChangedHook(api.h.db, api.h.polMan, api.h.nodeNotifier)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("updating resources using node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
|
return &v1.RegisterNodeResponse{Node: node.Proto()}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -480,10 +494,7 @@ func (api headscaleV1APIServer) ListNodes(
|
|||||||
resp.Online = true
|
resp.Online = true
|
||||||
}
|
}
|
||||||
|
|
||||||
validTags, invalidTags := api.h.ACLPolicy.TagsOfNode(
|
validTags := api.h.polMan.Tags(node)
|
||||||
node,
|
|
||||||
)
|
|
||||||
resp.InvalidTags = invalidTags
|
|
||||||
resp.ValidTags = validTags
|
resp.ValidTags = validTags
|
||||||
response[index] = resp
|
response[index] = resp
|
||||||
}
|
}
|
||||||
@ -759,11 +770,6 @@ func (api headscaleV1APIServer) SetPolicy(
|
|||||||
|
|
||||||
p := request.GetPolicy()
|
p := request.GetPolicy()
|
||||||
|
|
||||||
pol, err := policy.LoadACLPolicyFromBytes([]byte(p))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("loading ACL policy file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate and reject configuration that would error when applied
|
// Validate and reject configuration that would error when applied
|
||||||
// when creating a map response. This requires nodes, so there is still
|
// when creating a map response. This requires nodes, so there is still
|
||||||
// a scenario where they might be allowed if the server has no nodes
|
// a scenario where they might be allowed if the server has no nodes
|
||||||
@ -773,18 +779,13 @@ func (api headscaleV1APIServer) SetPolicy(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
return nil, fmt.Errorf("loading nodes from database to validate policy: %w", err)
|
||||||
}
|
}
|
||||||
users, err := api.h.db.ListUsers()
|
changed, err := api.h.polMan.SetPolicy([]byte(p))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loading users from database to validate policy: %w", err)
|
return nil, fmt.Errorf("setting policy: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
_, err = pol.CompileFilterRules(users, nodes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("verifying policy rules: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(nodes) > 0 {
|
if len(nodes) > 0 {
|
||||||
_, err = pol.CompileSSHPolicy(nodes[0], users, nodes)
|
_, err = api.h.polMan.SSHPolicy(nodes[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("verifying SSH rules: %w", err)
|
return nil, fmt.Errorf("verifying SSH rules: %w", err)
|
||||||
}
|
}
|
||||||
@ -795,12 +796,13 @@ func (api headscaleV1APIServer) SetPolicy(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
api.h.ACLPolicy = pol
|
// Only send update if the packet filter has changed.
|
||||||
|
if changed {
|
||||||
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
|
ctx := types.NotifyCtx(context.Background(), "acl-update", "na")
|
||||||
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
api.h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||||
Type: types.StateFullUpdate,
|
Type: types.StateFullUpdate,
|
||||||
})
|
})
|
||||||
|
}
|
||||||
|
|
||||||
response := &v1.SetPolicyResponse{
|
response := &v1.SetPolicyResponse{
|
||||||
Policy: updated.Data,
|
Policy: updated.Data,
|
||||||
|
@ -55,6 +55,7 @@ type Mapper struct {
|
|||||||
cfg *types.Config
|
cfg *types.Config
|
||||||
derpMap *tailcfg.DERPMap
|
derpMap *tailcfg.DERPMap
|
||||||
notif *notifier.Notifier
|
notif *notifier.Notifier
|
||||||
|
polMan policy.PolicyManager
|
||||||
|
|
||||||
uid string
|
uid string
|
||||||
created time.Time
|
created time.Time
|
||||||
@ -71,6 +72,7 @@ func NewMapper(
|
|||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
derpMap *tailcfg.DERPMap,
|
derpMap *tailcfg.DERPMap,
|
||||||
notif *notifier.Notifier,
|
notif *notifier.Notifier,
|
||||||
|
polMan policy.PolicyManager,
|
||||||
) *Mapper {
|
) *Mapper {
|
||||||
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||||
|
|
||||||
@ -79,6 +81,7 @@ func NewMapper(
|
|||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
derpMap: derpMap,
|
derpMap: derpMap,
|
||||||
notif: notif,
|
notif: notif,
|
||||||
|
polMan: polMan,
|
||||||
|
|
||||||
uid: uid,
|
uid: uid,
|
||||||
created: time.Now(),
|
created: time.Now(),
|
||||||
@ -153,11 +156,9 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
|
|||||||
func (m *Mapper) fullMapResponse(
|
func (m *Mapper) fullMapResponse(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
peers types.Nodes,
|
peers types.Nodes,
|
||||||
users []types.User,
|
|
||||||
pol *policy.ACLPolicy,
|
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
resp, err := m.baseWithConfigMapResponse(node, pol, capVer)
|
resp, err := m.baseWithConfigMapResponse(node, capVer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -165,11 +166,9 @@ func (m *Mapper) fullMapResponse(
|
|||||||
err = appendPeerChanges(
|
err = appendPeerChanges(
|
||||||
resp,
|
resp,
|
||||||
true, // full change
|
true, // full change
|
||||||
pol,
|
m.polMan,
|
||||||
node,
|
node,
|
||||||
capVer,
|
capVer,
|
||||||
users,
|
|
||||||
peers,
|
|
||||||
peers,
|
peers,
|
||||||
m.cfg,
|
m.cfg,
|
||||||
)
|
)
|
||||||
@ -184,19 +183,14 @@ func (m *Mapper) fullMapResponse(
|
|||||||
func (m *Mapper) FullMapResponse(
|
func (m *Mapper) FullMapResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
pol *policy.ACLPolicy,
|
|
||||||
messages ...string,
|
messages ...string,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
peers, err := m.ListPeers(node.ID)
|
peers, err := m.ListPeers(node.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
users, err := m.db.ListUsers()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := m.fullMapResponse(node, peers, users, pol, mapRequest.Version)
|
resp, err := m.fullMapResponse(node, peers, mapRequest.Version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -210,10 +204,9 @@ func (m *Mapper) FullMapResponse(
|
|||||||
func (m *Mapper) ReadOnlyMapResponse(
|
func (m *Mapper) ReadOnlyMapResponse(
|
||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
pol *policy.ACLPolicy,
|
|
||||||
messages ...string,
|
messages ...string,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version)
|
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -249,7 +242,6 @@ func (m *Mapper) PeerChangedResponse(
|
|||||||
node *types.Node,
|
node *types.Node,
|
||||||
changed map[types.NodeID]bool,
|
changed map[types.NodeID]bool,
|
||||||
patches []*tailcfg.PeerChange,
|
patches []*tailcfg.PeerChange,
|
||||||
pol *policy.ACLPolicy,
|
|
||||||
messages ...string,
|
messages ...string,
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
resp := m.baseMapResponse()
|
resp := m.baseMapResponse()
|
||||||
@ -259,11 +251,6 @@ func (m *Mapper) PeerChangedResponse(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
users, err := m.db.ListUsers()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listing users for map response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var removedIDs []tailcfg.NodeID
|
var removedIDs []tailcfg.NodeID
|
||||||
var changedIDs []types.NodeID
|
var changedIDs []types.NodeID
|
||||||
for nodeID, nodeChanged := range changed {
|
for nodeID, nodeChanged := range changed {
|
||||||
@ -284,11 +271,9 @@ func (m *Mapper) PeerChangedResponse(
|
|||||||
err = appendPeerChanges(
|
err = appendPeerChanges(
|
||||||
&resp,
|
&resp,
|
||||||
false, // partial change
|
false, // partial change
|
||||||
pol,
|
m.polMan,
|
||||||
node,
|
node,
|
||||||
mapRequest.Version,
|
mapRequest.Version,
|
||||||
users,
|
|
||||||
peers,
|
|
||||||
changedNodes,
|
changedNodes,
|
||||||
m.cfg,
|
m.cfg,
|
||||||
)
|
)
|
||||||
@ -315,7 +300,7 @@ func (m *Mapper) PeerChangedResponse(
|
|||||||
|
|
||||||
// Add the node itself, it might have changed, and particularly
|
// Add the node itself, it might have changed, and particularly
|
||||||
// if there are no patches or changes, this is a self update.
|
// if there are no patches or changes, this is a self update.
|
||||||
tailnode, err := tailNode(node, mapRequest.Version, pol, m.cfg)
|
tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -330,7 +315,6 @@ func (m *Mapper) PeerChangedPatchResponse(
|
|||||||
mapRequest tailcfg.MapRequest,
|
mapRequest tailcfg.MapRequest,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
changed []*tailcfg.PeerChange,
|
changed []*tailcfg.PeerChange,
|
||||||
pol *policy.ACLPolicy,
|
|
||||||
) ([]byte, error) {
|
) ([]byte, error) {
|
||||||
resp := m.baseMapResponse()
|
resp := m.baseMapResponse()
|
||||||
resp.PeersChangedPatch = changed
|
resp.PeersChangedPatch = changed
|
||||||
@ -459,12 +443,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
|
|||||||
// incremental.
|
// incremental.
|
||||||
func (m *Mapper) baseWithConfigMapResponse(
|
func (m *Mapper) baseWithConfigMapResponse(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
pol *policy.ACLPolicy,
|
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
) (*tailcfg.MapResponse, error) {
|
) (*tailcfg.MapResponse, error) {
|
||||||
resp := m.baseMapResponse()
|
resp := m.baseMapResponse()
|
||||||
|
|
||||||
tailnode, err := tailNode(node, capVer, pol, m.cfg)
|
tailnode, err := tailNode(node, capVer, m.polMan, m.cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -517,35 +500,30 @@ func appendPeerChanges(
|
|||||||
resp *tailcfg.MapResponse,
|
resp *tailcfg.MapResponse,
|
||||||
|
|
||||||
fullChange bool,
|
fullChange bool,
|
||||||
pol *policy.ACLPolicy,
|
polMan policy.PolicyManager,
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
users []types.User,
|
|
||||||
peers types.Nodes,
|
|
||||||
changed types.Nodes,
|
changed types.Nodes,
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
) error {
|
) error {
|
||||||
packetFilter, err := pol.CompileFilterRules(users, append(peers, node))
|
filter := polMan.Filter()
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sshPolicy, err := pol.CompileSSHPolicy(node, users, peers)
|
sshPolicy, err := polMan.SSHPolicy(node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there are filter rules present, see if there are any nodes that cannot
|
// If there are filter rules present, see if there are any nodes that cannot
|
||||||
// access each-other at all and remove them from the peers.
|
// access each-other at all and remove them from the peers.
|
||||||
if len(packetFilter) > 0 {
|
if len(filter) > 0 {
|
||||||
changed = policy.FilterNodesByACL(node, changed, packetFilter)
|
changed = policy.FilterNodesByACL(node, changed, filter)
|
||||||
}
|
}
|
||||||
|
|
||||||
profiles := generateUserProfiles(node, changed)
|
profiles := generateUserProfiles(node, changed)
|
||||||
|
|
||||||
dnsConfig := generateDNSConfig(cfg, node)
|
dnsConfig := generateDNSConfig(cfg, node)
|
||||||
|
|
||||||
tailPeers, err := tailNodes(changed, capVer, pol, cfg)
|
tailPeers, err := tailNodes(changed, capVer, polMan, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -570,7 +548,7 @@ func appendPeerChanges(
|
|||||||
// new PacketFilters field and "base" allows us to send a full update when we
|
// new PacketFilters field and "base" allows us to send a full update when we
|
||||||
// have to send an empty list, avoiding the hack in the else block.
|
// have to send an empty list, avoiding the hack in the else block.
|
||||||
resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
resp.PacketFilters = map[string][]tailcfg.FilterRule{
|
||||||
"base": policy.ReduceFilterRules(node, packetFilter),
|
"base": policy.ReduceFilterRules(node, filter),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// This is a hack to avoid sending an empty list of packet filters.
|
// This is a hack to avoid sending an empty list of packet filters.
|
||||||
@ -578,11 +556,11 @@ func appendPeerChanges(
|
|||||||
// be omitted, causing the client to consider it unchanged, keeping the
|
// be omitted, causing the client to consider it unchanged, keeping the
|
||||||
// previous packet filter. Worst case, this can cause a node that previously
|
// previous packet filter. Worst case, this can cause a node that previously
|
||||||
// has access to a node to _not_ loose access if an empty (allow none) is sent.
|
// has access to a node to _not_ loose access if an empty (allow none) is sent.
|
||||||
reduced := policy.ReduceFilterRules(node, packetFilter)
|
reduced := policy.ReduceFilterRules(node, filter)
|
||||||
if len(reduced) > 0 {
|
if len(reduced) > 0 {
|
||||||
resp.PacketFilter = reduced
|
resp.PacketFilter = reduced
|
||||||
} else {
|
} else {
|
||||||
resp.PacketFilter = packetFilter
|
resp.PacketFilter = filter
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -461,18 +461,19 @@ func Test_fullMapResponse(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{user1, user2}, append(tt.peers, tt.node))
|
||||||
|
|
||||||
mappy := NewMapper(
|
mappy := NewMapper(
|
||||||
nil,
|
nil,
|
||||||
tt.cfg,
|
tt.cfg,
|
||||||
tt.derpMap,
|
tt.derpMap,
|
||||||
nil,
|
nil,
|
||||||
|
polMan,
|
||||||
)
|
)
|
||||||
|
|
||||||
got, err := mappy.fullMapResponse(
|
got, err := mappy.fullMapResponse(
|
||||||
tt.node,
|
tt.node,
|
||||||
tt.peers,
|
tt.peers,
|
||||||
[]types.User{user1, user2},
|
|
||||||
tt.pol,
|
|
||||||
0,
|
0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
func tailNodes(
|
func tailNodes(
|
||||||
nodes types.Nodes,
|
nodes types.Nodes,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
pol *policy.ACLPolicy,
|
polMan policy.PolicyManager,
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
) ([]*tailcfg.Node, error) {
|
) ([]*tailcfg.Node, error) {
|
||||||
tNodes := make([]*tailcfg.Node, len(nodes))
|
tNodes := make([]*tailcfg.Node, len(nodes))
|
||||||
@ -23,7 +23,7 @@ func tailNodes(
|
|||||||
node, err := tailNode(
|
node, err := tailNode(
|
||||||
node,
|
node,
|
||||||
capVer,
|
capVer,
|
||||||
pol,
|
polMan,
|
||||||
cfg,
|
cfg,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -40,7 +40,7 @@ func tailNodes(
|
|||||||
func tailNode(
|
func tailNode(
|
||||||
node *types.Node,
|
node *types.Node,
|
||||||
capVer tailcfg.CapabilityVersion,
|
capVer tailcfg.CapabilityVersion,
|
||||||
pol *policy.ACLPolicy,
|
polMan policy.PolicyManager,
|
||||||
cfg *types.Config,
|
cfg *types.Config,
|
||||||
) (*tailcfg.Node, error) {
|
) (*tailcfg.Node, error) {
|
||||||
addrs := node.Prefixes()
|
addrs := node.Prefixes()
|
||||||
@ -81,7 +81,7 @@ func tailNode(
|
|||||||
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
|
return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tags, _ := pol.TagsOfNode(node)
|
tags := polMan.Tags(node)
|
||||||
tags = lo.Uniq(append(tags, node.ForcedTags...))
|
tags = lo.Uniq(append(tags, node.ForcedTags...))
|
||||||
|
|
||||||
tNode := tailcfg.Node{
|
tNode := tailcfg.Node{
|
||||||
|
@ -184,6 +184,7 @@ func TestTailNode(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
polMan, _ := policy.NewPolicyManagerForTest(tt.pol, []types.User{}, types.Nodes{tt.node})
|
||||||
cfg := &types.Config{
|
cfg := &types.Config{
|
||||||
BaseDomain: tt.baseDomain,
|
BaseDomain: tt.baseDomain,
|
||||||
DNSConfig: tt.dnsConfig,
|
DNSConfig: tt.dnsConfig,
|
||||||
@ -192,7 +193,7 @@ func TestTailNode(t *testing.T) {
|
|||||||
got, err := tailNode(
|
got, err := tailNode(
|
||||||
tt.node,
|
tt.node,
|
||||||
0,
|
0,
|
||||||
tt.pol,
|
polMan,
|
||||||
cfg,
|
cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -245,7 +246,7 @@ func TestNodeExpiry(t *testing.T) {
|
|||||||
tn, err := tailNode(
|
tn, err := tailNode(
|
||||||
node,
|
node,
|
||||||
0,
|
0,
|
||||||
&policy.ACLPolicy{},
|
&policy.PolicyManagerV1{},
|
||||||
&types.Config{},
|
&types.Config{},
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/juanfont/headscale/hscontrol/db"
|
"github.com/juanfont/headscale/hscontrol/db"
|
||||||
"github.com/juanfont/headscale/hscontrol/notifier"
|
"github.com/juanfont/headscale/hscontrol/notifier"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/policy"
|
||||||
"github.com/juanfont/headscale/hscontrol/types"
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
"github.com/juanfont/headscale/hscontrol/util"
|
"github.com/juanfont/headscale/hscontrol/util"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
@ -53,6 +54,7 @@ type AuthProviderOIDC struct {
|
|||||||
registrationCache *zcache.Cache[string, key.MachinePublic]
|
registrationCache *zcache.Cache[string, key.MachinePublic]
|
||||||
notifier *notifier.Notifier
|
notifier *notifier.Notifier
|
||||||
ipAlloc *db.IPAllocator
|
ipAlloc *db.IPAllocator
|
||||||
|
polMan policy.PolicyManager
|
||||||
|
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
oauth2Config *oauth2.Config
|
oauth2Config *oauth2.Config
|
||||||
@ -65,6 +67,7 @@ func NewAuthProviderOIDC(
|
|||||||
db *db.HSDatabase,
|
db *db.HSDatabase,
|
||||||
notif *notifier.Notifier,
|
notif *notifier.Notifier,
|
||||||
ipAlloc *db.IPAllocator,
|
ipAlloc *db.IPAllocator,
|
||||||
|
polMan policy.PolicyManager,
|
||||||
) (*AuthProviderOIDC, error) {
|
) (*AuthProviderOIDC, error) {
|
||||||
var err error
|
var err error
|
||||||
// grab oidc config if it hasn't been already
|
// grab oidc config if it hasn't been already
|
||||||
@ -96,6 +99,7 @@ func NewAuthProviderOIDC(
|
|||||||
registrationCache: registrationCache,
|
registrationCache: registrationCache,
|
||||||
notifier: notif,
|
notifier: notif,
|
||||||
ipAlloc: ipAlloc,
|
ipAlloc: ipAlloc,
|
||||||
|
polMan: polMan,
|
||||||
|
|
||||||
oidcProvider: oidcProvider,
|
oidcProvider: oidcProvider,
|
||||||
oauth2Config: oauth2Config,
|
oauth2Config: oauth2Config,
|
||||||
@ -478,6 +482,11 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
|||||||
return nil, fmt.Errorf("creating or updating user: %w", err)
|
return nil, fmt.Errorf("creating or updating user: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = usersChangedHook(a.db, a.polMan, a.notifier)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("updating resources using user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -501,6 +510,11 @@ func (a *AuthProviderOIDC) registerNode(
|
|||||||
return fmt.Errorf("could not register node: %w", err)
|
return fmt.Errorf("could not register node: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = nodesChangedHook(a.db, a.polMan, a.notifier)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("updating resources using node: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
181
hscontrol/policy/pm.go
Normal file
181
hscontrol/policy/pm.go
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
package policy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"go4.org/netipx"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
"tailscale.com/util/deephash"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PolicyManager interface {
|
||||||
|
Filter() []tailcfg.FilterRule
|
||||||
|
SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error)
|
||||||
|
Tags(*types.Node) []string
|
||||||
|
ApproversForRoute(netip.Prefix) []string
|
||||||
|
ExpandAlias(string) (*netipx.IPSet, error)
|
||||||
|
SetPolicy([]byte) (bool, error)
|
||||||
|
SetUsers(users []types.User) (bool, error)
|
||||||
|
SetNodes(nodes types.Nodes) (bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPolicyManagerFromPath(path string, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||||
|
policyFile, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer policyFile.Close()
|
||||||
|
|
||||||
|
policyBytes, err := io.ReadAll(policyFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return NewPolicyManager(policyBytes, users, nodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPolicyManager(polB []byte, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||||
|
var pol *ACLPolicy
|
||||||
|
var err error
|
||||||
|
if polB != nil && len(polB) > 0 {
|
||||||
|
pol, err = LoadACLPolicyFromBytes(polB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing policy: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pm := PolicyManagerV1{
|
||||||
|
pol: pol,
|
||||||
|
users: users,
|
||||||
|
nodes: nodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = pm.updateLocked()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPolicyManagerForTest(pol *ACLPolicy, users []types.User, nodes types.Nodes) (PolicyManager, error) {
|
||||||
|
pm := PolicyManagerV1{
|
||||||
|
pol: pol,
|
||||||
|
users: users,
|
||||||
|
nodes: nodes,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := pm.updateLocked()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pm, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type PolicyManagerV1 struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
pol *ACLPolicy
|
||||||
|
|
||||||
|
users []types.User
|
||||||
|
nodes types.Nodes
|
||||||
|
|
||||||
|
filterHash deephash.Sum
|
||||||
|
filter []tailcfg.FilterRule
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateLocked updates the filter rules based on the current policy and nodes.
|
||||||
|
// It must be called with the lock held.
|
||||||
|
func (pm *PolicyManagerV1) updateLocked() (bool, error) {
|
||||||
|
filter, err := pm.pol.CompileFilterRules(pm.users, pm.nodes)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("compiling filter rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filterHash := deephash.Hash(&filter)
|
||||||
|
if filterHash == pm.filterHash {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.filter = filter
|
||||||
|
pm.filterHash = filterHash
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManagerV1) Filter() []tailcfg.FilterRule {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
return pm.filter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManagerV1) SSHPolicy(node *types.Node) (*tailcfg.SSHPolicy, error) {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
return pm.pol.CompileSSHPolicy(node, pm.users, pm.nodes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) {
|
||||||
|
pol, err := LoadACLPolicyFromBytes(polB)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("parsing policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
pm.pol = pol
|
||||||
|
|
||||||
|
return pm.updateLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUsers updates the users in the policy manager and updates the filter rules.
|
||||||
|
func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
|
||||||
|
pm.users = users
|
||||||
|
return pm.updateLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNodes updates the nodes in the policy manager and updates the filter rules.
|
||||||
|
func (pm *PolicyManagerV1) SetNodes(nodes types.Nodes) (bool, error) {
|
||||||
|
pm.mu.Lock()
|
||||||
|
defer pm.mu.Unlock()
|
||||||
|
pm.nodes = nodes
|
||||||
|
return pm.updateLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManagerV1) Tags(node *types.Node) []string {
|
||||||
|
if pm == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tags, _ := pm.pol.TagsOfNode(node)
|
||||||
|
return tags
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManagerV1) ApproversForRoute(route netip.Prefix) []string {
|
||||||
|
// TODO(kradalby): This can be a parse error of the address in the policy,
|
||||||
|
// in the new policy this will be typed and not a problem, in this policy
|
||||||
|
// we will just return empty list
|
||||||
|
if pm.pol == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
approvers, _ := pm.pol.AutoApprovers.GetRouteApprovers(route)
|
||||||
|
return approvers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PolicyManagerV1) ExpandAlias(alias string) (*netipx.IPSet, error) {
|
||||||
|
ips, err := pm.pol.ExpandAlias(pm.nodes, pm.users, alias)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ips, nil
|
||||||
|
}
|
158
hscontrol/policy/pm_test.go
Normal file
158
hscontrol/policy/pm_test.go
Normal file
@ -0,0 +1,158 @@
|
|||||||
|
package policy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/juanfont/headscale/hscontrol/types"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"tailscale.com/tailcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPolicySetChange(t *testing.T) {
|
||||||
|
users := []types.User{
|
||||||
|
{
|
||||||
|
Model: gorm.Model{ID: 1},
|
||||||
|
Name: "testuser",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
users []types.User
|
||||||
|
nodes types.Nodes
|
||||||
|
policy []byte
|
||||||
|
wantUsersChange bool
|
||||||
|
wantNodesChange bool
|
||||||
|
wantPolicyChange bool
|
||||||
|
wantFilter []tailcfg.FilterRule
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "set-nodes",
|
||||||
|
nodes: types.Nodes{
|
||||||
|
{
|
||||||
|
IPv4: iap("100.64.0.2"),
|
||||||
|
User: users[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantNodesChange: false,
|
||||||
|
wantFilter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set-users",
|
||||||
|
users: users,
|
||||||
|
wantUsersChange: false,
|
||||||
|
wantFilter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set-users-and-node",
|
||||||
|
users: users,
|
||||||
|
nodes: types.Nodes{
|
||||||
|
{
|
||||||
|
IPv4: iap("100.64.0.2"),
|
||||||
|
User: users[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantUsersChange: false,
|
||||||
|
wantNodesChange: true,
|
||||||
|
wantFilter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"100.64.0.2/32"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "set-policy",
|
||||||
|
policy: []byte(`
|
||||||
|
{
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"100.64.0.61",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"100.64.0.62:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`),
|
||||||
|
wantPolicyChange: true,
|
||||||
|
wantFilter: []tailcfg.FilterRule{
|
||||||
|
{
|
||||||
|
SrcIPs: []string{"100.64.0.61/32"},
|
||||||
|
DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
pol := `
|
||||||
|
{
|
||||||
|
"groups": {
|
||||||
|
"group:example": [
|
||||||
|
"testuser",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
|
||||||
|
"hosts": {
|
||||||
|
"host-1": "100.64.0.1",
|
||||||
|
"subnet-1": "100.100.101.100/24",
|
||||||
|
},
|
||||||
|
|
||||||
|
"acls": [
|
||||||
|
{
|
||||||
|
"action": "accept",
|
||||||
|
"src": [
|
||||||
|
"group:example",
|
||||||
|
],
|
||||||
|
"dst": [
|
||||||
|
"host-1:*",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
`
|
||||||
|
pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if tt.policy != nil {
|
||||||
|
change, err := pm.SetPolicy(tt.policy)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantPolicyChange, change)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.users != nil {
|
||||||
|
change, err := pm.SetUsers(tt.users)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantUsersChange, change)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.nodes != nil {
|
||||||
|
change, err := pm.SetNodes(tt.nodes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.wantNodesChange, change)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
|
||||||
|
t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -286,7 +286,7 @@ func (m *mapSession) serveLongPoll() {
|
|||||||
switch update.Type {
|
switch update.Type {
|
||||||
case types.StateFullUpdate:
|
case types.StateFullUpdate:
|
||||||
m.tracef("Sending Full MapResponse")
|
m.tracef("Sending Full MapResponse")
|
||||||
data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
|
data, err = m.mapper.FullMapResponse(m.req, m.node, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
|
||||||
case types.StatePeerChanged:
|
case types.StatePeerChanged:
|
||||||
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
|
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
|
||||||
|
|
||||||
@ -296,12 +296,12 @@ func (m *mapSession) serveLongPoll() {
|
|||||||
|
|
||||||
lastMessage = update.Message
|
lastMessage = update.Message
|
||||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
|
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
|
||||||
updateType = "change"
|
updateType = "change"
|
||||||
|
|
||||||
case types.StatePeerChangedPatch:
|
case types.StatePeerChangedPatch:
|
||||||
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
|
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
|
||||||
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy)
|
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches)
|
||||||
updateType = "patch"
|
updateType = "patch"
|
||||||
case types.StatePeerRemoved:
|
case types.StatePeerRemoved:
|
||||||
changed := make(map[types.NodeID]bool, len(update.Removed))
|
changed := make(map[types.NodeID]bool, len(update.Removed))
|
||||||
@ -310,13 +310,13 @@ func (m *mapSession) serveLongPoll() {
|
|||||||
changed[nodeID] = false
|
changed[nodeID] = false
|
||||||
}
|
}
|
||||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
|
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, lastMessage)
|
||||||
updateType = "remove"
|
updateType = "remove"
|
||||||
case types.StateSelfUpdate:
|
case types.StateSelfUpdate:
|
||||||
lastMessage = update.Message
|
lastMessage = update.Message
|
||||||
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
|
||||||
// create the map so an empty (self) update is sent
|
// create the map so an empty (self) update is sent
|
||||||
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage)
|
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, lastMessage)
|
||||||
updateType = "remove"
|
updateType = "remove"
|
||||||
case types.StateDERPUpdated:
|
case types.StateDERPUpdated:
|
||||||
m.tracef("Sending DERPUpdate MapResponse")
|
m.tracef("Sending DERPUpdate MapResponse")
|
||||||
@ -488,9 +488,12 @@ func (m *mapSession) handleEndpointUpdate() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.h.ACLPolicy != nil {
|
// TODO(kradalby): Only update the node that has actually changed
|
||||||
|
nodesChangedHook(m.h.db, m.h.polMan, m.h.nodeNotifier)
|
||||||
|
|
||||||
|
if m.h.polMan != nil {
|
||||||
// update routes with peer information
|
// update routes with peer information
|
||||||
err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node)
|
err := m.h.db.EnableAutoApprovedRoutes(m.h.polMan, m.node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, "Error running auto approved routes")
|
m.errf(err, "Error running auto approved routes")
|
||||||
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
mapResponseEndpointUpdates.WithLabelValues("error").Inc()
|
||||||
@ -544,7 +547,7 @@ func (m *mapSession) handleEndpointUpdate() {
|
|||||||
func (m *mapSession) handleReadOnlyRequest() {
|
func (m *mapSession) handleReadOnlyRequest() {
|
||||||
m.tracef("Client asked for a lite update, responding without peers")
|
m.tracef("Client asked for a lite update, responding without peers")
|
||||||
|
|
||||||
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node, m.h.ACLPolicy)
|
mapResp, err := m.mapper.ReadOnlyMapResponse(m.req, m.node)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.errf(err, "Failed to create MapResponse")
|
m.errf(err, "Failed to create MapResponse")
|
||||||
http.Error(m.w, "", http.StatusInternalServerError)
|
http.Error(m.w, "", http.StatusInternalServerError)
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -79,6 +79,10 @@ type Option = func(c *HeadscaleInContainer)
|
|||||||
// HeadscaleInContainer instance.
|
// HeadscaleInContainer instance.
|
||||||
func WithACLPolicy(acl *policy.ACLPolicy) Option {
|
func WithACLPolicy(acl *policy.ACLPolicy) Option {
|
||||||
return func(hsic *HeadscaleInContainer) {
|
return func(hsic *HeadscaleInContainer) {
|
||||||
|
if acl == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(kradalby): Move somewhere appropriate
|
// TODO(kradalby): Move somewhere appropriate
|
||||||
hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath
|
hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user