WIP, do not review

This commit is contained in:
Josh Bleecher Snyder 2021-03-29 13:47:12 -07:00
parent 0c5c16327d
commit bd4388df36
11 changed files with 76 additions and 98 deletions

View File

@ -39,7 +39,7 @@ func getVal() []interface{} {
Addresses: []netaddr.IPPrefix{{Bits: 5, IP: netaddr.IPFrom16([16]byte{3: 3})}}, Addresses: []netaddr.IPPrefix{{Bits: 5, IP: netaddr.IPFrom16([16]byte{3: 3})}},
Peers: []wgcfg.Peer{ Peers: []wgcfg.Peer{
{ {
Endpoints: "foo:5", Endpoints: wgcfg.Endpoints{HostPorts: []string{"foo:5"}},
}, },
}, },
}, },

View File

@ -35,7 +35,7 @@ var (
errDisabled = errors.New("magicsock: legacy networking disabled") errDisabled = errors.New("magicsock: legacy networking disabled")
) )
func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs string) (conn.Endpoint, error) { func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs []string) (conn.Endpoint, error) {
if c.disableLegacy { if c.disableLegacy {
return nil, errDisabled return nil, errDisabled
} }
@ -46,14 +46,12 @@ func (c *Conn) createLegacyEndpointLocked(pk key.Public, addrs string) (conn.End
curAddr: -1, curAddr: -1,
} }
if addrs != "" { for _, ep := range addrs {
for _, ep := range strings.Split(addrs, ",") { ipp, err := netaddr.ParseIPPort(ep)
ipp, err := netaddr.ParseIPPort(ep) if err != nil {
if err != nil { return nil, fmt.Errorf("bogus address %q", ep)
return nil, fmt.Errorf("bogus address %q", ep)
}
a.ipPorts = append(a.ipPorts, ipp)
} }
a.ipPorts = append(a.ipPorts, ipp)
} }
// If this endpoint is being updated, remember its old set of // If this endpoint is being updated, remember its old set of

View File

