diff --git a/src/core/tcp.go b/src/core/tcp.go index 5c3e506b..4dc36a08 100644 --- a/src/core/tcp.go +++ b/src/core/tcp.go @@ -54,7 +54,7 @@ type TcpListener struct { } type TcpUpgrade struct { - upgrade func(c net.Conn) (net.Conn, error) + upgrade func(c net.Conn, o *tcpOptions) (net.Conn, error) name string } @@ -361,7 +361,7 @@ func (t *tcp) handler(sock net.Conn, incoming bool, options tcpOptions) chan str var upgraded bool if options.upgrade != nil { var err error - if sock, err = options.upgrade.upgrade(sock); err != nil { + if sock, err = options.upgrade.upgrade(sock, &options); err != nil { t.links.core.log.Errorln("TCP handler upgrade failed:", err) return nil } diff --git a/src/core/tls.go b/src/core/tls.go index 4fdcf997..4c25225b 100644 --- a/src/core/tls.go +++ b/src/core/tls.go @@ -9,6 +9,7 @@ import ( "crypto/x509/pkix" "encoding/hex" "encoding/pem" + "errors" "log" "math/big" "net" @@ -76,16 +77,47 @@ func (t *tcptls) init(tcp *tcp) { } } -func (t *tcptls) upgradeListener(c net.Conn) (net.Conn, error) { - conn := tls.Server(c, t.config) +func (t *tcptls) configForOptions(options *tcpOptions) *tls.Config { + config := *t.config + config.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + if len(rawCerts) != 1 { + return errors.New("tls not exactly 1 cert") + } + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return errors.New("tls failed to parse cert") + } + if cert.PublicKeyAlgorithm != x509.Ed25519 { + return errors.New("tls wrong cert algorithm") + } + pk := cert.PublicKey.(ed25519.PublicKey) + var key keyArray + copy(key[:], pk) + // If options does not have a pinned key, then pin one now + if options.pinnedEd25519Keys == nil { + options.pinnedEd25519Keys = make(map[keyArray]struct{}) + options.pinnedEd25519Keys[key] = struct{}{} + } + if _, isIn := options.pinnedEd25519Keys[key]; !isIn { + return errors.New("tls key does not match pinned key") + } + return nil + } + return &config +} + +func (t *tcptls) upgradeListener(c net.Conn, options *tcpOptions) (net.Conn, error) { + config := t.configForOptions(options) + conn := tls.Server(c, config) if err := conn.Handshake(); err != nil { return c, err } return conn, nil } -func (t *tcptls) upgradeDialer(c net.Conn) (net.Conn, error) { - conn := tls.Client(c, t.config) +func (t *tcptls) upgradeDialer(c net.Conn, options *tcpOptions) (net.Conn, error) { + config := t.configForOptions(options) + conn := tls.Client(c, config) if err := conn.Handshake(); err != nil { return c, err }