Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions src/Castle.Core.AsyncInterceptor/AsyncDeterminationInterceptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ public class AsyncDeterminationInterceptor : IInterceptor
typeof(AsyncDeterminationInterceptor)
.GetMethod(nameof(HandleAsyncWithResult), BindingFlags.Static | BindingFlags.NonPublic)!;

private static readonly MethodInfo HandleAsyncEnumerableInfo =
typeof(AsyncDeterminationInterceptor).GetMethod(nameof(HandleAsyncEnumerable), BindingFlags.Static | BindingFlags.NonPublic)!;

private static readonly ConcurrentDictionary<Type, GenericAsyncHandler> GenericAsyncHandlers = new();

/// <summary>
Expand All @@ -34,6 +37,7 @@ private enum MethodType
Synchronous,
AsyncAction,
AsyncFunction,
AsyncEnumerableFunction,
}

/// <summary>
Expand All @@ -56,6 +60,7 @@ public virtual void Intercept(IInvocation invocation)
AsyncInterceptor.InterceptAsynchronous(invocation);
return;
case MethodType.AsyncFunction:
case MethodType.AsyncEnumerableFunction:
GetHandler(invocation.Method.ReturnType).Invoke(invocation, AsyncInterceptor);
return;
case MethodType.Synchronous:
Expand All @@ -70,6 +75,11 @@ public virtual void Intercept(IInvocation invocation)
/// </summary>
private static MethodType GetMethodType(Type returnType)
{
TypeInfo typeInfo = returnType.GetTypeInfo();

if (typeInfo.IsGenericType && typeInfo.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>))
return MethodType.AsyncEnumerableFunction;

// If there's no return type, or it's not a task, then assume it's a synchronous method.
if (returnType == typeof(void) || !typeof(Task).IsAssignableFrom(returnType))
return MethodType.Synchronous;
Expand All @@ -92,9 +102,21 @@ private static GenericAsyncHandler GetHandler(Type returnType)
/// </summary>
private static GenericAsyncHandler CreateHandler(Type returnType)
{
Type taskReturnType = returnType.GetGenericArguments()[0];
MethodInfo method = HandleAsyncMethodInfo.MakeGenericMethod(taskReturnType);
return (GenericAsyncHandler)method.CreateDelegate(typeof(GenericAsyncHandler));
Type genericType = returnType.GetGenericTypeDefinition();
if (typeof(Task).IsAssignableFrom(genericType))
{
Type? taskReturnType = returnType.GetGenericArguments()[0];
MethodInfo method = HandleAsyncMethodInfo.MakeGenericMethod(taskReturnType);
return (GenericAsyncHandler)method.CreateDelegate(typeof(GenericAsyncHandler));
}
else if (genericType == typeof(IAsyncEnumerable<>))
{
Type enumerableType = returnType.GetGenericArguments()[0];
MethodInfo method = HandleAsyncEnumerableInfo.MakeGenericMethod(enumerableType);
return (GenericAsyncHandler)method.CreateDelegate(typeof(GenericAsyncHandler));
}

throw new ArgumentException("Only Task, Task<> or IAsyncEnumerable<> return types are supported", nameof(returnType));
}

/// <summary>
Expand All @@ -107,4 +129,13 @@ private static void HandleAsyncWithResult<TResult>(IInvocation invocation, IAsyn
{
asyncInterceptor.InterceptAsynchronous<TResult>(invocation);
}

/// <summary>
/// This method is created as a delegate and used to make the call to the generic
/// <see cref="IAsyncInterceptor.InterceptAsyncEnumerable{T}"/> method.
/// </summary>
/// <typeparam name="TResult">The type of the <see cref="IAsyncEnumerable{T}"/> of the method
/// <paramref name="invocation"/>.</typeparam>
private static void HandleAsyncEnumerable<TResult>(IInvocation invocation, IAsyncInterceptor asyncInterceptor)
=> asyncInterceptor.InterceptAsyncEnumerable<TResult>(invocation);
}
6 changes: 6 additions & 0 deletions src/Castle.Core.AsyncInterceptor/AsyncInterceptorBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ public void InterceptAsynchronous<TResult>(IInvocation invocation)
InterceptAsync(invocation, invocation.CaptureProceedInfo(), ProceedAsynchronous<TResult>);
}

/// <inheritdoc/>
public void InterceptAsyncEnumerable<TResult>(IInvocation invocation)
{
throw new NotImplementedException(); // TODO : discuss the correct model here
}

