diff --git a/src/Castle.Core.AsyncInterceptor/AsyncInterceptorBase.cs b/src/Castle.Core.AsyncInterceptor/AsyncInterceptorBase.cs index 7690d45..44b2c3a 100644 --- a/src/Castle.Core.AsyncInterceptor/AsyncInterceptorBase.cs +++ b/src/Castle.Core.AsyncInterceptor/AsyncInterceptorBase.cs @@ -127,6 +127,7 @@ private static void InterceptSynchronousResult(AsyncInterceptorBase me, } task.RethrowIfFaulted(); + invocation.ReturnValue = task.Result; } private static Task ProceedSynchronous(IInvocation invocation, IInvocationProceedInfo proceedInfo) diff --git a/test/Castle.Core.AsyncInterceptor.Tests/InterfaceProxies/TestProcessingReturnValueWithoutInvokingAsyncInterceptor.cs b/test/Castle.Core.AsyncInterceptor.Tests/InterfaceProxies/TestProcessingReturnValueWithoutInvokingAsyncInterceptor.cs new file mode 100644 index 0000000..85876c9 --- /dev/null +++ b/test/Castle.Core.AsyncInterceptor.Tests/InterfaceProxies/TestProcessingReturnValueWithoutInvokingAsyncInterceptor.cs @@ -0,0 +1,98 @@ +// Copyright (c) 2016-2023 James Skimming. All rights reserved. +// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. + +namespace Castle.DynamicProxy.InterfaceProxies; + +using System.Reflection; + +public class TestProcessingReturnValueWithoutInvokingAsyncInterceptor : AsyncInterceptorBase +{ + private static readonly MethodInfo TaskFromResultMethod = typeof(Task) + .GetMethod(nameof(Task.FromResult), BindingFlags.Static | BindingFlags.Public)!; + + private readonly ListLogger _log; + + public TestProcessingReturnValueWithoutInvokingAsyncInterceptor(ListLogger log) + { + _log = log ?? throw new ArgumentNullException(nameof(log)); + } + + protected override Task InterceptAsync( + IInvocation invocation, + IInvocationProceedInfo proceedInfo, + Func proceed) + { + try + { + _log.Add($"{invocation.Method.Name}:StartingVoidInvocation"); + + /* Without invoking original method + await proceed(invocation, proceedInfo).ConfigureAwait(false); + */ + + _log.Add($"{invocation.Method.Name}:CompletedVoidInvocation"); + + // There is no async call in this example so we just return a completed task without async. + return Task.CompletedTask; + } + catch (Exception e) + { + _log.Add($"{invocation.Method.Name}:VoidExceptionThrown:{e.Message}"); + throw; + } + } + + protected override Task InterceptAsync( + IInvocation invocation, + IInvocationProceedInfo proceedInfo, + Func> proceed) + { + try + { + _log.Add($"{invocation.Method.Name}:StartingResultInvocation"); + + /* Without invoking original method + TResult result = await proceed(invocation, proceedInfo).ConfigureAwait(false); + */ + + // But we need a default result + TResult? result = GetDefaultValue(); + + _log.Add($"{invocation.Method.Name}:CompletedResultInvocation:{result}"); + + // Add ! here because the return value type in the base definition is supposed to be a nullable reference + // but it isn't + return Task.FromResult(result!); + } + catch (Exception e) + { + _log.Add($"{invocation.Method.Name}:VoidExceptionThrown:{e.Message}"); + throw; + } + } + + private TResult? GetDefaultValue() + { + return (TResult?)GetDefaultValue(typeof(TResult)); + } + + private object? GetDefaultValue(Type type) + { + if (type.IsAssignableFrom(typeof(Task))) + { + if (type.IsGenericType) + { + Type innerType = type.GetGenericArguments().Single(); + object? innerResult = GetDefaultValue(innerType); + MethodInfo fromResult = TaskFromResultMethod.MakeGenericMethod(innerType); + return fromResult.Invoke(null, new[] { innerResult }); + } + + return Task.CompletedTask; + } + + if (type.IsValueType) return Activator.CreateInstance(type); + + return null; + } +} diff --git a/test/Castle.Core.AsyncInterceptor.Tests/ProcessingAsyncInterceptorWithoutInvokingShould.cs b/test/Castle.Core.AsyncInterceptor.Tests/ProcessingAsyncInterceptorWithoutInvokingShould.cs new file mode 100644 index 0000000..9a68e5d --- /dev/null +++ b/test/Castle.Core.AsyncInterceptor.Tests/ProcessingAsyncInterceptorWithoutInvokingShould.cs @@ -0,0 +1,102 @@ +// Copyright (c) 2016-2023 James Skimming. All rights reserved. +// Licensed under the Apache License, Version 2.0. See LICENSE in the project root for license information. + +namespace Castle.DynamicProxy; + +using Castle.DynamicProxy.InterfaceProxies; +using Xunit; +using Xunit.Abstractions; + +public class WhenProcessingSynchronousVoidMethodsWithoutInvoking +{ + private readonly ListLogger _log; + private readonly IInterfaceToProxy _proxy; + + public WhenProcessingSynchronousVoidMethodsWithoutInvoking(ITestOutputHelper output) + { + _log = new ListLogger(output); + var interceptor = new TestProcessingReturnValueWithoutInvokingAsyncInterceptor(_log); + _proxy = ProxyGen.CreateProxy(_log, interceptor); + } + + [Fact] + public void ShouldLog4Entries() + { + // Act + _proxy.SynchronousVoidMethod(); + + // Assert + Assert.Equal(2, _log.Count); + } +} + +public class WhenProcessingSynchronousResultMethodsWithoutInvoking +{ + private readonly ListLogger _log; + private readonly IInterfaceToProxy _proxy; + + public WhenProcessingSynchronousResultMethodsWithoutInvoking(ITestOutputHelper output) + { + _log = new ListLogger(output); + var interceptor = new TestProcessingReturnValueWithoutInvokingAsyncInterceptor(_log); + _proxy = ProxyGen.CreateProxy(_log, interceptor); + } + + [Fact] + public void ShouldLog4Entries() + { + // Act + Guid result = _proxy.SynchronousResultMethod(); + + // Assert + Assert.Equal(Guid.Empty, result); + Assert.Equal(2, _log.Count); + } +} + +public class WhenProcessingAsynchronousVoidMethodsWithoutInvoking +{ + private readonly ListLogger _log; + private readonly IInterfaceToProxy _proxy; + + public WhenProcessingAsynchronousVoidMethodsWithoutInvoking(ITestOutputHelper output) + { + _log = new ListLogger(output); + var interceptor = new TestProcessingReturnValueWithoutInvokingAsyncInterceptor(_log); + _proxy = ProxyGen.CreateProxy(_log, interceptor); + } + + [Fact] + public async Task ShouldLog4Entries() + { + // Act + await _proxy.AsynchronousVoidMethod().ConfigureAwait(false); + + // Assert + Assert.Equal(2, _log.Count); + } +} + +public class WhenProcessingAsynchronousResultMethodsWithoutInvoking +{ + private readonly ListLogger _log; + private readonly IInterfaceToProxy _proxy; + + public WhenProcessingAsynchronousResultMethodsWithoutInvoking(ITestOutputHelper output) + { + _log = new ListLogger(output); + var interceptor = new TestProcessingReturnValueWithoutInvokingAsyncInterceptor(_log); + _proxy = ProxyGen.CreateProxy(_log, interceptor); + } + + [Fact] + public async Task ShouldLog4Entries() + { + // Act + Guid result = await _proxy.AsynchronousResultMethod().ConfigureAwait(false); + + // Assert + Assert.Equal(Guid.Empty, result); + Assert.Equal(2, _log.Count); + } +}