mirror of
https://github.com/juanfont/headscale.git
synced 2025-08-12 04:47:36 +00:00
remove "stripEmailDomain" argument
This commit makes a wrapper function round the normalisation requiring "stripEmailDomain" which has to be passed in almost all functions of headscale by loading it from Viper instead. Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:

committed by
Kristoffer Dalby

parent
161243c787
commit
717abe89c1
@@ -121,14 +121,13 @@ func GenerateFilterRules(
|
||||
policy *ACLPolicy,
|
||||
machine *types.Machine,
|
||||
peers types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
|
||||
// If there is no policy defined, we default to allow all
|
||||
if policy == nil {
|
||||
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
|
||||
}
|
||||
|
||||
rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain)
|
||||
rules, err := policy.generateFilterRules(machine, peers)
|
||||
if err != nil {
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||
}
|
||||
@@ -136,7 +135,7 @@ func GenerateFilterRules(
|
||||
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
|
||||
|
||||
var sshPolicy *tailcfg.SSHPolicy
|
||||
sshRules, err := policy.generateSSHRules(machine, peers, stripEmailDomain)
|
||||
sshRules, err := policy.generateSSHRules(machine, peers)
|
||||
if err != nil {
|
||||
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
|
||||
}
|
||||
@@ -154,7 +153,6 @@ func GenerateFilterRules(
|
||||
func (pol *ACLPolicy) generateFilterRules(
|
||||
machine *types.Machine,
|
||||
peers types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) ([]tailcfg.FilterRule, error) {
|
||||
rules := []tailcfg.FilterRule{}
|
||||
machines := append(peers, *machine)
|
||||
@@ -166,7 +164,7 @@ func (pol *ACLPolicy) generateFilterRules(
|
||||
|
||||
srcIPs := []string{}
|
||||
for srcIndex, src := range acl.Sources {
|
||||
srcs, err := pol.getIPsFromSource(src, machines, stripEmailDomain)
|
||||
srcs, err := pol.getIPsFromSource(src, machines)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Interface("src", src).
|
||||
@@ -193,7 +191,6 @@ func (pol *ACLPolicy) generateFilterRules(
|
||||
dest,
|
||||
machines,
|
||||
needsWildcard,
|
||||
stripEmailDomain,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -220,7 +217,6 @@ func (pol *ACLPolicy) generateFilterRules(
|
||||
func (pol *ACLPolicy) generateSSHRules(
|
||||
machine *types.Machine,
|
||||
peers types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) ([]*tailcfg.SSHRule, error) {
|
||||
rules := []*tailcfg.SSHRule{}
|
||||
|
||||
@@ -247,7 +243,7 @@ func (pol *ACLPolicy) generateSSHRules(
|
||||
for index, sshACL := range pol.SSHs {
|
||||
var dest netipx.IPSetBuilder
|
||||
for _, src := range sshACL.Destinations {
|
||||
expanded, err := pol.ExpandAlias(append(peers, *machine), src, stripEmailDomain)
|
||||
expanded, err := pol.ExpandAlias(append(peers, *machine), src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -289,7 +285,7 @@ func (pol *ACLPolicy) generateSSHRules(
|
||||
Any: true,
|
||||
})
|
||||
} else if isGroup(rawSrc) {
|
||||
users, err := pol.getUsersInGroup(rawSrc, stripEmailDomain)
|
||||
users, err := pol.getUsersInGroup(rawSrc)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
|
||||
@@ -306,7 +302,6 @@ func (pol *ACLPolicy) generateSSHRules(
|
||||
expandedSrcs, err := pol.ExpandAlias(
|
||||
peers,
|
||||
rawSrc,
|
||||
stripEmailDomain,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
@@ -358,9 +353,8 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
|
||||
func (pol *ACLPolicy) getIPsFromSource(
|
||||
src string,
|
||||
machines types.Machines,
|
||||
stripEmaildomain bool,
|
||||
) ([]string, error) {
|
||||
ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain)
|
||||
ipSet, err := pol.ExpandAlias(machines, src)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
@@ -380,7 +374,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||
dest string,
|
||||
machines types.Machines,
|
||||
needsWildcard bool,
|
||||
stripEmaildomain bool,
|
||||
) ([]tailcfg.NetPortRange, error) {
|
||||
var tokens []string
|
||||
|
||||
@@ -434,7 +427,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
|
||||
expanded, err := pol.ExpandAlias(
|
||||
machines,
|
||||
alias,
|
||||
stripEmaildomain,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -519,7 +511,6 @@ func parseProtocol(protocol string) ([]int, bool, error) {
|
||||
func (pol *ACLPolicy) ExpandAlias(
|
||||
machines types.Machines,
|
||||
alias string,
|
||||
stripEmailDomain bool,
|
||||
) (*netipx.IPSet, error) {
|
||||
if isWildcard(alias) {
|
||||
return util.ParseIPSet("*", nil)
|
||||
@@ -533,16 +524,16 @@ func (pol *ACLPolicy) ExpandAlias(
|
||||
|
||||
// if alias is a group
|
||||
if isGroup(alias) {
|
||||
return pol.getIPsFromGroup(alias, machines, stripEmailDomain)
|
||||
return pol.getIPsFromGroup(alias, machines)
|
||||
}
|
||||
|
||||
// if alias is a tag
|
||||
if isTag(alias) {
|
||||
return pol.getIPsFromTag(alias, machines, stripEmailDomain)
|
||||
return pol.getIPsFromTag(alias, machines)
|
||||
}
|
||||
|
||||
// if alias is a user
|
||||
if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil {
|
||||
if ips, err := pol.getIPsForUser(alias, machines); ips != nil {
|
||||
return ips, err
|
||||
}
|
||||
|
||||
@@ -551,7 +542,7 @@ func (pol *ACLPolicy) ExpandAlias(
|
||||
if h, ok := pol.Hosts[alias]; ok {
|
||||
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
|
||||
|
||||
return pol.ExpandAlias(machines, h.String(), stripEmailDomain)
|
||||
return pol.ExpandAlias(machines, h.String())
|
||||
}
|
||||
|
||||
// if alias is an IP
|
||||
@@ -576,12 +567,11 @@ func excludeCorrectlyTaggedNodes(
|
||||
aclPolicy *ACLPolicy,
|
||||
nodes types.Machines,
|
||||
user string,
|
||||
stripEmailDomain bool,
|
||||
) types.Machines {
|
||||
out := types.Machines{}
|
||||
tags := []string{}
|
||||
for tag := range aclPolicy.TagOwners {
|
||||
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
|
||||
owners, _ := getTagOwners(aclPolicy, user)
|
||||
ns := append(owners, user)
|
||||
if util.StringOrPrefixListContains(ns, user) {
|
||||
tags = append(tags, tag)
|
||||
@@ -674,7 +664,6 @@ func filterMachinesByUser(machines types.Machines, user string) types.Machines {
|
||||
func getTagOwners(
|
||||
pol *ACLPolicy,
|
||||
tag string,
|
||||
stripEmailDomain bool,
|
||||
) ([]string, error) {
|
||||
var owners []string
|
||||
ows, ok := pol.TagOwners[tag]
|
||||
@@ -687,7 +676,7 @@ func getTagOwners(
|
||||
}
|
||||
for _, owner := range ows {
|
||||
if isGroup(owner) {
|
||||
gs, err := pol.getUsersInGroup(owner, stripEmailDomain)
|
||||
gs, err := pol.getUsersInGroup(owner)
|
||||
if err != nil {
|
||||
return []string{}, err
|
||||
}
|
||||
@@ -704,7 +693,6 @@ func getTagOwners(
|
||||
// after some validation.
|
||||
func (pol *ACLPolicy) getUsersInGroup(
|
||||
group string,
|
||||
stripEmailDomain bool,
|
||||
) ([]string, error) {
|
||||
users := []string{}
|
||||
log.Trace().Caller().Interface("pol", pol).Msg("test")
|
||||
@@ -723,7 +711,7 @@ func (pol *ACLPolicy) getUsersInGroup(
|
||||
ErrInvalidGroup,
|
||||
)
|
||||
}
|
||||
grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain)
|
||||
grp, err := util.NormalizeToFQDNRulesConfigFromViper(group)
|
||||
if err != nil {
|
||||
return []string{}, fmt.Errorf(
|
||||
"failed to normalize group %q, err: %w",
|
||||
@@ -740,11 +728,10 @@ func (pol *ACLPolicy) getUsersInGroup(
|
||||
func (pol *ACLPolicy) getIPsFromGroup(
|
||||
group string,
|
||||
machines types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) (*netipx.IPSet, error) {
|
||||
build := netipx.IPSetBuilder{}
|
||||
|
||||
users, err := pol.getUsersInGroup(group, stripEmailDomain)
|
||||
users, err := pol.getUsersInGroup(group)
|
||||
if err != nil {
|
||||
return &netipx.IPSet{}, err
|
||||
}
|
||||
@@ -761,7 +748,6 @@ func (pol *ACLPolicy) getIPsFromGroup(
|
||||
func (pol *ACLPolicy) getIPsFromTag(
|
||||
alias string,
|
||||
machines types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) (*netipx.IPSet, error) {
|
||||
build := netipx.IPSetBuilder{}
|
||||
|
||||
@@ -773,7 +759,7 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||
}
|
||||
|
||||
// find tag owners
|
||||
owners, err := getTagOwners(pol, alias, stripEmailDomain)
|
||||
owners, err := getTagOwners(pol, alias)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrInvalidTag) {
|
||||
ipSet, _ := build.IPSet()
|
||||
@@ -808,12 +794,11 @@ func (pol *ACLPolicy) getIPsFromTag(
|
||||
func (pol *ACLPolicy) getIPsForUser(
|
||||
user string,
|
||||
machines types.Machines,
|
||||
stripEmailDomain bool,
|
||||
) (*netipx.IPSet, error) {
|
||||
build := netipx.IPSetBuilder{}
|
||||
|
||||
filteredMachines := filterMachinesByUser(machines, user)
|
||||
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain)
|
||||
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user)
|
||||
|
||||
// shortcurcuit if we have no machines to get ips from.
|
||||
if len(filteredMachines) == 0 {
|
||||
@@ -885,7 +870,6 @@ func isTag(str string) bool {
|
||||
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
|
||||
func (pol *ACLPolicy) GetTagsOfMachine(
|
||||
machine types.Machine,
|
||||
stripEmailDomain bool,
|
||||
) ([]string, []string) {
|
||||
validTags := make([]string, 0)
|
||||
invalidTags := make([]string, 0)
|
||||
@@ -893,7 +877,7 @@ func (pol *ACLPolicy) GetTagsOfMachine(
|
||||
validTagMap := make(map[string]bool)
|
||||
invalidTagMap := make(map[string]bool)
|
||||
for _, tag := range machine.HostInfo.RequestTags {
|
||||
owners, err := getTagOwners(pol, tag, stripEmailDomain)
|
||||
owners, err := getTagOwners(pol, tag)
|
||||
if errors.Is(err, ErrInvalidTag) {
|
||||
invalidTagMap[tag] = true
|
||||
|
||||
|
Reference in New Issue
Block a user