mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-08 09:07:44 +00:00
Introduce a state store to LocalBackend.
The store is passed-in by callers of NewLocalBackend and ipnserver.Run, but currently all callers are hardcoded to an in-memory store. The store is unused. Signed-Off-By: David Anderson <dave@natulte.net>
This commit is contained in:
parent
21280ca2d1
commit
5bc632271b
@ -44,14 +44,15 @@ func main() {
|
|||||||
log.Printf("fixConsoleOutput: %v\n", err)
|
log.Printf("fixConsoleOutput: %v\n", err)
|
||||||
}
|
}
|
||||||
config := getopt.StringLong("config", 'f', "", "path to config file")
|
config := getopt.StringLong("config", 'f', "", "path to config file")
|
||||||
|
statekey := getopt.StringLong("statekey", 0, "", "state key for daemon-side config")
|
||||||
server := getopt.StringLong("server", 's', "https://login.tailscale.com", "URL to tailcontrol server")
|
server := getopt.StringLong("server", 's', "https://login.tailscale.com", "URL to tailcontrol server")
|
||||||
nuroutes := getopt.BoolLong("no-single-routes", 'N', "disallow (non-subnet) routes to single nodes")
|
nuroutes := getopt.BoolLong("no-single-routes", 'N', "disallow (non-subnet) routes to single nodes")
|
||||||
rroutes := getopt.BoolLong("remote-routes", 'R', "allow routing subnets to remote nodes")
|
rroutes := getopt.BoolLong("remote-routes", 'R', "allow routing subnets to remote nodes")
|
||||||
droutes := getopt.BoolLong("default-routes", 'D', "allow default route on remote node")
|
droutes := getopt.BoolLong("default-routes", 'D', "allow default route on remote node")
|
||||||
getopt.Parse()
|
getopt.Parse()
|
||||||
if *config == "" {
|
if *config == "" && *statekey == "" {
|
||||||
logpolicy.New("tailnode.log.tailscale.io", "tailscale")
|
logpolicy.New("tailnode.log.tailscale.io", "tailscale")
|
||||||
log.Fatal("no --config file specified")
|
log.Fatal("no --config or --statekey provided")
|
||||||
}
|
}
|
||||||
if len(getopt.Args()) > 0 {
|
if len(getopt.Args()) > 0 {
|
||||||
log.Fatalf("too many non-flag arguments: %#v", getopt.Args()[0])
|
log.Fatalf("too many non-flag arguments: %#v", getopt.Args()[0])
|
||||||
@ -60,16 +61,20 @@ func main() {
|
|||||||
pol := logpolicy.New("tailnode.log.tailscale.io", *config)
|
pol := logpolicy.New("tailnode.log.tailscale.io", *config)
|
||||||
defer pol.Close()
|
defer pol.Close()
|
||||||
|
|
||||||
prefs, err := loadConfig(*config)
|
var prefs *ipn.Prefs
|
||||||
|
if *config != "" {
|
||||||
|
localCfg, err := loadConfig(*config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(apenwarr): fix different semantics between prefs and uflags
|
// TODO(apenwarr): fix different semantics between prefs and uflags
|
||||||
// TODO(apenwarr): allow setting/using CorpDNS
|
// TODO(apenwarr): allow setting/using CorpDNS
|
||||||
|
prefs = &localCfg
|
||||||
prefs.WantRunning = true
|
prefs.WantRunning = true
|
||||||
prefs.RouteAll = *rroutes || *droutes
|
prefs.RouteAll = *rroutes || *droutes
|
||||||
prefs.AllowSingleHosts = !*nuroutes
|
prefs.AllowSingleHosts = !*nuroutes
|
||||||
|
}
|
||||||
|
|
||||||
c, err := safesocket.Connect("", "Tailscale", "tailscaled", 41112)
|
c, err := safesocket.Connect("", "Tailscale", "tailscaled", 41112)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -90,7 +95,8 @@ func main() {
|
|||||||
|
|
||||||
bc := ipn.NewBackendClient(log.Printf, clientToServer)
|
bc := ipn.NewBackendClient(log.Printf, clientToServer)
|
||||||
opts := ipn.Options{
|
opts := ipn.Options{
|
||||||
Prefs: &prefs,
|
StateKey: ipn.StateKey(*statekey),
|
||||||
|
Prefs: prefs,
|
||||||
ServerURL: *server,
|
ServerURL: *server,
|
||||||
Notify: func(n ipn.Notify) {
|
Notify: func(n ipn.Notify) {
|
||||||
log.Printf("Notify: %v\n", n)
|
log.Printf("Notify: %v\n", n)
|
||||||
@ -112,7 +118,7 @@ func main() {
|
|||||||
fmt.Fprintf(os.Stderr, "\nTo authenticate, visit:\n\n\t%s\n\n", *url)
|
fmt.Fprintf(os.Stderr, "\nTo authenticate, visit:\n\n\t%s\n\n", *url)
|
||||||
}
|
}
|
||||||
if p := n.Prefs; p != nil {
|
if p := n.Prefs; p != nil {
|
||||||
prefs = *p
|
prefs = p
|
||||||
saveConfig(*config, *p)
|
saveConfig(*config, *p)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -131,6 +137,9 @@ func loadConfig(path string) (ipn.Prefs, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func saveConfig(path string, prefs ipn.Prefs) error {
|
func saveConfig(path string, prefs ipn.Prefs) error {
|
||||||
|
if path == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
b, err := json.MarshalIndent(prefs, "", "\t")
|
b, err := json.MarshalIndent(prefs, "", "\t")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("save config: %v", err)
|
return fmt.Errorf("save config: %v", err)
|
||||||
|
@ -28,6 +28,7 @@ func main() {
|
|||||||
debug := getopt.StringLong("debug", 0, "", "Address of debug server")
|
debug := getopt.StringLong("debug", 0, "", "Address of debug server")
|
||||||
tunname := getopt.StringLong("tun", 0, "ts0", "tunnel interface name")
|
tunname := getopt.StringLong("tun", 0, "ts0", "tunnel interface name")
|
||||||
listenport := getopt.Uint16Long("port", 'p', magicsock.DefaultPort, "WireGuard port (0=autoselect)")
|
listenport := getopt.Uint16Long("port", 'p', magicsock.DefaultPort, "WireGuard port (0=autoselect)")
|
||||||
|
statepath := getopt.StringLong("state", 0, "", "Path of state file")
|
||||||
|
|
||||||
logf := wgengine.RusagePrefixLog(log.Printf)
|
logf := wgengine.RusagePrefixLog(log.Printf)
|
||||||
|
|
||||||
@ -58,6 +59,7 @@ func main() {
|
|||||||
e = wgengine.NewWatchdog(e)
|
e = wgengine.NewWatchdog(e)
|
||||||
|
|
||||||
opts := ipnserver.Options{
|
opts := ipnserver.Options{
|
||||||
|
StatePath: *statepath,
|
||||||
SurviveDisconnects: true,
|
SurviveDisconnects: true,
|
||||||
AllowQuit: false,
|
AllowQuit: false,
|
||||||
}
|
}
|
||||||
|
@ -50,11 +50,44 @@ type Notify struct {
|
|||||||
BackendLogID *string // public logtail id used by backend
|
BackendLogID *string // public logtail id used by backend
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// StateKey is an opaque identifier for a set of LocalBackend state
|
||||||
|
// (preferences, private keys, etc.).
|
||||||
|
//
|
||||||
|
// The reason we need this is that the Tailscale agent may be running
|
||||||
|
// on a multi-user machine, in a context where a single daemon is
|
||||||
|
// shared by several consecutive users. Ideally we would just use the
|
||||||
|
// username of the connected frontend as the StateKey.
|
||||||
|
//
|
||||||
|
// However, on Windows, there seems to be no safe way to figure out
|
||||||
|
// the owning user of a process connected over IPC mechanisms
|
||||||
|
// (sockets, named pipes). So instead, on Windows, we use a
|
||||||
|
// capability-oriented system where the frontend generates a random
|
||||||
|
// identifier for itself, and uses that as the StateKey when talking
|
||||||
|
// to the backend. That way, while we can't identify an OS user by
|
||||||
|
// name, we can tell two different users apart, because they'll have
|
||||||
|
// different opaque state keys (and no access to each others's keys).
|
||||||
|
//
|
||||||
|
// It would be much nicer if we could just figure out the OS user that
|
||||||
|
// owns the connected frontend, but here we are.
|
||||||
|
type StateKey string
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
FrontendLogID string // public logtail id used by frontend
|
// Public logtail id used by frontend.
|
||||||
|
FrontendLogID string
|
||||||
|
// Base URL for the tailcontrol server to talk to.
|
||||||
ServerURL string
|
ServerURL string
|
||||||
|
// StateKey and Prefs together define the state the backend should
|
||||||
|
// use:
|
||||||
|
// - StateKey=="" && Prefs!=nil: use Prefs for internal state,
|
||||||
|
// don't persist changes in the backend.
|
||||||
|
// - StateKey!="" && Prefs==nil: load the given backend-side
|
||||||
|
// state and use/update that.
|
||||||
|
// - StateKey!="" && Prefs!=nil: like the previous case, but do
|
||||||
|
// an initial overwrite of backend state with Prefs.
|
||||||
|
StateKey StateKey
|
||||||
Prefs *Prefs
|
Prefs *Prefs
|
||||||
Notify func(n Notify) `json:"-"`
|
// Callback for backend events.
|
||||||
|
Notify func(Notify) `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Backend interface {
|
type Backend interface {
|
||||||
|
@ -158,7 +158,7 @@ func newNode(t *testing.T, prefix string, https *httptest.Server) testNode {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewFakeEngine: %v\n", err)
|
t.Fatalf("NewFakeEngine: %v\n", err)
|
||||||
}
|
}
|
||||||
n, err := NewLocalBackend(logf, prefix, e1)
|
n, err := NewLocalBackend(logf, prefix, &MemoryStore{}, e1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("NewLocalBackend: %v\n", err)
|
t.Fatalf("NewLocalBackend: %v\n", err)
|
||||||
}
|
}
|
||||||
|
@ -29,6 +29,7 @@
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
|
StatePath string
|
||||||
SurviveDisconnects bool
|
SurviveDisconnects bool
|
||||||
AllowQuit bool
|
AllowQuit bool
|
||||||
}
|
}
|
||||||
@ -58,7 +59,17 @@ func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e w
|
|||||||
return fmt.Errorf("safesocket.Listen: %v", err)
|
return fmt.Errorf("safesocket.Listen: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := ipn.NewLocalBackend(logf, logid, e)
|
var store ipn.StateStore
|
||||||
|
if opts.StatePath != "" {
|
||||||
|
store, err = ipn.NewFileStore(opts.StatePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ipn.NewFileStore(%q): %v", opts.StatePath, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
store = &ipn.MemoryStore{}
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := ipn.NewLocalBackend(logf, logid, store, e)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("NewLocalBackend: %v", err)
|
return fmt.Errorf("NewLocalBackend: %v", err)
|
||||||
}
|
}
|
||||||
|
71
ipn/local.go
71
ipn/local.go
@ -5,6 +5,7 @@
|
|||||||
package ipn
|
package ipn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
@ -28,6 +29,7 @@ type LocalBackend struct {
|
|||||||
notify func(n Notify)
|
notify func(n Notify)
|
||||||
c *controlclient.Client
|
c *controlclient.Client
|
||||||
e wgengine.Engine
|
e wgengine.Engine
|
||||||
|
store StateStore
|
||||||
serverURL string
|
serverURL string
|
||||||
backendLogID string
|
backendLogID string
|
||||||
portpoll *portlist.Poller // may be nil
|
portpoll *portlist.Poller // may be nil
|
||||||
@ -36,6 +38,7 @@ type LocalBackend struct {
|
|||||||
|
|
||||||
// The mutex protects the following elements.
|
// The mutex protects the following elements.
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
|
stateKey StateKey
|
||||||
prefs Prefs
|
prefs Prefs
|
||||||
state State
|
state State
|
||||||
hiCache tailcfg.Hostinfo
|
hiCache tailcfg.Hostinfo
|
||||||
@ -52,7 +55,9 @@ type LocalBackend struct {
|
|||||||
statusChanged *sync.Cond
|
statusChanged *sync.Cond
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLocalBackend(logf logger.Logf, logid string, e wgengine.Engine) (*LocalBackend, error) {
|
// NewLocalBackend returns a new LocalBackend that is ready to run,
|
||||||
|
// but is not actually running.
|
||||||
|
func NewLocalBackend(logf logger.Logf, logid string, store StateStore, e wgengine.Engine) (*LocalBackend, error) {
|
||||||
if e == nil {
|
if e == nil {
|
||||||
panic("ipn.NewLocalBackend: wgengine must not be nil")
|
panic("ipn.NewLocalBackend: wgengine must not be nil")
|
||||||
}
|
}
|
||||||
@ -68,6 +73,7 @@ func NewLocalBackend(logf logger.Logf, logid string, e wgengine.Engine) (*LocalB
|
|||||||
b := LocalBackend{
|
b := LocalBackend{
|
||||||
logf: logf,
|
logf: logf,
|
||||||
e: e,
|
e: e,
|
||||||
|
store: store,
|
||||||
backendLogID: logid,
|
backendLogID: logid,
|
||||||
state: NoState,
|
state: NoState,
|
||||||
portpoll: portpoll,
|
portpoll: portpoll,
|
||||||
@ -113,8 +119,8 @@ func (b *LocalBackend) SetCmpDiff(cmpDiff func(x, y interface{}) string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *LocalBackend) Start(opts Options) error {
|
func (b *LocalBackend) Start(opts Options) error {
|
||||||
if opts.Prefs == nil {
|
if opts.Prefs == nil && opts.StateKey == "" {
|
||||||
panic("Prefs can't be nil yet")
|
return errors.New("no state key or prefs provided")
|
||||||
}
|
}
|
||||||
|
|
||||||
if b.c != nil {
|
if b.c != nil {
|
||||||
@ -128,7 +134,11 @@ func (b *LocalBackend) Start(opts Options) error {
|
|||||||
b.c.Shutdown()
|
b.c.Shutdown()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.Prefs != nil {
|
||||||
b.logf("Start: %v\n", opts.Prefs.Pretty())
|
b.logf("Start: %v\n", opts.Prefs.Pretty())
|
||||||
|
} else {
|
||||||
|
b.logf("Start\n")
|
||||||
|
}
|
||||||
|
|
||||||
hi := controlclient.NewHostinfo()
|
hi := controlclient.NewHostinfo()
|
||||||
hi.BackendLogID = b.backendLogID
|
hi.BackendLogID = b.backendLogID
|
||||||
@ -139,9 +149,12 @@ func (b *LocalBackend) Start(opts Options) error {
|
|||||||
b.hiCache = hi
|
b.hiCache = hi
|
||||||
b.state = NoState
|
b.state = NoState
|
||||||
b.serverURL = opts.ServerURL
|
b.serverURL = opts.ServerURL
|
||||||
if opts.Prefs != nil {
|
|
||||||
b.prefs = *opts.Prefs
|
if err := b.loadStateWithLock(opts.StateKey, opts.Prefs); err != nil {
|
||||||
|
b.mu.Unlock()
|
||||||
|
return fmt.Errorf("loading requested state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.notify = opts.Notify
|
b.notify = opts.Notify
|
||||||
b.netMapCache = nil
|
b.netMapCache = nil
|
||||||
b.mu.Unlock()
|
b.mu.Unlock()
|
||||||
@ -187,6 +200,11 @@ func (b *LocalBackend) Start(opts Options) error {
|
|||||||
if new.Persist != nil {
|
if new.Persist != nil {
|
||||||
persist := *new.Persist // copy
|
persist := *new.Persist // copy
|
||||||
b.prefs.Persist = &persist
|
b.prefs.Persist = &persist
|
||||||
|
if b.stateKey != "" {
|
||||||
|
if err := b.store.WriteState(b.stateKey, b.prefs.ToBytes()); err != nil {
|
||||||
|
b.logf("Failed to save new controlclient state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
np := b.prefs
|
np := b.prefs
|
||||||
b.send(Notify{Prefs: &np})
|
b.send(Notify{Prefs: &np})
|
||||||
}
|
}
|
||||||
@ -257,6 +275,8 @@ func (b *LocalBackend) Start(opts Options) error {
|
|||||||
blid := b.backendLogID
|
blid := b.backendLogID
|
||||||
b.logf("Backend: logs: be:%v fe:%v\n", blid, opts.FrontendLogID)
|
b.logf("Backend: logs: be:%v fe:%v\n", blid, opts.FrontendLogID)
|
||||||
b.send(Notify{BackendLogID: &blid})
|
b.send(Notify{BackendLogID: &blid})
|
||||||
|
nprefs := b.prefs // make a copy
|
||||||
|
b.send(Notify{Prefs: &nprefs})
|
||||||
|
|
||||||
cli.Login(nil, controlclient.LoginDefault)
|
cli.Login(nil, controlclient.LoginDefault)
|
||||||
return nil
|
return nil
|
||||||
@ -338,6 +358,42 @@ func (b *LocalBackend) popBrowserAuthNow() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *LocalBackend) loadStateWithLock(key StateKey, prefs *Prefs) error {
|
||||||
|
switch {
|
||||||
|
case key != "" && prefs != nil:
|
||||||
|
b.logf("Importing frontend prefs into backend store")
|
||||||
|
if err := b.store.WriteState(key, prefs.ToBytes()); err != nil {
|
||||||
|
return fmt.Errorf("store.WriteState: %v", err)
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
case key != "":
|
||||||
|
b.logf("Using backend prefs")
|
||||||
|
bs, err := b.store.ReadState(key)
|
||||||
|
if err != nil {
|
||||||
|
if err == ErrStateNotExist {
|
||||||
|
b.prefs = NewPrefs()
|
||||||
|
b.stateKey = key
|
||||||
|
b.logf("Created empty state for %q", key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("store.ReadState(%q): %v", key, err)
|
||||||
|
}
|
||||||
|
b.prefs, err = PrefsFromBytes(bs, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("PrefsFromBytes: %v", err)
|
||||||
|
}
|
||||||
|
b.stateKey = key
|
||||||
|
case prefs != nil:
|
||||||
|
b.logf("Using frontend prefs")
|
||||||
|
b.prefs = *prefs
|
||||||
|
b.stateKey = ""
|
||||||
|
default:
|
||||||
|
panic("state key and prefs are unset")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (b *LocalBackend) State() State {
|
func (b *LocalBackend) State() State {
|
||||||
b.mu.Lock()
|
b.mu.Lock()
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
@ -436,6 +492,11 @@ func (b *LocalBackend) SetPrefs(new Prefs) {
|
|||||||
old := b.prefs
|
old := b.prefs
|
||||||
new.Persist = old.Persist // caller isn't allowed to override this
|
new.Persist = old.Persist // caller isn't allowed to override this
|
||||||
b.prefs = new
|
b.prefs = new
|
||||||
|
if b.stateKey != "" {
|
||||||
|
if err := b.store.WriteState(b.stateKey, b.prefs.ToBytes()); err != nil {
|
||||||
|
b.logf("Failed to save new controlclient state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
b.mu.Unlock()
|
b.mu.Unlock()
|
||||||
|
|
||||||
if old.WantRunning != new.WantRunning {
|
if old.WantRunning != new.WantRunning {
|
||||||
|
@ -33,6 +33,9 @@ type Prefs struct {
|
|||||||
Persist *controlclient.Persist `json:"Config"`
|
Persist *controlclient.Persist `json:"Config"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsEmpty reports whether p is nil or pointing to a Prefs zero value.
|
||||||
|
func (uc *Prefs) IsEmpty() bool { return uc == nil || *uc == Prefs{} }
|
||||||
|
|
||||||
func (uc *Prefs) Pretty() string {
|
func (uc *Prefs) Pretty() string {
|
||||||
var ucp string
|
var ucp string
|
||||||
if uc.Persist != nil {
|
if uc.Persist != nil {
|
||||||
|
117
ipn/store.go
Normal file
117
ipn/store.go
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ipn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"tailscale.com/atomicfile"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrStateNotExist is returned by StateStore.ReadState when the
|
||||||
|
// requested state id doesn't exist.
|
||||||
|
var ErrStateNotExist = errors.New("no state with given id")
|
||||||
|
|
||||||
|
// StateStore persists state, and produces it back on request.
|
||||||
|
type StateStore interface {
|
||||||
|
// ReadState returns the bytes associated with id. Returns (nil,
|
||||||
|
// ErrStateNotExist) if the id doesn't have associated state.
|
||||||
|
ReadState(id StateKey) ([]byte, error)
|
||||||
|
// WriteState saves bs as the state associated with id.
|
||||||
|
WriteState(id StateKey, bs []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// MemoryStore is a store that keeps state in memory only.
|
||||||
|
type MemoryStore struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
cache map[StateKey][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MemoryStore) ReadState(id StateKey) ([]byte, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.cache == nil {
|
||||||
|
s.cache = map[StateKey][]byte{}
|
||||||
|
}
|
||||||
|
bs, ok := s.cache[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrStateNotExist
|
||||||
|
}
|
||||||
|
return bs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MemoryStore) WriteState(id StateKey, bs []byte) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.cache == nil {
|
||||||
|
s.cache = map[StateKey][]byte{}
|
||||||
|
}
|
||||||
|
s.cache[id] = append([]byte(nil), bs...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileStore is a StateStore that uses a JSON file for persistence.
|
||||||
|
type FileStore struct {
|
||||||
|
path string
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
cache map[StateKey][]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileStore returns a new file store that persists to path.
|
||||||
|
func NewFileStore(path string) (*FileStore, error) {
|
||||||
|
bs, err := ioutil.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
// Write out an initial file, to verify that we can write
|
||||||
|
// to the path.
|
||||||
|
if err = atomicfile.WriteFile(path, []byte("{}"), 0600); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &FileStore{
|
||||||
|
path: path,
|
||||||
|
cache: map[StateKey][]byte{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ret := &FileStore{
|
||||||
|
path: path,
|
||||||
|
cache: map[StateKey][]byte{},
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(bs, &ret.cache); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadState returns the bytes persisted for id, if any.
|
||||||
|
func (s *FileStore) ReadState(id StateKey) ([]byte, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
bs, ok := s.cache[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrStateNotExist
|
||||||
|
}
|
||||||
|
return bs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteState persists bs under the key id.
|
||||||
|
func (s *FileStore) WriteState(id StateKey, bs []byte) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.cache[id] = append([]byte(nil), bs...)
|
||||||
|
bs, err := json.MarshalIndent(s.cache, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return atomicfile.WriteFile(s.path, bs, 0600)
|
||||||
|
}
|
121
ipn/store_test.go
Normal file
121
ipn/store_test.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package ipn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testStoreSemantics(t *testing.T, store StateStore) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
// if true, data is data to write. If false, data is expected
|
||||||
|
// output of read.
|
||||||
|
write bool
|
||||||
|
id StateKey
|
||||||
|
data string
|
||||||
|
// If write=false, true if we expect a not-exist error.
|
||||||
|
notExists bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
id: "foo",
|
||||||
|
notExists: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
write: true,
|
||||||
|
id: "foo",
|
||||||
|
data: "bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "foo",
|
||||||
|
data: "bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "baz",
|
||||||
|
notExists: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
write: true,
|
||||||
|
id: "baz",
|
||||||
|
data: "quux",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "foo",
|
||||||
|
data: "bar",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: "baz",
|
||||||
|
data: "quux",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, test := range tests {
|
||||||
|
if test.write {
|
||||||
|
if err := store.WriteState(test.id, []byte(test.data)); err != nil {
|
||||||
|
t.Errorf("writing %q to %q: %v", test.data, test.id, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
bs, err := store.ReadState(test.id)
|
||||||
|
if err != nil {
|
||||||
|
if test.notExists && err == ErrStateNotExist {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
t.Errorf("reading %q: %v", test.id, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if string(bs) != test.data {
|
||||||
|
t.Errorf("reading %q: got %q, want %q", test.id, string(bs), test.data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMemoryStore(t *testing.T) {
|
||||||
|
store := &MemoryStore{}
|
||||||
|
testStoreSemantics(t, store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileStore(t *testing.T) {
|
||||||
|
f, err := ioutil.TempFile("", "test_ipn_store")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
path := f.Name()
|
||||||
|
f.Close()
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
store, err := NewFileStore(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating file store failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
testStoreSemantics(t, store)
|
||||||
|
|
||||||
|
// Build a brand new file store and check that both IDs written
|
||||||
|
// above are still there.
|
||||||
|
store, err = NewFileStore(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating second file store failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := map[StateKey]string{
|
||||||
|
"foo": "bar",
|
||||||
|
"baz": "quux",
|
||||||
|
}
|
||||||
|
for id, want := range expected {
|
||||||
|
bs, err := store.ReadState(id)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("reading %q (2nd store): %v", id, err)
|
||||||
|
}
|
||||||
|
if string(bs) != want {
|
||||||
|
t.Errorf("reading %q (2nd store): got %q, want %q", id, string(bs), want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user