mirror of
https://github.com/tailscale/tailscale.git
synced 2025-02-18 02:48:40 +00:00
safesocket: add ConnectContext
This adds a variant for Connect that takes in a context.Context which allows passing through cancellation etc by the caller. Updates tailscale/corp#18266 Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
parent
3672f66c74
commit
4b6a0c42c8
@ -167,7 +167,7 @@ func (s *FileSystemForRemote) buildChild(share *drive.Share) *compositedav.Child
|
|||||||
return fmt.Sprintf("http://%s/%s/%s", hex.EncodeToString([]byte(share.Name)), secretToken, url.PathEscape(share.Name)), nil
|
return fmt.Sprintf("http://%s/%s/%s", hex.EncodeToString([]byte(share.Name)), secretToken, url.PathEscape(share.Name)), nil
|
||||||
},
|
},
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
Dial: func(_, shareAddr string) (net.Conn, error) {
|
DialContext: func(ctx context.Context, _, shareAddr string) (net.Conn, error) {
|
||||||
shareNameHex, _, err := net.SplitHostPort(shareAddr)
|
shareNameHex, _, err := net.SplitHostPort(shareAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to parse share address %v: %w", shareAddr, err)
|
return nil, fmt.Errorf("unable to parse share address %v: %w", shareAddr, err)
|
||||||
@ -188,10 +188,11 @@ func (s *FileSystemForRemote) buildChild(share *drive.Share) *compositedav.Child
|
|||||||
_, err = netip.ParseAddrPort(addr)
|
_, err = netip.ParseAddrPort(addr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// this is a regular network address, dial normally
|
// this is a regular network address, dial normally
|
||||||
return net.Dial("tcp", addr)
|
var std net.Dialer
|
||||||
|
return std.DialContext(ctx, "tcp", addr)
|
||||||
}
|
}
|
||||||
// assume this is a safesocket address
|
// assume this is a safesocket address
|
||||||
return safesocket.Connect(addr)
|
return safesocket.ConnectContext(ctx, addr)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -709,7 +709,7 @@ func dialContext(ctx context.Context, netw, addr string, netMon *netmon.Monitor,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if version.IsWindowsGUI() && strings.HasPrefix(netw, "tcp") {
|
if version.IsWindowsGUI() && strings.HasPrefix(netw, "tcp") {
|
||||||
if c, err := safesocket.Connect(""); err == nil {
|
if c, err := safesocket.ConnectContext(ctx, ""); err == nil {
|
||||||
fmt.Fprintf(c, "CONNECT %s HTTP/1.0\r\n\r\n", addr)
|
fmt.Fprintf(c, "CONNECT %s HTTP/1.0\r\n\r\n", addr)
|
||||||
br := bufio.NewReader(c)
|
br := bufio.NewReader(c)
|
||||||
res, err := http.ReadResponse(br, nil)
|
res, err := http.ReadResponse(br, nil)
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
package safesocket
|
package safesocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -57,7 +58,7 @@ func TestBasics(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
c, err := Connect(sock)
|
c, err := ConnectContext(context.Background(), sock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errs <- err
|
errs <- err
|
||||||
return
|
return
|
||||||
|
@ -16,9 +16,8 @@ import (
|
|||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
func connect(path string) (net.Conn, error) {
|
func connect(ctx context.Context, path string) (net.Conn, error) {
|
||||||
dl := time.Now().Add(20 * time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
|
||||||
ctx, cancel := context.WithDeadline(context.Background(), dl)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
// We use the identification impersonation level so that tailscaled may
|
// We use the identification impersonation level so that tailscaled may
|
||||||
// obtain information about our token for access control purposes.
|
// obtain information about our token for access control purposes.
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package safesocket
|
package safesocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -52,11 +53,14 @@ func tailscaledStillStarting() bool {
|
|||||||
return tailscaledProcExists()
|
return tailscaledProcExists()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect connects to tailscaled using a unix socket or named pipe.
|
// ConnectContext connects to tailscaled using a unix socket or named pipe.
|
||||||
func Connect(path string) (net.Conn, error) {
|
func ConnectContext(ctx context.Context, path string) (net.Conn, error) {
|
||||||
for {
|
for {
|
||||||
c, err := connect(path)
|
c, err := connect(ctx, path)
|
||||||
if err != nil && tailscaledStillStarting() {
|
if err != nil && tailscaledStillStarting() {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
time.Sleep(250 * time.Millisecond)
|
time.Sleep(250 * time.Millisecond)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -64,6 +68,12 @@ func Connect(path string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Connect connects to tailscaled using a unix socket or named pipe.
|
||||||
|
// Deprecated: use ConnectContext instead.
|
||||||
|
func Connect(path string) (net.Conn, error) {
|
||||||
|
return ConnectContext(context.Background(), path)
|
||||||
|
}
|
||||||
|
|
||||||
// Listen returns a listener either on Unix socket path (on Unix), or
|
// Listen returns a listener either on Unix socket path (on Unix), or
|
||||||
// the NamedPipe path (on Windows).
|
// the NamedPipe path (on Windows).
|
||||||
func Listen(path string) (net.Listener, error) {
|
func Listen(path string) (net.Listener, error) {
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
package safesocket
|
package safesocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/akutz/memconn"
|
"github.com/akutz/memconn"
|
||||||
@ -15,6 +16,6 @@ func listen(path string) (net.Listener, error) {
|
|||||||
return memconn.Listen("memu", memName)
|
return memconn.Listen("memu", memName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(_ string) (net.Conn, error) {
|
func connect(ctx context.Context, _ string) (net.Conn, error) {
|
||||||
return memconn.Dial("memu", memName)
|
return memconn.DialContext(ctx, "memu", memName)
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package safesocket
|
package safesocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
@ -85,7 +86,7 @@ func (fc plan9FileConn) SetWriteDeadline(t time.Time) error {
|
|||||||
return syscall.EPLAN9
|
return syscall.EPLAN9
|
||||||
}
|
}
|
||||||
|
|
||||||
func connect(path string) (net.Conn, error) {
|
func connect(_ context.Context, path string) (net.Conn, error) {
|
||||||
f, err := os.OpenFile(path, os.O_RDWR, 0666)
|
f, err := os.OpenFile(path, os.O_RDWR, 0666)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
package safesocket
|
package safesocket
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
@ -16,11 +17,12 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
)
|
)
|
||||||
|
|
||||||
func connect(path string) (net.Conn, error) {
|
func connect(ctx context.Context, path string) (net.Conn, error) {
|
||||||
if runtime.GOOS == "js" {
|
if runtime.GOOS == "js" {
|
||||||
return nil, errors.New("safesocket.Connect not yet implemented on js/wasm")
|
return nil, errors.New("safesocket.Connect not yet implemented on js/wasm")
|
||||||
}
|
}
|
||||||
return net.Dial("unix", path)
|
var std net.Dialer
|
||||||
|
return std.DialContext(ctx, "unix", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func listen(path string) (net.Listener, error) {
|
func listen(path string) (net.Listener, error) {
|
||||||
|
@ -1497,7 +1497,7 @@ func (n *testNode) Ping(otherNode *testNode) error {
|
|||||||
func (n *testNode) AwaitListening() {
|
func (n *testNode) AwaitListening() {
|
||||||
t := n.env.t
|
t := n.env.t
|
||||||
if err := tstest.WaitFor(20*time.Second, func() (err error) {
|
if err := tstest.WaitFor(20*time.Second, func() (err error) {
|
||||||
c, err := safesocket.Connect(n.sockFile)
|
c, err := safesocket.ConnectContext(context.Background(), n.sockFile)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
c.Close()
|
c.Close()
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user