diff --git a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs index a3a40f95..03107238 100644 --- a/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs +++ b/test/StreamJsonRpc.Tests/AsyncEnumerableTests.cs @@ -58,6 +58,8 @@ public partial interface IServer IAsyncEnumerable GetNumbersAsync(CancellationToken cancellationToken); + IAsyncEnumerable GetNumbersThatWerePassedInAsync(IAsyncEnumerable numbers, CancellationToken cancellationToken); + IAsyncEnumerable GetNumbersNoCancellationAsync(); IAsyncEnumerable WaitTillCanceledBeforeFirstItemAsync(CancellationToken cancellationToken); @@ -144,6 +146,25 @@ public async Task GetIAsyncEnumerableAsReturnType(bool useProxy) Assert.Equal(Server.ValuesReturnedByEnumerables, realizedValuesCount); } + [Theory] + [PairwiseData] + public async Task GetIAsyncEnumerableAsReturnTypeAndParameter(bool useProxy) + { + IAsyncEnumerable? numbers = Enumerable.Range(1, Server.ValuesReturnedByEnumerables).AsAsyncEnumerable(); + + int realizedValuesCount = 0; + IAsyncEnumerable enumerable = useProxy + ? this.clientProxy.Value.GetNumbersThatWerePassedInAsync(numbers, this.TimeoutToken) + : await this.clientRpc.InvokeWithCancellationAsync>(nameof(Server.GetNumbersThatWerePassedInAsync), new object[] { numbers }, this.TimeoutToken); + await foreach (int number in enumerable) + { + realizedValuesCount++; + this.Logger.WriteLine(number.ToString(CultureInfo.InvariantCulture)); + } + + Assert.Equal(Server.ValuesReturnedByEnumerables, realizedValuesCount); + } + [Fact] public async Task GetIAsyncEnumerableAsReturnType_WithProxy_NoCancellation() { @@ -698,6 +719,14 @@ public async IAsyncEnumerable GetNumbersAsync([EnumeratorCancellation] Canc } } + public async IAsyncEnumerable GetNumbersThatWerePassedInAsync(IAsyncEnumerable numbers, [EnumeratorCancellation] CancellationToken cancellationToken) + { + await foreach (int number in numbers.WithCancellation(cancellationToken)) + { + yield return number; + } + } + public IAsyncEnumerable GetNumbersParameterizedAsync(int batchSize, int readAhead, int prefetch, int totalCount, bool endWithException, CancellationToken cancellationToken) { return this.GetNumbersAsync(totalCount, endWithException, cancellationToken)