refact: use generics for contains functions

This commit is contained in:
Adrien Raffin-Caboisse 2022-04-25 21:50:40 +02:00
parent ea9aaa6022
commit 8061abe279
5 changed files with 14 additions and 22 deletions

View File

@ -332,7 +332,7 @@ func excludeCorrectlyTaggedNodes(
out := []Machine{}
tags := []string{}
for tag, ns := range aclPolicy.TagOwners {
if containsString(ns, namespace) {
if contains(ns, namespace) {
tags = append(tags, tag)
}
}
@ -342,7 +342,7 @@ func excludeCorrectlyTaggedNodes(
found := false
for _, t := range hi.RequestTags {
if containsString(tags, t) {
if contains(tags, t) {
found = true
break

View File

@ -372,12 +372,12 @@ func nodesToPtables(
tags += "," + tag
}
for _, tag := range machine.InvalidTags {
if !containsString(machine.ForcedTags, tag) {
if !contains(machine.ForcedTags, tag) {
tags += "," + pterm.LightRed(tag)
}
}
for _, tag := range machine.ValidTags {
if !containsString(machine.ForcedTags, tag) {
if !contains(machine.ForcedTags, tag) {
tags += "," + pterm.LightGreen(tag)
}
}

View File

@ -10,6 +10,7 @@ import (
"net/url"
"os"
"path/filepath"
"reflect"
"strconv"
"strings"
"time"
@ -565,9 +566,9 @@ func GetFileMode(key string) fs.FileMode {
return fs.FileMode(mode)
}
func containsString(ss []string, s string) bool {
for _, v := range ss {
if v == s {
func contains[T string](ts []T, t T) bool {
for _, v := range ts {
if reflect.DeepEqual(v,t) {
return true
}
}

View File

@ -125,7 +125,7 @@ func (machine Machine) isExpired() bool {
func containsAddresses(inputs []string, addrs []string) bool {
for _, addr := range addrs {
if containsString(inputs, addr) {
if contains(inputs, addr) {
return true
}
}
@ -803,7 +803,7 @@ func (h *Headscale) EnableRoutes(machine *Machine, routeStrs ...string) error {
}
for _, newRoute := range newRoutes {
if !containsIPPrefix(machine.GetAdvertisedRoutes(), newRoute) {
if !contains(machine.GetAdvertisedRoutes(), newRoute) {
return fmt.Errorf(
"route (%s) is not available on node %s: %w",
machine.Name,

View File

@ -12,6 +12,7 @@ import (
"encoding/json"
"fmt"
"net"
"reflect"
"strings"
"github.com/rs/zerolog/log"
@ -223,16 +224,6 @@ func (h *Headscale) getUsedIPs() (*netaddr.IPSet, error) {
return ipSet, nil
}
func containsString(ss []string, s string) bool {
for _, v := range ss {
if v == s {
return true
}
}
return false
}
func tailNodesToString(nodes []*tailcfg.Node) string {
temp := make([]string, len(nodes))
@ -282,9 +273,9 @@ func stringToIPPrefix(prefixes []string) ([]netaddr.IPPrefix, error) {
return result, nil
}
func containsIPPrefix(prefixes []netaddr.IPPrefix, prefix netaddr.IPPrefix) bool {
for _, p := range prefixes {
if prefix == p {
func contains[T string | netaddr.IPPrefix](ts []T, t T) bool {
for _, v := range ts {
if reflect.DeepEqual(v, t) {
return true
}
}