diff --git a/src/address/address.go b/src/address/address.go index 3960b783..eba61708 100644 --- a/src/address/address.go +++ b/src/address/address.go @@ -2,7 +2,11 @@ // Of particular importance are the functions used to derive addresses or subnets from a NodeID, or to get the NodeID and bitmask of the bits visible from an address, which is needed for DHT searches. package address -import "github.com/yggdrasil-network/yggdrasil-go/src/crypto" +import ( + "fmt" + + "github.com/yggdrasil-network/yggdrasil-go/src/crypto" +) // Address represents an IPv6 address in the yggdrasil address range. type Address [16]byte @@ -128,6 +132,13 @@ func (a *Address) GetNodeIDandMask() (*crypto.NodeID, *crypto.NodeID) { return &nid, &mask } +// GetNodeIDLengthString returns a string representation of the known bits of the NodeID, along with the number of known bits, for use with yggdrasil.Dialer's Dial and DialContext functions. +func (a *Address) GetNodeIDLengthString() string { + nid, mask := a.GetNodeIDandMask() + l := mask.PrefixLength() + return fmt.Sprintf("%s/%d", nid.String(), l) +} + // GetNodeIDandMask returns two *NodeID. // The first is a NodeID with all the bits known from the Subnet set to their correct values. // The second is a bitmask with 1 bit set for each bit that was known from the Subnet. @@ -156,3 +167,10 @@ func (s *Subnet) GetNodeIDandMask() (*crypto.NodeID, *crypto.NodeID) { } return &nid, &mask } + +// GetNodeIDLengthString returns a string representation of the known bits of the NodeID, along with the number of known bits, for use with yggdrasil.Dialer's Dial and DialContext functions. +func (s *Subnet) GetNodeIDLengthString() string { + nid, mask := s.GetNodeIDandMask() + l := mask.PrefixLength() + return fmt.Sprintf("%s/%d", nid.String(), l) +} diff --git a/src/tuntap/conn.go b/src/tuntap/conn.go index 6db46b23..207cd14f 100644 --- a/src/tuntap/conn.go +++ b/src/tuntap/conn.go @@ -93,7 +93,7 @@ func (s *tunConn) _read(bs []byte) (err error) { skip = true } else if key, err := s.tun.ckr.getPublicKeyForAddress(srcAddr, addrlen); err == nil { srcNodeID := crypto.GetNodeID(&key) - if s.conn.RemoteAddr() == *srcNodeID { + if *s.conn.RemoteAddr().(*crypto.NodeID) == *srcNodeID { // This is the one allowed CKR case, where source and destination addresses are both good } else { // The CKR key associated with this address doesn't match the sender's NodeID @@ -170,7 +170,7 @@ func (s *tunConn) _write(bs []byte) (err error) { skip = true } else if key, err := s.tun.ckr.getPublicKeyForAddress(dstAddr, addrlen); err == nil { dstNodeID := crypto.GetNodeID(&key) - if s.conn.RemoteAddr() == *dstNodeID { + if *s.conn.RemoteAddr().(*crypto.NodeID) == *dstNodeID { // This is the one allowed CKR case, where source and destination addresses are both good } else { // The CKR key associated with this address doesn't match the sender's NodeID diff --git a/src/tuntap/iface.go b/src/tuntap/iface.go index 0da99631..3d788b1a 100644 --- a/src/tuntap/iface.go +++ b/src/tuntap/iface.go @@ -9,6 +9,7 @@ import ( "github.com/yggdrasil-network/yggdrasil-go/src/address" "github.com/yggdrasil-network/yggdrasil-go/src/crypto" "github.com/yggdrasil-network/yggdrasil-go/src/util" + "github.com/yggdrasil-network/yggdrasil-go/src/yggdrasil" "github.com/Arceliar/phony" ) @@ -225,7 +226,7 @@ func (tun *TunAdapter) _handlePacket(recvd []byte, err error) { return } // Do we have an active connection for this node address? - var dstNodeID, dstNodeIDMask *crypto.NodeID + var dstString string session, isIn := tun.addrToConn[dstAddr] if !isIn || session == nil { session, isIn = tun.subnetToConn[dstSnet] @@ -233,9 +234,9 @@ func (tun *TunAdapter) _handlePacket(recvd []byte, err error) { // Neither an address nor a subnet mapping matched, therefore populate // the node ID and mask to commence a search if dstAddr.IsValid() { - dstNodeID, dstNodeIDMask = dstAddr.GetNodeIDandMask() + dstString = dstAddr.GetNodeIDLengthString() } else { - dstNodeID, dstNodeIDMask = dstSnet.GetNodeIDandMask() + dstString = dstSnet.GetNodeIDLengthString() } } } @@ -243,27 +244,27 @@ func (tun *TunAdapter) _handlePacket(recvd []byte, err error) { if !isIn || session == nil { // Check we haven't been given empty node ID, really this shouldn't ever // happen but just to be sure... - if dstNodeID == nil || dstNodeIDMask == nil { - panic("Given empty dstNodeID and dstNodeIDMask - this shouldn't happen") + if dstString == "" { + panic("Given empty dstString - this shouldn't happen") } - _, known := tun.dials[*dstNodeID] - tun.dials[*dstNodeID] = append(tun.dials[*dstNodeID], bs) - for len(tun.dials[*dstNodeID]) > 32 { - util.PutBytes(tun.dials[*dstNodeID][0]) - tun.dials[*dstNodeID] = tun.dials[*dstNodeID][1:] + _, known := tun.dials[dstString] + tun.dials[dstString] = append(tun.dials[dstString], bs) + for len(tun.dials[dstString]) > 32 { + util.PutBytes(tun.dials[dstString][0]) + tun.dials[dstString] = tun.dials[dstString][1:] } if !known { go func() { - conn, err := tun.dialer.DialByNodeIDandMask(dstNodeID, dstNodeIDMask) + conn, err := tun.dialer.Dial("nodeid", dstString) tun.Act(nil, func() { - packets := tun.dials[*dstNodeID] - delete(tun.dials, *dstNodeID) + packets := tun.dials[dstString] + delete(tun.dials, dstString) if err != nil { return } // We've been given a connection so prepare the session wrapper var tc *tunConn - if tc, err = tun._wrap(conn); err != nil { + if tc, err = tun._wrap(conn.(*yggdrasil.Conn)); err != nil { // Something went wrong when storing the connection, typically that // something already exists for this address or subnet tun.log.Debugln("TUN/TAP iface wrap:", err) diff --git a/src/tuntap/tun.go b/src/tuntap/tun.go index 74d055ee..5d77ecab 100644 --- a/src/tuntap/tun.go +++ b/src/tuntap/tun.go @@ -52,7 +52,7 @@ type TunAdapter struct { //mutex sync.RWMutex // Protects the below addrToConn map[address.Address]*tunConn subnetToConn map[address.Subnet]*tunConn - dials map[crypto.NodeID][][]byte // Buffer of packets to send after dialing finishes + dials map[string][][]byte // Buffer of packets to send after dialing finishes isOpen bool } @@ -117,7 +117,7 @@ func (tun *TunAdapter) Init(config *config.NodeState, log *log.Logger, listener tun.dialer = dialer tun.addrToConn = make(map[address.Address]*tunConn) tun.subnetToConn = make(map[address.Subnet]*tunConn) - tun.dials = make(map[crypto.NodeID][][]byte) + tun.dials = make(map[string][][]byte) tun.writer.tun = tun tun.reader.tun = tun } @@ -219,7 +219,7 @@ func (tun *TunAdapter) handler() error { return err } phony.Block(tun, func() { - if _, err := tun._wrap(conn); err != nil { + if _, err := tun._wrap(conn.(*yggdrasil.Conn)); err != nil { // Something went wrong when storing the connection, typically that // something already exists for this address or subnet tun.log.Debugln("TUN/TAP handler wrap:", err) @@ -237,9 +237,9 @@ func (tun *TunAdapter) _wrap(conn *yggdrasil.Conn) (c *tunConn, err error) { } c = &s // Get the remote address and subnet of the other side - remoteNodeID := conn.RemoteAddr() - s.addr = *address.AddrForNodeID(&remoteNodeID) - s.snet = *address.SubnetForNodeID(&remoteNodeID) + remoteNodeID := conn.RemoteAddr().(*crypto.NodeID) + s.addr = *address.AddrForNodeID(remoteNodeID) + s.snet = *address.SubnetForNodeID(remoteNodeID) // Work out if this is already a destination we already know about atc, aok := tun.addrToConn[s.addr] stc, sok := tun.subnetToConn[s.snet] diff --git a/src/yggdrasil/conn.go b/src/yggdrasil/conn.go index bb5964b6..67426f4f 100644 --- a/src/yggdrasil/conn.go +++ b/src/yggdrasil/conn.go @@ -3,6 +3,7 @@ package yggdrasil import ( "errors" "fmt" + "net" "time" "github.com/yggdrasil-network/yggdrasil-go/src/crypto" @@ -348,14 +349,14 @@ func (c *Conn) Close() (err error) { // LocalAddr returns the complete node ID of the local side of the connection. // This is always going to return your own node's node ID. -func (c *Conn) LocalAddr() crypto.NodeID { - return *crypto.GetNodeID(&c.core.boxPub) +func (c *Conn) LocalAddr() net.Addr { + return crypto.GetNodeID(&c.core.boxPub) } // RemoteAddr returns the complete node ID of the remote side of the connection. -func (c *Conn) RemoteAddr() crypto.NodeID { +func (c *Conn) RemoteAddr() net.Addr { // RemoteAddr is set during the dial or accept, and isn't changed, so it's safe to access directly - return *c.nodeID + return c.nodeID } // SetDeadline is equivalent to calling both SetReadDeadline and diff --git a/src/yggdrasil/dialer.go b/src/yggdrasil/dialer.go index 04410855..47a68813 100644 --- a/src/yggdrasil/dialer.go +++ b/src/yggdrasil/dialer.go @@ -1,8 +1,10 @@ package yggdrasil import ( + "context" "encoding/hex" "errors" + "net" "strconv" "strings" "time" @@ -15,12 +17,15 @@ type Dialer struct { core *Core } -// TODO DialContext that allows timeouts/cancellation, Dial should just call this with no timeout set in the context - // Dial opens a session to the given node. The first paramter should be "nodeid" // and the second parameter should contain a hexadecimal representation of the -// target node ID. -func (d *Dialer) Dial(network, address string) (*Conn, error) { +// target node ID. It uses DialContext internally. +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(nil, network, address) +} + +// DialContext is used internally by Dial, and should only be used with a context that includes a timeout. It uses DialByNodeIDandMask internally. +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { var nodeID crypto.NodeID var nodeMask crypto.NodeID // Process @@ -28,7 +33,7 @@ func (d *Dialer) Dial(network, address string) (*Conn, error) { case "nodeid": // A node ID was provided - we don't need to do anything special with it if tokens := strings.Split(address, "/"); len(tokens) == 2 { - len, err := strconv.Atoi(tokens[1]) + l, err := strconv.Atoi(tokens[1]) if err != nil { return nil, err } @@ -37,7 +42,7 @@ func (d *Dialer) Dial(network, address string) (*Conn, error) { return nil, err } copy(nodeID[:], dest) - for idx := 0; idx < len; idx++ { + for idx := 0; idx < l; idx++ { nodeMask[idx/8] |= 0x80 >> byte(idx%8) } } else { @@ -50,7 +55,7 @@ func (d *Dialer) Dial(network, address string) (*Conn, error) { nodeMask[i] = 0xFF } } - return d.DialByNodeIDandMask(&nodeID, &nodeMask) + return d.DialByNodeIDandMask(ctx, &nodeID, &nodeMask) default: // An unexpected address type was given, so give up return nil, errors.New("unexpected address type") @@ -58,20 +63,25 @@ func (d *Dialer) Dial(network, address string) (*Conn, error) { } // DialByNodeIDandMask opens a session to the given node based on raw -// NodeID parameters. -func (d *Dialer) DialByNodeIDandMask(nodeID, nodeMask *crypto.NodeID) (*Conn, error) { +// NodeID parameters. If ctx is nil or has no timeout, then a default timeout of 6 seconds will apply, beginning *after* the search finishes. +func (d *Dialer) DialByNodeIDandMask(ctx context.Context, nodeID, nodeMask *crypto.NodeID) (net.Conn, error) { conn := newConn(d.core, nodeID, nodeMask, nil) if err := conn.search(); err != nil { + // TODO: make searches take a context, so they can be cancelled early conn.Close() return nil, err } conn.session.setConn(nil, conn) - t := time.NewTimer(6 * time.Second) // TODO use a context instead - defer t.Stop() + var cancel context.CancelFunc + if ctx == nil { + ctx = context.Background() + } + ctx, cancel = context.WithTimeout(ctx, 6*time.Second) + defer cancel() select { case <-conn.session.init: return conn, nil - case <-t.C: + case <-ctx.Done(): conn.Close() return nil, errors.New("session handshake timeout") } diff --git a/src/yggdrasil/listener.go b/src/yggdrasil/listener.go index fec543f4..63830970 100644 --- a/src/yggdrasil/listener.go +++ b/src/yggdrasil/listener.go @@ -13,7 +13,7 @@ type Listener struct { } // Accept blocks until a new incoming session is received -func (l *Listener) Accept() (*Conn, error) { +func (l *Listener) Accept() (net.Conn, error) { select { case c, ok := <-l.conn: if !ok {