@ -11,6 +11,7 @@ import (
"context" "context"
crand "crypto/rand" crand "crypto/rand"
"encoding/binary" "encoding/binary"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"hash/fnv" "hash/fnv"
@ -27,7 +28,6 @@ import (
"time" "time"
"github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/conn"
"go4.org/mem"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"inet.af/netaddr" "inet.af/netaddr"
@ -2755,17 +2755,23 @@ func (c *Conn) ParseEndpoint(keyAddrs string) (conn.Endpoint, error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.logf("magicsock: ParseEndpoint: key=%s: %s", pk.ShortString(), derpStr(addrs)) var endpoints wgcfg.Endpoints
err := json.Unmarshal([]byte(addrs), &endpoints)
if !strings.HasSuffix(addrs, wgcfg.EndpointDiscoSuffix) {
return c.createLegacyEndpointLocked(pk, addrs)
}
discoHex := strings.TrimSuffix(addrs, wgcfg.EndpointDiscoSuffix)
discoKey, err := key.NewPublicFromHexMem(mem.S(discoHex))
if err != nil { if err != nil {
return nil, fmt.Errorf("magicsock: invalid discokey endpoint %q for %v: %w", addrs, pk.ShortString(), err) c.logf("[unexpected] magicsock: failed to parse addrs %q", addrs)
return nil, err
} }
if pk != key.Public(endpoints.PublicKey) {
c.logf("[unexpected] magicsock: incorrect public key in addrs, want %x, addrs is %q", pk, addrs)
return nil, errors.New("bad public key in CreateEndpoint")
}
c.logf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), derpStr(addrs))
discoKey := endpoints.DiscoKey
if discoKey.IsZero() {
return c.createLegacyEndpointLocked(pk, endpoints.HostPorts)
}
de := &discoEndpoint{ de := &discoEndpoint{
c: c, c: c,
publicKey: tailcfg.NodeKey(pk), // peer public key (for WireGuard + DERP) publicKey: tailcfg.NodeKey(pk), // peer public key (for WireGuard + DERP)

View File

@ -470,10 +470,14 @@ func makeConfigs(t *testing.T, addrs []netaddr.IPPort) []wgcfg.Config {
if peerNum == i { if peerNum == i {
continue continue
} }
publicKey := privKeys[peerNum].Public()
peer := wgcfg.Peer{ peer := wgcfg.Peer{
PublicKey: privKeys[peerNum].Public(), PublicKey: publicKey,
AllowedIPs: addresses[peerNum], AllowedIPs: addresses[peerNum],
Endpoints: addr.String(), Endpoints: wgcfg.Endpoints{
PublicKey: publicKey,
HostPorts: []string{addr.String()},
},
PersistentKeepalive: 25, PersistentKeepalive: 25,
} }
cfg.Peers = append(cfg.Peers, peer) cfg.Peers = append(cfg.Peers, peer)
@ -1060,12 +1064,12 @@ func testTwoDevicePing(t *testing.T, d *devices) {
}) })
// Add DERP relay. // Add DERP relay.
derpEp := "127.3.3.40:1" derpEp := []string{"127.3.3.40:1"}
ep0 := cfgs[0].Peers[0].Endpoints ep0 := cfgs[0].Peers[0].Endpoints
ep0 = derpEp + "," + ep0 ep0.HostPorts = append(derpEp[:1:1], ep0.HostPorts...)
cfgs[0].Peers[0].Endpoints = ep0 cfgs[0].Peers[0].Endpoints = ep0
ep1 := cfgs[1].Peers[0].Endpoints ep1 := cfgs[1].Peers[0].Endpoints
ep1 = derpEp + "," + ep1 ep1.HostPorts = append(derpEp[:1:1], ep1.HostPorts...)
cfgs[1].Peers[0].Endpoints = ep1 cfgs[1].Peers[0].Endpoints = ep1
if err := m1.Reconfig(&cfgs[0]); err != nil { if err := m1.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err) t.Fatal(err)
@ -1081,8 +1085,8 @@ func testTwoDevicePing(t *testing.T, d *devices) {
}) })
// Disable real route. // Disable real route.
cfgs[0].Peers[0].Endpoints = derpEp cfgs[0].Peers[0].Endpoints.HostPorts = derpEp
cfgs[1].Peers[0].Endpoints = derpEp cfgs[1].Peers[0].Endpoints.HostPorts = derpEp
if err := m1.Reconfig(&cfgs[0]); err != nil { if err := m1.Reconfig(&cfgs[0]); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1109,7 +1113,7 @@ func testTwoDevicePing(t *testing.T, d *devices) {
// Give one peer a non-DERP endpoint. We expect the other to // Give one peer a non-DERP endpoint. We expect the other to
// accept it via roamAddr. // accept it via roamAddr.
cfgs[0].Peers[0].Endpoints = ep0 cfgs[0].Peers[0].Endpoints = ep0
if ep2 := cfgs[1].Peers[0].Endpoints; len(ep2) != 1 { if ep2 := cfgs[1].Peers[0].Endpoints.HostPorts; len(ep2) != 1 {
t.Errorf("unexpected peer endpoints in dev2: %v", ep2) t.Errorf("unexpected peer endpoints in dev2: %v", ep2)
} }
if err := m2.Reconfig(&cfgs[1]); err != nil { if err := m2.Reconfig(&cfgs[1]); err != nil {
@ -1134,7 +1138,7 @@ func testTwoDevicePing(t *testing.T, d *devices) {
t.Fatal(err) t.Fatal(err)
} }
ep2 := cfg.Peers[0].Endpoints ep2 := cfg.Peers[0].Endpoints
if len(ep2) != 2 { if len(ep2.HostPorts) != 2 {
t.Error("handshake spray failed to find real route") t.Error("handshake spray failed to find real route")
} }
}) })

View File

