diff --git a/cmd/yggdrasil/main.go b/cmd/yggdrasil/main.go index 0ae8ab42..3ec6414c 100644 --- a/cmd/yggdrasil/main.go +++ b/cmd/yggdrasil/main.go @@ -191,9 +191,16 @@ func main() { // Set up the Yggdrasil node itself. { + iprange := net.IPNet{ + IP: net.ParseIP("200::"), + Mask: net.CIDRMask(7, 128), + } options := []core.SetupOption{ core.NodeInfo(cfg.NodeInfo), core.NodeInfoPrivacy(cfg.NodeInfoPrivacy), + core.PeerFilter(func(ip net.IP) bool { + return !iprange.Contains(ip) + }), } for _, addr := range cfg.Listen { options = append(options, core.ListenAddress(addr)) diff --git a/contrib/mobile/mobile.go b/contrib/mobile/mobile.go index 06f48027..abc89f1c 100644 --- a/contrib/mobile/mobile.go +++ b/contrib/mobile/mobile.go @@ -53,7 +53,15 @@ func (m *Yggdrasil) StartJSON(configjson []byte) error { } // Set up the Yggdrasil node itself. { - options := []core.SetupOption{} + iprange := net.IPNet{ + IP: net.ParseIP("200::"), + Mask: net.CIDRMask(7, 128), + } + options := []core.SetupOption{ + core.PeerFilter(func(ip net.IP) bool { + return !iprange.Contains(ip) + }), + } for _, peer := range m.config.Peers { options = append(options, core.Peer{URI: peer}) } diff --git a/src/core/core.go b/src/core/core.go index 2b206ee1..a7f9fe96 100644 --- a/src/core/core.go +++ b/src/core/core.go @@ -40,6 +40,7 @@ type Core struct { tls *tls.Config // immutable after startup //_peers map[Peer]*linkInfo // configurable after startup _listeners map[ListenAddress]struct{} // configurable after startup + peerFilter func(ip net.IP) bool // immutable after startup nodeinfo NodeInfo // immutable after startup nodeinfoPrivacy NodeInfoPrivacy // immutable after startup _allowedPublicKeys map[[32]byte]struct{} // configurable after startup diff --git a/src/core/link.go b/src/core/link.go index c2267f24..d7e5b110 100644 --- a/src/core/link.go +++ b/src/core/link.go @@ -127,6 +127,7 @@ const ErrLinkPasswordInvalid = linkError("invalid password supplied") const ErrLinkUnrecognisedSchema = linkError("link schema unknown") const ErrLinkMaxBackoffInvalid = linkError("max backoff duration invalid") const ErrLinkSNINotSupported = linkError("SNI not supported on this link type") +const ErrLinkNoSuitableIPs = linkError("no suitable remote IPs") func (l *links) add(u *url.URL, sintf string, linkType linkType) error { var retErr error @@ -653,6 +654,43 @@ func (l *links) handler(linkType linkType, options linkOptions, conn net.Conn, s return err } +func (l *links) findSuitableIP(url *url.URL, fn func(hostname string, ip net.IP, port int) (net.Conn, error)) (net.Conn, error) { + host, p, err := net.SplitHostPort(url.Host) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(p) + if err != nil { + return nil, err + } + resp, err := net.LookupIP(host) + if err != nil { + return nil, err + } + var _ips [64]net.IP + ips := _ips[:0] + for _, ip := range resp { + if l.core.config.peerFilter != nil && !l.core.config.peerFilter(ip) { + continue + } + ips = append(ips, ip) + } + if len(ips) == 0 { + return nil, ErrLinkNoSuitableIPs + } + for _, ip := range ips { + var conn net.Conn + if conn, err = fn(host, ip, port); err != nil { + url := *url + url.RawQuery = "" + l.core.log.Debugln("Dialling", url.Redacted(), "reported error:", err) + continue + } + return conn, nil + } + return nil, err +} + func urlForLinkInfo(u url.URL) url.URL { u.RawQuery = "" return u diff --git a/src/core/link_quic.go b/src/core/link_quic.go index 9ad5456d..ffb69a6d 100644 --- a/src/core/link_quic.go +++ b/src/core/link_quic.go @@ -51,18 +51,23 @@ func (l *links) newLinkQUIC() *linkQUIC { } func (l *linkQUIC) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { - qc, err := quic.DialAddr(ctx, url.Host, l.tlsconfig, l.quicconfig) - if err != nil { - return nil, err - } - qs, err := qc.OpenStreamSync(ctx) - if err != nil { - return nil, err - } - return &linkQUICStream{ - Connection: qc, - Stream: qs, - }, nil + tlsconfig := l.tlsconfig.Clone() + return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) { + tlsconfig.ServerName = hostname + hostport := net.JoinHostPort(ip.String(), fmt.Sprintf("%d", port)) + qc, err := quic.DialAddr(ctx, hostport, l.tlsconfig, l.quicconfig) + if err != nil { + return nil, err + } + qs, err := qc.OpenStreamSync(ctx) + if err != nil { + return nil, err + } + return &linkQUICStream{ + Connection: qc, + Stream: qs, + }, nil + }) } func (l *linkQUIC) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { diff --git a/src/core/link_socks.go b/src/core/link_socks.go index 0f66661b..f33cd190 100644 --- a/src/core/link_socks.go +++ b/src/core/link_socks.go @@ -23,9 +23,6 @@ func (l *links) newLinkSOCKS() *linkSOCKS { } func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { - if url.Scheme != "sockstls" && options.tlsSNI != "" { - return nil, ErrLinkSNINotSupported - } var proxyAuth *proxy.Auth if url.User != nil && url.User.Username() != "" { proxyAuth = &proxy.Auth{ @@ -33,21 +30,34 @@ func (l *linkSOCKS) dial(_ context.Context, url *url.URL, info linkInfo, options } proxyAuth.Password, _ = url.User.Password() } - dialer, err := proxy.SOCKS5("tcp", url.Host, proxyAuth, proxy.Direct) - if err != nil { - return nil, fmt.Errorf("failed to configure proxy") - } - pathtokens := strings.Split(strings.Trim(url.Path, "/"), "/") - conn, err := dialer.Dial("tcp", pathtokens[0]) - if err != nil { - return nil, fmt.Errorf("failed to dial: %w", err) - } - if url.Scheme == "sockstls" { - tlsconfig := l.tls.config.Clone() - tlsconfig.ServerName = options.tlsSNI - conn = tls.Client(conn, tlsconfig) - } - return conn, nil + tlsconfig := l.tls.config.Clone() + return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) { + hostport := net.JoinHostPort(ip.String(), fmt.Sprintf("%d", port)) + dialer, err := l.tcp.dialerFor(&net.TCPAddr{ + IP: ip, + Port: port, + }, info.sintf) + if err != nil { + return nil, err + } + proxy, err := proxy.SOCKS5("tcp", hostport, proxyAuth, dialer) + if err != nil { + return nil, err + } + pathtokens := strings.Split(strings.Trim(url.Path, "/"), "/") + conn, err := proxy.Dial("tcp", pathtokens[0]) + if err != nil { + return nil, err + } + if url.Scheme == "sockstls" { + tlsconfig.ServerName = hostname + if sni := options.tlsSNI; sni != "" { + tlsconfig.ServerName = sni + } + conn = tls.Client(conn, tlsconfig) + } + return conn, nil + }) } func (l *linkSOCKS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { diff --git a/src/core/link_tcp.go b/src/core/link_tcp.go index 38c42def..3315f5c7 100644 --- a/src/core/link_tcp.go +++ b/src/core/link_tcp.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "net/url" - "strconv" "time" "github.com/Arceliar/phony" @@ -34,59 +33,18 @@ type tcpDialer struct { addr *net.TCPAddr } -func (l *linkTCP) dialersFor(url *url.URL, info linkInfo) ([]*tcpDialer, error) { - host, p, err := net.SplitHostPort(url.Host) - if err != nil { - return nil, err - } - port, err := strconv.Atoi(p) - if err != nil { - return nil, err - } - ips, err := net.LookupIP(host) - if err != nil { - return nil, err - } - dialers := make([]*tcpDialer, 0, len(ips)) - for _, ip := range ips { +func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { + return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) { addr := &net.TCPAddr{ IP: ip, Port: port, } - dialer, err := l.dialerFor(addr, info.sintf) + dialer, err := l.tcp.dialerFor(addr, info.sintf) if err != nil { - continue + return nil, err } - dialers = append(dialers, &tcpDialer{ - info: info, - dialer: dialer, - addr: addr, - }) - } - return dialers, nil -} - -func (l *linkTCP) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { - if options.tlsSNI != "" { - return nil, ErrLinkSNINotSupported - } - dialers, err := l.dialersFor(url, info) - if err != nil { - return nil, err - } - if len(dialers) == 0 { - return nil, nil - } - for _, d := range dialers { - var conn net.Conn - conn, err = d.dialer.DialContext(ctx, "tcp", d.addr.String()) - if err != nil { - l.core.log.Warnf("Failed to connect to %s: %s", d.addr, err) - continue - } - return conn, nil - } - return nil, err + return dialer.DialContext(ctx, "tcp", addr.String()) + }) } func (l *linkTCP) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) { diff --git a/src/core/link_tls.go b/src/core/link_tls.go index d51d0ce5..da3c7791 100644 --- a/src/core/link_tls.go +++ b/src/core/link_tls.go @@ -32,28 +32,26 @@ func (l *links) newLinkTLS(tcp *linkTCP) *linkTLS { } func (l *linkTLS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { - dialers, err := l.tcp.dialersFor(url, info) - if err != nil { - return nil, err - } - if len(dialers) == 0 { - return nil, nil - } - for _, d := range dialers { - tlsconfig := l.config.Clone() - tlsconfig.ServerName = options.tlsSNI + tlsconfig := l.config.Clone() + return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) { + tlsconfig.ServerName = hostname + if sni := options.tlsSNI; sni != "" { + tlsconfig.ServerName = sni + } + addr := &net.TCPAddr{ + IP: ip, + Port: port, + } + dialer, err := l.tcp.dialerFor(addr, info.sintf) + if err != nil { + return nil, err + } tlsdialer := &tls.Dialer{ - NetDialer: d.dialer, + NetDialer: dialer, Config: tlsconfig, } - var conn net.Conn - conn, err = tlsdialer.DialContext(ctx, "tcp", d.addr.String()) - if err != nil { - continue - } - return conn, nil - } - return nil, err + return tlsdialer.DialContext(ctx, "tcp", addr.String()) + }) } func (l *linkTLS) listen(ctx context.Context, url *url.URL, sintf string) (net.Listener, error) { diff --git a/src/core/link_unix.go b/src/core/link_unix.go index 1da6931d..ddbfa0ad 100644 --- a/src/core/link_unix.go +++ b/src/core/link_unix.go @@ -31,9 +31,6 @@ func (l *links) newLinkUNIX() *linkUNIX { } func (l *linkUNIX) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { - if options.tlsSNI != "" { - return nil, ErrLinkSNINotSupported - } addr, err := net.ResolveUnixAddr("unix", url.Path) if err != nil { return nil, err diff --git a/src/core/link_ws.go b/src/core/link_ws.go index 59816098..86f065a6 100644 --- a/src/core/link_ws.go +++ b/src/core/link_ws.go @@ -2,6 +2,7 @@ package core import ( "context" + "fmt" "net" "net/http" "net/url" @@ -87,18 +88,35 @@ func (l *links) newLinkWS() *linkWS { } func (l *linkWS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { - if options.tlsSNI != "" { - return nil, ErrLinkSNINotSupported - } - wsconn, _, err := websocket.Dial(ctx, url.String(), &websocket.DialOptions{ - Subprotocols: []string{"ygg-ws"}, + return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) { + u := *url + u.Host = net.JoinHostPort(ip.String(), fmt.Sprintf("%d", port)) + addr := &net.TCPAddr{ + IP: ip, + Port: port, + } + dialer, err := l.tcp.dialerFor(addr, info.sintf) + if err != nil { + return nil, err + } + wsconn, _, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: dialer.Dial, + DialContext: dialer.DialContext, + }, + }, + Subprotocols: []string{"ygg-ws"}, + Host: hostname, + }) + if err != nil { + return nil, err + } + return &linkWSConn{ + Conn: websocket.NetConn(ctx, wsconn, websocket.MessageBinary), + }, nil }) - if err != nil { - return nil, err - } - return &linkWSConn{ - Conn: websocket.NetConn(ctx, wsconn, websocket.MessageBinary), - }, nil } func (l *linkWS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { diff --git a/src/core/link_wss.go b/src/core/link_wss.go index f09d7955..4d968c2f 100644 --- a/src/core/link_wss.go +++ b/src/core/link_wss.go @@ -2,8 +2,10 @@ package core import ( "context" + "crypto/tls" "fmt" "net" + "net/http" "net/url" "github.com/Arceliar/phony" @@ -13,6 +15,7 @@ import ( type linkWSS struct { phony.Inbox *links + tlsconfig *tls.Config } type linkWSSConn struct { @@ -21,24 +24,45 @@ type linkWSSConn struct { func (l *links) newLinkWSS() *linkWSS { lwss := &linkWSS{ - links: l, + links: l, + tlsconfig: l.core.config.tls.Clone(), } return lwss } func (l *linkWSS) dial(ctx context.Context, url *url.URL, info linkInfo, options linkOptions) (net.Conn, error) { - if options.tlsSNI != "" { - return nil, ErrLinkSNINotSupported - } - wsconn, _, err := websocket.Dial(ctx, url.String(), &websocket.DialOptions{ - Subprotocols: []string{"ygg-ws"}, + tlsconfig := l.tlsconfig.Clone() + return l.links.findSuitableIP(url, func(hostname string, ip net.IP, port int) (net.Conn, error) { + tlsconfig.ServerName = hostname + u := *url + u.Host = net.JoinHostPort(ip.String(), fmt.Sprintf("%d", port)) + addr := &net.TCPAddr{ + IP: ip, + Port: port, + } + dialer, err := l.tcp.dialerFor(addr, info.sintf) + if err != nil { + return nil, err + } + wsconn, _, err := websocket.Dial(ctx, u.String(), &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + Dial: dialer.Dial, + DialContext: dialer.DialContext, + TLSClientConfig: tlsconfig, + }, + }, + Subprotocols: []string{"ygg-ws"}, + Host: hostname, + }) + if err != nil { + return nil, err + } + return &linkWSConn{ + Conn: websocket.NetConn(ctx, wsconn, websocket.MessageBinary), + }, nil }) - if err != nil { - return nil, err - } - return &linkWSSConn{ - Conn: websocket.NetConn(ctx, wsconn, websocket.MessageBinary), - }, nil } func (l *linkWSS) listen(ctx context.Context, url *url.URL, _ string) (net.Listener, error) { diff --git a/src/core/options.go b/src/core/options.go index 581c033b..7e67bfb4 100644 --- a/src/core/options.go +++ b/src/core/options.go @@ -3,6 +3,7 @@ package core import ( "crypto/ed25519" "fmt" + "net" "net/url" ) @@ -24,6 +25,8 @@ func (c *Core) _applyOption(opt SetupOption) (err error) { } case ListenAddress: c.config._listeners[v] = struct{}{} + case PeerFilter: + c.config.peerFilter = v case NodeInfo: c.config.nodeinfo = v case NodeInfoPrivacy: @@ -48,9 +51,11 @@ type Peer struct { type NodeInfo map[string]interface{} type NodeInfoPrivacy bool type AllowedPublicKey ed25519.PublicKey +type PeerFilter func(net.IP) bool func (a ListenAddress) isSetupOption() {} func (a Peer) isSetupOption() {} func (a NodeInfo) isSetupOption() {} func (a NodeInfoPrivacy) isSetupOption() {} func (a AllowedPublicKey) isSetupOption() {} +func (a PeerFilter) isSetupOption() {}