Reusable peer lookup/dial logic

This commit is contained in:
Neil Alexander 2024-11-16 22:59:03 +00:00
parent 75d2080e53
commit 42873be09b
No known key found for this signature in database
GPG Key ID: A02A2019A2BB0944
12 changed files with 193 additions and 124 deletions

View File

@ -191,9 +191,16 @@ func main() {
// Set up the Yggdrasil node itself. // Set up the Yggdrasil node itself.
{ {
iprange := net.IPNet{
IP: net.ParseIP("200::"),
Mask: net.CIDRMask(7, 128),
}
options := []core.SetupOption{ options := []core.SetupOption{
core.NodeInfo(cfg.NodeInfo), core.NodeInfo(cfg.NodeInfo),
core.NodeInfoPrivacy(cfg.NodeInfoPrivacy), core.NodeInfoPrivacy(cfg.NodeInfoPrivacy),
core.PeerFilter(func(ip net.IP) bool {
return !iprange.Contains(ip)
}),
} }
for _, addr := range cfg.Listen { for _, addr := range cfg.Listen {
options = append(options, core.ListenAddress(addr)) options = append(options, core.ListenAddress(addr))

View File

@ -53,7 +53,15 @@ func (m *Yggdrasil) StartJSON(configjson []byte) error {
} }
// Set up the Yggdrasil node itself. // Set up the Yggdrasil node itself.
{ {
options := []core.SetupOption{} iprange := net.IPNet{
IP: net.ParseIP("200::"),
Mask: net.CIDRMask(7, 128),
}
options := []core.SetupOption{
core.PeerFilter(func(ip net.IP) bool {
return !iprange.Contains(ip)
}),
}
for _, peer := range m.config.Peers { for _, peer := range m.config.Peers {
options = append(options, core.Peer{URI: peer}) options = append(options, core.Peer{URI: peer})
} }

View File

@ -40,6 +40,7 @@ type Core struct {
tls *tls.Config // immutable after startup tls *tls.Config // immutable after startup
//_peers map[Peer]*linkInfo // configurable after startup //_peers map[Peer]*linkInfo // configurable after startup
_listeners map[ListenAddress]struct{} // configurable after startup _listeners map[ListenAddress]struct{} // configurable after startup
peerFilter func(ip net.IP) bool // immutable after startup
nodeinfo NodeInfo // immutable after startup nodeinfo NodeInfo // immutable after startup
nodeinfoPrivacy NodeInfoPrivacy // immutable after startup nodeinfoPrivacy NodeInfoPrivacy // immutable after startup
_allowedPublicKeys map[[32]byte]struct{} // configurable after startup _allowedPublicKeys map[[32]byte]struct{} // configurable after startup

View File

@ -127,6 +127,7 @@ const ErrLinkPasswordInvalid = linkError("invalid password supplied")
const ErrLinkUnrecognisedSchema = linkError("link schema unknown") const ErrLinkUnrecognisedSchema = linkError("link schema unknown")
const ErrLinkMaxBackoffInvalid = linkError("max backoff duration invalid") const ErrLinkMaxBackoffInvalid = linkError("max backoff duration invalid")
const ErrLinkSNINotSupported = linkError("SNI not supported on this link type") const ErrLinkSNINotSupported = linkError("SNI not supported on this link type")
const ErrLinkNoSuitableIPs = linkError("no suitable remote IPs")
func (l *links) add(u *url.URL, sintf string, linkType linkType) error { func (l *links) add(u *url.URL, sintf string, linkType linkType) error {
var retErr error var retErr error
@ -653,6 +654,43 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn, s
return err return err
} }
func (l *links) findSuitableIP(url *url.URL, fn func(hostname string, ip net.IP, port int) (net.Conn, error)) (net.Conn, error) {
host, p, err := net.SplitHostPort(url.Host)
if err != nil {
return nil, err
}
port, err := strconv.Atoi(p)
if err != nil {
return nil, err
}
resp, err := net.LookupIP(host)
if err != nil {
return nil, err
}
var _ips [64]net.IP
ips := _ips[:0]
for _, ip := range resp {
if l.core.config.peerFilter != nil && !l.core.config.peerFilter(ip) {
continue
}
ips = append(ips, ip)
}
if len(ips) == 0 {
return nil, ErrLinkNoSuitableIPs
}
for _, ip := range ips {
var conn net.Conn
if conn, err = fn(host, ip, port); err != nil {
url := *url
url.RawQuery = ""
l.core.log.Debugln("Dialling", url.Redacted(), "reported error:", err)
continue
}
return conn, nil
}
return nil, err
}
func urlForLinkInfo(u url.URL) url.URL { func urlForLinkInfo(u url.URL) url.URL {
u.RawQuery = "" u.RawQuery = ""
return u return u

View File

@ -51,18 +51,23 @@ func (l *links) newLinkQUIC() *linkQUIC {
} }
func (l *linkQUIC) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { func (l *linkQUIC) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
qc, err := quic.DialAddr(ctx, url.Host, l.tlsconfig, l.quicconfig) tlsconfig := l.tlsconfig.Clone()
if err != nil { return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) {
return nil, err tlsconfig.ServerName = hostname
} hostport := net.JoinHostPort(ip.String(), fmt.Sprintf("%d", port))
qs, err := qc.OpenStreamSync(ctx) qc, err := quic.DialAddr(ctx, hostport, l.tlsconfig, l.quicconfig)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &linkQUICStream{ qs, err := qc.OpenStreamSync(ctx)
Connection: qc, if err != nil {
Stream: qs, return nil, err
}, nil }
return &linkQUICStream{
Connection: qc,
Stream: qs,
}, nil
})
} }
func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {

View File

@ -23,9 +23,6 @@ func (l *links) newLinkSOCKS() *linkSOCKS {
} }
func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
if url.Scheme != "sockstls" && options.tlsSNI != "" {
return nil, ErrLinkSNINotSupported
}
var proxyAuth *proxy.Auth var proxyAuth *proxy.Auth
if url.User != nil && url.User.Username() != "" { if url.User != nil && url.User.Username() != "" {
proxyAuth = &proxy.Auth{ proxyAuth = &proxy.Auth{
@ -33,21 +30,34 @@ func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options
} }
proxyAuth.Password, _ = url.User.Password() proxyAuth.Password, _ = url.User.Password()
} }
dialer, err := proxy.SOCKS5("tcp", url.Host, proxyAuth, proxy.Direct) tlsconfig := l.tls.config.Clone()
if err != nil { return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) {
return nil, fmt.Errorf("failed to configure proxy") hostport := net.JoinHostPort(ip.String(), fmt.Sprintf("%d", port))
} dialer, err := l.tcp.dialerFor(&net.TCPAddr{
pathtokens := strings.Split(strings.Trim(url.Path, "/"), "/") IP: ip,
conn, err := dialer.Dial("tcp", pathtokens[0]) Port: port,
if err != nil { }, info.sintf)
return nil, fmt.Errorf("failed to dial: %w", err) if err != nil {
} return nil, err
if url.Scheme == "sockstls" { }
tlsconfig := l.tls.config.Clone() proxy, err := proxy.SOCKS5("tcp", hostport, proxyAuth, dialer)
tlsconfig.ServerName = options.tlsSNI if err != nil {
conn = tls.Client(conn, tlsconfig) return nil, err
} }
return conn, nil pathtokens := strings.Split(strings.Trim(url.Path, "/"), "/")
conn, err := proxy.Dial("tcp", pathtokens[0])
if err != nil {
return nil, err
}
if url.Scheme == "sockstls" {
tlsconfig.ServerName = hostname
if sni := options.tlsSNI; sni != "" {
tlsconfig.ServerName = sni
}
conn = tls.Client(conn, tlsconfig)
}
return conn, nil
})
} }
func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {

View File

@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/url" "net/url"
"strconv"
"time" "time"
"github.com/Arceliar/phony" "github.com/Arceliar/phony"
@ -34,59 +33,18 @@ type tcpDialer struct {
addr *net.TCPAddr addr *net.TCPAddr
} }
func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error) { func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
host, p, err := net.SplitHostPort(url.Host) return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) {
if err != nil {
return nil, err
}
port, err := strconv.Atoi(p)
if err != nil {
return nil, err
}
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}
dialers := make([]*tcpDialer, 0, len(ips))
for _, ip := range ips {
addr := &net.TCPAddr{ addr := &net.TCPAddr{
IP: ip, IP: ip,
Port: port, Port: port,
} }
dialer, err := l.dialerFor(addr, info.sintf) dialer, err := l.tcp.dialerFor(addr, info.sintf)
if err != nil { if err != nil {
continue return nil, err
} }
dialers = append(dialers, &tcpDialer{ return dialer.DialContext(ctx, "tcp", addr.String())
info: info, })
dialer: dialer,
addr: addr,
})
}
return dialers, nil
}
func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
if options.tlsSNI != "" {
return nil, ErrLinkSNINotSupported
}
dialers, err := l.dialersFor(url, info)
if err != nil {
return nil, err
}
if len(dialers) == 0 {
return nil, nil
}
for _, d := range dialers {
var conn net.Conn
conn, err = d.dialer.DialContext(ctx, "tcp", d.addr.String())
if err != nil {
l.core.log.Warnf("Failed to connect to %s: %s", d.addr, err)
continue
}
return conn, nil
}
return nil, err
} }
func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) { func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) {

View File

@ -32,28 +32,26 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS {
} }
func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
dialers, err := l.tcp.dialersFor(url, info) tlsconfig := l.config.Clone()
if err != nil { return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) {
return nil, err tlsconfig.ServerName = hostname
} if sni := options.tlsSNI; sni != "" {
if len(dialers) == 0 { tlsconfig.ServerName = sni
return nil, nil }
} addr := &net.TCPAddr{
for _, d := range dialers { IP: ip,
tlsconfig := l.config.Clone() Port: port,
tlsconfig.ServerName = options.tlsSNI }
dialer, err := l.tcp.dialerFor(addr, info.sintf)
if err != nil {
return nil, err
}
tlsdialer := &tls.Dialer{ tlsdialer := &tls.Dialer{
NetDialer: d.dialer, NetDialer: dialer,
Config: tlsconfig, Config: tlsconfig,
} }
var conn net.Conn return tlsdialer.DialContext(ctx, "tcp", addr.String())
conn, err = tlsdialer.DialContext(ctx, "tcp", d.addr.String()) })
if err != nil {
continue
}
return conn, nil
}
return nil, err
} }
func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) { func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) {

View File

@ -31,9 +31,6 @@ func (l *links) newLinkUNIX() *linkUNIX {
} }
func (l *linkUNIX) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { func (l *linkUNIX) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
if options.tlsSNI != "" {
return nil, ErrLinkSNINotSupported
}
addr, err := net.ResolveUnixAddr("unix", url.Path) addr, err := net.ResolveUnixAddr("unix", url.Path)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -2,6 +2,7 @@ package core
import ( import (
"context" "context"
"fmt"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -87,18 +88,35 @@ func (l *links) newLinkWS() *linkWS {
} }
func (l *linkWS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { func (l *linkWS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
if options.tlsSNI != "" { return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) {
return nil, ErrLinkSNINotSupported u := *url
} u.Host = net.JoinHostPort(ip.String(), fmt.Sprintf("%d", port))
wsconn, _, err := websocket.Dial(ctx, url.String(), &websocket.DialOptions{ addr := &net.TCPAddr{
Subprotocols: []string{"ygg-ws"}, IP: ip,
Port: port,
}
dialer, err := l.tcp.dialerFor(addr, info.sintf)
if err != nil {
return nil, err
}
wsconn, _, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
HTTPClient: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: dialer.Dial,
DialContext: dialer.DialContext,
},
},
Subprotocols: []string{"ygg-ws"},
Host: hostname,
})
if err != nil {
return nil, err
}
return &linkWSConn{
Conn: websocket.NetConn(ctx, wsconn, websocket.MessageBinary),
}, nil
}) })
if err != nil {
return nil, err
}
return &linkWSConn{
Conn: websocket.NetConn(ctx, wsconn, websocket.MessageBinary),
}, nil
} }
func (l *linkWS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { func (l *linkWS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {

View File

@ -2,8 +2,10 @@ package core
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net" "net"
"net/http"
"net/url" "net/url"
"github.com/Arceliar/phony" "github.com/Arceliar/phony"
@ -13,6 +15,7 @@ import (
type linkWSS struct { type linkWSS struct {
phony.Inbox phony.Inbox
*links *links
tlsconfig *tls.Config
} }
type linkWSSConn struct { type linkWSSConn struct {
@ -21,24 +24,45 @@ type linkWSSConn struct {
func (l *links) newLinkWSS() *linkWSS { func (l *links) newLinkWSS() *linkWSS {
lwss := &linkWSS{ lwss := &linkWSS{
links: l, links: l,
tlsconfig: l.core.config.tls.Clone(),
} }
return lwss return lwss
} }
func (l *linkWSS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { func (l *linkWSS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) {
if options.tlsSNI != "" { tlsconfig := l.tlsconfig.Clone()
return nil, ErrLinkSNINotSupported return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) {
} tlsconfig.ServerName = hostname
wsconn, _, err := websocket.Dial(ctx, url.String(), &websocket.DialOptions{ u := *url
Subprotocols: []string{"ygg-ws"}, u.Host = net.JoinHostPort(ip.String(), fmt.Sprintf("%d", port))
addr := &net.TCPAddr{
IP: ip,
Port: port,
}
dialer, err := l.tcp.dialerFor(addr, info.sintf)
if err != nil {
return nil, err
}
wsconn, _, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{
HTTPClient: &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: dialer.Dial,
DialContext: dialer.DialContext,
TLSClientConfig: tlsconfig,
},
},
Subprotocols: []string{"ygg-ws"},
Host: hostname,
})
if err != nil {
return nil, err
}
return &linkWSConn{
Conn: websocket.NetConn(ctx, wsconn, websocket.MessageBinary),
}, nil
}) })
if err != nil {
return nil, err
}
return &linkWSSConn{
Conn: websocket.NetConn(ctx, wsconn, websocket.MessageBinary),
}, nil
} }
func (l *linkWSS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { func (l *linkWSS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) {

View File

@ -3,6 +3,7 @@ package core
import ( import (
"crypto/ed25519" "crypto/ed25519"
"fmt" "fmt"
"net"
"net/url" "net/url"
) )
@ -24,6 +25,8 @@ func (c *Core) _applyOption(opt SetupOption) (err error) {
} }
case ListenAddress: case ListenAddress:
c.config._listeners[v] = struct{}{} c.config._listeners[v] = struct{}{}
case PeerFilter:
c.config.peerFilter = v
case NodeInfo: case NodeInfo:
c.config.nodeinfo = v c.config.nodeinfo = v
case NodeInfoPrivacy: case NodeInfoPrivacy:
@ -48,9 +51,11 @@ type Peer struct {
type NodeInfo map[string]interface{} type NodeInfo map[string]interface{}
type NodeInfoPrivacy bool type NodeInfoPrivacy bool
type AllowedPublicKey ed25519.PublicKey type AllowedPublicKey ed25519.PublicKey
type PeerFilter func(net.IP) bool
func (a ListenAddress) isSetupOption() {} func (a ListenAddress) isSetupOption() {}
func (a Peer) isSetupOption() {} func (a Peer) isSetupOption() {}
func (a NodeInfo) isSetupOption() {} func (a NodeInfo) isSetupOption() {}
func (a NodeInfoPrivacy) isSetupOption() {} func (a NodeInfoPrivacy) isSetupOption() {}
func (a AllowedPublicKey) isSetupOption() {} func (a AllowedPublicKey) isSetupOption() {}
func (a PeerFilter) isSetupOption() {}