From 24ce3279f4e9eb6baa4732ade64b833366a45a2d Mon Sep 17 00:00:00 2001 From: Fran Bull Date: Thu, 27 Feb 2025 13:01:05 -0800 Subject: [PATCH] close the conn, don't leave it open --- tsconsensus/tsconsensus.go | 36 ++++++++++++++++++++------------- tsconsensus/tsconsensus_test.go | 22 +++++++++----------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/tsconsensus/tsconsensus.go b/tsconsensus/tsconsensus.go index 03be108de..38e9755d9 100644 --- a/tsconsensus/tsconsensus.go +++ b/tsconsensus/tsconsensus.go @@ -112,29 +112,37 @@ func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (n return sl.s.Dial(ctx, "tcp", string(address)) } +func (sl StreamLayer) connAuthorized(conn net.Conn) (bool, error) { + if conn.RemoteAddr() == nil { + return false, nil + } + addr, err := addrFromServerAddress(conn.RemoteAddr().String()) + if err != nil { + // bad RemoteAddr is not authorized + return false, nil + } + ctx := context.Background() // TODO + err = sl.auth.refresh(ctx) + if err != nil { + // might be authorized, we couldn't tell + return false, err + } + return sl.auth.allowsHost(addr), nil +} + func (sl StreamLayer) Accept() (net.Conn, error) { for { conn, err := sl.Listener.Accept() if err != nil || conn == nil { return conn, err } - ctx := context.Background() // TODO - err = sl.auth.refresh(ctx) + authorized, err := sl.connAuthorized(conn) if err != nil { - // TODO should we stay alive here? + conn.Close() return nil, err } - - if conn.RemoteAddr() == nil { - continue - } - addr, err := addrFromServerAddress(conn.RemoteAddr().String()) - if err != nil { - // TODO should we stay alive here? - return nil, err - } - - if !sl.auth.allowsHost(addr) { + if !authorized { + conn.Close() continue } return conn, err diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go index e5746b5ff..0f8b7cd45 100644 --- a/tsconsensus/tsconsensus_test.go +++ b/tsconsensus/tsconsensus_test.go @@ -546,14 +546,6 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) { ipv4, _ := ps[0].ts.TailscaleIPs() sAddr := fmt.Sprintf("%s:%d", ipv4, cfg.RaftPort) - isNetTimeoutErr := func(err error) bool { - var netErr net.Error - if !errors.As(err, &netErr) { - return false - } - return netErr.Timeout() - } - getErrorFromTryingToSend := func(s *tsnet.Server) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -561,7 +553,6 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) { if err != nil { t.Fatalf("unexpected Dial err: %v", err) } - conn.SetDeadline(time.Now().Add(5 * time.Second)) fmt.Fprintf(conn, "hellllllloooooo") status, err := bufio.NewReader(conn).ReadString('\n') if status != "" { @@ -573,15 +564,20 @@ func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) { return err } + isNetErr := func(err error) bool { + var netErr net.Error + return errors.As(err, &netErr) + } + err := getErrorFromTryingToSend(untaggedNode) - if !isNetTimeoutErr(err) { - t.Fatalf("untagged node trying to send should time out, got: %v", err) + if !isNetErr(err) { + t.Fatalf("untagged node trying to send should get a net.Error, got: %v", err) } // we still get an error trying to send but it's EOF the target node was happy to talk // to us but couldn't understand what we said. err = getErrorFromTryingToSend(taggedNode) - if isNetTimeoutErr(err) { - t.Fatalf("tagged node trying to send should not time out, got: %v", err) + if isNetErr(err) { + t.Fatalf("tagged node trying to send should not get a net.Error, got: %v", err) } }