// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

//go:build linux

package linuxfw

import (
	"errors"
	"fmt"
	"net/netip"
	"reflect"
	"strings"

	"github.com/google/nftables"
	"github.com/google/nftables/binaryutil"
	"github.com/google/nftables/expr"
	"golang.org/x/sys/unix"
)

// This file contains functionality that is currently (09/2024) used to set up
// routing for the Tailscale Kubernetes operator egress proxies. Each tailnet
// service (identified by tailnet IP or FQDN) that is exposed to cluster
// workloads gets its own prerouting chain, one per IP family of the service's
// target addresses. Each service's prerouting chain contains one or more
// portmapping rules. A portmapping rule DNATs traffic received on a particular
// port to a port of the tailnet service. Creating a chain per service makes it
// easier to delete a service when it is no longer needed and helps with
// readability.

// EnsurePortMapRuleForSvc makes sure a portmapping rule for the given portmap
// exists for the service. It:
//   - ensures that the nat table exists
//   - ensures that the nat table contains a prerouting chain for the given
//     service and the IP family of the target address
//   - ensures that the chain contains a portmapping rule matching the given
//     portmap (the rule is only created if it does not already exist)
//
// An error is returned if the chain cannot be ensured, the existing rules
// cannot be listed, pm.Protocol is not recognized, or flushing the queued
// changes fails.
func (n *nftablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
	natTable, svcChain, err := n.ensureChainForSvc(svc, targetIP)
	if err != nil {
		return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
	}

	// Rules are identified by the metadata attached to them as user data.
	ruleMeta := svcPortMapRuleMeta(svc, targetIP, pm)
	existing, err := n.findRuleByMetadata(natTable, svcChain, ruleMeta)
	switch {
	case err != nil:
		return fmt.Errorf("error looking up rule: %w", err)
	case existing != nil:
		// The rule is already in place, nothing to do.
		return nil
	}

	l4proto, err := protoFromString(pm.Protocol)
	if err != nil {
		return fmt.Errorf("error converting protocol %s: %w", pm.Protocol, err)
	}

	n.conn.InsertRule(portMapRule(natTable, svcChain, tun, targetIP, pm.MatchPort, pm.TargetPort, l4proto, ruleMeta))
	return n.conn.Flush()
}

// DeletePortMapRuleForSvc removes a single portmapping rule from the chain of
// the given service in the nat table for the IP family of targetIP. The rule
// is located via the metadata attached to it.
//
// It is a no-op (nil is returned) if the nat table, the service chain or the
// rule itself does not exist. Deleting the chain is left to the caller, who is
// expected to call DeleteSvc when the whole service should go away, so the
// case where this is the last rule in the chain is not handled here.
func (n *nftablesRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
	famTable, err := n.getNFTByAddr(targetIP)
	if err != nil {
		return fmt.Errorf("error setting up nftables for IP family of %s: %w", targetIP, err)
	}

	natTable, err := getTableIfExists(n.conn, famTable.Proto, "nat")
	if err != nil {
		return fmt.Errorf("error checking if nat table exists: %w", err)
	}
	if natTable == nil {
		return nil
	}

	svcChain, err := getChainFromTable(n.conn, natTable, svc)
	switch {
	case errors.Is(err, errorChainNotFound{natTable.Name, svc}):
		return nil // service chain does not exist, so neither does the portmapping rule
	case err != nil:
		return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
	}

	existing, err := n.findRuleByMetadata(natTable, svcChain, svcPortMapRuleMeta(svc, targetIP, pm))
	if err != nil {
		return fmt.Errorf("error checking if rule exists: %w", err)
	}
	if existing == nil {
		return nil
	}

	if err := n.conn.DelRule(existing); err != nil {
		return fmt.Errorf("error deleting rule: %w", err)
	}
	return n.conn.Flush()
}

// DeleteSvc deletes the chains for the given service if any exist.
//
// For every address in targetIPs the service chain is looked up in the nat
// table of that address's IP family and, if found, queued for deletion. All
// queued deletions are then flushed together.
//
// A missing nat table or a missing service chain for one target IP is not an
// error: that IP is skipped and the remaining target IPs are still processed.
// Returning early in that case would leave chains of the other IP family in
// place and would drop deletions that were already queued but not yet flushed.
//
// The tun and pm arguments are not used by this implementation.
func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error {
	for _, tip := range targetIPs {
		table, err := n.getNFTByAddr(tip)
		if err != nil {
			return fmt.Errorf("error setting up nftables for IP family of %s: %w", tip, err)
		}
		t, err := getTableIfExists(n.conn, table.Proto, "nat")
		if err != nil {
			return fmt.Errorf("error checking if nat table exists: %w", err)
		}
		if t == nil {
			continue // no nat table for this IP family, so no chain to delete
		}
		ch, err := getChainFromTable(n.conn, t, svc)
		if errors.Is(err, errorChainNotFound{t.Name, svc}) {
			continue // service chain does not exist for this IP family
		}
		if err != nil {
			return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
		}
		n.conn.DelChain(ch)
	}
	return n.conn.Flush()
}