/// <summary>
/// Override in derived classes to intercept method invocations.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net45;netstandard2.0;net5.0;net6.0;net7.0</TargetFrameworks>
<TargetFrameworks>netstandard2.0;net5.0;net6.0;net7.0</TargetFrameworks>
<RootNamespace>Castle.DynamicProxy</RootNamespace>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<GenerateAssemblyInfo>false</GenerateAssemblyInfo>
Expand Down Expand Up @@ -49,4 +49,28 @@
<None Include="..\..\docs\images\castle-logo.png" Pack="true" PackagePath="" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces">
<Version>7.0.0</Version>
</PackageReference>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net5.0'">
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces">
<Version>7.0.0</Version>
</PackageReference>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net6.0'">
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces">
<Version>7.0.0</Version>
</PackageReference>
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == 'net7.0'">
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces">
<Version>7.0.0</Version>
</PackageReference>
</ItemGroup>

</Project>
7 changes: 7 additions & 0 deletions src/Castle.Core.AsyncInterceptor/IAsyncInterceptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,11 @@ public interface IAsyncInterceptor
/// <typeparam name="TResult">The type of the <see cref="Task{T}"/> <see cref="Task{T}.Result"/>.</typeparam>
/// <param name="invocation">The method invocation.</param>
void InterceptAsynchronous<TResult>(IInvocation invocation);

/// <summary>
/// Intercepts a method <paramref name="invocation"/> that returns an <see cref="IAsyncEnumerable{T}"/>.
/// </summary>
/// <typeparam name="TResult">The type of the returned enumerable.</typeparam>
/// <param name="invocation">The method invocation.</param>
void InterceptAsyncEnumerable<TResult>(IInvocation invocation);
}
25 changes: 25 additions & 0 deletions src/Castle.Core.AsyncInterceptor/ProcessingAsyncInterceptor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information.

namespace Castle.DynamicProxy;
using System.Collections.Generic;

/// <summary>
/// A base type for an <see cref="IAsyncInterceptor"/> which executes only minimal processing when intercepting a
Expand Down Expand Up @@ -47,6 +48,18 @@ public void InterceptAsynchronous<TResult>(IInvocation invocation)
invocation.ReturnValue = SignalWhenCompleteAsync<TResult>(invocation, state);
}

/// <summary>
/// Intercepts a method <paramref name="invocation"/> with return type of <see cref="IAsyncEnumerable{T}"/>.
/// </summary>
/// <typeparam name="TResult">The type of ther returned enumerable.</typeparam>
/// <param name="invocation">The method invocation.</param>
public void InterceptAsyncEnumerable<TResult>(IInvocation invocation)
{
TState state = Proceed(invocation);
var innerAsync = (IAsyncEnumerable<TResult>)invocation.ReturnValue;
invocation.ReturnValue = SignalWhenEnumerationCompleteAsync<TResult>(invocation, innerAsync, state);
}

/// <summary>
/// Override in derived classes to receive signals prior method <paramref name="invocation"/>.
/// </summary>
Expand Down Expand Up @@ -140,4 +153,16 @@ private async Task<TResult> SignalWhenCompleteAsync<TResult>(IInvocation invocat

return result;
}

private async IAsyncEnumerable<TResult> SignalWhenEnumerationCompleteAsync<TResult>(IInvocation invocation, IAsyncEnumerable<TResult> innerAsync, TState state)
{
// loop / yield the proxied method
await foreach (TResult item in innerAsync)
{
yield return item;
}

// Signal that the invocation has been completed.
CompletedInvocation(invocation, state);
}
}
56 changes: 56 additions & 0 deletions test/Castle.Core.AsyncInterceptor.Tests/AsyncInterceptorShould.cs
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,59 @@ public async Task ShouldAllowInterceptionAfterInvocation()
Assert.Equal($"{MethodName}:InterceptEnd", _log[3]);
}
}

public class WhenInterceptingAsyncEnumerableMethods
{
private const string MethodName = nameof(IInterfaceToProxy.AsyncEnumerableMethod);
private readonly ListLogger _log;
private readonly IInterfaceToProxy _proxy;

public WhenInterceptingAsyncEnumerableMethods(ITestOutputHelper output)
{
_log = new ListLogger(output);
_proxy = ProxyGen.CreateProxy(_log, new TestAsyncInterceptor(_log));
}

[Fact]
public async Task ShouldLog4Entries()
{
// Act
List<Guid> results = new();
await foreach (Guid result in _proxy.AsyncEnumerableMethod())
{
results.Add(result);
}

// Assert
Assert.Equal(10, results.Count);
Assert.Equal(4, _log.Count);
}

[Fact]
public async Task ShouldAllowInterceptionPriorToInvocation()
{
// Act
List<Guid> results = new();
await foreach (Guid result in _proxy.AsyncEnumerableMethod())
{
results.Add(result);
}

// Assert
Assert.Equal($"{MethodName}:InterceptStart", _log[0]);
}

[Fact]
public async Task ShouldAllowInterceptionAfterInvocation()
{
// Act
List<Guid> results = new();
await foreach (Guid result in _proxy.AsyncEnumerableMethod())
{
results.Add(result);
}

// Assert
Assert.Equal($"{MethodName}:InterceptEnd", _log[3]);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,61 @@ public async Task ShouldAllowTimingAfterInvocation()
Assert.Equal($"{MethodName}:CompletedTiming:{_interceptor.Stopwatch.Elapsed:g}", _log[3]);
}
}

