Mirror of https://github.com/tailscale/tailscale.git

Move Linux client & common packages into a public repo.
control/controlclient/auto.go | 594 lines (new file)

@@ -0,0 +1,594 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package controlclient implements the client for the IPN control plane.
//
// It handles authentication, port picking, and collects the local
// network configuration.
package controlclient

import (
	"context"
	"encoding/json"
	"fmt"
	"reflect"
	"sync"
	"time"

	"golang.org/x/oauth2"
	"tailscale.com/logger"
	"tailscale.com/logtail/backoff"
	"tailscale.com/tailcfg"
)

// TODO(apenwarr): eliminate the 'state' variable, as it's now obsolete.
// It's used only by the unit tests.
type state int

const (
	stateNew = state(iota)
	stateNotAuthenticated
	stateAuthenticating
	stateURLVisitRequired
	stateAuthenticated
	stateSynchronized // connected and received map update
)

func (s state) MarshalText() ([]byte, error) {
	return []byte(s.String()), nil
}

func (s state) String() string {
	switch s {
	case stateNew:
		return "state:new"
	case stateNotAuthenticated:
		return "state:not-authenticated"
	case stateAuthenticating:
		return "state:authenticating"
	case stateURLVisitRequired:
		return "state:url-visit-required"
	case stateAuthenticated:
		return "state:authenticated"
	case stateSynchronized:
		return "state:synchronized"
	default:
		return fmt.Sprintf("state:unknown:%d", int(s))
	}
}

type Status struct {
	LoginFinished *struct{}
	Err           string
	URL           string
	Persist       *Persist         // locally persisted configuration
	NetMap        *NetworkMap      // server-pushed configuration
	Hostinfo      tailcfg.Hostinfo // current Hostinfo data
	state         state
}

// Equal reports whether s and s2 are equal.
func (s *Status) Equal(s2 *Status) bool {
	if s == nil && s2 == nil {
		return true
	}
	return s != nil && s2 != nil &&
		(s.LoginFinished == nil) == (s2.LoginFinished == nil) &&
		s.Err == s2.Err &&
		s.URL == s2.URL &&
		reflect.DeepEqual(s.Persist, s2.Persist) &&
		reflect.DeepEqual(s.NetMap, s2.NetMap) &&
		reflect.DeepEqual(s.Hostinfo, s2.Hostinfo) &&
		s.state == s2.state
}

func (s Status) String() string {
	b, err := json.MarshalIndent(s, "", "\t")
	if err != nil {
		panic(err)
	}
	return s.state.String() + " " + string(b)
}

type LoginGoal struct {
	wantLoggedIn bool          // true if we *want* to be logged in
	token        *oauth2.Token // oauth token to use when logging in
	flags        LoginFlags    // flags to use when logging in
	url          string        // auth url that needs to be visited
}

// Client connects to a tailcontrol server for a node.
type Client struct {
	direct   *Direct // our interface to the server APIs
	timeNow  func() time.Time
	logf     logger.Logf
	expiry   *time.Time
	closed   bool
	newMapCh chan struct{} // readable when we must restart a map request

	mu         sync.Mutex   // mutex guards the following fields
	statusFunc func(Status) // called to update Client status

	loggedIn     bool       // true if currently logged in
	loginGoal    *LoginGoal // non-nil if some login activity is desired
	synced       bool       // true if our netmap is up-to-date
	hostinfo     tailcfg.Hostinfo
	inPollNetMap bool // true if currently running a PollNetMap
	inSendStatus int  // number of sendStatus calls currently in progress
	state        state

	authCtx    context.Context // context used for auth requests
	mapCtx     context.Context // context used for netmap requests
	authCancel func()          // cancel the auth context
	mapCancel  func()          // cancel the netmap context
	quit       chan struct{}   // when closed, goroutines should all exit
	authDone   chan struct{}   // when closed, auth goroutine is done
	mapDone    chan struct{}   // when closed, map goroutine is done
}

// New creates and starts a new Client.
func New(opts Options) (*Client, error) {
	c, err := NewNoStart(opts)
	if c != nil {
		c.Start()
	}
	return c, err
}

// NewNoStart creates a new Client, but without calling Start on it.
func NewNoStart(opts Options) (*Client, error) {
	direct, err := NewDirect(opts)
	if err != nil {
		return nil, err
	}
	c := &Client{
		direct:   direct,
		timeNow:  opts.TimeNow,
		logf:     opts.Logf,
		newMapCh: make(chan struct{}, 1),
		quit:     make(chan struct{}),
		authDone: make(chan struct{}),
		mapDone:  make(chan struct{}),
	}
	c.authCtx, c.authCancel = context.WithCancel(context.Background())
	c.mapCtx, c.mapCancel = context.WithCancel(context.Background())
	return c, nil
}

// Start starts the client's goroutines.
//
// It should only be called for clients created by NewNoStart.
func (c *Client) Start() {
	go c.authRoutine()
	go c.mapRoutine()
}

func (c *Client) cancelAuth() {
	c.mu.Lock()
	if c.authCancel != nil {
		c.authCancel()
	}
	if !c.closed {
		c.authCtx, c.authCancel = context.WithCancel(context.Background())
	}
	c.mu.Unlock()
}

func (c *Client) cancelMapLocked() {
	if c.mapCancel != nil {
		c.mapCancel()
	}
	if !c.closed {
		c.mapCtx, c.mapCancel = context.WithCancel(context.Background())
	}
}

func (c *Client) cancelMapUnsafely() {
	c.mu.Lock()
	c.cancelMapLocked()
	c.mu.Unlock()
}

func (c *Client) cancelMapSafely() {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.logf("cancelMapSafely: synced=%v\n", c.synced)

	if c.inPollNetMap {
		// received at least one netmap since the last
		// interruption. That means the server has already
		// fully processed our last request, which might
		// include UpdateEndpoints(). Interrupt it and try
		// again.
		c.cancelMapLocked()
	} else {
		// !synced means we either haven't done a netmap
		// request yet, or it hasn't answered yet. So the
		// server is in an undefined state. If we send
		// another netmap request too soon, it might race
		// with the last one, and if we're very unlucky,
		// the new request will be applied before the old one,
		// and the wrong endpoints will get registered. We
		// have to tell the client to abort politely, only
		// after it receives a response to its existing netmap
		// request.
		select {
		case c.newMapCh <- struct{}{}:
			c.logf("cancelMapSafely: wrote to channel\n")
		default:
			// if channel write failed, then there was already
			// an outstanding newMapCh request. One is enough,
			// since it'll always use the latest endpoints.
			c.logf("cancelMapSafely: channel was full\n")
		}
	}
}

func (c *Client) authRoutine() {
	defer close(c.authDone)
	bo := backoff.Backoff{Name: "authRoutine"}

	for {
		c.mu.Lock()
		c.logf("authRoutine: %s\n", c.state)
		expiry := c.expiry
		goal := c.loginGoal
		ctx := c.authCtx
		synced := c.synced
		c.mu.Unlock()

		select {
		case <-c.quit:
			c.logf("authRoutine: quit\n")
			return
		default:
		}

		report := func(err error, msg string) {
			c.logf("%s: %v\n", msg, err)
			err = fmt.Errorf("%s: %v", msg, err)
			// don't send status updates for context errors,
			// since context cancelation is always on purpose.
			if ctx.Err() == nil {
				c.sendStatus("authRoutine1", err, "", nil)
			}
		}

		if goal == nil {
			// Wait for something interesting to happen
			var exp <-chan time.Time
			if expiry != nil && !expiry.IsZero() {
				// if expiry is in the future, don't delay
				// past that time.
				// If it's in the past, then it's already
				// being handled by someone, so no need to
				// wake ourselves up again.
				now := c.timeNow()
				if expiry.Before(now) {
					delay := expiry.Sub(now)
					if delay > 5*time.Second {
						delay = time.Second
					}
					exp = time.After(delay)
				}
			}
			select {
			case <-ctx.Done():
				c.logf("authRoutine: context done.\n")
			case <-exp:
				// Unfortunately the key expiry isn't provided
				// by the control server until mapRequest.
				// So we have to do some hackery with c.expiry
				// in here.
				// TODO(apenwarr): add a key expiry field in RegisterResponse.
				c.logf("authRoutine: key expiration check.\n")
				if synced && expiry != nil && !expiry.IsZero() && expiry.Before(c.timeNow()) {
					c.logf("Key expired; setting loggedIn=false.")

					c.mu.Lock()
					c.loginGoal = &LoginGoal{
						wantLoggedIn: c.loggedIn,
					}
					c.loggedIn = false
					c.expiry = nil
					c.mu.Unlock()
				}
			}
		} else if !goal.wantLoggedIn {
			err := c.direct.TryLogout(c.authCtx)
			if err != nil {
				report(err, "TryLogout")
				bo.BackOff(ctx, err)
				continue
			}

			// success
			c.mu.Lock()
			c.loggedIn = false
			c.loginGoal = nil
			c.state = stateNotAuthenticated
			c.synced = false
			c.mu.Unlock()

			c.sendStatus("authRoutine2", nil, "", nil)
			bo.BackOff(ctx, nil)
		} else { // ie. goal.wantLoggedIn
			c.mu.Lock()
			if goal.url != "" {
				c.state = stateURLVisitRequired
			} else {
				c.state = stateAuthenticating
			}
			c.mu.Unlock()

			var url string
			var err error
			var f string
			if goal.url != "" {
				url, err = c.direct.WaitLoginURL(ctx, goal.url)
				f = "WaitLoginURL"
			} else {
				url, err = c.direct.TryLogin(ctx, goal.token, goal.flags)
				f = "TryLogin"
			}
			if err != nil {
				report(err, f)
				bo.BackOff(ctx, err)
				continue
			} else if url != "" {
				if goal.url != "" {
					err = fmt.Errorf("weird: server required a new url?")
					report(err, "WaitLoginURL")
				}
				goal.url = url
				goal.token = nil
				goal.flags = LoginDefault

				c.mu.Lock()
				c.loginGoal = goal
				c.state = stateURLVisitRequired
				c.synced = false
				c.mu.Unlock()

				c.sendStatus("authRoutine3", err, url, nil)
				bo.BackOff(ctx, err)
				continue
			}

			// success
			c.mu.Lock()
			c.loggedIn = true
			c.loginGoal = nil
			c.state = stateAuthenticated
			c.mu.Unlock()

			c.sendStatus("authRoutine4", nil, "", nil)
			c.cancelMapSafely()
			bo.BackOff(ctx, nil)
		}
	}
}

func (c *Client) mapRoutine() {
	defer close(c.mapDone)
	bo := backoff.Backoff{Name: "mapRoutine"}

	for {
		c.mu.Lock()
		c.logf("mapRoutine: %s\n", c.state)
		loggedIn := c.loggedIn
		ctx := c.mapCtx
		c.mu.Unlock()

		select {
		case <-c.quit:
			c.logf("mapRoutine: quit\n")
			return
		default:
		}

		report := func(err error, msg string) {
			c.logf("%s: %v\n", msg, err)
			err = fmt.Errorf("%s: %v", msg, err)
			// don't send status updates for context errors,
			// since context cancelation is always on purpose.
			if ctx.Err() == nil {
				c.sendStatus("mapRoutine1", err, "", nil)
			}
		}

		if !loggedIn {
			// Wait for something interesting to happen
			c.mu.Lock()
			c.synced = false
			// c.state is set by authRoutine()
			c.mu.Unlock()

			select {
			case <-ctx.Done():
				c.logf("mapRoutine: context done.\n")
			case <-c.newMapCh:
				c.logf("mapRoutine: new map needed while idle.\n")
			}
		} else {
			// Be sure this is false when we're not inside
			// PollNetMap, so that cancelMapSafely() can notify
			// us correctly.
			c.mu.Lock()
			c.inPollNetMap = false
			c.mu.Unlock()

			err := c.direct.PollNetMap(ctx, -1, func(nm *NetworkMap) {
				c.mu.Lock()

				select {
				case <-c.newMapCh:
					c.logf("mapRoutine: new map request during PollNetMap. canceling.\n")
					c.cancelMapLocked()

					// Don't emit this netmap; we're
					// about to request a fresh one.
					c.mu.Unlock()
					return
				default:
				}

				c.synced = true
				c.inPollNetMap = true
				if c.loggedIn {
					c.state = stateSynchronized
				}
				exp := nm.Expiry
				c.expiry = &exp
				stillAuthed := c.loggedIn
				state := c.state

				c.mu.Unlock()

				c.logf("mapRoutine: netmap received: %s\n", state)
				if stillAuthed {
					c.sendStatus("mapRoutine2", nil, "", nm)
				}
			})

			c.mu.Lock()
			c.synced = false
			c.inPollNetMap = false
			if c.state == stateSynchronized {
				c.state = stateAuthenticated
			}
			c.mu.Unlock()

			if err != nil {
				report(err, "PollNetMap")
				bo.BackOff(ctx, err)
				continue
			}
			bo.BackOff(ctx, nil)
		}
	}
}

func (c *Client) AuthCantContinue() bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	return !c.loggedIn && (c.loginGoal == nil || c.loginGoal.url != "")
}

func (c *Client) SetStatusFunc(fn func(Status)) {
	c.mu.Lock()
	c.statusFunc = fn
	c.mu.Unlock()
}

func (c *Client) SetHostinfo(hi tailcfg.Hostinfo) {
	c.direct.SetHostinfo(hi)
	// Send new Hostinfo to server
	c.cancelMapSafely()
}

func (c *Client) sendStatus(who string, err error, url string, nm *NetworkMap) {
	c.mu.Lock()
	state := c.state
	loggedIn := c.loggedIn
	synced := c.synced
	statusFunc := c.statusFunc
	hi := c.hostinfo
	c.inSendStatus++
	c.mu.Unlock()

	c.logf("sendStatus: %s: %v\n", who, state)

	var p *Persist
	var fin *struct{}
	if state == stateAuthenticated {
		fin = &struct{}{}
	}
	if nm != nil && loggedIn && synced {
		pp := c.direct.GetPersist()
		p = &pp
	} else {
		// don't send netmap status, as it's misleading when we're
		// not logged in.
		nm = nil
	}
	new := Status{
		LoginFinished: fin,
		URL:           url,
		Persist:       p,
		NetMap:        nm,
		Hostinfo:      hi,
		state:         state,
	}
	if err != nil {
		new.Err = err.Error()
	}
	if statusFunc != nil {
		statusFunc(new)
	}

	c.mu.Lock()
	c.inSendStatus--
	c.mu.Unlock()
}

func (c *Client) Login(t *oauth2.Token, flags LoginFlags) {
	c.logf("client.Login(%v, %v)\n", t != nil, flags)

	c.mu.Lock()
	c.loginGoal = &LoginGoal{
		wantLoggedIn: true,
		token:        t,
		flags:        flags,
	}
	c.mu.Unlock()

	c.cancelAuth()
}

func (c *Client) Logout() {
	c.logf("client.Logout()\n")

	c.mu.Lock()
	c.loginGoal = &LoginGoal{
		wantLoggedIn: false,
	}
	c.mu.Unlock()

	c.cancelAuth()
}

func (c *Client) UpdateEndpoints(localPort uint16, endpoints []string) {
	changed, err := c.direct.SetEndpoints(localPort, endpoints)
	if err != nil {
		c.sendStatus("updateEndpoints", err, "", nil)
	} else if changed {
		c.cancelMapSafely()
	}
}

func (c *Client) Shutdown() {
	c.logf("client.Shutdown()\n")

	c.mu.Lock()
	inSendStatus := c.inSendStatus
	closed := c.closed
	if !closed {
		c.closed = true
		c.statusFunc = nil
	}
	c.mu.Unlock()

	c.logf("client.Shutdown: inSendStatus=%v\n", inSendStatus)
	if !closed {
		close(c.quit)
		c.cancelAuth()
		<-c.authDone
		c.cancelMapUnsafely()
		<-c.mapDone
		c.logf("Client.Shutdown done.\n")
	}
}
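For orientation, a minimal sketch (not part of this commit) of how a daemon might drive the Client above. The import path is assumed from the file layout, and the server URL is hypothetical; a real caller would also need a Hostinfo with BackendLogID set, per direct.go below.

package main

import (
	"log"

	"tailscale.com/control/controlclient" // import path assumed from file layout
)

func main() {
	c, err := controlclient.New(controlclient.Options{
		ServerURL: "https://control.example.invalid", // hypothetical control server
		Logf:      log.Printf,
	})
	if err != nil {
		log.Fatal(err)
	}
	defer c.Shutdown()

	// Status updates carry auth URLs and server-pushed netmaps.
	c.SetStatusFunc(func(s controlclient.Status) {
		if s.URL != "" {
			log.Printf("visit %s to authenticate", s.URL)
		}
		if s.NetMap != nil {
			log.Printf("netmap:\n%s", s.NetMap.Concise())
		}
	})
	c.Login(nil, controlclient.LoginInteractive)
	select {} // block; the auth and map goroutines do the work
}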
control/controlclient/auto_test.go | 1107 lines (new file)

(File diff suppressed because it is too large.)
control/controlclient/controlclient_test.go | 68 lines (new file)

@@ -0,0 +1,68 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package controlclient

import (
	"reflect"
	"testing"
)

func fieldsOf(t reflect.Type) (fields []string) {
	for i := 0; i < t.NumField(); i++ {
		fields = append(fields, t.Field(i).Name)
	}
	return
}

func TestStatusEqual(t *testing.T) {
	// Verify that the Equal method stays in sync with reality
	equalHandles := []string{"LoginFinished", "Err", "URL", "Persist", "NetMap", "Hostinfo", "state"}
	if have := fieldsOf(reflect.TypeOf(Status{})); !reflect.DeepEqual(have, equalHandles) {
		t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
			have, equalHandles)
	}

	tests := []struct {
		a, b *Status
		want bool
	}{
		{
			&Status{},
			nil,
			false,
		},
		{
			nil,
			&Status{},
			false,
		},
		{
			&Status{},
			&Status{},
			true,
		},
		{
			&Status{state: stateNew},
			&Status{state: stateNew},
			true,
		},
		{
			&Status{state: stateNew},
			&Status{state: stateAuthenticated},
			false,
		},
		{
			&Status{LoginFinished: nil},
			&Status{LoginFinished: new(struct{})},
			false,
		},
	}
	for i, tt := range tests {
		got := tt.a.Equal(tt.b)
		if got != tt.want {
			t.Errorf("%d. Equal = %v; want %v", i, got, tt.want)
		}
	}
}
control/controlclient/direct.go | 656 lines (new file)

@@ -0,0 +1,656 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package controlclient

import (
	"bytes"
	"context"
	"crypto/rand"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"runtime"
	"strings"
	"sync"
	"time"

	"github.com/tailscale/wireguard-go/wgcfg"
	"golang.org/x/crypto/nacl/box"
	"golang.org/x/oauth2"
	"tailscale.com/logger"
	"tailscale.com/tailcfg"
	"tailscale.com/version"
	"tailscale.com/wgengine/filter"
)

type Persist struct {
	PrivateMachineKey wgcfg.PrivateKey
	PrivateNodeKey    wgcfg.PrivateKey
	OldPrivateNodeKey wgcfg.PrivateKey // needed to request key rotation
	Provider          string
	LoginName         string
}

func (p *Persist) Pretty() string {
	var mk, ok, nk wgcfg.Key
	if !p.PrivateMachineKey.IsZero() {
		mk = *p.PrivateMachineKey.Public()
	}
	if !p.OldPrivateNodeKey.IsZero() {
		ok = *p.OldPrivateNodeKey.Public()
	}
	if !p.PrivateNodeKey.IsZero() {
		nk = *p.PrivateNodeKey.Public()
	}
	return fmt.Sprintf("Persist{m=%v, o=%v, n=%v u=%#v}",
		mk.ShortString(), ok.ShortString(), nk.ShortString(),
		p.LoginName)
}

// Direct is the client that connects to a tailcontrol server for a node.
type Direct struct {
	httpc           *http.Client // HTTP client used to talk to tailcontrol
	serverURL       string       // URL of the tailcontrol server
	timeNow         func() time.Time
	newDecompressor func() (Decompressor, error)
	keepAlive       bool
	logf            logger.Logf

	mu           sync.Mutex // mutex guards the following fields
	serverKey    wgcfg.Key
	persist      Persist
	tryingNewKey wgcfg.PrivateKey
	expiry       *time.Time
	hostinfo     tailcfg.Hostinfo
	endpoints    []string
	localPort    uint16
	cmdCh        chan interface{}
	doneCh       chan struct{}
}

type Options struct {
	Persist         Persist          // initial persistent data
	HTTPC           *http.Client     // HTTP client used to talk to tailcontrol
	ServerURL       string           // URL of the tailcontrol server
	TimeNow         func() time.Time // time.Now implementation used by Client
	Hostinfo        *tailcfg.Hostinfo
	NewDecompressor func() (Decompressor, error)
	KeepAlive       bool
	Logf            logger.Logf
}

type Decompressor interface {
	DecodeAll(input, dst []byte) ([]byte, error)
	Close()
}

// NewDirect returns a new Direct client.
func NewDirect(opts Options) (*Direct, error) {
	if opts.ServerURL == "" {
		return nil, errors.New("controlclient.New: no server URL specified")
	}
	opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
	if opts.HTTPC == nil {
		opts.HTTPC = http.DefaultClient
	}
	if opts.TimeNow == nil {
		opts.TimeNow = time.Now
	}
	if opts.Logf == nil {
		// TODO(apenwarr): remove this default and fail instead.
		opts.Logf = log.Printf
	}

	c := &Direct{
		httpc:           opts.HTTPC,
		serverURL:       opts.ServerURL,
		timeNow:         opts.TimeNow,
		logf:            opts.Logf,
		newDecompressor: opts.NewDecompressor,
		keepAlive:       opts.KeepAlive,
		persist:         opts.Persist,
	}
	if opts.Hostinfo == nil {
		c.SetHostinfo(NewHostinfo())
	} else {
		c.SetHostinfo(*opts.Hostinfo)
	}

	return c, nil
}

func NewHostinfo() tailcfg.Hostinfo {
	hostname, _ := os.Hostname()
	os := runtime.GOOS
	switch os {
	case "darwin":
		switch runtime.GOARCH {
		case "arm", "arm64":
			os = "iOS"
		default:
			os = "macOS"
		}
	}

	return tailcfg.Hostinfo{
		IPNVersion: version.LONG,
		Hostname:   hostname,
		OS:         os,
	}
}

func (c *Direct) SetHostinfo(hi tailcfg.Hostinfo) {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.logf("Hostinfo: %v\n", hi)
	c.hostinfo = hi
}

func (c *Direct) GetPersist() Persist {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.persist
}

type LoginFlags int

const (
	LoginDefault     = LoginFlags(0)
	LoginInteractive = LoginFlags(1 << iota) // force user login and key refresh
)

func (c *Direct) TryLogout(ctx context.Context) error {
	c.logf("direct.TryLogout()\n")

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.persist.PrivateNodeKey != (wgcfg.PrivateKey{}) {
		// TODO(crawshaw): Tell the server. This node key should be immediately invalidated.
	}
	c.persist = Persist{
		PrivateMachineKey: c.persist.PrivateMachineKey,
	}
	return nil
}

func (c *Direct) TryLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags) (url string, err error) {
	c.logf("direct.TryLogin(%v, %v)\n", t != nil, flags)
	return c.doLoginOrRegen(ctx, t, flags, false, "")
}

func (c *Direct) WaitLoginURL(ctx context.Context, url string) (newUrl string, err error) {
	c.logf("direct.WaitLoginURL\n")
	return c.doLoginOrRegen(ctx, nil, LoginDefault, false, url)
}

func (c *Direct) doLoginOrRegen(ctx context.Context, t *oauth2.Token, flags LoginFlags, regen bool, url string) (newUrl string, err error) {
	mustregen, url, err := c.doLogin(ctx, t, flags, regen, url)
	if err != nil {
		return url, err
	}
	if mustregen {
		_, url, err = c.doLogin(ctx, t, flags, true, url)
	}
	return url, err
}

func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags, regen bool, url string) (mustregen bool, newurl string, err error) {
	c.mu.Lock()
	persist := c.persist
	tryingNewKey := c.tryingNewKey
	serverKey := c.serverKey
	expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow())
	c.mu.Unlock()

	if persist.PrivateMachineKey == (wgcfg.PrivateKey{}) {
		c.logf("Generating a new machinekey.\n")
		mkey, err := wgcfg.NewPrivateKey()
		if err != nil {
			log.Fatal(err)
		}
		persist.PrivateMachineKey = *mkey
	}

	if expired {
		c.logf("Old key expired -> regen=true\n")
		regen = true
	}
	if (flags & LoginInteractive) != 0 {
		c.logf("LoginInteractive -> regen=true\n")
		regen = true
	}

	c.logf("doLogin(regen=%v, hasUrl=%v)\n", regen, url != "")
	if serverKey == (wgcfg.Key{}) {
		var err error
		serverKey, err = loadServerKey(ctx, c.httpc, c.serverURL)
		if err != nil {
			return regen, url, err
		}

		c.mu.Lock()
		c.serverKey = serverKey
		c.mu.Unlock()
	}

	var oldNodeKey wgcfg.Key
	if url != "" {
	} else if regen || persist.PrivateNodeKey == (wgcfg.PrivateKey{}) {
		c.logf("Generating a new nodekey.\n")
		persist.OldPrivateNodeKey = persist.PrivateNodeKey
		key, err := wgcfg.NewPrivateKey()
		if err != nil {
			c.logf("login keygen: %v", err)
			return regen, url, err
		}
		tryingNewKey = *key
	} else {
		// Try refreshing the current key first
		tryingNewKey = persist.PrivateNodeKey
	}
	if persist.OldPrivateNodeKey != (wgcfg.PrivateKey{}) {
		oldNodeKey = *persist.OldPrivateNodeKey.Public()
	}

	if tryingNewKey == (wgcfg.PrivateKey{}) {
		log.Fatalf("tryingNewKey is empty, give up\n")
	}
	if c.hostinfo.BackendLogID == "" {
		err = errors.New("hostinfo: BackendLogID missing")
		return regen, url, err
	}
	request := tailcfg.RegisterRequest{
		Version:    1,
		OldNodeKey: tailcfg.NodeKey(oldNodeKey),
		NodeKey:    tailcfg.NodeKey(*tryingNewKey.Public()),
		Hostinfo:   c.hostinfo,
		Followup:   url,
	}
	c.logf("RegisterReq: onode=%v node=%v fup=%v\n",
		request.OldNodeKey.AbbrevString(),
		request.NodeKey.AbbrevString(), url != "")
	request.Auth.Oauth2Token = t
	request.Auth.Provider = persist.Provider
	request.Auth.LoginName = persist.LoginName
	bodyData, err := encode(request, &serverKey, &persist.PrivateMachineKey)
	if err != nil {
		return regen, url, err
	}
	body := bytes.NewReader(bodyData)

	u := fmt.Sprintf("%s/machine/%s", c.serverURL, persist.PrivateMachineKey.Public().HexString())
	req, err := http.NewRequest("POST", u, body)
	if err != nil {
		return regen, url, err
	}
	req = req.WithContext(ctx)

	res, err := c.httpc.Do(req)
	if err != nil {
		return regen, url, fmt.Errorf("register request: %v", err)
	}
	c.logf("RegisterReq: returned.\n")
	resp := tailcfg.RegisterResponse{}
	if err := decode(res, &resp, &serverKey, &persist.PrivateMachineKey); err != nil {
		return regen, url, fmt.Errorf("register request: %v", err)
	}

	if resp.NodeKeyExpired {
		if regen {
			return true, "", fmt.Errorf("weird: regen=true but server says NodeKeyExpired: %v", request.NodeKey)
		}
		c.logf("server reports new node key %v has expired",
			request.NodeKey.AbbrevString())
		return true, "", nil
	}
	if persist.Provider == "" {
		persist.Provider = resp.Login.Provider
	}
	if persist.LoginName == "" {
		persist.LoginName = resp.Login.LoginName
	}

	// TODO(crawshaw): RegisterResponse should be able to mechanically
	// communicate some extra instructions from the server:
	//	- new node key required
	//	- machine key no longer supported
	//	- user is disabled

	if resp.AuthURL != "" {
		c.logf("AuthURL is %.20v...\n", resp.AuthURL)
	} else {
		c.logf("No AuthURL\n")
	}

	c.mu.Lock()
	if resp.AuthURL == "" {
		// key rotation is complete
		persist.PrivateNodeKey = tryingNewKey
	} else {
		// save it for the retry-with-URL
		c.tryingNewKey = tryingNewKey
	}
	c.persist = persist
	c.mu.Unlock()

	if err != nil {
		return regen, "", err
	}
	if ctx.Err() != nil {
		return regen, "", ctx.Err()
	}
	return false, resp.AuthURL, nil
}

func sameStrings(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if a[i] != b[i] {
			return false
		}
	}
	return true
}

func (c *Direct) newEndpoints(localPort uint16, endpoints []string) bool {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Nothing new?
	if c.localPort == localPort && sameStrings(c.endpoints, endpoints) {
		return false // unchanged
	}
	c.logf("client.newEndpoints(%v, %v)\n", localPort, endpoints)
	if len(c.endpoints) > 0 {
		// empty the old list without deallocating it
		c.endpoints = c.endpoints[:0]
	}
	c.localPort = localPort
	c.endpoints = append(c.endpoints, endpoints...)
	return true // changed
}

// SetEndpoints updates the list of locally advertised endpoints.
// It won't be replicated to the server until a *fresh* call to PollNetMap().
// You don't need to restart PollNetMap if we return changed==false.
func (c *Direct) SetEndpoints(localPort uint16, endpoints []string) (changed bool, err error) {
	// (no log message on function entry, because it clutters the logs
	// if endpoints haven't changed. newEndpoints() will log it.)
	changed = c.newEndpoints(localPort, endpoints)
	return changed, nil
}

func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkMap)) error {
	c.mu.Lock()
	persist := c.persist
	serverURL := c.serverURL
	serverKey := c.serverKey
	hostinfo := c.hostinfo
	localPort := c.localPort
	ep := append([]string(nil), c.endpoints...)
	c.mu.Unlock()

	if hostinfo.BackendLogID == "" {
		return errors.New("hostinfo: BackendLogID missing")
	}

	allowStream := maxPolls != 1
	c.logf("PollNetMap: stream=%v :%v %v\n", maxPolls, localPort, ep)

	request := tailcfg.MapRequest{
		Version:   4,
		KeepAlive: c.keepAlive,
		NodeKey:   tailcfg.NodeKey(*persist.PrivateNodeKey.Public()),
		Endpoints: ep,
		Stream:    allowStream,
		Hostinfo:  hostinfo,
	}
	if c.newDecompressor != nil {
		request.Compress = "zstd"
	}

	bodyData, err := encode(request, &serverKey, &persist.PrivateMachineKey)
	if err != nil {
		return err
	}

	u := fmt.Sprintf("%s/machine/%s/map", serverURL, persist.PrivateMachineKey.Public().HexString())
	req, err := http.NewRequest("POST", u, bytes.NewReader(bodyData))
	if err != nil {
		return err
	}
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	req = req.WithContext(ctx)

	res, err := c.httpc.Do(req)
	if err != nil {
		return err
	}
	if res.StatusCode != 200 {
		msg, _ := ioutil.ReadAll(res.Body)
		res.Body.Close()
		return fmt.Errorf("initial fetch failed %d: %s",
			res.StatusCode, strings.TrimSpace(string(msg)))
	}
	defer res.Body.Close()

	// If we go more than pollTimeout without hearing from the server,
	// end the long poll. We should be receiving a keep alive ping
	// every minute.
	const pollTimeout = 120 * time.Second
	timeout := time.NewTimer(pollTimeout)
	timeoutReset := make(chan struct{})
	defer close(timeoutReset)
	go func() {
		for {
			select {
			case <-timeout.C:
				c.logf("map response long-poll timed out!")
				cancel()
				return
			case _, ok := <-timeoutReset:
				if !ok {
					return // channel closed, shut down goroutine
				}
				if !timeout.Stop() {
					<-timeout.C
				}
				timeout.Reset(pollTimeout)
			}
		}
	}()

	// If allowStream, then the server will use an HTTP long poll to
	// return incremental results. There is always one response right
	// away, followed by a delay, and eventually others.
	// If !allowStream, it'll still send the first result in exactly
	// the same format before just closing the connection.
	// We can use this same read loop either way.
	var msg []byte
	for i := 0; i < maxPolls || maxPolls < 0; i++ {
		var siz [4]byte
		if _, err := io.ReadFull(res.Body, siz[:]); err != nil {
			return err
		}
		size := binary.LittleEndian.Uint32(siz[:])
		msg = append(msg[:0], make([]byte, size)...)
		if _, err := io.ReadFull(res.Body, msg); err != nil {
			return err
		}

		var resp tailcfg.MapResponse

		// Default filter if the key is missing from the incoming
		// json (ie. old tailcontrol server without PacketFilter
		// support). If even an empty PacketFilter is provided, this
		// will be overwritten.
		// TODO(apenwarr 2020-02-01): remove after tailcontrol is fully deployed.
		resp.PacketFilter = filter.MatchAllowAll

		if err := c.decodeMsg(msg, &resp); err != nil {
			return err
		}
		if resp.KeepAlive {
			c.logf("map response keep alive received")
			timeoutReset <- struct{}{}
			continue
		}

		nm := &NetworkMap{
			NodeKey:      tailcfg.NodeKey(*persist.PrivateNodeKey.Public()),
			PrivateKey:   persist.PrivateNodeKey,
			Expiry:       resp.Node.KeyExpiry,
			Addresses:    resp.Node.Addresses,
			Peers:        resp.Peers,
			LocalPort:    localPort,
			User:         resp.Node.User,
			UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfile),
			Domain:       resp.Domain,
			Roles:        resp.Roles,
			DNS:          resp.DNS,
			DNSDomains:   resp.SearchPaths,
			Hostinfo:     resp.Node.Hostinfo,
			PacketFilter: resp.PacketFilter,
		}
		for _, profile := range resp.UserProfiles {
			nm.UserProfiles[profile.ID] = profile
		}
		if resp.Node.MachineAuthorized {
			nm.MachineStatus = tailcfg.MachineAuthorized
		} else {
			nm.MachineStatus = tailcfg.MachineUnauthorized
		}
		//c.logf("new network map[%d]:\n%s", i, nm.Concise())

		c.mu.Lock()
		c.expiry = &nm.Expiry
		c.mu.Unlock()

		cb(nm)
	}
	if ctx.Err() != nil {
		return ctx.Err()
	}
	return nil
}

func decode(res *http.Response, v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) error {
	defer res.Body.Close()
	msg, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20))
	if err != nil {
		return err
	}
	if res.StatusCode != 200 {
		return fmt.Errorf("%d: %v", res.StatusCode, string(msg))
	}
	return decodeMsg(msg, v, serverKey, mkey)
}

func (c *Direct) decodeMsg(msg []byte, v interface{}) error {
	mkey := c.persist.PrivateMachineKey
	serverKey := c.serverKey

	decrypted, err := decryptMsg(msg, &serverKey, &mkey)
	if err != nil {
		return err
	}
	var b []byte
	if c.newDecompressor == nil {
		b = decrypted
	} else {
		decoder, err := c.newDecompressor()
		if err != nil {
			return err
		}
		defer decoder.Close()
		b, err = decoder.DecodeAll(decrypted, nil)
		if err != nil {
			return err
		}
	}
	if err := json.Unmarshal(b, v); err != nil {
		return fmt.Errorf("response: %v", err)
	}
	return nil
}

func decodeMsg(msg []byte, v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) error {
	decrypted, err := decryptMsg(msg, serverKey, mkey)
	if err != nil {
		return err
	}
	if err := json.Unmarshal(decrypted, v); err != nil {
		return fmt.Errorf("response: %v", err)
	}
	return nil
}

func decryptMsg(msg []byte, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) ([]byte, error) {
	var nonce [24]byte
	if len(msg) < len(nonce)+1 {
		return nil, fmt.Errorf("response missing nonce, len=%d", len(msg))
	}
	copy(nonce[:], msg)
	msg = msg[len(nonce):]

	pub, pri := (*[32]byte)(serverKey), (*[32]byte)(mkey)
	decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
	if !ok {
		return nil, fmt.Errorf("cannot decrypt response")
	}
	return decrypted, nil
}

func encode(v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) ([]byte, error) {
	b, err := json.Marshal(v)
	if err != nil {
		return nil, err
	}
	var nonce [24]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		panic(err)
	}
	pub, pri := (*[32]byte)(serverKey), (*[32]byte)(mkey)
	msg := box.Seal(nonce[:], b, &nonce, pub, pri)
	return msg, nil
}

func loadServerKey(ctx context.Context, httpc *http.Client, serverURL string) (wgcfg.Key, error) {
	req, err := http.NewRequest("GET", serverURL+"/key", nil)
	if err != nil {
		return wgcfg.Key{}, fmt.Errorf("create control key request: %v", err)
	}
	req = req.WithContext(ctx)
	res, err := httpc.Do(req)
	if err != nil {
		return wgcfg.Key{}, fmt.Errorf("fetch control key: %v", err)
	}
	defer res.Body.Close()
	b, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<16))
	if err != nil {
		return wgcfg.Key{}, fmt.Errorf("fetch control key response: %v", err)
	}
	if res.StatusCode != 200 {
		return wgcfg.Key{}, fmt.Errorf("fetch control key: %d: %s", res.StatusCode, string(b))
	}
	key, err := wgcfg.ParseHexKey(string(b))
	if err != nil {
		return wgcfg.Key{}, fmt.Errorf("fetch control key: %v", err)
	}
	return *key, nil
}
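The encode/decryptMsg pair above defines the wire encoding: a random 24-byte NaCl nonce followed by the box-sealed JSON payload. A self-contained round-trip sketch of that framing, using throwaway keys in place of the machine and server keys:

package main

import (
	"crypto/rand"
	"fmt"
	"io"

	"golang.org/x/crypto/nacl/box"
)

func main() {
	// Throwaway keypairs standing in for the server and machine keys.
	serverPub, serverPriv, _ := box.GenerateKey(rand.Reader)
	clientPub, clientPriv, _ := box.GenerateKey(rand.Reader)

	var nonce [24]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		panic(err)
	}

	plaintext := []byte(`{"Version":1}`)
	// Seal appends the ciphertext to its first argument, so the wire
	// message is nonce || box(plaintext), which is what decryptMsg expects.
	msg := box.Seal(nonce[:], plaintext, &nonce, serverPub, clientPriv)

	// Receiver side: split the nonce back off and open the box.
	var n2 [24]byte
	copy(n2[:], msg)
	out, ok := box.Open(nil, msg[len(n2):], &n2, clientPub, serverPriv)
	fmt.Println(ok, string(out)) // true {"Version":1}
}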
control/controlclient/direct_test.go | 305 lines (new file)

@@ -0,0 +1,305 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// +build depends_on_currently_unreleased

package controlclient

import (
	"context"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"os"
	"testing"
	"time"

	"github.com/klauspost/compress/zstd"
	"github.com/tailscale/wireguard-go/wgcfg"
	"tailscale.com/tailcfg"
	"tailscale.io/control" // not yet released
)

func TestClientsReusingKeys(t *testing.T) {
	tmpdir, err := ioutil.TempDir("", "control-test-")
	if err != nil {
		t.Fatal(err)
	}
	var server *control.Server
	httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		server.ServeHTTP(w, r)
	}))
	server, err = control.New(tmpdir, httpsrv.URL, true)
	if err != nil {
		t.Fatal(err)
	}
	server.QuietLogging = true
	defer func() {
		httpsrv.CloseClientConnections()
		httpsrv.Close()
		os.RemoveAll(tmpdir)
	}()

	hi := NewHostinfo()
	hi.FrontendLogID = "go-test-only"
	hi.BackendLogID = "go-test-only"
	c1, err := NewDirect(Options{
		ServerURL: httpsrv.URL,
		HTTPC:     httpsrv.Client(),
		//TimeNow: s.control.TimeNow,
		Logf: func(fmt string, args ...interface{}) {
			t.Helper()
			t.Logf("c1: "+fmt, args...)
		},
		Hostinfo: &hi,
	})
	if err != nil {
		t.Fatal(err)
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	authURL, err := c1.TryLogin(ctx, nil, 0)
	if err != nil {
		t.Fatal(err)
	}
	const user = "testuser1@tailscale.onmicrosoft.com"
	postAuthURL(t, ctx, httpsrv.Client(), user, authURL)
	newURL, err := c1.WaitLoginURL(ctx, authURL)
	if err != nil {
		t.Fatal(err)
	}
	if newURL != "" {
		t.Fatalf("unexpected newURL: %s", newURL)
	}

	pollErrCh := make(chan error)
	go func() {
		err := c1.PollNetMap(ctx, -1, func(netMap *NetworkMap) {})
		pollErrCh <- err
	}()

	select {
	case err := <-pollErrCh:
		t.Fatal(err)
	default:
	}

	c2, err := NewDirect(Options{
		ServerURL: httpsrv.URL,
		HTTPC:     httpsrv.Client(),
		Logf: func(fmt string, args ...interface{}) {
			t.Helper()
			t.Logf("c2: "+fmt, args...)
		},
		Persist:  c1.GetPersist(),
		Hostinfo: &hi,
		NewDecompressor: func() (Decompressor, error) {
			return zstd.NewReader(nil)
		},
		KeepAlive: true,
	})
	if err != nil {
		t.Fatal(err)
	}
	authURL, err = c2.TryLogin(ctx, nil, 0)
	if err != nil {
		t.Fatal(err)
	}
	if authURL != "" {
		t.Errorf("unexpected authURL %s", authURL)
	}

	err = c2.PollNetMap(ctx, 1, func(netMap *NetworkMap) {})
	if err != nil {
		t.Fatal(err)
	}

	select {
	case err := <-pollErrCh:
		t.Logf("expected poll error: %v", err)
	case <-time.After(5 * time.Second):
		t.Fatal("first client poll failed to close")
	}
}

func TestClientsReusingOldKey(t *testing.T) {
	tmpdir, err := ioutil.TempDir("", "control-test-")
	if err != nil {
		t.Fatal(err)
	}
	var server *control.Server
	httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		server.ServeHTTP(w, r)
	}))
	server, err = control.New(tmpdir, httpsrv.URL, true)
	if err != nil {
		t.Fatal(err)
	}
	server.QuietLogging = true
	defer func() {
		httpsrv.CloseClientConnections()
		httpsrv.Close()
		os.RemoveAll(tmpdir)
	}()

	hi := NewHostinfo()
	hi.FrontendLogID = "go-test-only"
	hi.BackendLogID = "go-test-only"
	genOpts := func() Options {
		return Options{
			ServerURL: httpsrv.URL,
			HTTPC:     httpsrv.Client(),
			//TimeNow: s.control.TimeNow,
			Logf: func(fmt string, args ...interface{}) {
				t.Helper()
				t.Logf("c1: "+fmt, args...)
			},
			Hostinfo: &hi,
		}
	}

	// Login with a new node key. This requires authorization.
	c1, err := NewDirect(genOpts())
	if err != nil {
		t.Fatal(err)
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	authURL, err := c1.TryLogin(ctx, nil, 0)
	if err != nil {
		t.Fatal(err)
	}
	const user = "testuser1@tailscale.onmicrosoft.com"
	postAuthURL(t, ctx, httpsrv.Client(), user, authURL)
	newURL, err := c1.WaitLoginURL(ctx, authURL)
	if err != nil {
		t.Fatal(err)
	}
	if newURL != "" {
		t.Fatalf("unexpected newURL: %s", newURL)
	}

	if err := c1.PollNetMap(ctx, 1, func(netMap *NetworkMap) {}); err != nil {
		t.Fatal(err)
	}

	newPrivKey := func(t *testing.T) wgcfg.PrivateKey {
		t.Helper()
		k, err := wgcfg.NewPrivateKey()
		if err != nil {
			t.Fatal(err)
		}
		return *k
	}

	// Replace the previous key with a new key.
	persist1 := c1.GetPersist()
	persist2 := Persist{
		PrivateMachineKey: persist1.PrivateMachineKey,
		OldPrivateNodeKey: persist1.PrivateNodeKey,
		PrivateNodeKey:    newPrivKey(t),
	}
	opts := genOpts()
	opts.Persist = persist2

	c1, err = NewDirect(opts)
	if err != nil {
		t.Fatal(err)
	}
	if authURL, err := c1.TryLogin(ctx, nil, 0); err != nil {
		t.Fatal(err)
	} else if authURL == "" {
		t.Fatal("expected authURL for reused oldNodeKey, got none")
	} else {
		postAuthURL(t, ctx, httpsrv.Client(), user, authURL)
		if newURL, err := c1.WaitLoginURL(ctx, authURL); err != nil {
			t.Fatal(err)
		} else if newURL != "" {
			t.Fatalf("unexpected newURL: %s", newURL)
		}
	}
	if p := c1.GetPersist(); p.PrivateNodeKey != opts.Persist.PrivateNodeKey {
		t.Error("unexpected node key change")
	} else {
		persist2 = p
	}

	// Here we simulate a client using old persistent data.
	// We use the key we have already replaced as the old node key.
	// This requires the user to authenticate.
	persist3 := Persist{
		PrivateMachineKey: persist1.PrivateMachineKey,
		OldPrivateNodeKey: persist1.PrivateNodeKey,
		PrivateNodeKey:    newPrivKey(t),
	}
	opts = genOpts()
	opts.Persist = persist3

	c1, err = NewDirect(opts)
	if err != nil {
		t.Fatal(err)
	}
	if authURL, err := c1.TryLogin(ctx, nil, 0); err != nil {
		t.Fatal(err)
	} else if authURL == "" {
		t.Fatal("expected authURL for reused oldNodeKey, got none")
	} else {
		postAuthURL(t, ctx, httpsrv.Client(), user, authURL)
		if newURL, err := c1.WaitLoginURL(ctx, authURL); err != nil {
			t.Fatal(err)
		} else if newURL != "" {
			t.Fatalf("unexpected newURL: %s", newURL)
		}
	}
	if err := c1.PollNetMap(ctx, 1, func(netMap *NetworkMap) {}); err != nil {
		t.Fatal(err)
	}

	// At this point, there should only be one node for the machine key
	// registered as active in the server.
	mkey := tailcfg.MachineKey(*persist1.PrivateMachineKey.Public())
	nodeIDs, err := server.DB().MachineNodes(mkey)
	if err != nil {
		t.Fatal(err)
	}
	if len(nodeIDs) != 1 {
		t.Logf("active nodes for machine key %v:", mkey)
		for i, nodeID := range nodeIDs {
			nodeKey := server.DB().NodeKey(nodeID)
			t.Logf("\tnode %d: id=%v, key=%v", i, nodeID, nodeKey)
		}
		t.Fatalf("want 1 active node for the client machine, got %d", len(nodeIDs))
	}

	// Now try the previous node key. It should fail.
	opts = genOpts()
	opts.Persist = persist2
	c1, err = NewDirect(opts)
	if err != nil {
		t.Fatal(err)
	}
	// TODO(crawshaw): make this return an actual error.
	// Have cfgdb track expired keys, and when an expired key is reused
	// produce an error.
	if authURL, err := c1.TryLogin(ctx, nil, 0); err != nil {
		t.Fatal(err)
	} else if authURL == "" {
		t.Fatal("expected authURL for reused nodeKey, got none")
	} else {
		postAuthURL(t, ctx, httpsrv.Client(), user, authURL)
		if newURL, err := c1.WaitLoginURL(ctx, authURL); err != nil {
			t.Fatal(err)
		} else if newURL != "" {
			t.Fatalf("unexpected newURL: %s", newURL)
		}
	}
	if err := c1.PollNetMap(ctx, 1, func(netMap *NetworkMap) {}); err != nil {
		t.Fatal(err)
	}
	if nodeIDs, err := server.DB().MachineNodes(mkey); err != nil {
		t.Fatal(err)
	} else if len(nodeIDs) != 1 {
		t.Fatalf("want 1 active node for the client machine, got %d", len(nodeIDs))
	}
}
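The tests above exercise PollNetMap, whose response body is a stream of frames: a 4-byte little-endian length, then that many bytes of encrypted (and possibly zstd-compressed) JSON. A sketch of just that framing in isolation, with a frame built by hand instead of an HTTP response body:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// readFrame mirrors the read loop in PollNetMap: length prefix, then payload.
func readFrame(r io.Reader) ([]byte, error) {
	var siz [4]byte
	if _, err := io.ReadFull(r, siz[:]); err != nil {
		return nil, err
	}
	msg := make([]byte, binary.LittleEndian.Uint32(siz[:]))
	if _, err := io.ReadFull(r, msg); err != nil {
		return nil, err
	}
	return msg, nil
}

func main() {
	// Build one frame by hand, then read it back.
	payload := []byte(`{"KeepAlive":true}`)
	var buf bytes.Buffer
	binary.Write(&buf, binary.LittleEndian, uint32(len(payload)))
	buf.Write(payload)

	got, err := readFrame(&buf)
	fmt.Println(string(got), err) // {"KeepAlive":true} <nil>
}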
control/controlclient/netmap.go | 294 lines (new file)

@@ -0,0 +1,294 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package controlclient

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log"
	"net"
	"runtime"
	"strings"
	"time"

	"github.com/tailscale/wireguard-go/wgcfg"
	"tailscale.com/tailcfg"
	"tailscale.com/wgengine/filter"
)

type NetworkMap struct {
	// Core networking

	NodeKey       tailcfg.NodeKey
	PrivateKey    wgcfg.PrivateKey
	Expiry        time.Time
	Addresses     []wgcfg.CIDR
	LocalPort     uint16 // used for debugging
	MachineStatus tailcfg.MachineStatus
	Peers         []tailcfg.Node
	DNS           []wgcfg.IP
	DNSDomains    []string
	Hostinfo      tailcfg.Hostinfo
	PacketFilter  filter.Matches

	// ACLs

	User   tailcfg.UserID
	Domain string
	// TODO(crawshaw): reduce UserProfiles to []tailcfg.UserProfile?
	// There are lots of ways to slice this data, leave it up to users.
	UserProfiles map[tailcfg.UserID]tailcfg.UserProfile
	Roles        []tailcfg.Role
	// TODO(crawshaw): Groups []tailcfg.Group
	// TODO(crawshaw): Capabilities []tailcfg.Capability
}

func (n *NetworkMap) Equal(n2 *NetworkMap) bool {
	// TODO(crawshaw): this is crude, but is an easy way to avoid bugs.
	b, err := json.Marshal(n)
	if err != nil {
		panic(err)
	}
	b2, err := json.Marshal(n2)
	if err != nil {
		panic(err)
	}
	return bytes.Equal(b, b2)
}

func (n *NetworkMap) isEmpty() bool {
	if n == nil {
		return true
	}
	return n.Equal(&NetworkMap{})
}

func (nm NetworkMap) String() string {
	return nm.Concise()
}

func keyString(key [32]byte) string {
	b64 := base64.StdEncoding.EncodeToString(key[:])
	abbrev := "invalid"
	if len(b64) == 44 {
		abbrev = b64[0:4] + "…" + b64[39:43]
	}
	return fmt.Sprintf("[%s]", abbrev)
}

func (nm *NetworkMap) Concise() string {
	buf := new(strings.Builder)
	fmt.Fprintf(buf, "NetworkMap: self: %v auth=%v :%v %v\n",
		keyString(nm.NodeKey), nm.MachineStatus,
		nm.LocalPort, nm.Addresses)
	for _, p := range nm.Peers {
		aip := make([]string, len(p.AllowedIPs))
		for i, a := range p.AllowedIPs {
			aip[i] = fmt.Sprint(a)
		}
		u := fmt.Sprint(p.User)
		if strings.HasPrefix(u, "userid:") {
			u = "u:" + u[7:]
		}
		f1 := fmt.Sprintf(" %v %-6v %v",
			keyString(p.Key), u, p.Endpoints)
		f2 := fmt.Sprintf(" %*v\n", 70-len(f1),
			strings.Join(aip, " "))
		fmt.Fprintf(buf, "%s%s", f1, f2)
	}
	return buf.String()
}

func (nm *NetworkMap) JSON() string {
	b, err := json.MarshalIndent(*nm, "", "  ")
	if err != nil {
		return fmt.Sprintf("[json error: %v]", err)
	}
	return string(b)
}

// TODO(apenwarr): delete me once relaynode doesn't need this anymore.
// control.go:userMap() supersedes it. This does not belong in the client.
func (nm *NetworkMap) UserMap() map[string][]filter.IP {
	// Make a lookup table of roles
	log.Printf("roles list is: %v\n", nm.Roles)
	roles := make(map[tailcfg.RoleID]tailcfg.Role)
	for _, role := range nm.Roles {
		roles[role.ID] = role
	}

	// First, go through each node's addresses and make a lookup table
	// of IP->User.
	fwd := make(map[wgcfg.IP]string)
	for _, node := range nm.Peers {
		for _, addr := range node.Addresses {
			if addr.Mask == 32 && addr.IP.Is4() {
				user, ok := nm.UserProfiles[node.User]
				if ok {
					fwd[addr.IP] = user.LoginName
				}
			}
		}
	}

	// Next, reverse the mapping into User->IP.
	rev := make(map[string][]filter.IP)
	for ip, username := range fwd {
		ip4 := ip.To4()
		if ip4 != nil {
			fip := filter.NewIP(net.IP(ip4))
			rev[username] = append(rev[username], fip)
		}
	}

	// Now add roles, which are lists of users, and therefore lists
	// of those users' IP addresses.
	for _, user := range nm.UserProfiles {
		for _, roleid := range user.Roles {
			role, ok := roles[roleid]
			if ok {
				rolename := "role:" + role.Name
				rev[rolename] = append(rev[rolename], rev[user.LoginName]...)
			}
		}
	}

	//log.Printf("Usermap is: %v\n", rev)
	return rev
}

var iOS = runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64")
var keepalive = !iOS

const (
	UAllowSingleHosts = 1 << iota
	UAllowSubnetRoutes
	UAllowDefaultRoute
	UHackDefaultRoute

	UDefault = 0
)

// Several programs need to parse these arguments into uflags, so let's
// centralize it here.
func UFlagsHelper(uroutes, rroutes, droutes bool) int {
	uflags := 0
	if uroutes {
		uflags |= UAllowSingleHosts
	}
	if rroutes {
		uflags |= UAllowSubnetRoutes
	}
	if droutes {
		uflags |= UAllowDefaultRoute
	}
	return uflags
}

func (nm *NetworkMap) UAPI(uflags int, dnsOverride []wgcfg.IP) string {
	wgcfg, err := nm.WGCfg(uflags, dnsOverride)
	if err != nil {
		log.Fatalf("WGCfg() failed unexpectedly: %v\n", err)
	}
	s, err := wgcfg.ToUAPI()
	if err != nil {
		log.Fatalf("ToUAPI() failed unexpectedly: %v\n", err)
	}
	return s
}

func (nm *NetworkMap) WGCfg(uflags int, dnsOverride []wgcfg.IP) (*wgcfg.Config, error) {
	s := nm._WireGuardConfig(uflags, dnsOverride, true)
	return wgcfg.FromWgQuick(s, "tailscale")
}

// TODO(apenwarr): This mode is dangerous.
// Discarding the extra endpoints is almost universally the wrong choice.
// Except that plain wireguard can't handle a peer with multiple endpoints.
// (Yet?)
func (nm *NetworkMap) WireGuardConfigOneEndpoint(uflags int, dnsOverride []wgcfg.IP) string {
	return nm._WireGuardConfig(uflags, dnsOverride, false)
}

func (nm *NetworkMap) _WireGuardConfig(uflags int, dnsOverride []wgcfg.IP, allEndpoints bool) string {
	buf := new(strings.Builder)
	fmt.Fprintf(buf, "[Interface]\n")
	fmt.Fprintf(buf, "PrivateKey = %s\n", base64.StdEncoding.EncodeToString(nm.PrivateKey[:]))
	if len(nm.Addresses) > 0 {
		fmt.Fprintf(buf, "Address = ")
		for i, cidr := range nm.Addresses {
			if i > 0 {
				fmt.Fprintf(buf, ", ")
			}
			fmt.Fprintf(buf, "%s", cidr)
		}
		fmt.Fprintf(buf, "\n")
	}
	fmt.Fprintf(buf, "ListenPort = %d\n", nm.LocalPort)
	if len(dnsOverride) > 0 {
		dnss := []string{}
		for _, ip := range dnsOverride {
			dnss = append(dnss, ip.String())
		}
		fmt.Fprintf(buf, "DNS = %s\n", strings.Join(dnss, ","))
	}
	fmt.Fprintf(buf, "\n")

	for i, peer := range nm.Peers {
		if (uflags&UAllowSingleHosts) == 0 && len(peer.AllowedIPs) < 2 {
			log.Printf("wgcfg: %v skipping a single-host peer.\n", peer.Key.AbbrevString())
			continue
		}
		if i > 0 {
			fmt.Fprintf(buf, "\n")
		}
		fmt.Fprintf(buf, "[Peer]\n")
		fmt.Fprintf(buf, "PublicKey = %s\n", base64.StdEncoding.EncodeToString(peer.Key[:]))
		if len(peer.Endpoints) > 0 {
			if len(peer.Endpoints) == 1 {
				fmt.Fprintf(buf, "Endpoint = %s", peer.Endpoints[0])
			} else if allEndpoints {
				// TODO(apenwarr): This mode is incompatible.
				// Normal wireguard clients don't know how to
				// parse it (yet?)
				fmt.Fprintf(buf, "Endpoint = %s",
					strings.Join(peer.Endpoints, ","))
			} else {
				fmt.Fprintf(buf, "Endpoint = %s # other endpoints: %s",
					peer.Endpoints[0],
					strings.Join(peer.Endpoints[1:], ", "))
			}
			buf.WriteByte('\n')
		}
		var aips []string
		for _, allowedIP := range peer.AllowedIPs {
			aip := allowedIP.String()
			if allowedIP.Mask == 0 {
				if (uflags & UAllowDefaultRoute) == 0 {
					log.Printf("wgcfg: %v skipping default route\n", peer.Key.AbbrevString())
					continue
				}
				if (uflags & UHackDefaultRoute) != 0 {
					aip = "10.0.0.0/8"
					log.Printf("wgcfg: %v converting default route => %v\n", peer.Key.AbbrevString(), aip)
				}
			} else if allowedIP.Mask < 32 {
				if (uflags & UAllowSubnetRoutes) == 0 {
					log.Printf("wgcfg: %v skipping subnet route\n", peer.Key.AbbrevString())
					continue
				}
			}
			aips = append(aips, aip)
		}
		fmt.Fprintf(buf, "AllowedIPs = %s\n", strings.Join(aips, ", "))
		if keepalive {
			fmt.Fprintf(buf, "PersistentKeepalive = 25\n")
		}
	}

	return buf.String()
}
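A condensed sketch of the AllowedIPs policy that _WireGuardConfig applies above: /0 routes require UAllowDefaultRoute, masks shorter than /32 require UAllowSubnetRoutes, and single-host (/32-only) peers are filtered earlier by the UAllowSingleHosts check on the peer as a whole. The keepRoute helper is hypothetical, written only to illustrate the flag checks:

package main

import "fmt"

const (
	UAllowSingleHosts = 1 << iota
	UAllowSubnetRoutes
	UAllowDefaultRoute
	UHackDefaultRoute
)

// keepRoute is a hypothetical helper mirroring the per-AllowedIP decisions
// in _WireGuardConfig (single-host filtering actually happens per peer,
// before this point, so /32 routes always pass here).
func keepRoute(uflags, maskBits int) bool {
	switch {
	case maskBits == 0:
		return uflags&UAllowDefaultRoute != 0
	case maskBits < 32:
		return uflags&UAllowSubnetRoutes != 0
	default:
		return true
	}
}

func main() {
	uflags := UAllowSingleHosts | UAllowSubnetRoutes // no default route
	for _, m := range []int{0, 24, 32} {
		fmt.Printf("/%d kept: %v\n", m, keepRoute(uflags, m))
	}
}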