diff --git a/acls.go b/acls.go index 2073ee84..53bb7023 100644 --- a/acls.go +++ b/acls.go @@ -163,23 +163,20 @@ func (h *Headscale) UpdateACLRules() error { // generateACLPeerCacheMap takes a list of Tailscale filter rules and generates a map // of which Sources ("*" and IPs) can access destinations. This is to speed up the // process of generating MapResponses when deciding which Peers to inform nodes about. -func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]struct{} { - aclCachePeerMap := make(map[string]map[string]struct{}) +func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string][]string { + aclCachePeerMap := make(map[string][]string) for _, rule := range rules { for _, srcIP := range rule.SrcIPs { for _, ip := range expandACLPeerAddr(srcIP) { if data, ok := aclCachePeerMap[ip]; ok { for _, dstPort := range rule.DstPorts { - for _, dstIP := range expandACLPeerAddr(dstPort.IP) { - data[dstIP] = struct{}{} - } + data = append(data, dstPort.IP) } + aclCachePeerMap[ip] = data } else { - dstPortsMap := make(map[string]struct{}, len(rule.DstPorts)) + dstPortsMap := make([]string, 0) for _, dstPort := range rule.DstPorts { - for _, dstIP := range expandACLPeerAddr(dstPort.IP) { - dstPortsMap[dstIP] = struct{}{} - } + dstPortsMap = append(dstPortsMap, dstPort.IP) } aclCachePeerMap[ip] = dstPortsMap } diff --git a/app.go b/app.go index 26a8e23b..480689bc 100644 --- a/app.go +++ b/app.go @@ -87,7 +87,7 @@ type Headscale struct { aclPolicy *ACLPolicy aclRules []tailcfg.FilterRule aclPeerCacheMapRW sync.RWMutex - aclPeerCacheMap map[string]map[string]struct{} + aclPeerCacheMap map[string][]string sshPolicy *tailcfg.SSHPolicy lastStateChange *xsync.MapOf[string, time.Time] diff --git a/machine.go b/machine.go index 6dfa9501..1b70b1e2 100644 --- a/machine.go +++ b/machine.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "errors" "fmt" + "net" "net/netip" "sort" "strconv" @@ -172,7 +173,7 @@ func filterMachinesByACL( machine *Machine, machines Machines, lock *sync.RWMutex, - aclPeerCacheMap map[string]map[string]struct{}, + aclPeerCacheMap map[string][]string, ) Machines { log.Trace(). Caller(). @@ -197,43 +198,34 @@ func filterMachinesByACL( if dstMap, ok := aclPeerCacheMap["*"]; ok { // match source and all destination - if _, dstOk := dstMap["*"]; dstOk { - peers[peer.ID] = peer - continue + for _, dst := range dstMap { + if dst == "*" { + peers[peer.ID] = peer + + continue + } } // match source and all destination for _, peerIP := range peerIPs { - if _, dstOk := dstMap[peerIP]; dstOk { - peers[peer.ID] = peer + for _, dst := range dstMap { + _, cdr, _ := net.ParseCIDR(dst) + ip := net.ParseIP(peerIP) + if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { + peers[peer.ID] = peer - continue + continue + } } } // match all sources and source for _, machineIP := range machineIPs { - if _, dstOk := dstMap[machineIP]; dstOk { - peers[peer.ID] = peer - - continue - } - } - } - - for _, machineIP := range machineIPs { - if dstMap, ok := aclPeerCacheMap[machineIP]; ok { - // match source and all destination - if _, dstOk := dstMap["*"]; dstOk { - peers[peer.ID] = peer - - continue - } - - // match source and destination - for _, peerIP := range peerIPs { - if _, dstOk := dstMap[peerIP]; dstOk { + for _, dst := range dstMap { + _, cdr, _ := net.ParseCIDR(dst) + ip := net.ParseIP(machineIP) + if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { peers[peer.ID] = peer continue @@ -242,22 +234,55 @@ func filterMachinesByACL( } } - for _, peerIP := range peerIPs { - if dstMap, ok := aclPeerCacheMap[peerIP]; ok { + for _, machineIP := range machineIPs { + if dstMap, ok := aclPeerCacheMap[machineIP]; ok { // match source and all destination - if _, dstOk := dstMap["*"]; dstOk { - peers[peer.ID] = peer - - continue - } - // match return path - for _, machineIP := range machineIPs { - if _, dstOk := dstMap[machineIP]; dstOk { + for _, dst := range dstMap { + if dst == "*" { peers[peer.ID] = peer continue } } + + // match source and destination + for _, peerIP := range peerIPs { + for _, dst := range dstMap { + _, cdr, _ := net.ParseCIDR(dst) + ip := net.ParseIP(peerIP) + if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { + peers[peer.ID] = peer + + continue + } + } + } + } + } + + for _, peerIP := range peerIPs { + if dstMap, ok := aclPeerCacheMap[peerIP]; ok { + // match source and all destination + for _, dst := range dstMap { + if dst == "*" { + peers[peer.ID] = peer + + continue + } + } + + // match return path + for _, machineIP := range machineIPs { + for _, dst := range dstMap { + _, cdr, _ := net.ParseCIDR(dst) + ip := net.ParseIP(machineIP) + if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) { + peers[peer.ID] = peer + + continue + } + } + } } } }