diff --git a/pkg/beholder/auth_test.go b/pkg/beholder/auth_test.go index e30ceb405..96aae1af8 100644 --- a/pkg/beholder/auth_test.go +++ b/pkg/beholder/auth_test.go @@ -193,4 +193,139 @@ func TestRotatingAuth(t *testing.T) { mockSigner.AssertExpectations(t) }) + + t.Run("concurrent access during header rotation detects race condition", func(t *testing.T) { + // This test is designed to catch the race condition where r.headers + // is read without holding the lock while another goroutine is writing to it. + // + // Race condition scenario: + // 1. Goroutine A: Acquires lock, updates r.headers[key], releases lock + // 2. Goroutine B: Reads r.headers without lock (line 132 in auth.go) + // 3. Result: Concurrent map read/write = DATA RACE + // + // Run with: go test -race -run "TestRotatingAuth/concurrent" ./pkg/beholder + + mockSigner := &MockSigner{} + dummySignature := ed25519.Sign(privKey, []byte("test data")) + + // Make signing slow to increase chance of catching the race + mockSigner. + On("Sign", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + time.Sleep(10 * time.Millisecond) + }). + Return(dummySignature, nil) + + // Use a very short TTL to force frequent rotations + ttl := 1 * time.Millisecond + auth := beholder.NewRotatingAuth(pubKey, mockSigner, ttl, false) + + // Force initial header creation + _, err := auth.Headers(t.Context()) + require.NoError(t, err) + + // Wait for TTL to expire + time.Sleep(5 * time.Millisecond) + + // Launch multiple goroutines that will all try to access headers + // when TTL has expired, causing concurrent rotation attempts + const numGoroutines = 50 + errChan := make(chan error, numGoroutines) + doneChan := make(chan struct{}) + + for i := 0; i < numGoroutines; i++ { + go func() { + for { + select { + case <-doneChan: + return + default: + // Continuously read headers to maximize chance of race + _, err := auth.Headers(context.Background()) + if err != nil { + errChan <- err + return + } + } + } + }() + } + + // Let goroutines run for a bit to trigger multiple rotations + time.Sleep(100 * time.Millisecond) + close(doneChan) + + // Check for errors + select { + case err := <-errChan: + t.Fatalf("Unexpected error during concurrent access: %v", err) + case <-time.After(100 * time.Millisecond): + // No errors, test passed + } + + // If run with -race flag, the race detector will catch the issue + // even if the test doesn't fail functionally + }) + + t.Run("concurrent header reads during rotation", func(t *testing.T) { + // Another variant focusing on the specific race between + // writing to r.headers[key] and returning r.headers + + mockSigner := &MockSigner{} + dummySignature := ed25519.Sign(privKey, []byte("test data")) + + callCount := 0 + mockSigner. + On("Sign", mock.Anything, mock.Anything, mock.Anything). + Run(func(args mock.Arguments) { + callCount++ + // Slow down signing to create a window for the race + time.Sleep(20 * time.Millisecond) + }). + Return(dummySignature, nil) + + ttl := 10 * time.Millisecond + auth := beholder.NewRotatingAuth(pubKey, mockSigner, ttl, false) + + // Create initial headers + headers1, err := auth.Headers(t.Context()) + require.NoError(t, err) + require.NotEmpty(t, headers1) + + // Wait for TTL to expire + time.Sleep(15 * time.Millisecond) + + // Now launch concurrent readers + // One will trigger rotation (acquire lock, start signing) + // Others should either wait or read the map concurrently (race!) + const numReaders = 20 + results := make(chan map[string]string, numReaders) + errors := make(chan error, numReaders) + + for i := 0; i < numReaders; i++ { + go func() { + headers, err := auth.Headers(context.Background()) + if err != nil { + errors <- err + return + } + results <- headers + }() + } + + // Collect results + for i := 0; i < numReaders; i++ { + select { + case err := <-errors: + t.Fatalf("Unexpected error: %v", err) + case headers := <-results: + assert.NotEmpty(t, headers) + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for results") + } + } + + // Verify signing was called (rotation happened) + assert.Greater(t, callCount, 1, "Expected at least one rotation to occur") + }) }