public class WhenTimingAsyncEnumerableMethods
{
private const string MethodName = nameof(IInterfaceToProxy.AsyncEnumerableMethod);
private readonly ListLogger _log;
private readonly TestAsyncTimingInterceptor _interceptor;
private readonly IInterfaceToProxy _proxy;

public WhenTimingAsyncEnumerableMethods(ITestOutputHelper output)
{
_log = new ListLogger(output);
_interceptor = new TestAsyncTimingInterceptor(_log);
_proxy = ProxyGen.CreateProxy(_log, _interceptor);
}

[Fact]
public async Task ShouldLog4Entries()
{
// Act
List<Guid> results = new();
await foreach (Guid result in _proxy.AsyncEnumerableMethod())
{
results.Add(result);
}

// Assert
Assert.Equal(10, results.Count);
Assert.Equal(4, _log.Count);
}

[Fact]
public async Task ShouldAllowTimingPriorToInvocation()
{
// Act
List<Guid> results = new();
await foreach (Guid result in _proxy.AsyncEnumerableMethod())
{
results.Add(result);
}

// Assert
Assert.Equal($"{MethodName}:StartingTiming", _log[0]);
}

[Fact]
public async Task ShouldAllowTimingAfterInvocation()
{
// Act
List<Guid> results = new();
await foreach (Guid result in _proxy.AsyncEnumerableMethod())
{
results.Add(result);
}

// Assert
Assert.Equal($"{MethodName}:CompletedTiming:{_interceptor.Stopwatch.Elapsed:g}", _log[3]);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFrameworks>net472;netcoreapp3.1;net5.0;net6.0;net7.0</TargetFrameworks>
<TargetFrameworks>netcoreapp3.1;net5.0;net6.0;net7.0</TargetFrameworks>
<IsPackable>false</IsPackable>
<NoWarn>$(NoWarn);SA0001</NoWarn>
</PropertyGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,19 @@ public Task<Guid> AsynchronousResultMethod()
_log.Add(nameof(AsynchronousResultMethod) + ":End");
return Task.FromResult(Guid.NewGuid());
}

public IAsyncEnumerable<Guid> AsyncEnumerableMethod()
{
throw new NotImplementedException();
}

public IAsyncEnumerator<Guid> AsyncEnumerableExceptionMethodNoReturnValues()
{
throw new NotImplementedException();
}

public IAsyncEnumerator<Guid> AsyncEnumerableExceptionMethodReturnSomeValues()
{
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,19 @@ public async Task<Guid> AsynchronousResultMethod()
_log.Add(nameof(AsynchronousResultMethod) + ":End");
return Guid.NewGuid();
}

public IAsyncEnumerable<Guid> AsyncEnumerableMethod()
{
throw new NotImplementedException();
}

public IAsyncEnumerator<Guid> AsyncEnumerableExceptionMethodNoReturnValues()
{
throw new NotImplementedException();
}

public IAsyncEnumerator<Guid> AsyncEnumerableExceptionMethodReturnSomeValues()
{
throw new NotImplementedException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,34 @@ public async Task<Guid> AsynchronousResultExceptionMethod()
await Task.Delay(10).ConfigureAwait(false);
throw new InvalidOperationException(nameof(AsynchronousResultExceptionMethod) + ":Exception");
}

public async IAsyncEnumerable<Guid> AsyncEnumerableMethod()
{
_log.Add(nameof(AsyncEnumerableMethod) + ":Start");
for (int i = 0; i < 10; i++)
{
await Task.Delay(10).ConfigureAwait(false);
yield return Guid.NewGuid();
}

_log.Add(nameof(AsyncEnumerableMethod) + ":End");
}

public IAsyncEnumerator<Guid> AsyncEnumerableExceptionMethodNoReturnValues()
{
_log.Add(nameof(AsyncEnumerableExceptionMethodNoReturnValues) + ":Start");
throw new InvalidOperationException(nameof(AsyncEnumerableExceptionMethodNoReturnValues) + ":Exception");
}

public async IAsyncEnumerator<Guid> AsyncEnumerableExceptionMethodReturnSomeValues()
{
_log.Add(nameof(AsyncEnumerableExceptionMethodReturnSomeValues) + ":Start");
for (int i = 0; i < 2; i++)
{
await Task.Delay(10).ConfigureAwait(false);
yield return Guid.NewGuid();
}

throw new InvalidOperationException(nameof(AsyncEnumerableExceptionMethodReturnSomeValues) + ":Exception");
}
}
Loading