// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package natlab

import (
	"context"
	"fmt"
	"net"
	"sync"
	"time"

	"inet.af/netaddr"
)

// mapping is the state of an allocated NAT session.
//
// lanSrc and lanDst are the source and destination of the outbound
// packet that created the session. wanSrc is the WAN ip:port that
// lanSrc is rewritten to on egress; inbound packets addressed to wanSrc
// are rewritten back to lanSrc. deadline is the instant after which the
// session is treated as expired; it is pushed forward every time an
// outbound packet uses the session.
type mapping struct {
	lanSrc   netaddr.IPPort
	lanDst   netaddr.IPPort
	wanSrc   netaddr.IPPort
	deadline time.Time

	// pc is a PacketConn that reserves an outbound port on the NAT's
	// WAN interface. We do this because ListenPacket already has
	// random port selection logic built in. Additionally this means
	// that concurrent use of ListenPacket for connections originating
	// from the NAT box won't conflict with NAT mappings, since both
	// use PacketConn to reserve ports on the machine. Closing pc
	// releases the port.
	pc net.PacketConn
}

// NATType is the mapping behavior of a NAT device. Values express
// different modes defined by RFC 4787. The zero value is
// EndpointIndependentNAT. Any value other than the constants below is
// rejected (with a panic) when the NAT computes a session key.
type NATType int

const (
	// EndpointIndependentNAT specifies a destination endpoint
	// independent NAT. All traffic from a source ip:port gets mapped
	// to a single WAN ip:port, regardless of destination.
	EndpointIndependentNAT NATType = iota
	// AddressDependentNAT specifies a destination address dependent
	// NAT. Every distinct destination IP gets its own WAN ip:port
	// allocation; the destination port is ignored.
	AddressDependentNAT
	// AddressAndPortDependentNAT specifies a destination
	// address-and-port dependent NAT. Every distinct destination
	// ip:port gets its own WAN ip:port allocation.
	AddressAndPortDependentNAT
)

// natKey is the lookup key for a NAT session. While it contains a
// 4-tuple ({src,dst} {ip,port}), some NATTypes will zero out some
// fields, so in practice the key is either a 2-tuple (src only),
// 3-tuple (src ip+port and dst ip, with dst port zero) or 4-tuple
// (src+dst ip+port). Keys are built by NATType.key.
type natKey struct {
	src, dst netaddr.IPPort
}

// key builds the byLAN lookup key for a flow from src to dst under
// mapping behavior t. The source is always kept in full. The
// destination is dropped entirely for EndpointIndependentNAT, reduced
// to its IP (port zero) for AddressDependentNAT, and kept whole for
// AddressAndPortDependentNAT. Any other NATType panics.
func (t NATType) key(src, dst netaddr.IPPort) natKey {
	switch t {
	case EndpointIndependentNAT:
		return natKey{src: src}
	case AddressDependentNAT:
		return natKey{src: src, dst: netaddr.IPPortFrom(dst.IP(), 0)}
	case AddressAndPortDependentNAT:
		return natKey{src: src, dst: dst}
	}
	panic(fmt.Sprintf("unknown NAT type %v", t))
}

// DefaultMappingTimeout is the default timeout for a NAT mapping. It is
// used whenever SNAT44.MappingTimeout is left at zero.
const DefaultMappingTimeout = 30 * time.Second

