// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

// Package connstats maintains statistics about connections
// flowing through a TUN device (which operates at the IP layer).
package connstats

import (
	"context"
	"net/netip"
	"sync"
	"time"

	"golang.org/x/sync/errgroup"
	"tailscale.com/net/packet"
	"tailscale.com/net/tsaddr"
	"tailscale.com/types/netlogtype"
)

// Statistics maintains counters for every connection.
// All methods are safe for concurrent use.
//
// Use NewStatistics to obtain a Statistics that dumps its counters:
// only NewStatistics sets up connCntsCh, shutdownCtx, shutdown, and the
// background goroutine that invokes the dump function. A zero Statistics
// has none of these.
type Statistics struct {
	// maxConns, when positive, is the number of tracked connections
	// (virtual plus physical) at which preInsertConn extracts the
	// current batch before admitting a new connection.
	maxConns int // immutable once set

	// mu guards the embedded connCnts (the current, not yet extracted batch).
	mu sync.Mutex
	connCnts

	// connCntsCh carries batches extracted by preInsertConn
	// to the goroutine started in NewStatistics.
	connCntsCh chan connCnts
	// shutdownCtx is canceled by shutdown; both come from
	// context.WithCancel in NewStatistics.
	shutdownCtx context.Context
	shutdown    context.CancelFunc
	// group runs the dump goroutine; Shutdown waits on it.
	group errgroup.Group
}

// connCnts is one batch of connection counters together with
// the time window over which the batch was accumulated.
type connCnts struct {
	// start is set (in UTC) when preInsertConn allocates the maps.
	start time.Time
	// end is set (in UTC) when extractLocked hands the batch off.
	end time.Time
	// virtual holds the counters maintained by updateVirtual.
	virtual map[netlogtype.Connection]netlogtype.Counts
	// physical holds the counters maintained by updatePhysical.
	physical map[netlogtype.Connection]netlogtype.Counts
}

// NewStatistics creates a data structure for tracking connection statistics
// that periodically dumps the virtual and physical connection counts
// depending on whether the maxPeriod or maxConns is exceeded.
//
// If maxPeriod is not positive, no periodic dump is scheduled.
// If maxConns is not positive, the number of connections never triggers a dump.
// If dump is nil, extracted batches are discarded. Empty batches are
// never passed to dump.
//
// The dump function is called from a single goroutine.
// Shutdown must be called to cleanup resources.
func NewStatistics(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) *Statistics {
	s := &Statistics{maxConns: maxConns}
	s.connCntsCh = make(chan connCnts, 256)
	s.shutdownCtx, s.shutdown = context.WithCancel(context.Background())

	// flush hands a non-empty batch to the dump callback, if there is one.
	flush := func(cc connCnts) {
		if len(cc.virtual)+len(cc.physical) > 0 && dump != nil {
			dump(cc.start, cc.end, cc.virtual, cc.physical)
		}
	}

	s.group.Go(func() error {
		// TODO(joetsai): Using a ticker is problematic on mobile platforms
		// where waking up a process every maxPeriod when there is no activity
		// is a drain on battery life. Switch this to instead use
		// a time.Timer that is triggered upon network activity.
		//
		// A zero Ticker has a nil channel, so its select case below
		// never fires when no period is configured.
		ticker := new(time.Ticker)
		if maxPeriod > 0 {
			ticker = time.NewTicker(maxPeriod)
			defer ticker.Stop()
		}

		for s.shutdownCtx.Err() == nil {
			select {
			case cc := <-s.connCntsCh:
				flush(cc)
			case <-ticker.C:
				flush(s.extract())
			case <-s.shutdownCtx.Done():
			}
		}

		// Shutdown was requested. When several select cases are ready at
		// once, one is picked at random, so the loop above may have left
		// without ever extracting the current counters and with batches
		// still queued. Dump the queued (older) batches first, then the
		// counters still held in memory, so that nothing is silently lost.
		for drained := false; !drained; {
			select {
			case cc := <-s.connCntsCh:
				flush(cc)
			default:
				drained = true
			}
		}
		flush(s.extract())
		return nil
	})
	return s
}

// UpdateTxVirtual records one transmitted IP packet, given as raw bytes in b.
// The packet's source and destination map directly onto
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateTxVirtual(b []byte) {
	const receive = false
	s.updateVirtual(b, receive)
}

// UpdateRxVirtual records one received IP packet, given as raw bytes in b.
// The packet's source and destination are swapped relative to
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateRxVirtual(b []byte) {
	const receive = true
	s.updateVirtual(b, receive)
}

// tailscaleServiceIPv4 and tailscaleServiceIPv6 hold the addresses returned
// by tsaddr.TailscaleServiceIP and tsaddr.TailscaleServiceIPv6.
// updateVirtual does not count connections whose Dst address
// matches either of them.
var tailscaleServiceIPv4 = tsaddr.TailscaleServiceIP()

var tailscaleServiceIPv6 = tsaddr.TailscaleServiceIPv6()

