From a413fa4f85564a62a28ed140b530d57cf8a63c1e Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Thu, 3 Nov 2022 12:17:16 +0500 Subject: [PATCH] control/controlclient: export NoiseClient This allows reusing the NoiseClient in other repos without having to reimplement the earlyPayload logic. Signed-off-by: Maisem Ali --- control/controlclient/direct.go | 27 +++++------------ control/controlclient/noise.go | 47 +++++++++++++++++++++-------- control/controlclient/noise_test.go | 2 +- 3 files changed, 43 insertions(+), 33 deletions(-) diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index db3fe9d63..020397142 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -84,8 +84,8 @@ type Direct struct { serverKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key serverNoiseKey key.MachinePublic - sfGroup singleflight.Group[struct{}, *noiseClient] // protects noiseClient creation. - noiseClient *noiseClient + sfGroup singleflight.Group[struct{}, *NoiseClient] // protects noiseClient creation. + noiseClient *NoiseClient persist persist.Persist authKey string @@ -262,7 +262,7 @@ func NewDirect(opts Options) (*Direct, error) { } } if opts.NoiseTestClient != nil { - c.noiseClient = &noiseClient{ + c.noiseClient = &NoiseClient{ Client: opts.NoiseTestClient, } c.serverNoiseKey = key.NewMachine().Public() // prevent early error before hitting test client @@ -1470,7 +1470,7 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, timeoutReset chan<- } // getNoiseClient returns the noise client, creating one if one doesn't exist. -func (c *Direct) getNoiseClient() (*noiseClient, error) { +func (c *Direct) getNoiseClient() (*NoiseClient, error) { c.mu.Lock() serverNoiseKey := c.serverNoiseKey nc := c.noiseClient @@ -1485,13 +1485,13 @@ func (c *Direct) getNoiseClient() (*noiseClient, error) { if c.dialPlan != nil { dp = c.dialPlan.Load } - nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*noiseClient, error) { + nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*NoiseClient, error) { k, err := c.getMachinePrivKey() if err != nil { return nil, err } c.logf("creating new noise client") - nc, err := newNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, dp) + nc, err := NewNoiseClient(k, serverNoiseKey, c.serverURL, c.dialer, dp) if err != nil { return nil, err } @@ -1618,20 +1618,7 @@ func (c *Direct) GetSingleUseNoiseRoundTripper(ctx context.Context) (http.RoundT if err != nil { return nil, nil, err } - for tries := 0; tries < 3; tries++ { - conn, err := nc.getConn(ctx) - if err != nil { - return nil, nil, err - } - earlyPayloadMaybeNil, err := conn.getEarlyPayload(ctx) - if err != nil { - return nil, nil, err - } - if conn.h2cc.ReserveNewRequest() { - return conn, earlyPayloadMaybeNil, nil - } - } - return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection") + return nc.GetSingleUseRoundTripper(ctx) } // doPingerPing sends a Ping to pr.IP using pinger, and sends an http request back to diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index d7ecfcafa..884ba0375 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -35,7 +35,7 @@ type noiseConn struct { *controlbase.Conn id int - pool *noiseClient + pool *NoiseClient h2cc *http2.ClientConn readHeaderOnce sync.Once // guards init of reader field @@ -135,9 +135,9 @@ func (c *noiseConn) Close() error { return nil } -// noiseClient provides a http.Client to connect to tailcontrol over +// NoiseClient provides a http.Client to connect to tailcontrol over // the ts2021 protocol. -type noiseClient struct { +type NoiseClient struct { // Client is an HTTP client to talk to the coordination server. // It automatically makes a new Noise connection as needed. // It does not support node key proofs. To do that, call @@ -175,11 +175,11 @@ type noiseClient struct { connPool map[int]*noiseConn // active connections not yet closed; see noiseConn.Close } -// newNoiseClient returns a new noiseClient for the provided server and machine key. +// NewNoiseClient returns a new noiseClient for the provided server and machine key. // serverURL is of the form https://: (no trailing slash). // // dialPlan may be nil -func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*noiseClient, error) { +func NewNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, serverURL string, dialer *tsdial.Dialer, dialPlan func() *tailcfg.ControlDialPlan) (*NoiseClient, error) { u, err := url.Parse(serverURL) if err != nil { return nil, err @@ -200,7 +200,7 @@ func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, httpPort = "80" httpsPort = "443" } - np := &noiseClient{ + np := &NoiseClient{ serverPubKey: serverPubKey, privKey: privKey, host: u.Hostname(), @@ -227,7 +227,30 @@ func newNoiseClient(privKey key.MachinePrivate, serverPubKey key.MachinePublic, return np, nil } -func (nc *noiseClient) getConn(ctx context.Context) (*noiseConn, error) { +// GetSingleUseRoundTripper returns a RoundTripper that can be only be used once +// (and must be used once) to make a single HTTP request over the noise channel +// to the coordination server. +// +// In addition to the RoundTripper, it returns the HTTP/2 channel's early noise +// payload, if any. +func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) { + for tries := 0; tries < 3; tries++ { + conn, err := nc.getConn(ctx) + if err != nil { + return nil, nil, err + } + earlyPayloadMaybeNil, err := conn.getEarlyPayload(ctx) + if err != nil { + return nil, nil, err + } + if conn.h2cc.ReserveNewRequest() { + return conn, earlyPayloadMaybeNil, nil + } + } + return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection") +} + +func (nc *NoiseClient) getConn(ctx context.Context) (*noiseConn, error) { nc.mu.Lock() if last := nc.last; last != nil && last.canTakeNewRequest() { nc.mu.Unlock() @@ -242,7 +265,7 @@ func (nc *noiseClient) getConn(ctx context.Context) (*noiseConn, error) { return conn, nil } -func (nc *noiseClient) RoundTrip(req *http.Request) (*http.Response, error) { +func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() conn, err := nc.getConn(ctx) if err != nil { @@ -253,7 +276,7 @@ func (nc *noiseClient) RoundTrip(req *http.Request) (*http.Response, error) { // connClosed removes the connection with the provided ID from the pool // of active connections. -func (nc *noiseClient) connClosed(id int) { +func (nc *NoiseClient) connClosed(id int) { nc.mu.Lock() defer nc.mu.Unlock() conn := nc.connPool[id] @@ -267,7 +290,7 @@ func (nc *noiseClient) connClosed(id int) { // Close closes all the underlying noise connections. // It is a no-op and returns nil if the connection is already closed. -func (nc *noiseClient) Close() error { +func (nc *NoiseClient) Close() error { nc.mu.Lock() conns := nc.connPool nc.connPool = nil @@ -284,7 +307,7 @@ func (nc *noiseClient) Close() error { // dial opens a new connection to tailcontrol, fetching the server noise key // if not cached. -func (nc *noiseClient) dial() (*noiseConn, error) { +func (nc *NoiseClient) dial() (*noiseConn, error) { nc.mu.Lock() connID := nc.nextID nc.nextID++ @@ -369,7 +392,7 @@ func (nc *noiseClient) dial() (*noiseConn, error) { return ncc, nil } -func (nc *noiseClient) post(ctx context.Context, path string, body any) (*http.Response, error) { +func (nc *NoiseClient) post(ctx context.Context, path string, body any) (*http.Response, error) { jbody, err := json.Marshal(body) if err != nil { return nil, err diff --git a/control/controlclient/noise_test.go b/control/controlclient/noise_test.go index 469bc281d..9485177d6 100644 --- a/control/controlclient/noise_test.go +++ b/control/controlclient/noise_test.go @@ -75,7 +75,7 @@ func (tt noiseClientTest) run(t *testing.T) { defer hs.Close() dialer := new(tsdial.Dialer) - nc, err := newNoiseClient(clientPrivate, serverPrivate.Public(), hs.URL, dialer, nil) + nc, err := NewNoiseClient(clientPrivate, serverPrivate.Public(), hs.URL, dialer, nil) if err != nil { t.Fatal(err) }