net/tlsdial: add package for TLS dials, and make DERP & controlclient use it

This will do the iOS-optimized cert checking in a following change.
This commit is contained in:
Brad Fitzpatrick 2020-04-25 13:24:53 -07:00
parent d427fc023e
commit b6fa5a69be
3 changed files with 36 additions and 13 deletions

View File

@ -26,6 +26,7 @@
"github.com/tailscale/wireguard-go/wgcfg" "github.com/tailscale/wireguard-go/wgcfg"
"golang.org/x/crypto/nacl/box" "golang.org/x/crypto/nacl/box"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"tailscale.com/net/tlsdial"
"tailscale.com/tailcfg" "tailscale.com/tailcfg"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/version" "tailscale.com/version"
@ -93,7 +94,6 @@ type Direct struct {
type Options struct { type Options struct {
Persist Persist // initial persistent data Persist Persist // initial persistent data
HTTPC *http.Client // HTTP client used to talk to tailcontrol
ServerURL string // URL of the tailcontrol server ServerURL string // URL of the tailcontrol server
AuthKey string // optional node auth key for auto registration AuthKey string // optional node auth key for auto registration
TimeNow func() time.Time // time.Now implementation used by Client TimeNow func() time.Time // time.Now implementation used by Client
@ -114,9 +114,6 @@ func NewDirect(opts Options) (*Direct, error) {
return nil, errors.New("controlclient.New: no server URL specified") return nil, errors.New("controlclient.New: no server URL specified")
} }
opts.ServerURL = strings.TrimRight(opts.ServerURL, "/") opts.ServerURL = strings.TrimRight(opts.ServerURL, "/")
if opts.HTTPC == nil {
opts.HTTPC = http.DefaultClient
}
if opts.TimeNow == nil { if opts.TimeNow == nil {
opts.TimeNow = time.Now opts.TimeNow = time.Now
} }
@ -125,8 +122,14 @@ func NewDirect(opts Options) (*Direct, error) {
// TODO(bradfitz): ... but then it shouldn't be in Options. // TODO(bradfitz): ... but then it shouldn't be in Options.
opts.Logf = log.Printf opts.Logf = log.Printf
} }
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.ForceAttemptHTTP2 = true
tr.TLSClientConfig = tlsdial.Config("", tr.TLSClientConfig)
httpc := &http.Client{Transport: tr}
c := &Direct{ c := &Direct{
httpc: opts.HTTPC, httpc: httpc,
serverURL: opts.ServerURL, serverURL: opts.ServerURL,
timeNow: opts.TimeNow, timeNow: opts.TimeNow,
logf: opts.Logf, logf: opts.Logf,

View File

@ -26,6 +26,7 @@
"tailscale.com/derp" "tailscale.com/derp"
"tailscale.com/net/dnscache" "tailscale.com/net/dnscache"
"tailscale.com/net/tlsdial"
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
) )
@ -37,8 +38,8 @@
// Send/Recv will completely re-establish the connection (unless Close // Send/Recv will completely re-establish the connection (unless Close
// has been called). // has been called).
type Client struct { type Client struct {
TLSConfig *tls.Config // for sever connection, optional, nil means default TLSConfig *tls.Config // optional; nil means default
DNSCache *dnscache.Resolver // optional; if nil, no caching DNSCache *dnscache.Resolver // optional; nil means no caching
privateKey key.Private privateKey key.Private
logf logger.Logf logf logger.Logf
@ -182,12 +183,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to
if c.url.Scheme == "https" { if c.url.Scheme == "https" {
tlsConfig := &tls.Config{} httpConn = tls.Client(tcpConn, tlsdial.Config(c.url.Host, c.TLSConfig))
if c.TLSConfig != nil {
tlsConfig = c.TLSConfig.Clone()
}
tlsConfig.ServerName = c.url.Host
httpConn = tls.Client(tcpConn, tlsConfig)
} else { } else {
httpConn = tcpConn httpConn = tcpConn
} }

24
net/tlsdial/tlsdial.go Normal file
View File

@ -0,0 +1,24 @@
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package tlsdial sets up a tls.Config for x509 validation, using
// a memory-optimized path for iOS.
package tlsdial
import "crypto/tls"
// Config returns a tls.Config for dialing the given host.
// If base is non-nil, it's cloned as the base config before
// being configured and returned.
func Config(host string, base *tls.Config) *tls.Config {
var conf *tls.Config
if base == nil {
conf = new(tls.Config)
} else {
conf = base.Clone()
}
conf.ServerName = host
return conf
}