// updateVirtual decodes the IP packet in b and adds it to the virtual
// counters of its connection. When receive is true the packet's endpoints
// are swapped to form the connection key and the Rx counters are bumped;
// otherwise the Tx counters are bumped. The byte count is len(b).
func (s *Statistics) updateVirtual(b []byte, receive bool) {
	var parsed packet.Parsed
	parsed.Decode(b)

	src, dst := parsed.Src, parsed.Dst
	if receive {
		src, dst = dst, src
	}
	key := netlogtype.Connection{Proto: parsed.IPProto, Src: src, Dst: dst}

	// Network logging covers traffic between two Tailscale nodes.
	// The internal Tailscale service is not another node, so traffic with it
	// is not logged. It also tends to consist of a large number of
	// discrete flows (e.g., DNS lookups).
	if ip := key.Dst.Addr(); ip == tailscaleServiceIPv4 || ip == tailscaleServiceIPv6 {
		return
	}

	size := uint64(len(b))

	s.mu.Lock()
	defer s.mu.Unlock()

	counts, ok := s.virtual[key]
	if !ok {
		// A new connection must be admitted first.
		if !s.preInsertConn() {
			return
		}
	}
	switch {
	case receive:
		counts.RxPackets++
		counts.RxBytes += size
	default:
		counts.TxPackets++
		counts.TxBytes += size
	}
	s.virtual[key] = counts
}

// UpdateTxPhysical records one transmitted wireguard packet of n bytes.
// The src is always a Tailscale IP address, representing some remote peer.
// The dst is a remote IP address and port that corresponds
// with some physical peer backing the Tailscale IP address.
func (s *Statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, n int) {
	const receive = false
	s.updatePhysical(src, dst, n, receive)
}

// UpdateRxPhysical records one received wireguard packet of n bytes.
// The src is always a Tailscale IP address, representing some remote peer.
// The dst is a remote IP address and port that corresponds
// with some physical peer backing the Tailscale IP address.
func (s *Statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, n int) {
	const receive = true
	s.updatePhysical(src, dst, n, receive)
}

// updatePhysical adds one packet of n bytes to the physical counters of the
// connection keyed by src (with port 0) and dst. The Rx counters are bumped
// when receive is true, the Tx counters otherwise.
func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, receive bool) {
	key := netlogtype.Connection{
		Src: netip.AddrPortFrom(src, 0),
		Dst: dst,
	}
	size := uint64(n)

	s.mu.Lock()
	defer s.mu.Unlock()

	counts, ok := s.physical[key]
	if !ok {
		// A new connection must be admitted first.
		if !s.preInsertConn() {
			return
		}
	}
	switch {
	case receive:
		counts.RxPackets++
		counts.RxBytes += size
	default:
		counts.TxPackets++
		counts.TxBytes += size
	}
	s.physical[key] = counts
}

// preInsertConn updates the maps to handle insertion of a new connection.
// It reports false if insertion is not allowed (i.e., after shutdown).
//
// It reads and mutates the embedded connCnts without locking; the callers
// in this file (updateVirtual and updatePhysical) hold s.mu while calling it.
func (s *Statistics) preInsertConn() bool {
	// Check whether insertion of a new connection will exceed maxConns.
	if s.maxConns > 0 && len(s.virtual)+len(s.physical) == s.maxConns {
		// Extract the current statistics and send it to the serializer.
		// Avoid blocking the network packet handling path.
		// Note that extractLocked runs before the select picks a case,
		// so the batch is reset even when the channel is full.
		select {
		case s.connCntsCh <- s.extractLocked():
		default:
			// TODO(joetsai): Log that we are dropping an entire connCounts.
		}
	}

	// Initialize the maps if nil.
	if s.virtual == nil && s.physical == nil {
		s.start = time.Now().UTC()
		s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
		s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
	}

	// A Statistics that was not built by NewStatistics has no shutdown
	// context. Calling Err on the nil interface would panic, so treat
	// that case as "not shut down".
	if s.shutdownCtx == nil {
		return true
	}
	return s.shutdownCtx.Err() == nil
}

// extract acquires s.mu and returns the result of extractLocked.
func (s *Statistics) extract() (cc connCnts) {
	s.mu.Lock()
	defer s.mu.Unlock()
	cc = s.extractLocked()
	return cc
}

// extractLocked removes and returns the current batch of counters,
// stamping its end time with the current UTC time.
// When no connection is tracked it returns a zero connCnts
// and leaves the current state untouched.
// It does no locking itself; extract calls it with s.mu held.
func (s *Statistics) extractLocked() connCnts {
	var batch connCnts
	if tracked := len(s.virtual) + len(s.physical); tracked != 0 {
		s.end = time.Now().UTC()
		batch, s.connCnts = s.connCnts, connCnts{}
	}
	return batch
}

// TestExtract synchronously takes the current virtual and physical
// statistics maps out of s, which resets the counters.
// This should only be used for testing purposes.
func (s *Statistics) TestExtract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
	snapshot := s.extract()
	virtual = snapshot.virtual
	physical = snapshot.physical
	return virtual, physical
}

// Shutdown performs a final flush of statistics.
// Statistics for any subsequent calls to Update will be dropped.
// It is safe to call Shutdown concurrently and repeatedly.
//
// The context argument is ignored: Shutdown cancels the shutdown context
// and then blocks until the goroutine started by NewStatistics has returned.
// It returns the result of waiting on that goroutine.
func (s *Statistics) Shutdown(context.Context) error {
	// A Statistics that was not built by NewStatistics has no cancel
	// function; calling the nil func would panic.
	if s.shutdown != nil {
		s.shutdown()
	}
	return s.group.Wait()
}