// SNAT44 implements an IPv4-to-IPv4 source NAT (SNAT) translator, with
// optional builtin firewall.
//
// Machine and ExternalInterface must be set, and ExternalInterface must
// belong to Machine; this is checked (with a panic) every time a packet
// goes through the NAT translation path. The remaining exported fields
// are optional and have usable zero values.
type SNAT44 struct {
	// Machine is the machine to which this NAT is attached. Altered
	// packets are injected back into this Machine for processing.
	Machine *Machine
	// ExternalInterface is the "WAN" interface of Machine. Packets
	// from other sources get NATed onto this interface.
	ExternalInterface *Interface
	// Type specifies the mapping allocation behavior for this NAT.
	// The zero value is EndpointIndependentNAT.
	Type NATType
	// MappingTimeout is the lifetime of individual NAT sessions. Once
	// a session expires, the mapped port effectively "closes" to new
	// traffic. If MappingTimeout is 0, DefaultMappingTimeout is used.
	MappingTimeout time.Duration
	// Firewall is an optional packet handler that will be invoked as
	// a firewall during NAT translation. The firewall always sees
	// packets in their "LAN form", i.e. before translation in the
	// outbound direction and after translation in the inbound
	// direction.
	Firewall PacketHandler
	// TimeNow is a function that returns the current time. If
	// nil, time.Now is used.
	TimeNow func() time.Time

	// mu guards byLAN and byWAN, which initLocked allocates lazily.
	// Every mapping is stored in both maps; expired mappings stay in
	// them until gc removes them.
	mu    sync.Mutex
	byLAN map[natKey]*mapping         // lookup by outbound packet tuple
	byWAN map[netaddr.IPPort]*mapping // lookup by wan ip:port only
}

// timeNow reports the current time. It uses the TimeNow hook when one
// is configured and falls back to the wall clock otherwise.
func (n *SNAT44) timeNow() time.Time {
	if clock := n.TimeNow; clock != nil {
		return clock()
	}
	return time.Now()
}

// mappingTimeout returns the session lifetime to apply. It is the
// configured MappingTimeout, or DefaultMappingTimeout when that field
// is zero.
func (n *SNAT44) mappingTimeout() time.Duration {
	if d := n.MappingTimeout; d != 0 {
		return d
	}
	return DefaultMappingTimeout
}

// initLocked lazily allocates the session tables. It then verifies that
// ExternalInterface is attached to Machine, and panics if it is not.
// The caller must hold n.mu.
func (n *SNAT44) initLocked() {
	if n.byLAN == nil {
		n.byLAN = make(map[natKey]*mapping)
		n.byWAN = make(map[netaddr.IPPort]*mapping)
	}
	if owner := n.ExternalInterface.Machine(); owner != n.Machine {
		panic(fmt.Sprintf("NAT given interface %s that is not part of given machine %s", n.ExternalInterface, n.Machine.Name))
	}
}

// HandleOut processes packets originated locally by Machine. NAT never
// rewrites such packets. They are only passed through Firewall when one
// is configured, and returned unchanged otherwise.
func (n *SNAT44) HandleOut(p *Packet, oif *Interface) *Packet {
	if n.Firewall == nil {
		return p
	}
	return n.Firewall.HandleOut(p, oif)
}

// HandleIn processes packets arriving at Machine.
//
// A packet that arrives on ExternalInterface and whose destination
// matches an unexpired mapping is rewritten (DNAT) to that mapping's LAN
// source and returned without firewalling. The rewritten packet is no
// longer locally destined, so it comes back through HandleForward,
// where the firewall sees it. Every other packet is passed to Firewall
// if one is configured, or returned unchanged.
func (n *SNAT44) HandleIn(p *Packet, iif *Interface) *Packet {
	if iif == n.ExternalInterface {
		// Hold n.mu only for the table lookup. The firewall below is
		// user-supplied code and runs without the NAT lock, the same way
		// HandleForward invokes it. The deferred unlock keeps the mutex
		// consistent if initLocked or TimeNow panics.
		lanSrc, hit := func() (netaddr.IPPort, bool) {
			n.mu.Lock()
			defer n.mu.Unlock()
			n.initLocked()

			now := n.timeNow()
			m := n.byWAN[p.Dst]
			if m == nil || now.After(m.deadline) {
				return netaddr.IPPort{}, false
			}
			return m.lanSrc, true
		}()
		if hit {
			p.Dst = lanSrc
			p.Trace("dnat to %v", p.Dst)
			// Don't process firewall here. We mutated the packet such
			// that it's no longer destined locally, so we'll get
			// reinvoked as HandleForward and need to process the altered
			// packet there.
			return p
		}
		// NAT didn't hit: fall through to the firewall, or allow the
		// packet in for local socket handling.
	}

	// NAT doesn't apply (wrong interface, or no live mapping).
	if n.Firewall != nil {
		return n.Firewall.HandleIn(p, iif)
	}
	return p
}

