wgengine/filter: twiddle bits to optimize

Part of #19.

name            old time/op    new time/op    delta
Filter/icmp4-8    32.2ns ± 3%    32.5ns ± 2%     ~     (p=0.524 n=10+8)
Filter/icmp6-8    49.7ns ± 6%    43.1ns ± 4%  -13.12%  (p=0.000 n=9+10)

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
Josh Bleecher Snyder 2020-11-12 17:10:38 -08:00 committed by David Anderson
parent 5062131aad
commit 47380ebcfb
3 changed files with 57 additions and 10 deletions

View File

@ -25,7 +25,7 @@ func newFilter(logf logger.Logf) *Filter {
{Srcs: nets("0.0.0.0/0"), Dsts: netports("100.122.98.50:*")},
{Srcs: nets("0.0.0.0/0"), Dsts: netports("0.0.0.0/0:443")},
{Srcs: nets("153.1.1.1", "153.1.1.2", "153.3.3.3"), Dsts: netports("1.2.3.4:999")},
{Srcs: nets("::1", "::2"), Dsts: netports("2001::1:22")},
{Srcs: nets("::1", "::2"), Dsts: netports("2001::1:22", "2001::2:22")},
{Srcs: nets("::/0"), Dsts: netports("::/0:443")},
}
@ -65,8 +65,9 @@ func TestFilter(t *testing.T) {
{Accept, parsed(packet.TCP, "::1", "2001::1", 0, 22)},
{Accept, parsed(packet.ICMPv6, "::1", "2001::1", 0, 0)},
{Accept, parsed(packet.TCP, "::2", "2001::1", 0, 22)},
{Accept, parsed(packet.TCP, "::2", "2001::2", 0, 22)},
{Drop, parsed(packet.TCP, "::1", "2001::1", 0, 23)},
{Drop, parsed(packet.TCP, "::1", "2001::2", 0, 22)},
{Drop, parsed(packet.TCP, "::1", "2001::3", 0, 22)},
{Drop, parsed(packet.TCP, "::3", "2001::1", 0, 22)},
// allow * => *:443
{Accept, parsed(packet.TCP, "::1", "2001::1", 0, 443)},
@ -83,7 +84,7 @@ func TestFilter(t *testing.T) {
aclFunc = acl.runIn6
}
if got, why := aclFunc(&test.p); test.want != got {
t.Errorf("#%d runIn4 got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
t.Errorf("#%d runIn got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
}
if test.p.IPProto == packet.TCP {
var got Response
@ -98,7 +99,7 @@ func TestFilter(t *testing.T) {
// TCP and UDP are treated equivalently in the filter - verify that.
test.p.IPProto = packet.UDP
if got, why := aclFunc(&test.p); test.want != got {
t.Errorf("#%d runIn4 (UDP) got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
t.Errorf("#%d runIn (UDP) got=%v want=%v why=%q packet:%v", i, got, test.want, why, test.p)
}
}
// Update UDP state

View File

@ -46,8 +46,16 @@ func nets6FromIPPrefixes(pfxs []netaddr.IPPrefix) (ret []net6) {
}
func (n net6) Contains(ip packet.IP6) bool {
return ((n.ip.Hi&n.mask.Hi) == (ip.Hi&n.mask.Hi) &&
(n.ip.Lo&n.mask.Lo) == (ip.Lo&n.mask.Lo))
// This is equivalent to the more straightforward implementation:
// ((n.ip.Hi & n.mask.Hi) == (ip.Hi & n.mask.Hi) &&
// (n.ip.Lo & n.mask.Lo) == (ip.Lo & n.mask.Lo))
//
// This implementation runs significantly faster because it
// eliminates branches and minimizes the required
// bit-twiddling.
a := (n.ip.Hi ^ ip.Hi) & n.mask.Hi
b := (n.ip.Lo ^ ip.Lo) & n.mask.Lo
return (a | b) == 0
}
func (n net6) Bits() int {
@ -128,12 +136,13 @@ func (ms matches6) match(q *packet.Parsed) bool {
}
func (ms matches6) matchIPsOnly(q *packet.Parsed) bool {
for _, m := range ms {
if !ip6InList(q.SrcIP6, m.srcs) {
for i := range ms {
if !ip6InList(q.SrcIP6, ms[i].srcs) {
continue
}
for _, dst := range m.dsts {
if dst.net.Contains(q.DstIP6) {
dsts := ms[i].dsts
for i := range dsts {
if dsts[i].net.Contains(q.DstIP6) {
return true
}
}

View File

@ -0,0 +1,37 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package filter
import "testing"
// Verifies that the fast bit-twiddling implementation of Contains
// works the same as the easy-to-read implementation. Since we can't
// sensibly check it on 128 bits, the test runs over 4-bit
// "IPs". Bit-twiddling is the same at any width, so this adequately
// proves that the implementations are equivalent.
func TestOptimizedContains(t *testing.T) {
for ipHi := 0; ipHi < 0xf; ipHi++ {
for ipLo := 0; ipLo < 0xf; ipLo++ {
for nIPHi := 0; nIPHi < 0xf; nIPHi++ {
for nIPLo := 0; nIPLo < 0xf; nIPLo++ {
for maskHi := 0; maskHi < 0xf; maskHi++ {
for maskLo := 0; maskLo < 0xf; maskLo++ {
a := (nIPHi ^ ipHi) & maskHi
b := (nIPLo ^ ipLo) & maskLo
got := (a | b) == 0
want := ((nIPHi&maskHi) == (ipHi&maskHi) && (nIPLo&maskLo) == (ipLo&maskLo))
if got != want {
t.Errorf("mask %1x%1x/%1x%1x %1x%1x got=%v want=%v", nIPHi, nIPLo, maskHi, maskLo, ipHi, ipLo, got, want)
}
}
}
}
}
}
}
}