mirror of
https://github.com/tailscale/tailscale.git
synced 2025-10-09 08:01:31 +00:00
safesocket: add ConnectContext
This adds a variant for Connect that takes in a context.Context which allows passing through cancellation etc by the caller. Updates tailscale/corp#18266 Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
package safesocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -57,7 +58,7 @@ func TestBasics(t *testing.T) {
|
||||
}()
|
||||
|
||||
go func() {
|
||||
c, err := Connect(sock)
|
||||
c, err := ConnectContext(context.Background(), sock)
|
||||
if err != nil {
|
||||
errs <- err
|
||||
return
|
||||
|
@@ -16,9 +16,8 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func connect(path string) (net.Conn, error) {
|
||||
dl := time.Now().Add(20 * time.Second)
|
||||
ctx, cancel := context.WithDeadline(context.Background(), dl)
|
||||
func connect(ctx context.Context, path string) (net.Conn, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
|
||||
defer cancel()
|
||||
// We use the identification impersonation level so that tailscaled may
|
||||
// obtain information about our token for access control purposes.
|
||||
|
@@ -6,6 +6,7 @@
|
||||
package safesocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"runtime"
|
||||
@@ -52,11 +53,14 @@ func tailscaledStillStarting() bool {
|
||||
return tailscaledProcExists()
|
||||
}
|
||||
|
||||
// Connect connects to tailscaled using a unix socket or named pipe.
|
||||
func Connect(path string) (net.Conn, error) {
|
||||
// ConnectContext connects to tailscaled using a unix socket or named pipe.
|
||||
func ConnectContext(ctx context.Context, path string) (net.Conn, error) {
|
||||
for {
|
||||
c, err := connect(path)
|
||||
c, err := connect(ctx, path)
|
||||
if err != nil && tailscaledStillStarting() {
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
@@ -64,6 +68,12 @@ func Connect(path string) (net.Conn, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// Connect connects to tailscaled using a unix socket or named pipe.
|
||||
// Deprecated: use ConnectContext instead.
|
||||
func Connect(path string) (net.Conn, error) {
|
||||
return ConnectContext(context.Background(), path)
|
||||
}
|
||||
|
||||
// Listen returns a listener either on Unix socket path (on Unix), or
|
||||
// the NamedPipe path (on Windows).
|
||||
func Listen(path string) (net.Listener, error) {
|
||||
|
@@ -4,6 +4,7 @@
|
||||
package safesocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/akutz/memconn"
|
||||
@@ -15,6 +16,6 @@ func listen(path string) (net.Listener, error) {
|
||||
return memconn.Listen("memu", memName)
|
||||
}
|
||||
|
||||
func connect(_ string) (net.Conn, error) {
|
||||
return memconn.Dial("memu", memName)
|
||||
func connect(ctx context.Context, _ string) (net.Conn, error) {
|
||||
return memconn.DialContext(ctx, "memu", memName)
|
||||
}
|
||||
|
@@ -6,6 +6,7 @@
|
||||
package safesocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@@ -85,7 +86,7 @@ func (fc plan9FileConn) SetWriteDeadline(t time.Time) error {
|
||||
return syscall.EPLAN9
|
||||
}
|
||||
|
||||
func connect(path string) (net.Conn, error) {
|
||||
func connect(_ context.Context, path string) (net.Conn, error) {
|
||||
f, err := os.OpenFile(path, os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@@ -6,6 +6,7 @@
|
||||
package safesocket
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -16,11 +17,12 @@ import (
|
||||
"runtime"
|
||||
)
|
||||
|
||||
func connect(path string) (net.Conn, error) {
|
||||
func connect(ctx context.Context, path string) (net.Conn, error) {
|
||||
if runtime.GOOS == "js" {
|
||||
return nil, errors.New("safesocket.Connect not yet implemented on js/wasm")
|
||||
}
|
||||
return net.Dial("unix", path)
|
||||
var std net.Dialer
|
||||
return std.DialContext(ctx, "unix", path)
|
||||
}
|
||||
|
||||
func listen(path string) (net.Listener, error) {
|
||||
|
Reference in New Issue
Block a user