close the conn, don't leave it open

This commit is contained in:
Fran Bull 2025-02-27 13:01:05 -08:00
parent ad7d1ee07a
commit 24ce3279f4
2 changed files with 31 additions and 27 deletions

View File

@ -112,29 +112,37 @@ func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (n
return sl.s.Dial(ctx, "tcp", string(address)) return sl.s.Dial(ctx, "tcp", string(address))
} }
func (sl StreamLayer) connAuthorized(conn net.Conn) (bool, error) {
if conn.RemoteAddr() == nil {
return false, nil
}
addr, err := addrFromServerAddress(conn.RemoteAddr().String())
if err != nil {
// bad RemoteAddr is not authorized
return false, nil
}
ctx := context.Background() // TODO
err = sl.auth.refresh(ctx)
if err != nil {
// might be authorized, we couldn't tell
return false, err
}
return sl.auth.allowsHost(addr), nil
}
func (sl StreamLayer) Accept() (net.Conn, error) { func (sl StreamLayer) Accept() (net.Conn, error) {
for { for {
conn, err := sl.Listener.Accept() conn, err := sl.Listener.Accept()
if err != nil || conn == nil { if err != nil || conn == nil {
return conn, err return conn, err
} }
ctx := context.Background() // TODO authorized, err := sl.connAuthorized(conn)
err = sl.auth.refresh(ctx)
if err != nil { if err != nil {
// TODO should we stay alive here? conn.Close()
return nil, err return nil, err
} }
if !authorized {
if conn.RemoteAddr() == nil { conn.Close()
continue
}
addr, err := addrFromServerAddress(conn.RemoteAddr().String())
if err != nil {
// TODO should we stay alive here?
return nil, err
}
if !sl.auth.allowsHost(addr) {
continue continue
} }
return conn, err return conn, err

View File

@ -546,14 +546,6 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) {
ipv4, _ := ps[0].ts.TailscaleIPs() ipv4, _ := ps[0].ts.TailscaleIPs()
sAddr := fmt.Sprintf("%s:%d", ipv4, cfg.RaftPort) sAddr := fmt.Sprintf("%s:%d", ipv4, cfg.RaftPort)
isNetTimeoutErr := func(err error) bool {
var netErr net.Error
if !errors.As(err, &netErr) {
return false
}
return netErr.Timeout()
}
getErrorFromTryingToSend := func(s *tsnet.Server) error { getErrorFromTryingToSend := func(s *tsnet.Server) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
@ -561,7 +553,6 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected Dial err: %v", err) t.Fatalf("unexpected Dial err: %v", err)
} }
conn.SetDeadline(time.Now().Add(5 * time.Second))
fmt.Fprintf(conn, "hellllllloooooo") fmt.Fprintf(conn, "hellllllloooooo")
status, err := bufio.NewReader(conn).ReadString('\n') status, err := bufio.NewReader(conn).ReadString('\n')
if status != "" { if status != "" {
@ -573,15 +564,20 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) {
return err return err
} }
isNetErr := func(err error) bool {
var netErr net.Error
return errors.As(err, &netErr)
}
err := getErrorFromTryingToSend(untaggedNode) err := getErrorFromTryingToSend(untaggedNode)
if !isNetTimeoutErr(err) { if !isNetErr(err) {
t.Fatalf("untagged node trying to send should time out, got: %v", err) t.Fatalf("untagged node trying to send should get a net.Error, got: %v", err)
} }
// we still get an error trying to send but it's EOF the target node was happy to talk // we still get an error trying to send but it's EOF the target node was happy to talk
// to us but couldn't understand what we said. // to us but couldn't understand what we said.
err = getErrorFromTryingToSend(taggedNode) err = getErrorFromTryingToSend(taggedNode)
if isNetTimeoutErr(err) { if isNetErr(err) {
t.Fatalf("tagged node trying to send should not time out, got: %v", err) t.Fatalf("tagged node trying to send should not get a net.Error, got: %v", err)
} }
} }