types/key: add MachinePrivate and MachinePublic.

Plumb throughout the codebase as a replacement for the mixed use of
tailcfg.MachineKey and wgkey.Private/Public.

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2021-09-01 01:52:27 -07:00
committed by Dave Anderson
parent 4ce091cbd8
commit 4fdb88efe1
24 changed files with 605 additions and 234 deletions

View File

@@ -26,14 +26,13 @@ import (
"time"
"github.com/klauspost/compress/zstd"
"golang.org/x/crypto/nacl/box"
"go4.org/mem"
"inet.af/netaddr"
"tailscale.com/net/tsaddr"
"tailscale.com/smallzstd"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/wgkey"
)
const msgLimit = 1 << 20 // encrypted message length limit
@@ -57,8 +56,8 @@ type Server struct {
mu sync.Mutex
inServeMap int
cond *sync.Cond // lazily initialized by condLocked
pubKey wgkey.Key
privKey wgkey.Private
pubKey key.MachinePublic
privKey key.MachinePrivate
nodes map[tailcfg.NodeKey]*tailcfg.Node
users map[tailcfg.NodeKey]*tailcfg.User
logins map[tailcfg.NodeKey]*tailcfg.Login
@@ -199,25 +198,21 @@ func (s *Server) serveUnhandled(w http.ResponseWriter, r *http.Request) {
go panic(fmt.Sprintf("testcontrol.Server received unhandled request: %s", got.Bytes()))
}
func (s *Server) publicKey() wgkey.Key {
func (s *Server) publicKey() key.MachinePublic {
pub, _ := s.keyPair()
return pub
}
func (s *Server) privateKey() wgkey.Private {
func (s *Server) privateKey() key.MachinePrivate {
_, priv := s.keyPair()
return priv
}
func (s *Server) keyPair() (pub wgkey.Key, priv wgkey.Private) {
func (s *Server) keyPair() (pub key.MachinePublic, priv key.MachinePrivate) {
s.mu.Lock()
defer s.mu.Unlock()
if s.pubKey.IsZero() {
var err error
s.privKey, err = wgkey.NewPrivate()
if err != nil {
go panic(err) // bring down test, even if in http.Handler
}
s.privKey = key.NewMachine()
s.pubKey = s.privKey.Public()
}
return s.pubKey, s.privKey
@@ -226,7 +221,7 @@ func (s *Server) keyPair() (pub wgkey.Key, priv wgkey.Private) {
func (s *Server) serveKey(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(200)
io.WriteString(w, s.publicKey().HexString())
io.WriteString(w, s.publicKey().UntypedHexString())
}
func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) {
@@ -237,12 +232,11 @@ func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) {
mkeyStr = mkeyStr[:i]
}
key, err := wgkey.ParseHex(mkeyStr)
mkey, err := key.ParseMachinePublicUntyped(mem.S(mkeyStr))
if err != nil {
http.Error(w, "bad machine key hex", 400)
return
}
mkey := tailcfg.MachineKey(key)
if r.Method != "POST" {
http.Error(w, "POST required", 400)
@@ -281,7 +275,7 @@ func (s *Server) AddFakeNode() {
s.nodes = make(map[tailcfg.NodeKey]*tailcfg.Node)
}
nk := tailcfg.NodeKey(key.NewPrivate().Public())
mk := tailcfg.MachineKey(key.NewPrivate().Public())
mk := key.NewMachine().Public()
dk := tailcfg.DiscoKey(key.NewPrivate().Public())
id := int64(binary.LittleEndian.Uint64(nk[:]))
ip := netaddr.IPv4(nk[0], nk[1], nk[2], nk[3])
@@ -398,7 +392,7 @@ func (s *Server) CompleteAuth(authPathOrURL string) bool {
return true
}
func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey tailcfg.MachineKey) {
func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key.MachinePublic) {
msg, err := ioutil.ReadAll(io.LimitReader(r.Body, msgLimit))
if err != nil {
r.Body.Close()
@@ -563,7 +557,7 @@ func (s *Server) InServeMap() int {
return s.inServeMap
}
func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey tailcfg.MachineKey) {
func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.MachinePublic) {
s.incrInServeMap(1)
defer s.incrInServeMap(-1)
ctx := r.Context()
@@ -741,7 +735,7 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
return res, nil
}
func (s *Server) sendMapMsg(w http.ResponseWriter, mkey tailcfg.MachineKey, compress bool, msg interface{}) error {
func (s *Server) sendMapMsg(w http.ResponseWriter, mkey key.MachinePublic, compress bool, msg interface{}) error {
resBytes, err := s.encode(mkey, compress, msg)
if err != nil {
return err
@@ -765,21 +759,12 @@ func (s *Server) sendMapMsg(w http.ResponseWriter, mkey tailcfg.MachineKey, comp
return nil
}
func (s *Server) decode(mkey tailcfg.MachineKey, msg []byte, v interface{}) error {
func (s *Server) decode(mkey key.MachinePublic, msg []byte, v interface{}) error {
if len(msg) == msgLimit {
return errors.New("encrypted message too long")
}
var nonce [24]byte
if len(msg) < len(nonce)+1 {
return errors.New("missing nonce")
}
copy(nonce[:], msg)
msg = msg[len(nonce):]
priv := s.privateKey()
pub, pri := (*[32]byte)(&mkey), (*[32]byte)(&priv)
decrypted, ok := box.Open(nil, msg, &nonce, pub, pri)
decrypted, ok := s.privateKey().OpenFrom(mkey, msg)
if !ok {
return errors.New("can't decrypt request")
}
@@ -796,7 +781,7 @@ var zstdEncoderPool = &sync.Pool{
},
}
func (s *Server) encode(mkey tailcfg.MachineKey, compress bool, v interface{}) (b []byte, err error) {
func (s *Server) encode(mkey key.MachinePublic, compress bool, v interface{}) (b []byte, err error) {
var isBytes bool
if b, isBytes = v.([]byte); !isBytes {
b, err = json.Marshal(v)
@@ -810,14 +795,7 @@ func (s *Server) encode(mkey tailcfg.MachineKey, compress bool, v interface{}) (
encoder.Close()
zstdEncoderPool.Put(encoder)
}
var nonce [24]byte
if _, err := io.ReadFull(crand.Reader, nonce[:]); err != nil {
panic(err)
}
priv := s.privateKey()
pub, pri := (*[32]byte)(&mkey), (*[32]byte)(&priv)
msgData := box.Seal(nonce[:], b, &nonce, pub, pri)
return msgData, nil
return s.privateKey().SealTo(mkey, b), nil
}
// filterInvalidIPv6Endpoints removes invalid IPv6 endpoints from eps,