remove "stripEmailDomain" argument

This commit makes a wrapper function round the normalisation requiring
"stripEmailDomain" which has to be passed in almost all functions of
headscale by loading it from Viper instead.

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby
2023-06-12 15:29:34 +02:00
committed by Kristoffer Dalby
parent 161243c787
commit 717abe89c1
16 changed files with 127 additions and 220 deletions

View File

@@ -121,14 +121,13 @@ func GenerateFilterRules(
policy *ACLPolicy,
machine *types.Machine,
peers types.Machines,
stripEmailDomain bool,
) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) {
// If there is no policy defined, we default to allow all
if policy == nil {
return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil
}
rules, err := policy.generateFilterRules(machine, peers, stripEmailDomain)
rules, err := policy.generateFilterRules(machine, peers)
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
@@ -136,7 +135,7 @@ func GenerateFilterRules(
log.Trace().Interface("ACL", rules).Msg("ACL rules generated")
var sshPolicy *tailcfg.SSHPolicy
sshRules, err := policy.generateSSHRules(machine, peers, stripEmailDomain)
sshRules, err := policy.generateSSHRules(machine, peers)
if err != nil {
return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err
}
@@ -154,7 +153,6 @@ func GenerateFilterRules(
func (pol *ACLPolicy) generateFilterRules(
machine *types.Machine,
peers types.Machines,
stripEmailDomain bool,
) ([]tailcfg.FilterRule, error) {
rules := []tailcfg.FilterRule{}
machines := append(peers, *machine)
@@ -166,7 +164,7 @@ func (pol *ACLPolicy) generateFilterRules(
srcIPs := []string{}
for srcIndex, src := range acl.Sources {
srcs, err := pol.getIPsFromSource(src, machines, stripEmailDomain)
srcs, err := pol.getIPsFromSource(src, machines)
if err != nil {
log.Error().
Interface("src", src).
@@ -193,7 +191,6 @@ func (pol *ACLPolicy) generateFilterRules(
dest,
machines,
needsWildcard,
stripEmailDomain,
)
if err != nil {
log.Error().
@@ -220,7 +217,6 @@ func (pol *ACLPolicy) generateFilterRules(
func (pol *ACLPolicy) generateSSHRules(
machine *types.Machine,
peers types.Machines,
stripEmailDomain bool,
) ([]*tailcfg.SSHRule, error) {
rules := []*tailcfg.SSHRule{}
@@ -247,7 +243,7 @@ func (pol *ACLPolicy) generateSSHRules(
for index, sshACL := range pol.SSHs {
var dest netipx.IPSetBuilder
for _, src := range sshACL.Destinations {
expanded, err := pol.ExpandAlias(append(peers, *machine), src, stripEmailDomain)
expanded, err := pol.ExpandAlias(append(peers, *machine), src)
if err != nil {
return nil, err
}
@@ -289,7 +285,7 @@ func (pol *ACLPolicy) generateSSHRules(
Any: true,
})
} else if isGroup(rawSrc) {
users, err := pol.getUsersInGroup(rawSrc, stripEmailDomain)
users, err := pol.getUsersInGroup(rawSrc)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)
@@ -306,7 +302,6 @@ func (pol *ACLPolicy) generateSSHRules(
expandedSrcs, err := pol.ExpandAlias(
peers,
rawSrc,
stripEmailDomain,
)
if err != nil {
log.Error().
@@ -358,9 +353,8 @@ func sshCheckAction(duration string) (*tailcfg.SSHAction, error) {
func (pol *ACLPolicy) getIPsFromSource(
src string,
machines types.Machines,
stripEmaildomain bool,
) ([]string, error) {
ipSet, err := pol.ExpandAlias(machines, src, stripEmaildomain)
ipSet, err := pol.ExpandAlias(machines, src)
if err != nil {
return []string{}, err
}
@@ -380,7 +374,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
dest string,
machines types.Machines,
needsWildcard bool,
stripEmaildomain bool,
) ([]tailcfg.NetPortRange, error) {
var tokens []string
@@ -434,7 +427,6 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
expanded, err := pol.ExpandAlias(
machines,
alias,
stripEmaildomain,
)
if err != nil {
return nil, err
@@ -519,7 +511,6 @@ func parseProtocol(protocol string) ([]int, bool, error) {
func (pol *ACLPolicy) ExpandAlias(
machines types.Machines,
alias string,
stripEmailDomain bool,
) (*netipx.IPSet, error) {
if isWildcard(alias) {
return util.ParseIPSet("*", nil)
@@ -533,16 +524,16 @@ func (pol *ACLPolicy) ExpandAlias(
// if alias is a group
if isGroup(alias) {
return pol.getIPsFromGroup(alias, machines, stripEmailDomain)
return pol.getIPsFromGroup(alias, machines)
}
// if alias is a tag
if isTag(alias) {
return pol.getIPsFromTag(alias, machines, stripEmailDomain)
return pol.getIPsFromTag(alias, machines)
}
// if alias is a user
if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil {
if ips, err := pol.getIPsForUser(alias, machines); ips != nil {
return ips, err
}
@@ -551,7 +542,7 @@ func (pol *ACLPolicy) ExpandAlias(
if h, ok := pol.Hosts[alias]; ok {
log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry")
return pol.ExpandAlias(machines, h.String(), stripEmailDomain)
return pol.ExpandAlias(machines, h.String())
}
// if alias is an IP
@@ -576,12 +567,11 @@ func excludeCorrectlyTaggedNodes(
aclPolicy *ACLPolicy,
nodes types.Machines,
user string,
stripEmailDomain bool,
) types.Machines {
out := types.Machines{}
tags := []string{}
for tag := range aclPolicy.TagOwners {
owners, _ := getTagOwners(aclPolicy, user, stripEmailDomain)
owners, _ := getTagOwners(aclPolicy, user)
ns := append(owners, user)
if util.StringOrPrefixListContains(ns, user) {
tags = append(tags, tag)
@@ -674,7 +664,6 @@ func filterMachinesByUser(machines types.Machines, user string) types.Machines {
func getTagOwners(
pol *ACLPolicy,
tag string,
stripEmailDomain bool,
) ([]string, error) {
var owners []string
ows, ok := pol.TagOwners[tag]
@@ -687,7 +676,7 @@ func getTagOwners(
}
for _, owner := range ows {
if isGroup(owner) {
gs, err := pol.getUsersInGroup(owner, stripEmailDomain)
gs, err := pol.getUsersInGroup(owner)
if err != nil {
return []string{}, err
}
@@ -704,7 +693,6 @@ func getTagOwners(
// after some validation.
func (pol *ACLPolicy) getUsersInGroup(
group string,
stripEmailDomain bool,
) ([]string, error) {
users := []string{}
log.Trace().Caller().Interface("pol", pol).Msg("test")
@@ -723,7 +711,7 @@ func (pol *ACLPolicy) getUsersInGroup(
ErrInvalidGroup,
)
}
grp, err := util.NormalizeToFQDNRules(group, stripEmailDomain)
grp, err := util.NormalizeToFQDNRulesConfigFromViper(group)
if err != nil {
return []string{}, fmt.Errorf(
"failed to normalize group %q, err: %w",
@@ -740,11 +728,10 @@ func (pol *ACLPolicy) getUsersInGroup(
func (pol *ACLPolicy) getIPsFromGroup(
group string,
machines types.Machines,
stripEmailDomain bool,
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
users, err := pol.getUsersInGroup(group, stripEmailDomain)
users, err := pol.getUsersInGroup(group)
if err != nil {
return &netipx.IPSet{}, err
}
@@ -761,7 +748,6 @@ func (pol *ACLPolicy) getIPsFromGroup(
func (pol *ACLPolicy) getIPsFromTag(
alias string,
machines types.Machines,
stripEmailDomain bool,
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
@@ -773,7 +759,7 @@ func (pol *ACLPolicy) getIPsFromTag(
}
// find tag owners
owners, err := getTagOwners(pol, alias, stripEmailDomain)
owners, err := getTagOwners(pol, alias)
if err != nil {
if errors.Is(err, ErrInvalidTag) {
ipSet, _ := build.IPSet()
@@ -808,12 +794,11 @@ func (pol *ACLPolicy) getIPsFromTag(
func (pol *ACLPolicy) getIPsForUser(
user string,
machines types.Machines,
stripEmailDomain bool,
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}
filteredMachines := filterMachinesByUser(machines, user)
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain)
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user)
// shortcurcuit if we have no machines to get ips from.
if len(filteredMachines) == 0 {
@@ -885,7 +870,6 @@ func isTag(str string) bool {
// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag.
func (pol *ACLPolicy) GetTagsOfMachine(
machine types.Machine,
stripEmailDomain bool,
) ([]string, []string) {
validTags := make([]string, 0)
invalidTags := make([]string, 0)
@@ -893,7 +877,7 @@ func (pol *ACLPolicy) GetTagsOfMachine(
validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool)
for _, tag := range machine.HostInfo.RequestTags {
owners, err := getTagOwners(pol, tag, stripEmailDomain)
owners, err := getTagOwners(pol, tag)
if errors.Is(err, ErrInvalidTag) {
invalidTagMap[tag] = true