mirror of
https://github.com/tailscale/tailscale.git
synced 2025-08-13 22:47:30 +00:00
util/linuxfw: initial implementation of package
This package is an initial implementation of something that can read netfilter and iptables rules from the Linux kernel without needing to shell out to an external utility; it speaks directly to the kernel using syscalls and parses the data returned. Currently this is read-only since it only knows how to parse a subset of the available data. Signed-off-by: Andrew Dunham <andrew@tailscale.com> Change-Id: Iccadf5dcc081b73268d8ccf8884c24eb6a6f1ff5
This commit is contained in:

committed by
Andrew Dunham

parent
3c107ff301
commit
ba48ec5e39
269
util/linuxfw/nftables.go
Normal file
269
util/linuxfw/nftables.go
Normal file
@@ -0,0 +1,269 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build linux && !(386 || loong64)
|
||||
|
||||
package linuxfw
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/nftables/xt"
|
||||
"github.com/josharian/native"
|
||||
"golang.org/x/sys/unix"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
// DebugNetfilter prints debug information about netfilter rules to the
|
||||
// provided log function.
|
||||
func DebugNetfilter(logf logger.Logf) error {
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chains, err := conn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot list chains: %w", err)
|
||||
}
|
||||
|
||||
if len(chains) == 0 {
|
||||
logf("netfilter: no chains")
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
logf("netfilter: table=%s chain=%s", chain.Table.Name, chain.Name)
|
||||
|
||||
rules, err := conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
return rules[i].Position < rules[j].Position
|
||||
})
|
||||
|
||||
for i, rule := range rules {
|
||||
logf("netfilter: rule[%d]: pos=%d flags=%d", i, rule.Position, rule.Flags)
|
||||
for _, ex := range rule.Exprs {
|
||||
switch v := ex.(type) {
|
||||
case *expr.Meta:
|
||||
key := metaKeyNames[v.Key]
|
||||
if key == "" {
|
||||
key = "UNKNOWN"
|
||||
}
|
||||
logf("netfilter: Meta: key=%s source_register=%v register=%d", key, v.SourceRegister, v.Register)
|
||||
|
||||
case *expr.Cmp:
|
||||
op := cmpOpNames[v.Op]
|
||||
if op == "" {
|
||||
op = "UNKNOWN"
|
||||
}
|
||||
logf("netfilter: Cmp: op=%s register=%d data=%s", op, v.Register, formatMaybePrintable(v.Data))
|
||||
|
||||
case *expr.Counter:
|
||||
// don't print
|
||||
|
||||
case *expr.Verdict:
|
||||
kind := verdictNames[v.Kind]
|
||||
if kind == "" {
|
||||
kind = "UNKNOWN"
|
||||
}
|
||||
logf("netfilter: Verdict: kind=%s data=%s", kind, v.Chain)
|
||||
|
||||
case *expr.Target:
|
||||
logf("netfilter: Target: name=%s info=%s", v.Name, printTargetInfo(v.Name, v.Info))
|
||||
|
||||
case *expr.Match:
|
||||
logf("netfilter: Match: name=%s info=%+v", v.Name, printMatchInfo(v.Name, v.Info))
|
||||
|
||||
case *expr.Payload:
|
||||
logf("netfilter: Payload: op=%s src=%d dst=%d base=%s offset=%d len=%d",
|
||||
payloadOperationTypeNames[v.OperationType],
|
||||
v.SourceRegister, v.DestRegister,
|
||||
payloadBaseNames[v.Base],
|
||||
v.Offset, v.Len)
|
||||
// TODO(andrew): csum
|
||||
|
||||
case *expr.Bitwise:
|
||||
var xor string
|
||||
for _, b := range v.Xor {
|
||||
if b != 0 {
|
||||
xor = fmt.Sprintf(" xor=%v", v.Xor)
|
||||
break
|
||||
}
|
||||
}
|
||||
logf("netfilter: Bitwise: src=%d dst=%d len=%d mask=%v%s",
|
||||
v.SourceRegister, v.DestRegister, v.Len, v.Mask, xor)
|
||||
|
||||
default:
|
||||
logf("netfilter: unknown %T: %+v", v, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DetectNetfilter returns the number of nftables rules present in the system.
|
||||
func DetectNetfilter() (int, error) {
|
||||
conn, err := nftables.New()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
chains, err := conn.ListChains()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("cannot list chains: %w", err)
|
||||
}
|
||||
|
||||
var validRules int
|
||||
for _, chain := range chains {
|
||||
rules, err := conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
validRules += len(rules)
|
||||
}
|
||||
return validRules, nil
|
||||
}
|
||||
|
||||
func printMatchInfo(name string, info xt.InfoAny) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(`{`)
|
||||
|
||||
var handled bool = true
|
||||
switch v := info.(type) {
|
||||
// TODO(andrew): we should support these common types
|
||||
//case *xt.ConntrackMtinfo3:
|
||||
//case *xt.ConntrackMtinfo2:
|
||||
case *xt.Tcp:
|
||||
fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts))
|
||||
if v.Option != 0 {
|
||||
fmt.Fprintf(&sb, " Option:%d", v.Option)
|
||||
}
|
||||
if v.FlagsMask != 0 {
|
||||
fmt.Fprintf(&sb, " FlagsMask:%d", v.FlagsMask)
|
||||
}
|
||||
if v.FlagsCmp != 0 {
|
||||
fmt.Fprintf(&sb, " FlagsCmp:%d", v.FlagsCmp)
|
||||
}
|
||||
if v.InvFlags != 0 {
|
||||
fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags)
|
||||
}
|
||||
|
||||
case *xt.Udp:
|
||||
fmt.Fprintf(&sb, "Src:%s Dst:%s", formatPortRange(v.SrcPorts), formatPortRange(v.DstPorts))
|
||||
if v.InvFlags != 0 {
|
||||
fmt.Fprintf(&sb, " InvFlags:%d", v.InvFlags)
|
||||
}
|
||||
|
||||
case *xt.AddrType:
|
||||
var sprefix, dprefix string
|
||||
if v.InvertSource {
|
||||
sprefix = "!"
|
||||
}
|
||||
if v.InvertDest {
|
||||
dprefix = "!"
|
||||
}
|
||||
// TODO(andrew): translate source/dest
|
||||
fmt.Fprintf(&sb, "Source:%s%d Dest:%s%d", sprefix, v.Source, dprefix, v.Dest)
|
||||
|
||||
case *xt.AddrTypeV1:
|
||||
// TODO(andrew): translate source/dest
|
||||
fmt.Fprintf(&sb, "Source:%d Dest:%d", v.Source, v.Dest)
|
||||
|
||||
var flags []string
|
||||
for flag, name := range addrTypeFlagNames {
|
||||
if v.Flags&flag != 0 {
|
||||
flags = append(flags, name)
|
||||
}
|
||||
}
|
||||
if len(flags) > 0 {
|
||||
sort.Strings(flags)
|
||||
fmt.Fprintf(&sb, "Flags:%s", strings.Join(flags, ","))
|
||||
}
|
||||
|
||||
default:
|
||||
handled = false
|
||||
}
|
||||
if handled {
|
||||
sb.WriteString(`}`)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
unknown, ok := info.(*xt.Unknown)
|
||||
if !ok {
|
||||
return fmt.Sprintf("(%T)%+v", info, info)
|
||||
}
|
||||
data := []byte(*unknown)
|
||||
|
||||
// Things where upstream has no type
|
||||
handled = true
|
||||
switch name {
|
||||
case "pkttype":
|
||||
if len(data) != 8 {
|
||||
handled = false
|
||||
break
|
||||
}
|
||||
|
||||
pkttype := int(native.Endian.Uint32(data[0:4]))
|
||||
invert := int(native.Endian.Uint32(data[4:8]))
|
||||
var invertPrefix string
|
||||
if invert != 0 {
|
||||
invertPrefix = "!"
|
||||
}
|
||||
|
||||
pkttypeName := packetTypeNames[pkttype]
|
||||
if pkttypeName != "" {
|
||||
fmt.Fprintf(&sb, "PktType:%s%s", invertPrefix, pkttypeName)
|
||||
} else {
|
||||
fmt.Fprintf(&sb, "PktType:%s%d", invertPrefix, pkttype)
|
||||
}
|
||||
|
||||
default:
|
||||
handled = true
|
||||
}
|
||||
|
||||
if !handled {
|
||||
return fmt.Sprintf("(%T)%+v", info, info)
|
||||
}
|
||||
|
||||
sb.WriteString(`}`)
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func printTargetInfo(name string, info xt.InfoAny) string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString(`{`)
|
||||
|
||||
unknown, ok := info.(*xt.Unknown)
|
||||
if !ok {
|
||||
return fmt.Sprintf("(%T)%+v", info, info)
|
||||
}
|
||||
data := []byte(*unknown)
|
||||
|
||||
// Things where upstream has no type
|
||||
switch name {
|
||||
case "LOG":
|
||||
if len(data) != 32 {
|
||||
fmt.Fprintf(&sb, `Error:"bad size; want 32, got %d"`, len(data))
|
||||
break
|
||||
}
|
||||
|
||||
level := data[0]
|
||||
logflags := data[1]
|
||||
prefix := unix.ByteSliceToString(data[2:])
|
||||
fmt.Fprintf(&sb, "Level:%d LogFlags:%d Prefix:%q", level, logflags, prefix)
|
||||
default:
|
||||
return fmt.Sprintf("(%T)%+v", info, info)
|
||||
}
|
||||
|
||||
sb.WriteString(`}`)
|
||||
return sb.String()
|
||||
}
|
Reference in New Issue
Block a user