// portMapRule builds (but does not install) the DNAT rule for one port
// mapping in chain ch of table t. The rule matches packets whose output
// interface name differs from tun, whose L4 protocol equals proto and whose
// transport-header bytes 2-3 (the TCP/UDP destination port) equal matchPort,
// and rewrites their destination to targetIP:targetPort. meta is attached to
// the rule as user data so that the rule can be found again later.
func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule {
	// Anything that is not an IPv4 address is NATed as IPv6.
	natFamily := uint32(unix.NFPROTO_IPV6)
	if targetIP.Is4() {
		natFamily = unix.NFPROTO_IPV4
	}

	// Conditions a packet has to satisfy for the rule to apply.
	matchExprs := []expr.Any{
		// Output interface name is not tun.
		&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
		&expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: []byte(tun)},
		// L4 protocol is proto.
		&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
		// Two bytes at offset 2 of the transport header equal matchPort.
		&expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2},
		&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.BigEndian.PutUint16(matchPort)},
	}

	// Load the new destination address and port into registers 1 and 2 and
	// DNAT to them.
	dnatExprs := []expr.Any{
		&expr.Immediate{Register: 1, Data: targetIP.AsSlice()},
		&expr.Immediate{Register: 2, Data: binaryutil.BigEndian.PutUint16(targetPort)},
		&expr.NAT{
			Type:        expr.NATTypeDestNAT,
			Family:      natFamily,
			RegAddrMin:  1,
			RegAddrMax:  1,
			RegProtoMin: 2,
			RegProtoMax: 2,
		},
	}

	return &nftables.Rule{
		Table:    t,
		Chain:    ch,
		UserData: meta,
		Exprs:    append(matchExprs, dnatExprs...),
	}
}

// svcPortMapRuleMeta returns the metadata that identifies the portmapping
// rule for the given service, target IP and portmap. The metadata is built
// from the service name, the target IP and the portmap's match port, target
// port and protocol (used verbatim, without case normalization), and can be
// compared against a rule's user data to find the rule again.
// https://github.com/google/nftables/issues/48
func svcPortMapRuleMeta(svcName string, targetIP netip.Addr, pm PortMap) []byte {
	// The format must stay stable: existing rules are looked up by it.
	const metaFormat = "svc:%s,targetIP:%s:matchPort:%v,targetPort:%v,proto:%v"
	meta := fmt.Sprintf(metaFormat, svcName, targetIP.String(), pm.MatchPort, pm.TargetPort, pm.Protocol)
	return []byte(meta)
}

// findRuleByMetadata returns the first rule in chain ch of table t whose user
// data equals meta. It returns (nil, nil) if no such rule exists, or if the
// runner has no connection, t or ch is nil, or meta is empty. An error is
// returned only if listing the chain's rules fails.
func (n *nftablesRunner) findRuleByMetadata(t *nftables.Table, ch *nftables.Chain, meta []byte) (*nftables.Rule, error) {
	// Without a connection, table, chain or metadata there is nothing to match.
	switch {
	case n.conn == nil, t == nil, ch == nil, len(meta) == 0:
		return nil, nil
	}

	candidates, err := n.conn.GetRules(t, ch)
	if err != nil {
		return nil, fmt.Errorf("error listing rules: %w", err)
	}
	for i := range candidates {
		if candidate := candidates[i]; reflect.DeepEqual(candidate.UserData, meta) {
			return candidate, nil
		}
	}
	return nil, nil
}

// ensureChainForSvc ensures that the nat table for the IP family of targetIP
// exists and that it contains a NAT prerouting chain named after svc with an
// accept policy. It returns the nat table and the service chain.
func (n *nftablesRunner) ensureChainForSvc(svc string, targetIP netip.Addr) (*nftables.Table, *nftables.Chain, error) {
	famTable, err := n.getNFTByAddr(targetIP)
	if err != nil {
		return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", targetIP, err)
	}

	natTable, err := createTableIfNotExist(n.conn, famTable.Proto, "nat")
	if err != nil {
		return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
	}

	// chainInfo takes the policy by pointer, so it needs an addressable value.
	acceptPolicy := nftables.ChainPolicyAccept
	wanted := chainInfo{
		table:         natTable,
		name:          svc,
		chainType:     nftables.ChainTypeNAT,
		chainHook:     nftables.ChainHookPrerouting,
		chainPriority: nftables.ChainPriorityNATDest,
		chainPolicy:   &acceptPolicy,
	}
	svcChain, err := getOrCreateChain(n.conn, wanted)
	if err != nil {
		return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err)
	}
	return natTable, svcChain, nil
}

// PortMap is the port mapping for a service rule.
type PortMap struct {
	// MatchPort is the local port to which the rule should apply.
	MatchPort uint16
	// TargetPort is the port to which the traffic should be forwarded.
	TargetPort uint16
	// Protocol is the protocol to match packets on. Only TCP and UDP are
	// supported.
	Protocol string
}

// protoFromString converts a protocol name to its IP protocol number. The
// name is matched case-insensitively; only "tcp" and "udp" are recognized,
// any other value results in an error.
func protoFromString(s string) (uint8, error) {
	name := strings.ToLower(s)
	if name == "tcp" {
		return unix.IPPROTO_TCP, nil
	}
	if name == "udp" {
		return unix.IPPROTO_UDP, nil
	}
	return 0, fmt.Errorf("unrecognized protocol: %q", s)
}