mirror of
https://github.com/tailscale/tailscale.git
synced 2025-07-31 00:03:47 +00:00
close the conn, don't leave it open
This commit is contained in:
parent
ad7d1ee07a
commit
24ce3279f4
@ -112,29 +112,37 @@ func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (n
|
||||
return sl.s.Dial(ctx, "tcp", string(address))
|
||||
}
|
||||
|
||||
func (sl StreamLayer) connAuthorized(conn net.Conn) (bool, error) {
|
||||
if conn.RemoteAddr() == nil {
|
||||
return false, nil
|
||||
}
|
||||
addr, err := addrFromServerAddress(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
// bad RemoteAddr is not authorized
|
||||
return false, nil
|
||||
}
|
||||
ctx := context.Background() // TODO
|
||||
err = sl.auth.refresh(ctx)
|
||||
if err != nil {
|
||||
// might be authorized, we couldn't tell
|
||||
return false, err
|
||||
}
|
||||
return sl.auth.allowsHost(addr), nil
|
||||
}
|
||||
|
||||
func (sl StreamLayer) Accept() (net.Conn, error) {
|
||||
for {
|
||||
conn, err := sl.Listener.Accept()
|
||||
if err != nil || conn == nil {
|
||||
return conn, err
|
||||
}
|
||||
ctx := context.Background() // TODO
|
||||
err = sl.auth.refresh(ctx)
|
||||
authorized, err := sl.connAuthorized(conn)
|
||||
if err != nil {
|
||||
// TODO should we stay alive here?
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if conn.RemoteAddr() == nil {
|
||||
continue
|
||||
}
|
||||
addr, err := addrFromServerAddress(conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
// TODO should we stay alive here?
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !sl.auth.allowsHost(addr) {
|
||||
if !authorized {
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
return conn, err
|
||||
|
@ -546,14 +546,6 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) {
|
||||
ipv4, _ := ps[0].ts.TailscaleIPs()
|
||||
sAddr := fmt.Sprintf("%s:%d", ipv4, cfg.RaftPort)
|
||||
|
||||
isNetTimeoutErr := func(err error) bool {
|
||||
var netErr net.Error
|
||||
if !errors.As(err, &netErr) {
|
||||
return false
|
||||
}
|
||||
return netErr.Timeout()
|
||||
}
|
||||
|
||||
getErrorFromTryingToSend := func(s *tsnet.Server) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
@ -561,7 +553,6 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Dial err: %v", err)
|
||||
}
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
fmt.Fprintf(conn, "hellllllloooooo")
|
||||
status, err := bufio.NewReader(conn).ReadString('\n')
|
||||
if status != "" {
|
||||
@ -573,15 +564,20 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
|
||||
isNetErr := func(err error) bool {
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr)
|
||||
}
|
||||
|
||||
err := getErrorFromTryingToSend(untaggedNode)
|
||||
if !isNetTimeoutErr(err) {
|
||||
t.Fatalf("untagged node trying to send should time out, got: %v", err)
|
||||
if !isNetErr(err) {
|
||||
t.Fatalf("untagged node trying to send should get a net.Error, got: %v", err)
|
||||
}
|
||||
// we still get an error trying to send but it's EOF the target node was happy to talk
|
||||
// to us but couldn't understand what we said.
|
||||
err = getErrorFromTryingToSend(taggedNode)
|
||||
if isNetTimeoutErr(err) {
|
||||
t.Fatalf("tagged node trying to send should not time out, got: %v", err)
|
||||
if isNetErr(err) {
|
||||
t.Fatalf("tagged node trying to send should not get a net.Error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user