Migrate IP fields in database to dedicated columns (#1869)

This commit is contained in:
Kristoffer Dalby
2024-04-17 07:03:06 +02:00
committed by GitHub
parent 85cef84e17
commit 2ce23df45a
39 changed files with 1885 additions and 1055 deletions

View File

@@ -31,6 +31,13 @@ var errOidcMutuallyExclusive = errors.New(
"oidc_client_secret and oidc_client_secret_path are mutually exclusive",
)
type IPAllocationStrategy string
const (
IPAllocationStrategySequential IPAllocationStrategy = "sequential"
IPAllocationStrategyRandom IPAllocationStrategy = "random"
)
// Config contains the initial Headscale configuration.
type Config struct {
ServerURL string
@@ -42,6 +49,7 @@ type Config struct {
NodeUpdateCheckInterval time.Duration
PrefixV4 *netip.Prefix
PrefixV6 *netip.Prefix
IPAllocation IPAllocationStrategy
NoisePrivateKeyPath string
BaseDomain string
Log LogConfig
@@ -230,6 +238,8 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("tuning.batch_change_delay", "800ms")
viper.SetDefault("tuning.node_mapsession_buffered_chan_size", 30)
viper.SetDefault("prefixes.allocation", IPAllocationStrategySequential)
if IsCLIConfigured() {
return nil
}
@@ -579,18 +589,16 @@ func GetDNSConfig() (*tailcfg.DNSConfig, string) {
return nil, ""
}
func Prefixes() (*netip.Prefix, *netip.Prefix, error) {
func PrefixV4() (*netip.Prefix, error) {
prefixV4Str := viper.GetString("prefixes.v4")
prefixV6Str := viper.GetString("prefixes.v6")
if prefixV4Str == "" {
return nil, nil
}
prefixV4, err := netip.ParsePrefix(prefixV4Str)
if err != nil {
return nil, nil, err
}
prefixV6, err := netip.ParsePrefix(prefixV6Str)
if err != nil {
return nil, nil, err
return nil, fmt.Errorf("parsing IPv4 prefix from config: %w", err)
}
builder := netipx.IPSetBuilder{}
@@ -603,13 +611,33 @@ func Prefixes() (*netip.Prefix, *netip.Prefix, error) {
prefixV4Str, tsaddr.CGNATRange())
}
return &prefixV4, nil
}
func PrefixV6() (*netip.Prefix, error) {
prefixV6Str := viper.GetString("prefixes.v6")
if prefixV6Str == "" {
return nil, nil
}
prefixV6, err := netip.ParsePrefix(prefixV6Str)
if err != nil {
return nil, fmt.Errorf("parsing IPv6 prefix from config: %w", err)
}
builder := netipx.IPSetBuilder{}
builder.AddPrefix(tsaddr.CGNATRange())
builder.AddPrefix(tsaddr.TailscaleULARange())
ipSet, _ := builder.IPSet()
if !ipSet.ContainsPrefix(prefixV6) {
log.Warn().
Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.",
prefixV6Str, tsaddr.TailscaleULARange())
}
return &prefixV4, &prefixV6, nil
return &prefixV6, nil
}
func GetHeadscaleConfig() (*Config, error) {
@@ -624,11 +652,27 @@ func GetHeadscaleConfig() (*Config, error) {
}, nil
}
prefix4, prefix6, err := Prefixes()
prefix4, err := PrefixV4()
if err != nil {
return nil, err
}
prefix6, err := PrefixV6()
if err != nil {
return nil, err
}
allocStr := viper.GetString("prefixes.allocation")
var alloc IPAllocationStrategy
switch allocStr {
case string(IPAllocationStrategySequential):
alloc = IPAllocationStrategySequential
case string(IPAllocationStrategyRandom):
alloc = IPAllocationStrategyRandom
default:
log.Fatal().Msgf("config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom)
}
dnsConfig, baseDomain := GetDNSConfig()
derpConfig := GetDERPConfig()
logConfig := GetLogTailConfig()
@@ -655,8 +699,9 @@ func GetHeadscaleConfig() (*Config, error) {
GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"),
DisableUpdateCheck: viper.GetBool("disable_check_updates"),
PrefixV4: prefix4,
PrefixV6: prefix6,
PrefixV4: prefix4,
PrefixV6: prefix6,
IPAllocation: IPAllocationStrategy(alloc),
NoisePrivateKeyPath: util.AbsolutePathFromConfigPath(
viper.GetString("noise.private_key_path"),

View File

@@ -1,12 +1,11 @@
package types
import (
"database/sql/driver"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/netip"
"sort"
"strconv"
"strings"
"time"
@@ -14,7 +13,6 @@ import (
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"go4.org/netipx"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
@@ -83,7 +81,19 @@ type Node struct {
HostinfoDatabaseField string `gorm:"column:host_info"`
Hostinfo *tailcfg.Hostinfo `gorm:"-"`
IPAddresses NodeAddresses
// IPv4DatabaseField is the string representation of v4 address,
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use V4 instead.
IPv4DatabaseField sql.NullString `gorm:"column:ipv4"`
IPv4 *netip.Addr `gorm:"-"`
// IPv6DatabaseField is the string representation of v4 address,
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use V6 instead.
IPv6DatabaseField sql.NullString `gorm:"column:ipv6"`
IPv6 *netip.Addr `gorm:"-"`
// Hostname represents the name given by the Tailscale
// client during registration
@@ -123,89 +133,6 @@ type (
Nodes []*Node
)
type NodeAddresses []netip.Addr
func (na NodeAddresses) Sort() {
sort.Slice(na, func(index1, index2 int) bool {
if na[index1].Is4() && na[index2].Is6() {
return true
}
if na[index1].Is6() && na[index2].Is4() {
return false
}
return na[index1].Compare(na[index2]) < 0
})
}
func (na NodeAddresses) StringSlice() []string {
na.Sort()
strSlice := make([]string, 0, len(na))
for _, addr := range na {
strSlice = append(strSlice, addr.String())
}
return strSlice
}
func (na NodeAddresses) Prefixes() []netip.Prefix {
addrs := []netip.Prefix{}
for _, nodeAddress := range na {
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
addrs = append(addrs, ip)
}
return addrs
}
func (na NodeAddresses) InIPSet(set *netipx.IPSet) bool {
for _, nodeAddr := range na {
if set.Contains(nodeAddr) {
return true
}
}
return false
}
// AppendToIPSet adds the individual ips in NodeAddresses to a
// given netipx.IPSetBuilder.
func (na NodeAddresses) AppendToIPSet(build *netipx.IPSetBuilder) {
for _, ip := range na {
build.Add(ip)
}
}
func (na *NodeAddresses) Scan(destination interface{}) error {
switch value := destination.(type) {
case string:
addresses := strings.Split(value, ",")
*na = (*na)[:0]
for _, addr := range addresses {
if len(addr) < 1 {
continue
}
parsed, err := netip.ParseAddr(addr)
if err != nil {
return err
}
*na = append(*na, parsed)
}
return nil
default:
return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (na NodeAddresses) Value() (driver.Value, error) {
addresses := strings.Join(na.StringSlice(), ",")
return addresses, nil
}
// IsExpired returns whether the node registration has expired.
func (node Node) IsExpired() bool {
// If Expiry is not set, the client has not indicated that
@@ -224,8 +151,65 @@ func (node *Node) IsEphemeral() bool {
return node.AuthKey != nil && node.AuthKey.Ephemeral
}
func (node *Node) IPs() []netip.Addr {
var ret []netip.Addr
if node.IPv4 != nil {
ret = append(ret, *node.IPv4)
}
if node.IPv6 != nil {
ret = append(ret, *node.IPv6)
}
return ret
}
func (node *Node) Prefixes() []netip.Prefix {
addrs := []netip.Prefix{}
for _, nodeAddress := range node.IPs() {
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
addrs = append(addrs, ip)
}
return addrs
}
func (node *Node) IPsAsString() []string {
var ret []string
if node.IPv4 != nil {
ret = append(ret, node.IPv4.String())
}
if node.IPv6 != nil {
ret = append(ret, node.IPv6.String())
}
return ret
}
func (node *Node) InIPSet(set *netipx.IPSet) bool {
for _, nodeAddr := range node.IPs() {
if set.Contains(nodeAddr) {
return true
}
}
return false
}
// AppendToIPSet adds the individual ips in NodeAddresses to a
// given netipx.IPSetBuilder.
func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) {
for _, ip := range node.IPs() {
build.Add(ip)
}
}
func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
allowedIPs := append([]netip.Addr{}, node2.IPAddresses...)
src := node.IPs()
allowedIPs := node2.IPs()
for _, route := range node2.Routes {
if route.Enabled {
@@ -237,7 +221,7 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
// TODO(kradalby): Cache or pregen this
matcher := matcher.MatchFromFilterRule(rule)
if !matcher.SrcsContainsIPs([]netip.Addr(node.IPAddresses)) {
if !matcher.SrcsContainsIPs(src) {
continue
}
@@ -250,13 +234,16 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool {
}
func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
found := make(Nodes, 0)
var found Nodes
for _, node := range nodes {
for _, mIP := range node.IPAddresses {
if ip == mIP {
found = append(found, node)
}
if node.IPv4 != nil && ip == *node.IPv4 {
found = append(found, node)
continue
}
if node.IPv6 != nil && ip == *node.IPv6 {
found = append(found, node)
}
}
@@ -281,10 +268,22 @@ func (node *Node) BeforeSave(tx *gorm.DB) error {
hi, err := json.Marshal(node.Hostinfo)
if err != nil {
return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err)
return fmt.Errorf("marshalling Hostinfo to store in db: %w", err)
}
node.HostinfoDatabaseField = string(hi)
if node.IPv4 != nil {
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = node.IPv4.String(), true
} else {
node.IPv4DatabaseField.String, node.IPv4DatabaseField.Valid = "", false
}
if node.IPv6 != nil {
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = node.IPv6.String(), true
} else {
node.IPv6DatabaseField.String, node.IPv6DatabaseField.Valid = "", false
}
return nil
}
@@ -296,19 +295,19 @@ func (node *Node) BeforeSave(tx *gorm.DB) error {
func (node *Node) AfterFind(tx *gorm.DB) error {
var machineKey key.MachinePublic
if err := machineKey.UnmarshalText([]byte(node.MachineKeyDatabaseField)); err != nil {
return fmt.Errorf("failed to unmarshal machine key from db: %w", err)
return fmt.Errorf("unmarshalling machine key from db: %w", err)
}
node.MachineKey = machineKey
var nodeKey key.NodePublic
if err := nodeKey.UnmarshalText([]byte(node.NodeKeyDatabaseField)); err != nil {
return fmt.Errorf("failed to unmarshal node key from db: %w", err)
return fmt.Errorf("unmarshalling node key from db: %w", err)
}
node.NodeKey = nodeKey
var discoKey key.DiscoPublic
if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil {
return fmt.Errorf("failed to unmarshal disco key from db: %w", err)
return fmt.Errorf("unmarshalling disco key from db: %w", err)
}
node.DiscoKey = discoKey
@@ -316,7 +315,7 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
for idx, ep := range node.EndpointsDatabaseField {
addrPort, err := netip.ParseAddrPort(ep)
if err != nil {
return fmt.Errorf("failed to parse endpoint from db: %w", err)
return fmt.Errorf("parsing endpoint from db: %w", err)
}
endpoints[idx] = addrPort
@@ -325,12 +324,28 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
var hi tailcfg.Hostinfo
if err := json.Unmarshal([]byte(node.HostinfoDatabaseField), &hi); err != nil {
log.Trace().Err(err).Msgf("Hostinfo content: %s", node.HostinfoDatabaseField)
return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err)
return fmt.Errorf("unmarshalling hostinfo from database: %w", err)
}
node.Hostinfo = &hi
if node.IPv4DatabaseField.Valid {
ip, err := netip.ParseAddr(node.IPv4DatabaseField.String)
if err != nil {
return fmt.Errorf("parsing IPv4 from database: %w", err)
}
node.IPv4 = &ip
}
if node.IPv6DatabaseField.Valid {
ip, err := netip.ParseAddr(node.IPv6DatabaseField.String)
if err != nil {
return fmt.Errorf("parsing IPv6 from database: %w", err)
}
node.IPv6 = &ip
}
return nil
}
@@ -339,9 +354,11 @@ func (node *Node) Proto() *v1.Node {
Id: uint64(node.ID),
MachineKey: node.MachineKey.String(),
NodeKey: node.NodeKey.String(),
DiscoKey: node.DiscoKey.String(),
IpAddresses: node.IPAddresses.StringSlice(),
NodeKey: node.NodeKey.String(),
DiscoKey: node.DiscoKey.String(),
// TODO(kradalby): replace list with v4, v6 field?
IpAddresses: node.IPsAsString(),
Name: node.Hostname,
GivenName: node.GivenName,
User: node.User.Proto(),

View File

@@ -12,6 +12,10 @@ import (
)
func Test_NodeCanAccess(t *testing.T) {
iap := func(ipStr string) *netip.Addr {
ip := netip.MustParseAddr(ipStr)
return &ip
}
tests := []struct {
name string
node1 Node
@@ -22,10 +26,10 @@ func Test_NodeCanAccess(t *testing.T) {
{
name: "no-rules",
node1: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.1")},
IPv4: iap("10.0.0.1"),
},
node2: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.2")},
IPv4: iap("10.0.0.2"),
},
rules: []tailcfg.FilterRule{},
want: false,
@@ -33,10 +37,10 @@ func Test_NodeCanAccess(t *testing.T) {
{
name: "wildcard",
node1: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.1")},
IPv4: iap("10.0.0.1"),
},
node2: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.2")},
IPv4: iap("10.0.0.2"),
},
rules: []tailcfg.FilterRule{
{
@@ -54,10 +58,10 @@ func Test_NodeCanAccess(t *testing.T) {
{
name: "other-cant-access-src",
node1: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")},
IPv4: iap("100.64.0.1"),
},
node2: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
IPv4: iap("100.64.0.3"),
},
rules: []tailcfg.FilterRule{
{
@@ -72,10 +76,10 @@ func Test_NodeCanAccess(t *testing.T) {
{
name: "dest-cant-access-src",
node1: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
IPv4: iap("100.64.0.3"),
},
node2: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
IPv4: iap("100.64.0.2"),
},
rules: []tailcfg.FilterRule{
{
@@ -90,10 +94,10 @@ func Test_NodeCanAccess(t *testing.T) {
{
name: "src-can-access-dest",
node1: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")},
IPv4: iap("100.64.0.2"),
},
node2: Node{
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")},
IPv4: iap("100.64.0.3"),
},
rules: []tailcfg.FilterRule{
{
@@ -118,32 +122,6 @@ func Test_NodeCanAccess(t *testing.T) {
}
}
func TestNodeAddressesOrder(t *testing.T) {
machineAddresses := NodeAddresses{
netip.MustParseAddr("2001:db8::2"),
netip.MustParseAddr("100.64.0.2"),
netip.MustParseAddr("2001:db8::1"),
netip.MustParseAddr("100.64.0.1"),
}
strSlice := machineAddresses.StringSlice()
expected := []string{
"100.64.0.1",
"100.64.0.2",
"2001:db8::1",
"2001:db8::2",
}
if len(strSlice) != len(expected) {
t.Fatalf("unexpected slice length: got %v, want %v", len(strSlice), len(expected))
}
for i, addr := range strSlice {
if addr != expected[i] {
t.Errorf("unexpected address at index %v: got %v, want %v", i, addr, expected[i])
}
}
}
func TestNodeFQDN(t *testing.T) {
tests := []struct {
name string