// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package wgcfg

import (
	"bufio"
	"encoding/hex"
	"fmt"
	"io"
	"net"
	"strconv"
	"strings"

	"inet.af/netaddr"
)

// ParseError describes a value from a UAPI stream that could not be parsed.
// why explains the failure; offender is the text that caused it.
type ParseError struct {
	why      string
	offender string
}

// Error formats the error as `why: "offender"`, with offender Go-quoted
// so that empty or whitespace-only values remain visible.
func (e *ParseError) Error() string {
	return e.why + ": " + fmt.Sprintf("%q", e.offender)
}

// validateEndpoints checks that s is a comma-separated list of endpoints,
// each acceptable to parseEndpoint. It returns the first parse error found,
// or nil. The empty string denotes an empty list and is valid.
func validateEndpoints(s string) error {
	// strings.Split("", ",") yields [""], which would fail to parse,
	// so the empty list is handled before splitting.
	if len(s) == 0 {
		return nil
	}
	for _, ep := range strings.Split(s, ",") {
		if _, _, err := parseEndpoint(ep); err != nil {
			return err
		}
	}
	return nil
}

// parseEndpoint splits a "host:port" endpoint string into its host and port.
//
// The port is the text after the last ':' and must be a decimal uint16.
// A host containing ':' (an IPv6 literal) must be enclosed in brackets,
// e.g. "[fe80::1]:51820". The brackets are stripped from the returned host.
// Unbalanced brackets, and brackets around anything net.ParseIP does not
// accept as an address containing ':', are rejected.
//
// Every validation failure is reported as a *ParseError.
func parseEndpoint(s string) (host string, port uint16, err error) {
	i := strings.LastIndexByte(s, ':')
	if i < 0 {
		return "", 0, &ParseError{"Missing port from endpoint", s}
	}
	host, portStr := s[:i], s[i+1:]
	if len(host) < 1 {
		return "", 0, &ParseError{"Invalid endpoint host", host}
	}
	uport, perr := strconv.ParseUint(portStr, 10, 16)
	if perr != nil {
		// Wrap like parseKeyHex does, so callers always get a *ParseError.
		return "", 0, &ParseError{"Invalid endpoint port: " + perr.Error(), portStr}
	}
	// Any ':' in the host (including at index 0, as in "::1") means an IPv6
	// literal, which is ambiguous unless bracketed.
	hasBracket := host[0] == '[' || host[len(host)-1] == ']'
	if hasBracket || strings.IndexByte(host, ':') >= 0 {
		bad := &ParseError{"Brackets must contain an IPv6 address", host}
		if len(host) <= 3 || host[0] != '[' || host[len(host)-1] != ']' {
			return "", 0, bad
		}
		inner := host[1 : len(host)-1]
		// Requiring a ':' inside rejects bracketed IPv4 such as "[1.2.3.4]".
		// net.ParseIP always returns a 16-byte slice on success, so its
		// length cannot be used to tell IPv4 from IPv6.
		if strings.IndexByte(inner, ':') < 0 || net.ParseIP(inner) == nil {
			return "", 0, bad
		}
		host = inner
	}
	return host, uint16(uport), nil
}

// parseKeyHex decodes the hex string s into a Key.
// It returns a *ParseError if s is not valid hex, or if it does not decode
// to exactly KeySize bytes.
func parseKeyHex(s string) (*Key, error) {
	raw, err := hex.DecodeString(s)
	switch {
	case err != nil:
		return nil, &ParseError{"Invalid key: " + err.Error(), s}
	case len(raw) != KeySize:
		return nil, &ParseError{"Keys must decode to exactly 32 bytes", s}
	}
	var out Key
	copy(out[:], raw)
	return &out, nil
}

