safer dial timeout handling, in case it was used with a nil context or a context that had no timeout set

This commit is contained in:
Arceliar 2019-10-21 20:47:50 -05:00
parent eccd9a348f
commit 681c8ca6f9

View File

@ -19,15 +19,12 @@ type Dialer struct {
// Dial opens a session to the given node. The first paramter should be "nodeid" // Dial opens a session to the given node. The first paramter should be "nodeid"
// and the second parameter should contain a hexadecimal representation of the // and the second parameter should contain a hexadecimal representation of the
// target node ID. Internally, it uses DialContext with a 6-second timeout. // target node ID. It uses DialContext internally.
func (d *Dialer) Dial(network, address string) (net.Conn, error) { func (d *Dialer) Dial(network, address string) (net.Conn, error) {
const timeout = 6 * time.Second return d.DialContext(nil, network, address)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return d.DialContext(ctx, network, address)
} }
// DialContext is used internally by Dial, and should only be used with a context that includes a timeout. // DialContext is used internally by Dial, and should only be used with a context that includes a timeout. It uses DialByNodeIDandMask internally.
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
var nodeID crypto.NodeID var nodeID crypto.NodeID
var nodeMask crypto.NodeID var nodeMask crypto.NodeID
@ -66,7 +63,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
} }
// DialByNodeIDandMask opens a session to the given node based on raw // DialByNodeIDandMask opens a session to the given node based on raw
// NodeID parameters. // NodeID parameters. If ctx is nil or has no timeout, then a default timeout of 6 seconds will apply, beginning *after* the search finishes.
func (d *Dialer) DialByNodeIDandMask(ctx context.Context, nodeID, nodeMask *crypto.NodeID) (net.Conn, error) { func (d *Dialer) DialByNodeIDandMask(ctx context.Context, nodeID, nodeMask *crypto.NodeID) (net.Conn, error) {
conn := newConn(d.core, nodeID, nodeMask, nil) conn := newConn(d.core, nodeID, nodeMask, nil)
if err := conn.search(); err != nil { if err := conn.search(); err != nil {
@ -75,10 +72,19 @@ func (d *Dialer) DialByNodeIDandMask(ctx context.Context, nodeID, nodeMask *cryp
return nil, err return nil, err
} }
conn.session.setConn(nil, conn) conn.session.setConn(nil, conn)
var c context.Context
var cancel context.CancelFunc
const timeout = 6 * time.Second
if ctx != nil {
c, cancel = context.WithTimeout(ctx, timeout)
} else {
c, cancel = context.WithTimeout(context.Background(), timeout)
}
defer cancel()
select { select {
case <-conn.session.init: case <-conn.session.init:
return conn, nil return conn, nil
case <-ctx.Done(): case <-c.Done():
conn.Close() conn.Close()
return nil, errors.New("session handshake timeout") return nil, errors.New("session handshake timeout")
} }