@ -675,15 +675,7 @@ func isTrimmablePeer(p *wgcfg.Peer, numPeers int) bool {
if forceFullWireguardConfig(numPeers) { if forceFullWireguardConfig(numPeers) {
return false return false
} }
if !isSingleEndpoint(p.Endpoints) { if p.Endpoints.DiscoKey.IsZero() {
return false
}
host, _, err := net.SplitHostPort(p.Endpoints)
if err != nil {
return false
}
if !strings.HasSuffix(host, ".disco.tailscale") {
return false return false
} }
@ -753,26 +745,6 @@ func (e *userspaceEngine) isActiveSince(dk tailcfg.DiscoKey, ip netaddr.IP, t ti
return unixTime >= t.Unix() return unixTime >= t.Unix()
} }
// discoKeyFromPeer returns the DiscoKey for a wireguard config's Peer.
//
// Invariant: isTrimmablePeer(p) == true, so it should have 1 endpoint with
// Host of form "<64-hex-digits>.disco.tailscale". If invariant is violated,
// we return the zero value.
func discoKeyFromPeer(p *wgcfg.Peer) tailcfg.DiscoKey {
if len(p.Endpoints) < 64 {
return tailcfg.DiscoKey{}
}
host, rest := p.Endpoints[:64], p.Endpoints[64:]
if !strings.HasPrefix(rest, ".disco.tailscale") {
return tailcfg.DiscoKey{}
}
k, err := key.NewPublicFromHexMem(mem.S(host))
if err != nil {
return tailcfg.DiscoKey{}
}
return tailcfg.DiscoKey(k)
}
// discoChanged are the set of peers whose disco keys have changed, implying they've restarted. // discoChanged are the set of peers whose disco keys have changed, implying they've restarted.
// If a peer is in this set and was previously in the live wireguard config, // If a peer is in this set and was previously in the live wireguard config,
// it needs to be first removed and then re-added to flush out its wireguard session key. // it needs to be first removed and then re-added to flush out its wireguard session key.
@ -820,7 +792,7 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Publ
} }
continue continue
} }
dk := discoKeyFromPeer(p) dk := p.Endpoints.DiscoKey
trackDisco = append(trackDisco, dk) trackDisco = append(trackDisco, dk)
recentlyActive := false recentlyActive := false
for _, cidr := range p.AllowedIPs { for _, cidr := range p.AllowedIPs {
@ -992,19 +964,19 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config,
// and a second time with it. // and a second time with it.
discoChanged := make(map[key.Public]bool) discoChanged := make(map[key.Public]bool)
{ {
prevEP := make(map[key.Public]string) prevEP := make(map[key.Public]tailcfg.DiscoKey)
for i := range e.lastCfgFull.Peers { for i := range e.lastCfgFull.Peers {
if p := &e.lastCfgFull.Peers[i]; isSingleEndpoint(p.Endpoints) { if p := &e.lastCfgFull.Peers[i]; !p.Endpoints.DiscoKey.IsZero() {
prevEP[key.Public(p.PublicKey)] = p.Endpoints prevEP[key.Public(p.PublicKey)] = p.Endpoints.DiscoKey
} }
} }
for i := range cfg.Peers { for i := range cfg.Peers {
p := &cfg.Peers[i] p := &cfg.Peers[i]
if !isSingleEndpoint(p.Endpoints) { if p.Endpoints.DiscoKey.IsZero() {
continue continue
} }
pub := key.Public(p.PublicKey) pub := key.Public(p.PublicKey)
if old, ok := prevEP[pub]; ok && old != p.Endpoints { if old, ok := prevEP[pub]; ok && old != p.Endpoints.DiscoKey {
discoChanged[pub] = true discoChanged[pub] = true
e.logf("wgengine: Reconfig: %s changed from %q to %q", pub.ShortString(), old, p.Endpoints) e.logf("wgengine: Reconfig: %s changed from %q to %q", pub.ShortString(), old, p.Endpoints)
} }

View File

@ -104,7 +104,9 @@ func TestUserspaceEngineReconfig(t *testing.T) {
AllowedIPs: []netaddr.IPPrefix{ AllowedIPs: []netaddr.IPPrefix{
{IP: netaddr.IPv4(100, 100, 99, 1), Bits: 32}, {IP: netaddr.IPv4(100, 100, 99, 1), Bits: 32},
}, },
Endpoints: discoHex + ".disco.tailscale:12345", Endpoints: wgcfg.Endpoints{
DiscoKey: dkFromHex(discoHex),
},
}, },
}, },
} }

View File

@ -7,6 +7,7 @@ package wgcfg
import ( import (
"inet.af/netaddr" "inet.af/netaddr"
"tailscale.com/tailcfg"
) )
// EndpointDiscoSuffix is appended to the hex representation of a peer's discovery key // EndpointDiscoSuffix is appended to the hex representation of a peer's discovery key
@ -28,10 +29,17 @@ type Config struct {
type Peer struct { type Peer struct {
PublicKey Key PublicKey Key
AllowedIPs []netaddr.IPPrefix AllowedIPs []netaddr.IPPrefix
Endpoints string // comma-separated host/port pairs: "1.2.3.4:56,[::]:80" Endpoints Endpoints // comma-separated host/port pairs: "1.2.3.4:56,[::]:80"
PersistentKeepalive uint16 PersistentKeepalive uint16
} }
// TODO: HostPorts always sorted?
type Endpoints struct {
PublicKey Key `json:"pk"`
DiscoKey tailcfg.DiscoKey `json:"dk,omitempty"`
HostPorts []string `json:"hp,omitempty"`
}
// Copy makes a deep copy of Config. // Copy makes a deep copy of Config.
// The result aliases no memory with the original. // The result aliases no memory with the original.
func (cfg Config) Copy() Config { func (cfg Config) Copy() Config {

View File

@ -9,6 +9,7 @@ import (
"bytes" "bytes"
"io" "io"
"os" "os"
"reflect"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -126,7 +127,7 @@ func TestDeviceConfig(t *testing.T) {
}) })
t.Run("device1 modify peer", func(t *testing.T) { t.Run("device1 modify peer", func(t *testing.T) {
cfg1.Peers[0].Endpoints = "1.2.3.4:12345" cfg1.Peers[0].Endpoints.HostPorts = []string{"1.2.3.4:12345"}
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -134,7 +135,7 @@ func TestDeviceConfig(t *testing.T) {
}) })
t.Run("device1 replace endpoint", func(t *testing.T) { t.Run("device1 replace endpoint", func(t *testing.T) {
cfg1.Peers[0].Endpoints = "1.1.1.1:123" cfg1.Peers[0].Endpoints.HostPorts = []string{"1.1.1.1:123"}
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -175,7 +176,7 @@ func TestDeviceConfig(t *testing.T) {
} }
peersEqual := func(p, q Peer) bool { peersEqual := func(p, q Peer) bool {
return p.PublicKey == q.PublicKey && p.PersistentKeepalive == q.PersistentKeepalive && return p.PublicKey == q.PublicKey && p.PersistentKeepalive == q.PersistentKeepalive &&
p.Endpoints == q.Endpoints && cidrsEqual(p.AllowedIPs, q.AllowedIPs) reflect.DeepEqual(p.Endpoints, q.Endpoints) && cidrsEqual(p.AllowedIPs, q.AllowedIPs)
} }
if !peersEqual(peer0(origCfg), peer0(newCfg)) { if !peersEqual(peer0(origCfg), peer0(newCfg)) {
t.Error("reconfig modified old peer") t.Error("reconfig modified old peer")

View File

@ -79,7 +79,10 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags,
} }
if !peer.DiscoKey.IsZero() { if !peer.DiscoKey.IsZero() {
cpeer.Endpoints = fmt.Sprintf("%x.disco.tailscale:12345", peer.DiscoKey[:]) cpeer.Endpoints = wgcfg.Endpoints{
PublicKey: wgcfg.Key(peer.Key),
DiscoKey: peer.DiscoKey,
}
} else { } else {
if err := appendEndpoint(cpeer, peer.DERP); err != nil { if err := appendEndpoint(cpeer, peer.DERP); err != nil {
return nil, err return nil, err
@ -147,9 +150,6 @@ func appendEndpoint(peer *wgcfg.Peer, epStr string) error {
if err != nil { if err != nil {
return fmt.Errorf("invalid port in endpoint %q for peer %v", epStr, peer.PublicKey.ShortString()) return fmt.Errorf("invalid port in endpoint %q for peer %v", epStr, peer.PublicKey.ShortString())
} }
if peer.Endpoints != "" { peer.Endpoints.HostPorts = append(peer.Endpoints.HostPorts, epStr)
peer.Endpoints += ","
}
peer.Endpoints += epStr
return nil return nil
} }

View File

@ -7,6 +7,7 @@ package wgcfg
import ( import (
"bufio" "bufio"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -25,6 +26,7 @@ func (e *ParseError) Error() string {
return fmt.Sprintf("%s: %q", e.why, e.offender) return fmt.Sprintf("%s: %q", e.why, e.offender)
} }
// TODO: delete??
func validateEndpoints(s string) error { func validateEndpoints(s string) error {
if s == "" { if s == "" {
// Otherwise strings.Split of the empty string produces [""]. // Otherwise strings.Split of the empty string produces [""].
@ -167,11 +169,10 @@ func (cfg *Config) handlePublicKeyLine(value string) (*Peer, error) {
func (cfg *Config) handlePeerLine(peer *Peer, key, value string) error { func (cfg *Config) handlePeerLine(peer *Peer, key, value string) error {
switch key { switch key {
case "endpoint": case "endpoint":
err := validateEndpoints(value) err := json.Unmarshal([]byte(value), &peer.Endpoints)
if err != nil { if err != nil {
return err return err
} }
peer.Endpoints = value
case "persistent_keepalive_interval": case "persistent_keepalive_interval":
n, err := strconv.ParseUint(value, 10, 16) n, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {

View File

@ -5,11 +5,11 @@
package wgcfg package wgcfg
import ( import (
"encoding/json"
"fmt" "fmt"
"io" "io"
"sort" "reflect"
"strconv" "strconv"
"strings"
"inet.af/netaddr" "inet.af/netaddr"
) )
@ -52,8 +52,12 @@ func (cfg *Config) ToUAPI(w io.Writer, prev *Config) error {
setPeer(p) setPeer(p)
set("protocol_version", "1") set("protocol_version", "1")
if !endpointsEqual(oldPeer.Endpoints, p.Endpoints) { if !reflect.DeepEqual(oldPeer.Endpoints, p.Endpoints) {
set("endpoint", p.Endpoints) buf, err := json.Marshal(p.Endpoints)
if err != nil {
return err
}
set("endpoint", string(buf))
} }
// TODO: replace_allowed_ips is expensive. // TODO: replace_allowed_ips is expensive.
@ -89,24 +93,6 @@ func (cfg *Config) ToUAPI(w io.Writer, prev *Config) error {
return stickyErr return stickyErr
} }
func endpointsEqual(x, y string) bool {
// Cheap comparisons.
if x == y {
return true
}
xs := strings.Split(x, ",")
ys := strings.Split(y, ",")
if len(xs) != len(ys) {
return false
}
// Otherwise, see if they're the same, but out of order.
sort.Strings(xs)
sort.Strings(ys)
x = strings.Join(xs, ",")
y = strings.Join(ys, ",")
return x == y
}
func cidrsEqual(x, y []netaddr.IPPrefix) bool { func cidrsEqual(x, y []netaddr.IPPrefix) bool {
// TODO: re-implement using netaddr.IPSet.Equal. // TODO: re-implement using netaddr.IPSet.Equal.
if len(x) != len(y) { if len(x) != len(y) {