Skip to content
78 changes: 27 additions & 51 deletions src/ModularPipelines/Engine/IModuleResultRegistry.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Concurrent;
using ModularPipelines.Models;

namespace ModularPipelines.Engine;
Expand Down Expand Up @@ -61,92 +62,67 @@ internal interface IModuleResultRegistry

/// <summary>
/// Default implementation of the module result registry.
/// Uses ConcurrentDictionary for lock-free reads and thread-safe writes.
/// </summary>
internal class ModuleResultRegistry : IModuleResultRegistry
{
private readonly Dictionary<Type, object> _results = new();
private readonly Dictionary<Type, TaskCompletionSource<object?>> _completionSources = new();
private readonly object _lock = new();
private readonly ConcurrentDictionary<Type, object> _results = new();
private readonly ConcurrentDictionary<Type, TaskCompletionSource<object?>> _completionSources = new();

public void RegisterModule(Type moduleType)
{
lock (_lock)
{
if (!_completionSources.ContainsKey(moduleType))
{
_completionSources[moduleType] = new TaskCompletionSource<object?>();
}
}
_completionSources.GetOrAdd(moduleType, _ => new TaskCompletionSource<object?>());
}

public void RegisterResult<T>(Type moduleType, ModuleResult<T> result)
{
lock (_lock)
{
_results[moduleType] = result;

if (_completionSources.TryGetValue(moduleType, out var tcs))
{
tcs.TrySetResult(result);
}
}
// Store result first, then signal completion
// TrySetResult provides release semantics, ensuring _results write is visible to awaiters
_results[moduleType] = result;
var tcs = _completionSources.GetOrAdd(moduleType, _ => new TaskCompletionSource<object?>());
tcs.TrySetResult(result);
}

public ModuleResult<T>? GetResult<T>(Type moduleType)
{
lock (_lock)
if (_results.TryGetValue(moduleType, out var result) && result is ModuleResult<T> typedResult)
{
if (_results.TryGetValue(moduleType, out var result) && result is ModuleResult<T> typedResult)
{
return typedResult;
}

return null;
return typedResult;
}

return null;
}

public IModuleResult? GetResult(Type moduleType)
{
lock (_lock)
if (_results.TryGetValue(moduleType, out var result))
{
if (_results.TryGetValue(moduleType, out var result))
{
return result as IModuleResult;
}

return null;
return result as IModuleResult;
}

return null;
}

public Task? GetCompletionTask(Type moduleType)
{
lock (_lock)
{
return _completionSources.TryGetValue(moduleType, out var tcs) ? tcs.Task : null;
}
return _completionSources.TryGetValue(moduleType, out var tcs) ? tcs.Task : null;
}

public void SetException(Type moduleType, Exception exception)
{
lock (_lock)
// Only set exception if module was previously registered (preserves original behavior)
if (_completionSources.TryGetValue(moduleType, out var tcs))
{
if (_completionSources.TryGetValue(moduleType, out var tcs))
{
tcs.TrySetException(exception);
}
tcs.TrySetException(exception);
}
}
Comment on lines 111 to 118

Copilot AI Jan 1, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The SetException method has a similar race condition issue. If RegisterModule is called after TryGetValue but before or during TrySetException, a new TaskCompletionSource could be created that never gets the exception set. This would cause any code waiting on GetCompletionTask to hang.

Consider using a similar fix as for RegisterResult to ensure atomicity of operations on the TaskCompletionSource.

Copilot uses AI. Check for mistakes.

public void RegisterResult(Type moduleType, IModuleResult result)
{
lock (_lock)
{
_results[moduleType] = result;

if (_completionSources.TryGetValue(moduleType, out var tcs))
{
tcs.TrySetResult(result);
}
}
// Store result first, then signal completion
// TrySetResult provides release semantics, ensuring _results write is visible to awaiters
_results[moduleType] = result;
var tcs = _completionSources.GetOrAdd(moduleType, _ => new TaskCompletionSource<object?>());
tcs.TrySetResult(result);
}
}
Loading