// HandleForward processes packets that Machine routes from iif to oif.
//
//   - Egress via ExternalInterface: the packet first goes to Firewall in
//     its LAN form. It is then source-NATed onto a WAN ip:port, creating
//     or refreshing the session for its flow.
//   - Ingress from ExternalInterface: the packet was already rewritten by
//     HandleIn. It only needs firewalling, and is allowed through if
//     there is no firewall.
//   - Anything else: no NAT applies. The packet is firewalled, or
//     dropped (nil) if there is no firewall.
func (n *SNAT44) HandleForward(p *Packet, iif, oif *Interface) *Packet {
	if oif != n.ExternalInterface {
		if n.Firewall != nil {
			return n.Firewall.HandleForward(p, iif, oif)
		}
		if iif == n.ExternalInterface {
			// Already un-NAT-ed by HandleIn and nothing to filter it.
			return p
		}
		// Not NAT traffic and no firewall to vouch for it.
		return nil
	}

	if p.Src.IP() == oif.V4() {
		// Packet already NATed and is just retraversing Forward, don't
		// touch it again.
		return p
	}

	if fw := n.Firewall; fw != nil {
		verdict := fw.HandleForward(p, iif, oif)
		if verdict == nil {
			// Firewall dropped it.
			return nil
		}
		if !p.Equivalent(verdict) {
			// Firewall rewrote the packet; hand its version back as-is.
			return verdict
		}
	}

	n.mu.Lock()
	defer n.mu.Unlock()
	n.initLocked()

	flow := n.Type.key(p.Src, p.Dst)
	now := n.timeNow()
	sess := n.byLAN[flow]
	if sess == nil || now.After(sess.deadline) {
		conn, wan := n.allocateMappedPort()
		sess = &mapping{
			lanSrc: p.Src,
			lanDst: p.Dst,
			wanSrc: wan,
			pc:     conn,
		}
		n.byLAN[flow] = sess
		n.byWAN[wan] = sess
	}
	// Each outbound packet extends the session's lifetime.
	sess.deadline = now.Add(n.mappingTimeout())

	p.Src = sess.wanSrc
	p.Trace("snat from %v", p.Src)
	return p
}

// allocateMappedPort reserves a fresh UDP port on ExternalInterface's
// IPv4 address. It returns the PacketConn holding the reservation and
// the resulting WAN ip:port. Expired sessions are collected first so
// their ports can be reused. It panics if ListenPacket fails. The caller
// must hold n.mu.
func (n *SNAT44) allocateMappedPort() (net.PacketConn, netaddr.IPPort) {
	n.gc()

	wanIP := n.ExternalInterface.V4()
	// Port "0" lets ListenPacket choose any free port.
	listenAddr := net.JoinHostPort(wanIP.String(), "0")
	conn, err := n.Machine.ListenPacket(context.Background(), "udp", listenAddr)
	if err != nil {
		panic(fmt.Sprintf("ran out of NAT ports: %v", err))
	}
	bound := conn.LocalAddr().(*net.UDPAddr)
	return conn, netaddr.IPPortFrom(wanIP, uint16(bound.Port))
}

// gc removes every session whose deadline has passed. It closes each
// session's PacketConn so the WAN port reservation is released, then
// drops the session from both lookup tables. The caller must hold n.mu.
func (n *SNAT44) gc() {
	now := n.timeNow()
	for k, m := range n.byLAN {
		if !now.After(m.deadline) {
			continue
		}
		// Best effort: the session is discarded regardless of whether
		// releasing its port reservation reports an error.
		_ = m.pc.Close()
		// Delete by the key the session is actually stored under, rather
		// than recomputing it from n.Type. Type is an exported field and
		// may differ from the value in effect when the session was
		// created. Deleting during range is permitted by the Go spec.
		delete(n.byLAN, k)
		delete(n.byWAN, m.wanSrc)
	}
}