mirror of
https://github.com/tailscale/tailscale.git
synced 2025-01-07 08:07:42 +00:00
derp: clean up derphttp client code, use contexts
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
parent
cdc10b74f1
commit
752146a70f
@ -97,10 +97,13 @@ func (s *Server) isClosed() bool {
|
|||||||
return s.closed
|
return s.closed
|
||||||
}
|
}
|
||||||
|
|
||||||
// Accept adds a new connection to the server.
|
// Accept adds a new connection to the server and serves it.
|
||||||
|
//
|
||||||
// The provided bufio ReadWriter must be already connected to nc.
|
// The provided bufio ReadWriter must be already connected to nc.
|
||||||
// Accept blocks until the Server is closed or the connection closes
|
// Accept blocks until the Server is closed or the connection closes
|
||||||
// on its own.
|
// on its own.
|
||||||
|
//
|
||||||
|
// Accept closes nc.
|
||||||
func (s *Server) Accept(nc net.Conn, brw *bufio.ReadWriter) {
|
func (s *Server) Accept(nc net.Conn, brw *bufio.ReadWriter) {
|
||||||
closed := make(chan struct{})
|
closed := make(chan struct{})
|
||||||
|
|
||||||
|
@ -12,16 +12,18 @@
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"tailscale.com/derp"
|
"tailscale.com/derp"
|
||||||
"tailscale.com/types/key"
|
"tailscale.com/types/key"
|
||||||
@ -37,15 +39,15 @@
|
|||||||
type Client struct {
|
type Client struct {
|
||||||
privateKey key.Private
|
privateKey key.Private
|
||||||
logf logger.Logf
|
logf logger.Logf
|
||||||
closed chan struct{}
|
|
||||||
url *url.URL
|
url *url.URL
|
||||||
resp *http.Response
|
|
||||||
|
|
||||||
netConnMu sync.Mutex
|
ctx context.Context // closed via cancelCtx in Client.Close
|
||||||
netConn net.Conn
|
cancelCtx context.CancelFunc
|
||||||
|
|
||||||
clientMu sync.Mutex
|
mu sync.Mutex
|
||||||
client *derp.Client
|
closed bool
|
||||||
|
netConn io.Closer
|
||||||
|
client *derp.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient returns a new DERP-over-HTTP client. It connects lazily.
|
// NewClient returns a new DERP-over-HTTP client. It connects lazily.
|
||||||
@ -55,12 +57,16 @@ func NewClient(privateKey key.Private, serverURL string, logf logger.Logf) (*Cli
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("derphttp.NewClient: %v", err)
|
return nil, fmt.Errorf("derphttp.NewClient: %v", err)
|
||||||
}
|
}
|
||||||
|
if urlPort(u) == "" {
|
||||||
|
return nil, fmt.Errorf("derphttp.NewClient: invalid URL scheme %q", u.Scheme)
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
c := &Client{
|
c := &Client{
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
logf: logf,
|
logf: logf,
|
||||||
url: u,
|
url: u,
|
||||||
closed: make(chan struct{}),
|
ctx: ctx,
|
||||||
|
cancelCtx: cancel,
|
||||||
}
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
@ -72,71 +78,119 @@ func (c *Client) Connect(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
|
func urlPort(u *url.URL) string {
|
||||||
// TODO: use ctx for TCP+TLS+HTTP below
|
if p := u.Port(); p != "" {
|
||||||
select {
|
return p
|
||||||
case <-c.closed:
|
|
||||||
return nil, ErrClientClosed
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
switch u.Scheme {
|
||||||
|
case "https":
|
||||||
|
return "443"
|
||||||
|
case "http":
|
||||||
|
return "80"
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
c.clientMu.Lock()
|
func (c *Client) connect(ctx context.Context, caller string) (client *derp.Client, err error) {
|
||||||
defer c.clientMu.Unlock()
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.closed {
|
||||||
|
return nil, ErrClientClosed
|
||||||
|
}
|
||||||
if c.client != nil {
|
if c.client != nil {
|
||||||
return c.client, nil
|
return c.client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logf("%s: connecting", caller)
|
c.logf("%s: connecting to %v", caller, c.url)
|
||||||
|
|
||||||
var netConn net.Conn
|
// timeout is the fallback maximum time (if ctx doesn't limit
|
||||||
|
// it further) to do all of: DNS + TCP + TLS + HTTP Upgrade +
|
||||||
|
// DERP upgrade.
|
||||||
|
const timeout = 10 * time.Second
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Printf("XXXX normal")
|
||||||
|
// Either timeout fired (handled below), or
|
||||||
|
// we're returning via the defer cancel()
|
||||||
|
// below.
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
log.Printf("XXXX dead2")
|
||||||
|
// Propagate a Client.Close call into
|
||||||
|
// cancelling this context.
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var tcpConn net.Conn
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("%s connect: %v", caller, err)
|
if ctx.Err() != nil {
|
||||||
if netConn != nil {
|
err = fmt.Errorf("%v: %v", ctx.Err(), err)
|
||||||
netConn.Close()
|
}
|
||||||
|
err = fmt.Errorf("%s connect to %v: %v", caller, c.url, err)
|
||||||
|
if tcpConn != nil {
|
||||||
|
go tcpConn.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if c.url.Scheme == "https" {
|
var d net.Dialer
|
||||||
port := c.url.Port()
|
log.Printf("Dialing: %q", net.JoinHostPort(c.url.Hostname(), urlPort(c.url)))
|
||||||
if port == "" {
|
tcpConn, err = d.DialContext(ctx, "tcp", net.JoinHostPort(c.url.Hostname(), urlPort(c.url)))
|
||||||
port = "443"
|
|
||||||
}
|
|
||||||
config := &tls.Config{}
|
|
||||||
var tlsConn *tls.Conn
|
|
||||||
tlsConn, err = tls.Dial("tcp", net.JoinHostPort(c.url.Host, port), config)
|
|
||||||
if tlsConn != nil {
|
|
||||||
netConn = tlsConn
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
netConn, err = net.Dial("tcp", c.url.Host)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.netConnMu.Lock()
|
// Now that we have a TCP connection, force close it.
|
||||||
c.netConn = netConn
|
done := make(chan struct{})
|
||||||
c.netConnMu.Unlock()
|
defer close(done)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Normal path. Upgrade occurred in time.
|
||||||
|
case <-ctx.Done():
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Normal path. Upgrade occurred in time.
|
||||||
|
// But the ctx.Done() is also done because
|
||||||
|
// the "defer cancel()" above scheduled
|
||||||
|
// before this goroutine.
|
||||||
|
default:
|
||||||
|
// The TLS or HTTP or DERP exchanges didn't complete
|
||||||
|
// in time. Force close the TCP connection to force
|
||||||
|
// them to fail quickly.
|
||||||
|
tcpConn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
conn := bufio.NewReadWriter(bufio.NewReader(netConn), bufio.NewWriter(netConn))
|
var httpConn net.Conn // a TCP conn or a TLS conn; what we speak HTTP to
|
||||||
|
if c.url.Scheme == "https" {
|
||||||
|
httpConn = tls.Client(tcpConn, &tls.Config{ServerName: c.url.Host})
|
||||||
|
} else {
|
||||||
|
httpConn = tcpConn
|
||||||
|
}
|
||||||
|
|
||||||
|
brw := bufio.NewReadWriter(bufio.NewReader(httpConn), bufio.NewWriter(httpConn))
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", c.url.String(), nil)
|
req, err := http.NewRequest("GET", c.url.String(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
req.Header.Set("Upgrade", "WebSocket")
|
req.Header.Set("Upgrade", "DERP")
|
||||||
req.Header.Set("Connection", "Upgrade")
|
req.Header.Set("Connection", "Upgrade")
|
||||||
if err := req.Write(conn); err != nil {
|
|
||||||
|
if err := req.Write(brw); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := conn.Flush(); err != nil {
|
if err := brw.Flush(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := http.ReadResponse(conn.Reader, req)
|
resp, err := http.ReadResponse(brw.Reader, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -145,14 +199,14 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien
|
|||||||
resp.Body.Close()
|
resp.Body.Close()
|
||||||
return nil, fmt.Errorf("GET failed: %v: %s", err, b)
|
return nil, fmt.Errorf("GET failed: %v: %s", err, b)
|
||||||
}
|
}
|
||||||
resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
|
|
||||||
|
|
||||||
derpClient, err := derp.NewClient(c.privateKey, netConn, conn, c.logf)
|
derpClient, err := derp.NewClient(c.privateKey, httpConn, brw, c.logf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
c.resp = resp
|
|
||||||
c.client = derpClient
|
c.client = derpClient
|
||||||
|
c.netConn = tcpConn
|
||||||
return c.client, nil
|
return c.client, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -162,7 +216,7 @@ func (c *Client) Send(dstKey key.Public, b []byte) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := client.Send(dstKey, b); err != nil {
|
if err := client.Send(dstKey, b); err != nil {
|
||||||
c.close()
|
c.Close()
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -174,7 +228,7 @@ func (c *Client) Recv(b []byte) (derp.ReceivedMessage, error) {
|
|||||||
}
|
}
|
||||||
m, err := client.Recv(b)
|
m, err := client.Recv(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.close()
|
c.Close()
|
||||||
}
|
}
|
||||||
return m, err
|
return m, err
|
||||||
}
|
}
|
||||||
@ -182,35 +236,20 @@ func (c *Client) Recv(b []byte) (derp.ReceivedMessage, error) {
|
|||||||
// Close closes the client. It will not automatically reconnect after
|
// Close closes the client. It will not automatically reconnect after
|
||||||
// being closed.
|
// being closed.
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
select {
|
c.cancelCtx() // not in lock, so it can cancel Connect, which holds mu
|
||||||
case <-c.closed:
|
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.closed {
|
||||||
return ErrClientClosed
|
return ErrClientClosed
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
close(c.closed)
|
c.closed = true
|
||||||
c.close()
|
if c.netConn != nil {
|
||||||
|
c.netConn.Close()
|
||||||
|
c.netConn = nil
|
||||||
|
}
|
||||||
|
c.client = nil
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) close() {
|
|
||||||
c.netConnMu.Lock()
|
|
||||||
netConn := c.netConn
|
|
||||||
c.netConnMu.Unlock()
|
|
||||||
|
|
||||||
if netConn != nil {
|
|
||||||
netConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
c.clientMu.Lock()
|
|
||||||
defer c.clientMu.Unlock()
|
|
||||||
if c.client == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.resp = nil
|
|
||||||
c.client = nil
|
|
||||||
c.netConnMu.Lock()
|
|
||||||
c.netConn = nil
|
|
||||||
c.netConnMu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
var ErrClientClosed = errors.New("derphttp.Client closed")
|
var ErrClientClosed = errors.New("derphttp.Client closed")
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
package derphttp
|
package derphttp
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"tailscale.com/derp"
|
"tailscale.com/derp"
|
||||||
@ -12,11 +13,11 @@
|
|||||||
|
|
||||||
func Handler(s *derp.Server) http.Handler {
|
func Handler(s *derp.Server) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if r.Header.Get("Upgrade") != "WebSocket" {
|
if p := r.Header.Get("Upgrade"); p != "WebSocket" && p != "DERP" {
|
||||||
http.Error(w, "DERP requires connection upgrade", http.StatusUpgradeRequired)
|
http.Error(w, "DERP requires connection upgrade", http.StatusUpgradeRequired)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.Header().Set("Upgrade", "WebSocket")
|
w.Header().Set("Upgrade", "DERP")
|
||||||
w.Header().Set("Connection", "Upgrade")
|
w.Header().Set("Connection", "Upgrade")
|
||||||
w.WriteHeader(http.StatusSwitchingProtocols)
|
w.WriteHeader(http.StatusSwitchingProtocols)
|
||||||
|
|
||||||
@ -27,6 +28,7 @@ func Handler(s *derp.Server) http.Handler {
|
|||||||
}
|
}
|
||||||
netConn, conn, err := h.Hijack()
|
netConn, conn, err := h.Hijack()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Printf("Hijack failed: %v", err)
|
||||||
http.Error(w, "HTTP does not support general TCP support", 500)
|
http.Error(w, "HTTP does not support general TCP support", 500)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user