// FromUAPI generates a Config from r.
// r should be generated by calling device.IpcGetOperation;
// it is not compatible with other uapi streams.
//
// The stream is a sequence of "key=value" lines. Lines before the first
// "public_key" line describe the device itself. Each "public_key" line starts
// a new peer, and the lines after it describe that peer. Blank lines are
// skipped. A malformed line, an unexpected key, or a read error aborts parsing
// and returns a nil Config.
func FromUAPI(r io.Reader) (*Config, error) {
	cfg := new(Config)
	var (
		cur      *Peer // peer currently being described; nil in the device section
		inDevice = true
	)

	sc := bufio.NewScanner(r)
	for sc.Scan() {
		ln := sc.Text()
		if ln == "" {
			continue
		}
		kv := strings.Split(ln, "=")
		if len(kv) != 2 {
			return nil, fmt.Errorf("failed to parse line %q, found %d =-separated parts, want 2", ln, len(kv))
		}
		k, v := kv[0], kv[1]

		var err error
		switch {
		case k == "public_key":
			// A public_key line ends the device section and begins a new peer.
			inDevice = false
			cur, err = cfg.handlePublicKeyLine(v)
		case inDevice:
			err = cfg.handleDeviceLine(k, v)
		default:
			err = cfg.handlePeerLine(cur, k, v)
		}
		if err != nil {
			return nil, err
		}
	}
	if err := sc.Err(); err != nil {
		return nil, err
	}
	return cfg, nil
}

// handleDeviceLine applies one device-section key=value pair from an
// IpcGetOperation stream to cfg.
//
// Only private_key is recorded. listen_port and fwmark are accepted and
// ignored. Any other key is an error.
func (cfg *Config) handleDeviceLine(key, value string) error {
	if key == "listen_port" || key == "fwmark" {
		return nil
	}
	if key != "private_key" {
		return fmt.Errorf("unexpected IpcGetOperation key: %v", key)
	}
	priv, err := parseKeyHex(value)
	if err != nil {
		return err
	}
	// wireguard-go guarantees not to send zero value; private keys are already clamped.
	cfg.PrivateKey = PrivateKey(*priv)
	return nil
}

// handlePublicKeyLine parses value as a hex-encoded public key, appends a new
// Peer carrying that key to cfg.Peers, and returns a pointer to it so that the
// following peer lines can fill it in.
//
// The returned pointer refers into cfg.Peers' backing array. A later append
// to cfg.Peers may reallocate that array, so the pointer should not be
// retained past the next call.
func (cfg *Config) handlePublicKeyLine(value string) (*Peer, error) {
	pub, err := parseKeyHex(value)
	if err != nil {
		return nil, err
	}
	idx := len(cfg.Peers)
	cfg.Peers = append(cfg.Peers, Peer{})
	cfg.Peers[idx].PublicKey = *pub
	return &cfg.Peers[idx], nil
}

// handlePeerLine applies one key=value pair from an IpcGetOperation stream to
// peer, the peer most recently started by a public_key line:
//
//   - endpoint: comma-separated endpoint list, validated and stored verbatim.
//   - persistent_keepalive_interval: decimal uint16.
//   - allowed_ip: a single IP prefix, appended to peer.AllowedIPs.
//   - protocol_version: must be "1".
//   - preshared_key, last_handshake_time_sec, last_handshake_time_nsec,
//     tx_bytes, rx_bytes: accepted and ignored.
//
// Any other key is an error.
func (cfg *Config) handlePeerLine(peer *Peer, key, value string) error {
	switch key {
	case "preshared_key", "last_handshake_time_sec", "last_handshake_time_nsec", "tx_bytes", "rx_bytes":
		return nil
	case "protocol_version":
		if value == "1" {
			return nil
		}
		return fmt.Errorf("invalid protocol version: %v", value)
	case "allowed_ip":
		prefix, err := netaddr.ParseIPPrefix(value)
		if err != nil {
			return err
		}
		peer.AllowedIPs = append(peer.AllowedIPs, prefix)
		return nil
	case "persistent_keepalive_interval":
		interval, err := strconv.ParseUint(value, 10, 16)
		if err != nil {
			return err
		}
		peer.PersistentKeepalive = uint16(interval)
		return nil
	case "endpoint":
		if err := validateEndpoints(value); err != nil {
			return err
		}
		peer.Endpoints = value
		return nil
	}
	return fmt.Errorf("unexpected IpcGetOperation key: %v", key)
}