diff --git a/hscontrol/db/ephemeral_garbage_collector_test.go b/hscontrol/db/ephemeral_garbage_collector_test.go
new file mode 100644
index 00000000..ae75c6d7
--- /dev/null
+++ b/hscontrol/db/ephemeral_garbage_collector_test.go
@@ -0,0 +1,389 @@
+package db
+
+import (
+	"math/rand"
+	"runtime"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/juanfont/headscale/hscontrol/types"
+	"github.com/stretchr/testify/assert"
+)
+
+const fiveHundredMillis = 500 * time.Millisecond
+const oneHundredMillis = 100 * time.Millisecond
+const fiftyMillis = 50 * time.Millisecond
+
+// TestEphemeralGarbageCollectorGoRoutineLeak is a test for a goroutine leak in EphemeralGarbageCollector().
+// It creates a new EphemeralGarbageCollector, schedules several nodes for deletion with a short expiry,
+// verifies that the nodes are deleted when the expiry time passes, and then checks
+// for any leaked goroutines after the garbage collector is closed.
+func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
+	// Count goroutines at the start
+	initialGoroutines := runtime.NumGoroutine()
+	t.Logf("Initial number of goroutines: %d", initialGoroutines)
+
+	// Basic deletion tracking mechanism
+	var deletedIDs []types.NodeID
+	var deleteMutex sync.Mutex
+	var deletionWg sync.WaitGroup
+
+	deleteFunc := func(nodeID types.NodeID) {
+		deleteMutex.Lock()
+		deletedIDs = append(deletedIDs, nodeID)
+		deleteMutex.Unlock()
+		deletionWg.Done()
+	}
+
+	// Start the GC
+	gc := NewEphemeralGarbageCollector(deleteFunc)
+	go gc.Start()
+
+	// Schedule several nodes for deletion with a short expiry
+	const expiry = fiftyMillis
+	const numNodes = 100
+
+	// Set up wait group for expected deletions
+	deletionWg.Add(numNodes)
+
+	for i := 1; i <= numNodes; i++ {
+		gc.Schedule(types.NodeID(i), expiry)
+	}
+
+	// Wait for all scheduled deletions to complete
+	deletionWg.Wait()
+
+	// Check nodes are deleted
+	deleteMutex.Lock()
+	assert.Equal(t, numNodes, len(deletedIDs), "Not all nodes were deleted")
+	deleteMutex.Unlock()
+
+	// Schedule and immediately cancel to exercise the cancellation path
+	for i := numNodes + 1; i <= numNodes*2; i++ {
+		nodeID := types.NodeID(i)
+		gc.Schedule(nodeID, time.Hour)
+		gc.Cancel(nodeID)
+	}
+
+	// Create a channel to signal when we're done with cleanup checks
+	cleanupDone := make(chan struct{})
+
+	// Close GC and check for leaks in a separate goroutine
+	go func() {
+		// Close GC
+		gc.Close()
+
+		// Give any potential leaked goroutines a chance to exit.
+		// A small sleep is still needed here as we're checking for the absence of goroutines.
+		time.Sleep(oneHundredMillis)
+
+		// Check for leaked goroutines
+		finalGoroutines := runtime.NumGoroutine()
+		t.Logf("Final number of goroutines: %d", finalGoroutines)
+
+		// NB: We have to allow for a small number of extra goroutines created by the test itself
+		assert.LessOrEqual(t, finalGoroutines, initialGoroutines+5,
+			"There are significantly more goroutines after GC usage, which suggests a leak")
+
+		close(cleanupDone)
+	}()
+
+	// Wait for cleanup to complete
+	<-cleanupDone
+}
+
+// TestEphemeralGarbageCollectorReschedule is a test for the rescheduling of nodes in EphemeralGarbageCollector().
+// It creates a new EphemeralGarbageCollector, schedules a node for deletion with a long expiry,
+// then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
+func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
+	// Deletion tracking mechanism
+	var deletedIDs []types.NodeID
+	var deleteMutex sync.Mutex
+
+	deleteFunc := func(nodeID types.NodeID) {
+		deleteMutex.Lock()
+		deletedIDs = append(deletedIDs, nodeID)
+		deleteMutex.Unlock()
+	}
+
+	// Start GC
+	gc := NewEphemeralGarbageCollector(deleteFunc)
+	go gc.Start()
+	defer gc.Close()
+
+	const shortExpiry = fiftyMillis
+	const longExpiry = 1 * time.Hour
+
+	nodeID := types.NodeID(1)
+
+	// Schedule node for deletion with long expiry
+	gc.Schedule(nodeID, longExpiry)
+
+	// Reschedule the same node with a shorter expiry
+	gc.Schedule(nodeID, shortExpiry)
+
+	// Wait for deletion
+	time.Sleep(shortExpiry * 2)
+
+	// Verify that the node was deleted once
+	deleteMutex.Lock()
+	assert.Equal(t, 1, len(deletedIDs), "Node should be deleted exactly once")
+	assert.Equal(t, nodeID, deletedIDs[0], "The correct node should be deleted")
+	deleteMutex.Unlock()
+}
+
+// TestEphemeralGarbageCollectorCancelAndReschedule is a test for the cancellation and rescheduling of nodes in EphemeralGarbageCollector().
+// It creates a new EphemeralGarbageCollector, schedules a node for deletion, cancels it, and then reschedules it,
+// and verifies that the node is deleted only once.
+func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
+	// Deletion tracking mechanism
+	var deletedIDs []types.NodeID
+	var deleteMutex sync.Mutex
+	deletionNotifier := make(chan types.NodeID, 1)
+
+	deleteFunc := func(nodeID types.NodeID) {
+		deleteMutex.Lock()
+		deletedIDs = append(deletedIDs, nodeID)
+		deleteMutex.Unlock()
+		deletionNotifier <- nodeID
+	}
+
+	// Start the GC
+	gc := NewEphemeralGarbageCollector(deleteFunc)
+	go gc.Start()
+	defer gc.Close()
+
+	nodeID := types.NodeID(1)
+	const expiry = fiftyMillis
+
+	// Schedule node for deletion
+	gc.Schedule(nodeID, expiry)
+
+	// Cancel the scheduled deletion
+	gc.Cancel(nodeID)
+
+	// Use a timeout to verify no deletion occurred
+	select {
+	case <-deletionNotifier:
+		t.Fatal("Node was deleted after cancellation")
+	case <-time.After(expiry * 2): // Still need a timeout for negative test
+		// This is expected - no deletion should occur
+	}
+
+	deleteMutex.Lock()
+	assert.Equal(t, 0, len(deletedIDs), "Node should not be deleted after cancellation")
+	deleteMutex.Unlock()
+
+	// Reschedule the node
+	gc.Schedule(nodeID, expiry)
+
+	// Wait for deletion with timeout
+	select {
+	case deletedNodeID := <-deletionNotifier:
+		// Verify the correct node was deleted
+		assert.Equal(t, nodeID, deletedNodeID, "The correct node should be deleted")
+	case <-time.After(time.Second): // Longer timeout as a safety net
+		t.Fatal("Timed out waiting for node deletion")
+	}
+
+	// Verify final state
+	deleteMutex.Lock()
+	assert.Equal(t, 1, len(deletedIDs), "Node should be deleted after rescheduling")
+	assert.Equal(t, nodeID, deletedIDs[0], "The correct node should be deleted")
+	deleteMutex.Unlock()
+}
+
+// TestEphemeralGarbageCollectorCloseBeforeTimerFires is a test for the closing of the EphemeralGarbageCollector before the timer fires.
+// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
+func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
+	// Deletion tracking
+	var deletedIDs []types.NodeID
+	var deleteMutex sync.Mutex
+
+	deleteFunc := func(nodeID types.NodeID) {
+		deleteMutex.Lock()
+		deletedIDs = append(deletedIDs, nodeID)
+		deleteMutex.Unlock()
+	}
+
+	// Start the GC
+	gc := NewEphemeralGarbageCollector(deleteFunc)
+	go gc.Start()
+
+	const longExpiry = 1 * time.Hour
+	const shortExpiry = fiftyMillis
+
+	// Schedule node deletion with a long expiry
+	gc.Schedule(types.NodeID(1), longExpiry)
+
+	// Close the GC before the timer fires
+	gc.Close()
+
+	// Wait a short time
+	time.Sleep(shortExpiry * 2)
+
+	// Verify that no deletion occurred
+	deleteMutex.Lock()
+	assert.Equal(t, 0, len(deletedIDs), "No node should be deleted when GC is closed before timer fires")
+	deleteMutex.Unlock()
+}
+
+// TestEphemeralGarbageCollectorScheduleAfterClose verifies that calling Schedule after Close
+// is a no-op and doesn't cause any panics, goroutine leaks, or other issues.
+func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
+	// Count initial goroutines to check for leaks
+	initialGoroutines := runtime.NumGoroutine()
+	t.Logf("Initial number of goroutines: %d", initialGoroutines)
+
+	// Deletion tracking
+	var deletedIDs []types.NodeID
+	var deleteMutex sync.Mutex
+	nodeDeleted := make(chan struct{})
+
+	deleteFunc := func(nodeID types.NodeID) {
+		deleteMutex.Lock()
+		deletedIDs = append(deletedIDs, nodeID)
+		deleteMutex.Unlock()
+		close(nodeDeleted) // Signal that deletion happened
+	}
+
+	// Start new GC
+	gc := NewEphemeralGarbageCollector(deleteFunc)
+
+	// Use a WaitGroup to ensure the GC goroutine has been launched
+	var startWg sync.WaitGroup
+	startWg.Add(1)
+	go func() {
+		startWg.Done() // Signal that the goroutine has started
+		gc.Start()
+	}()
+	startWg.Wait() // Wait for the GC goroutine to launch
+
+	// Close GC right away
+	gc.Close()
+
+	// Use a channel to signal when we should check the goroutine count
+	gcClosedCheck := make(chan struct{})
+	go func() {
+		// Give the GC time to fully close and clean up resources.
+		// This is still time-based but only affects when we check the goroutine count,
+		// not the actual test logic.
+		time.Sleep(oneHundredMillis)
+		close(gcClosedCheck)
+	}()
+
+	// Now try to schedule a node for deletion with a very short expiry.
+	// If the Schedule operation incorrectly creates a timer, it would fire quickly.
+	nodeID := types.NodeID(1)
+	gc.Schedule(nodeID, 1*time.Millisecond)
+
+	// Set up a timeout channel for our test
+	timeout := time.After(fiveHundredMillis)
+
+	// Check if any node was deleted (which shouldn't happen)
+	select {
+	case <-nodeDeleted:
+		t.Fatal("Node was deleted after GC was closed, which should not happen")
+	case <-timeout:
+		// This is the expected path - no deletion should occur
+	}
+
+	// Check no node was deleted
+	deleteMutex.Lock()
+	nodesDeleted := len(deletedIDs)
+	deleteMutex.Unlock()
+	assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
+
+	// Check for goroutine leaks after GC is fully closed
+	<-gcClosedCheck
+	finalGoroutines := runtime.NumGoroutine()
+	t.Logf("Final number of goroutines: %d", finalGoroutines)
+
+	// Allow for small fluctuations in goroutine count caused by the test framework itself
+	assert.LessOrEqual(t, finalGoroutines, initialGoroutines+2,
+		"There should be no significant goroutine leaks when Schedule is called after Close")
+}
+
+// TestEphemeralGarbageCollectorConcurrentScheduleAndClose tests the behavior of the garbage collector
+// when Schedule and Close are called concurrently from multiple goroutines.
+func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
+	// Count initial goroutines
+	initialGoroutines := runtime.NumGoroutine()
+	t.Logf("Initial number of goroutines: %d", initialGoroutines)
+
+	// Deletion tracking mechanism
+	var deletedIDs []types.NodeID
+	var deleteMutex sync.Mutex
+
+	deleteFunc := func(nodeID types.NodeID) {
+		deleteMutex.Lock()
+		deletedIDs = append(deletedIDs, nodeID)
+		deleteMutex.Unlock()
+	}
+
+	// Start the GC
+	gc := NewEphemeralGarbageCollector(deleteFunc)
+	go gc.Start()
+
+	// Number of concurrent scheduling goroutines
+	const numSchedulers = 10
+	const nodesPerScheduler = 50
+	const schedulingDuration = fiveHundredMillis
+
+	// Use WaitGroup to wait for all scheduling goroutines to finish
+	var wg sync.WaitGroup
+	wg.Add(numSchedulers + 1) // +1 for the closer goroutine
+
+	// Create a stopper channel to signal scheduling goroutines to stop
+	stopScheduling := make(chan struct{})
+
+	// Launch goroutines that continuously schedule nodes
+	for i := 0; i < numSchedulers; i++ {
+		go func(schedulerID int) {
+			defer wg.Done()
+
+			baseNodeID := schedulerID * nodesPerScheduler
+
+			// Keep scheduling nodes until done or signaled to stop
+			for j := 0; j < nodesPerScheduler; j++ {
+				select {
+				case <-stopScheduling:
+					return
+				default:
+					nodeID := types.NodeID(baseNodeID + j + 1)
+					gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during the test
+
+					// Short random sleep to vary the interleaving between schedulers
+					time.Sleep(time.Duration(rand.Intn(5)) * time.Millisecond)
+				}
+			}
+		}(i)
+	}
+
+	// After a short delay, close the garbage collector while schedulers are still running
+	go func() {
+		defer wg.Done()
+		time.Sleep(schedulingDuration / 2)
+
+		// Close GC
+		gc.Close()
+
+		// Signal schedulers to stop
+		close(stopScheduling)
+	}()
+
+	// Wait for all goroutines to complete
+	wg.Wait()
+
+	// Wait a bit longer to give any leaked goroutines a chance to surface
+	time.Sleep(oneHundredMillis)
+
+	// Check for leaks
+	finalGoroutines := runtime.NumGoroutine()
+	t.Logf("Final number of goroutines: %d", finalGoroutines)
+
+	// Allow for a small amount of goroutine-count variation caused by the test itself
+	assert.LessOrEqual(t, finalGoroutines, initialGoroutines+5,
+		"There should be no significant goroutine leaks during concurrent Schedule and Close operations")
+}
diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go
index 09bc795d..ed9e1f73 100644
--- a/hscontrol/db/node.go
+++ b/hscontrol/db/node.go
@@ -630,22 +630,59 @@ func NewEphemeralGarbageCollector(deleteFunc func(types.NodeID)) *EphemeralGarba
 // Close stops the garbage collector.
 func (e *EphemeralGarbageCollector) Close() {
-	e.cancelCh <- struct{}{}
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// Stop all timers
+	for _, timer := range e.toBeDeleted {
+		timer.Stop()
+	}
+
+	// Close the cancel channel to signal all goroutines to exit
+	close(e.cancelCh)
 }
 
 // Schedule schedules a node for deletion after the expiry duration.
+// If the garbage collector is already closed, this is a no-op.
 func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) {
 	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// Don't schedule new timers if the garbage collector is already closed
+	select {
+	case <-e.cancelCh:
+		// The cancel channel is closed, meaning the GC is shutting down
+		// or already shut down, so we shouldn't schedule anything new
+		return
+	default:
+		// Continue with scheduling
+	}
+
+	// If a timer already exists for this node, stop it first
+	if oldTimer, exists := e.toBeDeleted[nodeID]; exists {
+		oldTimer.Stop()
+	}
+
 	timer := time.NewTimer(expiry)
 	e.toBeDeleted[nodeID] = timer
-	e.mu.Unlock()
-
+	// Start a goroutine to handle the timer completion
 	go func() {
 		select {
-		case _, ok := <-timer.C:
-			if ok {
-				e.deleteCh <- nodeID
+		case <-timer.C:
+			// The GC may be shutting down just as this timer fires, in which
+			// case nothing is draining deleteCh anymore. To avoid blocking
+			// forever, try to send to deleteCh while also watching cancelCh
+			// so the goroutine can exit once the GC is closed.
+			select {
+			case e.deleteCh <- nodeID:
+				// Successfully sent to deleteCh
+			case <-e.cancelCh:
+				// GC is shutting down, don't send to deleteCh
+				return
 			}
+		case <-e.cancelCh:
+			// If the GC is closed, exit the goroutine
+			return
 		}
 	}()
 }
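
For context, a minimal usage sketch of the collector API exercised by these tests and changed above: construct with a delete callback, run Start in its own goroutine, Schedule or Cancel ephemeral nodes, and Close on shutdown (after this patch, Close stops all pending timers and closes cancelCh so timer goroutines exit). The callback body and the standalone main wiring here are illustrative only, not Headscale's actual call site.

package main

import (
	"fmt"
	"time"

	"github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	// Hypothetical delete callback; the real one removes the node from the database.
	deleteNode := func(id types.NodeID) {
		fmt.Println("deleting ephemeral node", id)
	}

	gc := db.NewEphemeralGarbageCollector(deleteNode)
	go gc.Start()
	defer gc.Close() // Schedule calls after this point become no-ops

	gc.Schedule(types.NodeID(1), 30*time.Second) // delete node 1 in 30s unless cancelled
	gc.Cancel(types.NodeID(1))                   // node reconnected: drop the pending deletion
	gc.Schedule(types.NodeID(2), 30*time.Second) // node 2 stays scheduled until expiry or Close
}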