headscale/utils.go

167 lines
3.8 KiB
Go
Raw Normal View History

2020-06-21 12:32:08 +02:00
// Code here is mostly taken from github.com/tailscale/tailscale
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package headscale
import (
"crypto/rand"
"encoding/json"
"fmt"
"io"
2021-08-13 10:33:19 +01:00
"strings"
2020-06-21 12:32:08 +02:00
"golang.org/x/crypto/nacl/box"
"inet.af/netaddr"
2021-08-13 10:33:19 +01:00
"tailscale.com/tailcfg"
2021-06-25 18:57:08 +02:00
"tailscale.com/types/wgkey"
2020-06-21 12:32:08 +02:00
)
2021-05-06 01:01:45 +02:00
// Error is a string-backed error type. Because its underlying type is string,
// values can be declared as constants and compared with ==, following
// https://dave.cheney.net/2016/04/07/constant-errors
type Error string

// Error implements the error interface by returning the underlying string.
func (e Error) Error() string {
	msg := string(e)

	return msg
}
2021-06-25 18:57:08 +02:00
func decode(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
2020-06-21 12:32:08 +02:00
return decodeMsg(msg, v, pubKey, privKey)
}
2021-06-25 18:57:08 +02:00
func decodeMsg(msg []byte, v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) error {
2020-06-21 12:32:08 +02:00
decrypted, err := decryptMsg(msg, pubKey, privKey)
if err != nil {
return err
}
// fmt.Println(string(decrypted))
if err := json.Unmarshal(decrypted, v); err != nil {
return fmt.Errorf("response: %v", err)
}
return nil
}
2021-06-25 18:57:08 +02:00
func decryptMsg(msg []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
2020-06-21 12:32:08 +02:00
var nonce [24]byte
if len(msg) < len(nonce)+1 {
return nil, fmt.Errorf("response missing nonce, len=%d", len(msg))
}
copy(nonce[:], msg)
msg = msg[len(nonce):]
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
if !ok {
return nil, fmt.Errorf("cannot decrypt response")
}
return decrypted, nil
}
2021-06-25 18:57:08 +02:00
func encode(v interface{}, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
2020-06-21 12:32:08 +02:00
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
2021-08-13 10:33:19 +01:00
2020-06-21 12:32:08 +02:00
return encodeMsg(b, pubKey, privKey)
}
2021-06-25 18:57:08 +02:00
func encodeMsg(b []byte, pubKey *wgkey.Key, privKey *wgkey.Private) ([]byte, error) {
2020-06-21 12:32:08 +02:00
var nonce [24]byte
if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
panic(err)
}
pub, pri := (*[32]byte)(pubKey), (*[32]byte)(privKey)
msg := box.Seal(nonce[:], b, &nonce, pub, pri)
return msg, nil
}
func (h *Headscale) getAvailableIP() (*netaddr.IP, error) {
ipPrefix := h.cfg.IPPrefix
usedIps, err := h.getUsedIPs()
if err != nil {
return nil, err
}
// Get the first IP in our prefix
ip := ipPrefix.IP()
2020-06-21 12:32:08 +02:00
for {
if !ipPrefix.Contains(ip) {
return nil, fmt.Errorf("could not find any suitable IP in %s", ipPrefix)
2020-06-21 12:32:08 +02:00
}
// Some OS (including Linux) does not like when IPs ends with 0 or 255, which
// is typically called network or broadcast. Lets avoid them and continue
// to look when we get one of those traditionally reserved IPs.
ipRaw := ip.As4()
if ipRaw[3] == 0 || ipRaw[3] == 255 {
ip = ip.Next()
continue
}
if ip.IsZero() &&
ip.IsLoopback() {
ip = ip.Next()
continue
2020-06-21 12:32:08 +02:00
}
if !containsIPs(usedIps, ip) {
return &ip, nil
2020-06-21 12:32:08 +02:00
}
ip = ip.Next()
2020-06-21 12:32:08 +02:00
}
}
// getUsedIPs returns the ip_address values plucked from the Machine model,
// parsed as netaddr.IP.
//
// The result has one entry per plucked value. Empty strings are left as the
// zero-value netaddr.IP. A database error or an unparsable address is
// returned as a wrapped error instead of being silently ignored.
func (h *Headscale) getUsedIPs() ([]netaddr.IP, error) {
	var addresses []string
	if err := h.db.Model(&Machine{}).Pluck("ip_address", &addresses).Error; err != nil {
		return nil, fmt.Errorf("failed to read used IPs from database, %w", err)
	}

	ips := make([]netaddr.IP, len(addresses))
	for index, addr := range addresses {
		if addr == "" {
			continue
		}

		ip, err := netaddr.ParseIP(addr)
		if err != nil {
			return nil, fmt.Errorf("failed to parse ip from database, %w", err)
		}
		ips[index] = ip
	}

	return ips, nil
}
func containsIPs(ips []netaddr.IP, ip netaddr.IP) bool {
for _, v := range ips {
if v == ip {
return true
2020-06-21 12:32:08 +02:00
}
}
2021-02-21 22:11:27 +01:00
return false
2020-06-21 12:32:08 +02:00
}
2021-08-13 10:33:19 +01:00
// tailNodesToString formats the names of nodes as "[ name1, name2 ](count)",
// where count is the number of nodes.
func tailNodesToString(nodes []*tailcfg.Node) string {
	names := make([]string, 0, len(nodes))
	for _, n := range nodes {
		names = append(names, n.Name)
	}

	return fmt.Sprintf("[ %s ](%d)", strings.Join(names, ", "), len(nodes))
}
// tailMapResponseToString summarises resp as
// "{ Node: <node name>, Peers: <tailNodesToString(resp.Peers)> }".
//
// resp.Node is a pointer that may be nil, for example in responses that
// carry no node information. In that case "<nil>" is printed instead of
// dereferencing it and panicking.
func tailMapResponseToString(resp tailcfg.MapResponse) string {
	nodeName := "<nil>"
	if resp.Node != nil {
		nodeName = resp.Node.Name
	}

	return fmt.Sprintf("{ Node: %s, Peers: %s }", nodeName, tailNodesToString(resp.Peers))
}
2021-10-29 17:04:58 +00:00
// IsLocalhost reports whether host contains LOCALHOST_V4 or LOCALHOST_V6 as a
// substring. Because this is a substring match, an input that merely contains
// one of those constants (for example with a port appended) also matches.
func IsLocalhost(host string) bool {
	return strings.Contains(host, LOCALHOST_V4) ||
		strings.Contains(host, LOCALHOST_V6)
}