diff --git a/.autover/changes/c27a62e6-91ca-4a59-9406-394866cdfa62.json b/.autover/changes/c27a62e6-91ca-4a59-9406-394866cdfa62.json new file mode 100644 index 000000000..39be8933f --- /dev/null +++ b/.autover/changes/c27a62e6-91ca-4a59-9406-394866cdfa62.json @@ -0,0 +1,18 @@ +{ + "Projects": [ + { + "Name": "Amazon.Lambda.RuntimeSupport", + "Type": "Minor", + "ChangelogMessages": [ + "(Preview) Add response streaming support" + ] + }, + { + "Name": "Amazon.Lambda.Core", + "Type": "Minor", + "ChangelogMessages": [ + "(Preview) Add response streaming support" + ] + } + ] +} diff --git a/.gitignore b/.gitignore index f91715274..1caae6fe4 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ *.suo *.user +**/.kiro/ + #################### # Build/Test folders #################### diff --git a/Libraries/Libraries.sln b/Libraries/Libraries.sln index f3214606a..23840bdfa 100644 --- a/Libraries/Libraries.sln +++ b/Libraries/Libraries.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 17 -VisualStudioVersion = 17.0.31717.71 +# Visual Studio Version 18 +VisualStudioVersion = 18.3.11512.155 d18.3 MinimumVisualStudioVersion = 10.0.40219.1 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{AAB54E74-20B1-42ED-BC3D-CE9F7BC7FD12}" EndProject @@ -151,6 +151,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TestCustomAuthorizerApp.Int EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TestCustomAuthorizerApp", "test\TestCustomAuthorizerApp\TestCustomAuthorizerApp.csproj", "{3BFA4B73-BA61-4578-833B-C5B3A16EDA9E}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ResponseStreamingFunctionHandlers", "test\Amazon.Lambda.RuntimeSupport.Tests\ResponseStreamingFunctionHandlers\ResponseStreamingFunctionHandlers.csproj", "{E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -941,6 +943,18 @@ Global {3BFA4B73-BA61-4578-833B-C5B3A16EDA9E}.Release|x64.Build.0 = Release|Any CPU {3BFA4B73-BA61-4578-833B-C5B3A16EDA9E}.Release|x86.ActiveCfg = Release|Any CPU {3BFA4B73-BA61-4578-833B-C5B3A16EDA9E}.Release|x86.Build.0 = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|x64.ActiveCfg = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|x64.Build.0 = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|x86.ActiveCfg = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Debug|x86.Build.0 = Debug|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|Any CPU.Build.0 = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|x64.ActiveCfg = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|x64.Build.0 = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|x86.ActiveCfg = Release|Any CPU + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1015,6 +1029,7 @@ Global {8D03BDF3-7078-4B46-A3F1-C73BE6D6CE0D} = {1DE4EE60-45BA-4EF7-BE00-B9EB861E4C69} {8EEDD576-7FC4-4FAC-A5A2-F58562753A53} = {1DE4EE60-45BA-4EF7-BE00-B9EB861E4C69} {3BFA4B73-BA61-4578-833B-C5B3A16EDA9E} = {1DE4EE60-45BA-4EF7-BE00-B9EB861E4C69} + {E404A7AC-812B-BC03-CA76-02C0BC2BA7F9} = {B5BD0336-7D08-492C-8489-42C987E29B39} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {503678A4-B8D1-4486-8915-405A3E9CF0EB} diff --git a/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/HttpResponseStreamPrelude.cs b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/HttpResponseStreamPrelude.cs new file mode 100644 index 000000000..67eb9d3ae --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/HttpResponseStreamPrelude.cs @@ -0,0 +1,95 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER +using System; +using System.Collections.Generic; +using System.Net; +using System.Runtime.Versioning; +using System.Text.Json; + +namespace Amazon.Lambda.Core.ResponseStreaming +{ + /// + /// The HTTP response prelude to be sent as the first chunk of a streaming response when using . + /// + [RequiresPreviewFeatures(LambdaResponseStreamFactory.PreviewMessage)] + public class HttpResponseStreamPrelude + { + /// + /// The Http status code to include in the response prelude. + /// + public HttpStatusCode? StatusCode { get; set; } + + /// + /// The response headers to include in the response prelude. This collection supports setting single value for the same headers. + /// + public IDictionary Headers { get; set; } = new Dictionary(); + + /// + /// The response headers to include in the response prelude. This collection supports setting multiple values for the same headers. + /// + public IDictionary> MultiValueHeaders { get; set; } = new Dictionary>(); + + /// + /// The list of cookies to include in the response prelude. This is used for Lambda Function URL responses, which support a separate "cookies" field in the response JSON for setting cookies, rather than requiring cookies to be set via the "Set-Cookie" header. + /// + public IList Cookies { get; set; } = new List(); + + internal byte[] ToByteArray() + { + var bufferWriter = new System.Buffers.ArrayBufferWriter(); + using (var writer = new Utf8JsonWriter(bufferWriter)) + { + writer.WriteStartObject(); + + if (StatusCode.HasValue) + writer.WriteNumber("statusCode", (int)StatusCode); + + if (Headers?.Count > 0) + { + writer.WriteStartObject("headers"); + foreach (var header in Headers) + { + writer.WriteString(header.Key, header.Value); + } + writer.WriteEndObject(); + } + + if (MultiValueHeaders?.Count > 0) + { + writer.WriteStartObject("multiValueHeaders"); + foreach (var header in MultiValueHeaders) + { + writer.WriteStartArray(header.Key); + foreach (var value in header.Value) + { + writer.WriteStringValue(value); + } + writer.WriteEndArray(); + } + writer.WriteEndObject(); + } + + if (Cookies?.Count > 0) + { + writer.WriteStartArray("cookies"); + foreach (var cookie in Cookies) + { + writer.WriteStringValue(cookie); + } + writer.WriteEndArray(); + } + + writer.WriteEndObject(); + } + + if (string.Equals(Environment.GetEnvironmentVariable("LAMBDA_NET_SERIALIZER_DEBUG"), "true", StringComparison.OrdinalIgnoreCase)) + { + LambdaLogger.Log(LogLevel.Information, "HTTP Response Stream Prelude JSON: {Prelude}", System.Text.Encoding.UTF8.GetString(bufferWriter.WrittenSpan)); + } + + return bufferWriter.WrittenSpan.ToArray(); + } + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/ILambdaResponseStream.cs b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/ILambdaResponseStream.cs new file mode 100644 index 000000000..1385e551e --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/ILambdaResponseStream.cs @@ -0,0 +1,40 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.Core.ResponseStreaming +{ + /// + /// Interface for writing streaming responses in AWS Lambda functions. + /// Obtained by calling within a handler. + /// + internal interface ILambdaResponseStream : IDisposable + { + /// + /// Asynchronously writes a portion of a byte array to the response stream. + /// + /// The byte array containing data to write. + /// The zero-based byte offset in buffer at which to begin copying bytes. + /// The number of bytes to write. + /// Optional cancellation token. + /// A task representing the asynchronous operation. + /// Thrown if the stream is already completed or an error has been reported. + Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default); + + + /// + /// Gets the total number of bytes written to the stream so far. + /// + long BytesWritten { get; } + + + /// + /// Gets whether an error has been reported. + /// + bool HasError { get; } + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStream.cs b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStream.cs new file mode 100644 index 000000000..83ac446a4 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStream.cs @@ -0,0 +1,123 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER + +using System; +using System.IO; +using System.Runtime.Versioning; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.Core.ResponseStreaming +{ + /// + /// A write-only, non-seekable subclass that streams response data + /// to the Lambda Runtime API. Returned by . + /// Integrates with standard .NET stream consumers such as . + /// + [RequiresPreviewFeatures(LambdaResponseStreamFactory.PreviewMessage)] + public class LambdaResponseStream : Stream + { + private readonly ILambdaResponseStream _responseStream; + + internal LambdaResponseStream(ILambdaResponseStream responseStream) + { + _responseStream = responseStream; + } + + /// + /// The number of bytes written to the Lambda response stream so far. + /// + public long BytesWritten => _responseStream.BytesWritten; + + /// + /// Asynchronously writes a byte array to the response stream. + /// + /// The byte array to write. + /// Optional cancellation token. + /// A task representing the asynchronous operation. + /// Thrown if the stream is already completed or an error has been reported. + public async Task WriteAsync(byte[] buffer, CancellationToken cancellationToken = default) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + + await WriteAsync(buffer, 0, buffer.Length, cancellationToken); + } + + /// + /// Asynchronously writes a portion of a byte array to the response stream. + /// + /// The byte array containing data to write. + /// The zero-based byte offset in buffer at which to begin copying bytes. + /// The number of bytes to write. + /// Optional cancellation token. + /// A task representing the asynchronous operation. + /// Thrown if the stream is already completed or an error has been reported. + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + await _responseStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + #region Noop Overrides + + /// Gets a value indicating whether the stream supports reading. Always false. + public override bool CanRead => false; + + /// Gets a value indicating whether the stream supports seeking. Always false. + public override bool CanSeek => false; + + /// Gets a value indicating whether the stream supports writing. Always true. + public override bool CanWrite => true; + + /// + /// Gets the total number of bytes written to the stream so far. + /// Equivalent to . + /// + public override long Length => BytesWritten; + + /// + /// Getting or setting the position is not supported. + /// + /// Always thrown. + public override long Position + { + get => throw new NotSupportedException($"{nameof(LambdaResponseStream)} does not support seeking."); + set => throw new NotSupportedException($"{nameof(LambdaResponseStream)} does not support seeking."); + } + + /// Not supported. + /// Always thrown. + public override long Seek(long offset, SeekOrigin origin) + => throw new NotImplementedException($"{nameof(LambdaResponseStream)} does not support seeking."); + + /// Not supported. + /// Always thrown. + public override int Read(byte[] buffer, int offset, int count) + => throw new NotImplementedException($"{nameof(LambdaResponseStream)} does not support reading."); + + /// Not supported. + /// Always thrown. + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => throw new NotImplementedException($"{nameof(LambdaResponseStream)} does not support reading."); + + /// + /// Writes a sequence of bytes to the stream. Delegates to the async path synchronously. + /// Prefer to avoid blocking. + /// + public override void Write(byte[] buffer, int offset, int count) + => WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult(); + + /// + /// Flush is a no-op; data is sent to the Runtime API immediately on each write. + /// + public override void Flush() { } + + /// Not supported. + /// Always thrown. + public override void SetLength(long value) + => throw new NotSupportedException($"{nameof(LambdaResponseStream)} does not support SetLength."); + #endregion + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStreamFactory.cs b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStreamFactory.cs new file mode 100644 index 000000000..1b9e6d3b6 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.Core/ResponseStreaming/LambdaResponseStreamFactory.cs @@ -0,0 +1,72 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER +using System; +using System.IO; +using System.Runtime.Versioning; + +namespace Amazon.Lambda.Core.ResponseStreaming +{ + /// + /// Factory to create Lambda response streams for writing streaming responses in AWS Lambda functions. The created streams are write-only and non-seekable. + /// + [RequiresPreviewFeatures(LambdaResponseStreamFactory.PreviewMessage)] + public class LambdaResponseStreamFactory + { + internal const string PreviewMessage = + "Response streaming is in preview till a new version of .NET Lambda runtime client that supports response streaming " + + "has been deployed to the .NET Lambda managed runtime. Till deployment has been made the feature can be used by deploying as an " + + "executable including the latest version of Amazon.Lambda.RuntimeSupport and setting the \"EnablePreviewFeatures\" in the Lambda " + + "project file to \"true\""; + + internal const string UninitializedFactoryMessage = + "LambdaResponseStreamFactory is not initialized. This is caused by mismatch versions of Amazon.Lambda.Core and Amazon.Lambda.RuntimeSupport. " + + "Update both packages to the current version to address the issue."; + + private static Func _streamFactory; + + internal static void SetLambdaResponseStream(Func streamFactory) + { + _streamFactory = streamFactory ?? throw new ArgumentNullException(nameof(streamFactory)); + } + + /// + /// Creates a a subclass of that can be used to write streaming responses back to callers of the Lambda function. Once + /// a Lambda function creates a response stream all output must be returned by writing to the stream; the Lambda function's handler + /// return value will be ignored. The stream is write-only and non-seekable. + /// + /// + public static LambdaResponseStream CreateStream() + { + if (_streamFactory == null) + throw new InvalidOperationException(UninitializedFactoryMessage); + + var runtimeResponseStream = _streamFactory(Array.Empty()); + return new LambdaResponseStream(runtimeResponseStream); + } + + /// + /// Creates a a subclass of for writing streaming responses, with an HTTP response prelude containing status code and headers. This should be used for + /// Lambda functions using response streaming that are invoked via the Lambda Function URLs or API Gateway HTTP APIs, where the response format is expected to be an HTTP response. + /// The prelude will be serialized and sent as the first chunk of the response stream, and should contain any necessary HTTP status code and headers. + /// + /// Once a Lambda function creates a response stream all output must be returned by writing to the stream; the Lambda function's handler + /// return value will be ignored. The stream is write-only and non-seekable. + /// + /// + /// The HTTP response prelude including status code and headers. + /// + public static LambdaResponseStream CreateHttpStream(HttpResponseStreamPrelude prelude) + { + if (_streamFactory == null) + throw new InvalidOperationException(UninitializedFactoryMessage); + + if (prelude is null) + throw new ArgumentNullException(nameof(prelude)); + + var runtimeResponseStream = _streamFactory(prelude.ToByteArray()); + return new LambdaResponseStream(runtimeResponseStream); + } + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs index 0e00f3e7f..bb6198d9e 100644 --- a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/LambdaBootstrap.cs @@ -20,6 +20,7 @@ using System.Threading; using System.Threading.Tasks; using Amazon.Lambda.RuntimeSupport.Bootstrap; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; using Amazon.Lambda.RuntimeSupport.Helpers; namespace Amazon.Lambda.RuntimeSupport @@ -225,6 +226,19 @@ internal LambdaBootstrap(HttpClient httpClient, LambdaBootstrapHandler handler, return; } #if NET8_0_OR_GREATER + + try + { + // Initalize in Amazon.Lambda.Core the factory for creating the response stream and related logic for supporting response streaming. + ResponseStreamLambdaCoreInitializerIsolated.InitializeCore(); + } + catch (TypeLoadException) + { + _logger.LogDebug("Failed to configure Amazon.Lambda.Core with factory to create response stream. This happens when the version of Amazon.Lambda.Core referenced by the Lambda function is out of date."); + } + + + // Check if Initialization type is SnapStart, and invoke the snapshot restore logic. if (_configuration.IsInitTypeSnapstart) { @@ -349,6 +363,7 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul _logger.LogInformation("Starting InvokeOnceAsync"); var invocation = await Client.GetNextInvocationAsync(cancellationToken); + var isMultiConcurrency = Utils.IsUsingMultiConcurrency(_environmentVariables); Func processingFunc = async () => { @@ -358,6 +373,17 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul SetInvocationTraceId(impl.RuntimeApiHeaders.TraceId); } + // Initialize ResponseStreamFactory — includes RuntimeApiClient reference + var runtimeApiClient = Client as RuntimeApiClient; + if (runtimeApiClient != null) + { + ResponseStreamFactory.InitializeInvocation( + invocation.LambdaContext.AwsRequestId, + isMultiConcurrency, + runtimeApiClient, + cancellationToken); + } + try { InvocationResponse response = null; @@ -372,15 +398,41 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul catch (Exception exception) { WriteUnhandledExceptionToLog(exception); - await Client.ReportInvocationErrorAsync(invocation.LambdaContext.AwsRequestId, exception, cancellationToken); + + var responseStream = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency); + if (responseStream != null) + { + responseStream.ReportError(exception); + } + else + { + await Client.ReportInvocationErrorAsync(invocation.LambdaContext.AwsRequestId, exception, cancellationToken); + } } finally { _logger.LogInformation("Finished invoking handler"); } - if (invokeSucceeded) + var streamIfCreated = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency); + if (streamIfCreated != null) + { + streamIfCreated.MarkCompleted(); + + // If streaming was started, await the HTTP send task to ensure it completes + var sendTask = ResponseStreamFactory.GetSendTask(isMultiConcurrency); + if (sendTask != null) + { + // Wait for the streaming response to finish sending before allowing the next invocation to be processed. This ensures that responses are sent in the order the invocations were received. + await sendTask; + sendTask.Result.Dispose(); + } + + streamIfCreated.Dispose(); + } + else if (invokeSucceeded) { + // No streaming — send buffered response _logger.LogInformation("Starting sending response"); try { @@ -415,6 +467,7 @@ internal async Task InvokeOnceAsync(CancellationToken cancellationToken = defaul } finally { + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency); invocation.Dispose(); } }; diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/RawStreamingHttpClient.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/RawStreamingHttpClient.cs new file mode 100644 index 000000000..0226e0660 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/RawStreamingHttpClient.cs @@ -0,0 +1,291 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +#if NET8_0_OR_GREATER +using System; +using System.Collections.Generic; +using System.Globalization; +using System.IO; +using System.Net.Sockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Helpers; + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// A raw HTTP/1.1 client for sending streaming responses to the Lambda Runtime API + /// with support for HTTP trailing headers (used for error reporting). + /// + /// .NET's HttpClient/SocketsHttpHandler does not support sending HTTP/1.1 trailing headers. + /// The Lambda Runtime API requires error information to be sent as trailing headers + /// (Lambda-Runtime-Function-Error-Type and Lambda-Runtime-Function-Error-Body) after + /// the chunked transfer encoding body. This class gives us full control over the + /// HTTP wire format to properly send those trailers. + /// + internal class RawStreamingHttpClient : IDisposable + { + private readonly string _host; + private readonly int _port; + private TcpClient _tcpClient; + internal Stream _networkStream; + private bool _disposed; + + private readonly InternalLogger _logger = InternalLogger.GetDefaultLogger(); + + public RawStreamingHttpClient(string hostAndPort) + { + var parts = hostAndPort.Split(':'); + _host = parts[0]; + _port = parts.Length > 1 ? int.Parse(parts[1], CultureInfo.InvariantCulture) : 80; + } + + /// + /// Sends a streaming response to the Lambda Runtime API. + /// Connects via TCP, sends HTTP headers, then streams the response body + /// using chunked transfer encoding. When the response stream completes, + /// writes the chunked encoding terminator with optional trailing headers + /// for error reporting. + /// + /// The Lambda request ID. + /// The response stream that provides data and error state. + /// The User-Agent header value. + /// Cancellation token. + public async Task SendStreamingResponseAsync( + string awsRequestId, + ResponseStream responseStream, + string userAgent, + CancellationToken cancellationToken = default) + { + _tcpClient = new TcpClient(); + _tcpClient.NoDelay = true; + await _tcpClient.ConnectAsync(_host, _port, cancellationToken); + _networkStream = _tcpClient.GetStream(); + + // Send HTTP request line and headers + var path = $"/2018-06-01/runtime/invocation/{awsRequestId}/response"; + var headers = new StringBuilder(); + headers.Append($"POST {path} HTTP/1.1\r\n"); + headers.Append($"Host: {_host}:{_port}\r\n"); + headers.Append($"User-Agent: {userAgent}\r\n"); + headers.Append($"Content-Type: application/vnd.awslambda.http-integration-response\r\n"); + headers.Append($"{StreamingConstants.ResponseModeHeader}: {StreamingConstants.StreamingResponseMode}\r\n"); + headers.Append("Transfer-Encoding: chunked\r\n"); + headers.Append($"Trailer: {StreamingConstants.ErrorTypeTrailer}, {StreamingConstants.ErrorBodyTrailer}\r\n"); + headers.Append("\r\n"); + + var headerBytes = Encoding.ASCII.GetBytes(headers.ToString()); + await _networkStream.WriteAsync(headerBytes, cancellationToken); + await _networkStream.FlushAsync(cancellationToken); + + // Hand the network stream (wrapped in a chunked writer) to the ResponseStream + var chunkedWriter = new ChunkedStreamWriter(_networkStream); + await responseStream.SetHttpOutputStreamAsync(chunkedWriter, cancellationToken); + + _logger.LogInformation("In SendStreamingResponseAsync waiting for the underlying Lambda response stream to indicate it is complete."); + + // Wait for the handler to finish writing + await responseStream.WaitForCompletionAsync(cancellationToken); + + // Write the chunked encoding terminator with optional trailers + if (responseStream.HasError) + { + _logger.LogInformation("Adding response stream trailing error headers"); + await WriteTerminatorWithTrailersAsync(responseStream.ReportedError, cancellationToken); + } + else + { + // No error — write simple terminator: 0\r\n\r\n + var terminator = Encoding.ASCII.GetBytes("0\r\n\r\n"); + await _networkStream.WriteAsync(terminator, cancellationToken); + } + + await _networkStream.FlushAsync(cancellationToken); + + // Read and discard the HTTP response (we don't need it, but must consume it) + await ReadAndDiscardResponseAsync(cancellationToken); + } + + /// + /// Writes the chunked encoding terminator with HTTP trailing headers for error reporting. + /// Format: + /// 0\r\n + /// Lambda-Runtime-Function-Error-Type: errorType\r\n + /// Lambda-Runtime-Function-Error-Body: base64EncodedErrorBodyJson\r\n + /// \r\n + /// + /// The error body JSON is Base64-encoded because LambdaJsonExceptionWriter produces + /// pretty-printed multi-line JSON. HTTP trailer values cannot contain raw CR/LF characters + /// as they would break the HTTP framing — the Runtime API would see the first newline + /// inside the JSON as the end of the trailer and treat the rest as malformed data, + /// resulting in Runtime.TruncatedResponse instead of the actual error. + /// + internal async Task WriteTerminatorWithTrailersAsync(Exception exception, CancellationToken cancellationToken) + { + var exceptionInfo = ExceptionInfo.GetExceptionInfo(exception); + var errorBodyJson = LambdaJsonExceptionWriter.WriteJson(exceptionInfo); + var errorBodyBase64 = Convert.ToBase64String(Encoding.UTF8.GetBytes(errorBodyJson)); + + InternalLogger.GetDefaultLogger().LogInformation($"Writing trailing header {StreamingConstants.ErrorTypeTrailer} with error type {exceptionInfo.ErrorType}."); + var trailers = new StringBuilder(); + trailers.Append("0\r\n"); // zero-length chunk (end of body) + trailers.Append($"{StreamingConstants.ErrorTypeTrailer}: {exceptionInfo.ErrorType}\r\n"); + trailers.Append($"{StreamingConstants.ErrorBodyTrailer}: {errorBodyBase64}\r\n"); + trailers.Append("\r\n"); // end of trailers + + var trailerBytes = Encoding.UTF8.GetBytes(trailers.ToString()); + await _networkStream.WriteAsync(trailerBytes, cancellationToken); + } + + /// + /// Reads and discards the HTTP response from the Runtime API. + /// We need to consume the response to properly close the connection, + /// but we don't need to process it. + /// + internal async Task ReadAndDiscardResponseAsync(CancellationToken cancellationToken) + { + var buffer = new byte[4096]; + try + { + // Read until we get the full response. The Runtime API sends a short response. + var totalRead = 0; + var responseText = new StringBuilder(); + while (true) + { + var bytesRead = await _networkStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken); + if (bytesRead == 0) + break; + + totalRead += bytesRead; + responseText.Append(Encoding.ASCII.GetString(buffer, 0, bytesRead)); + + // Check if we've received the complete response (ends with \r\n\r\n for headers, + // or we've read the content-length worth of body) + var text = responseText.ToString(); + if (text.Contains("\r\n\r\n")) + { + // Find Content-Length to know if there's a body to read + var headerEnd = text.IndexOf("\r\n\r\n", StringComparison.Ordinal); + var headers = text.Substring(0, headerEnd); + + var contentLengthMatch = System.Text.RegularExpressions.Regex.Match( + headers, @"Content-Length:\s*(\d+)", System.Text.RegularExpressions.RegexOptions.IgnoreCase); + + if (contentLengthMatch.Success) + { + var contentLength = int.Parse(contentLengthMatch.Groups[1].Value, CultureInfo.InvariantCulture); + var bodyStart = headerEnd + 4; // skip \r\n\r\n + var bodyRead = text.Length - bodyStart; + if (bodyRead >= contentLength) + break; + } + else + { + // No Content-Length, assume response is complete after headers + break; + } + } + + if (totalRead > 16384) + break; // Safety limit + } + } + catch (Exception ex) + { + // Log but don't throw — the streaming response was already sent + _logger.LogDebug($"Error reading Runtime API response: {ex.Message}"); + } + } + + public void Dispose() + { + if (!_disposed) + { + _networkStream?.Dispose(); + _tcpClient?.Dispose(); + _disposed = true; + } + } + } + + /// + /// A write-only Stream wrapper that writes data in HTTP/1.1 chunked transfer encoding format. + /// Each write produces a chunk: {size in hex}\r\n{data}\r\n + /// FlushAsync flushes the underlying network stream to ensure data is sent immediately. + /// The chunked encoding terminator (0\r\n...\r\n) is NOT written by this class — + /// it is handled by RawStreamingHttpClient to support trailing headers. + /// + internal class ChunkedStreamWriter : Stream + { + private readonly Stream _innerStream; + + public ChunkedStreamWriter(Stream innerStream) + { + _innerStream = innerStream ?? throw new ArgumentNullException(nameof(innerStream)); + } + + public override bool CanRead => false; + public override bool CanSeek => false; + public override bool CanWrite => true; + public override long Length => throw new NotSupportedException(); + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + WriteAsync(buffer, offset, count, CancellationToken.None).GetAwaiter().GetResult(); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (count == 0) return; + + // Write chunk header: size in hex + \r\n + var chunkHeader = Encoding.ASCII.GetBytes($"{count:X}\r\n"); + await _innerStream.WriteAsync(chunkHeader, 0, chunkHeader.Length, cancellationToken); + + // Write chunk data + await _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + + // Write chunk trailer: \r\n + var crlf = Encoding.ASCII.GetBytes("\r\n"); + await _innerStream.WriteAsync(crlf, 0, crlf.Length, cancellationToken); + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (buffer.Length == 0) return; + + var chunkHeader = Encoding.ASCII.GetBytes($"{buffer.Length:X}\r\n"); + await _innerStream.WriteAsync(chunkHeader, cancellationToken); + await _innerStream.WriteAsync(buffer, cancellationToken); + await _innerStream.WriteAsync(Encoding.ASCII.GetBytes("\r\n"), cancellationToken); + } + + public override void Flush() => _innerStream.Flush(); + + public override Task FlushAsync(CancellationToken cancellationToken) => + _innerStream.FlushAsync(cancellationToken); + + public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStream.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStream.cs new file mode 100644 index 000000000..8271bf4f1 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStream.cs @@ -0,0 +1,261 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.Buffers; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Helpers; + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// Represents the writable stream used by Lambda handlers to write response data for streaming invocations. + /// + internal class ResponseStream + { + private long _bytesWritten; + private bool _isCompleted; + private bool _hasError; + private Exception _reportedError; + private readonly object _lock = new object(); + + // The live HTTP output stream, set by RawStreamingHttpClient when sending the streaming response. + private Stream _httpOutputStream; + private bool _disposedValue; + + // The wait time is a sanity timeout to avoid waiting indefinitely if SetHttpOutputStreamAsync is not called or takes too long to call. + // Reality is that SetHttpOutputStreamAsync should be called very quickly after CreateStream, so this timeout is generous to avoid false positives but still protects against hanging indefinitely. + private readonly static TimeSpan _httpStreamWaitTimeout = TimeSpan.FromSeconds(30); + + private readonly SemaphoreSlim _httpStreamReady = new SemaphoreSlim(0, 1); + private readonly SemaphoreSlim _completionSignal = new SemaphoreSlim(0, 1); + + private static readonly byte[] PreludeDelimiter = new byte[8]; + + /// + /// The number of bytes written to the Lambda response stream so far. + /// + public long BytesWritten => _bytesWritten; + + /// + /// Gets a value indicating whether an error has occurred. + /// + public bool HasError => _hasError; + + private readonly byte[] _prelude; + + + private readonly InternalLogger _logger; + + + internal Exception ReportedError => _reportedError; + + internal ResponseStream(byte[] prelude) + { + _logger = InternalLogger.GetDefaultLogger(); + _prelude = prelude; + } + + /// + /// Called by RawStreamingHttpClient to provide the HTTP output stream (a ChunkedStreamWriter). + /// + internal async Task SetHttpOutputStreamAsync(Stream httpOutputStream, CancellationToken cancellationToken = default) + { + _httpOutputStream = httpOutputStream; + + // Write the prelude BEFORE releasing _httpStreamReady. This prevents a race + // where a handler WriteAsync that is already waiting on the semaphore could + // sneak in and write body data before the prelude, causing intermittent + // "Failed to parse prelude JSON" errors from API Gateway. + // + // Note: we intentionally do NOT check ThrowIfCompletedOrError() here. + // SetHttpOutputStreamAsync is infrastructure setup called by RawStreamingHttpClient, + // not a handler write. For fast-completing responses (e.g. Results.Json), + // LambdaBootstrap may call MarkCompleted() before the TCP connection is established + // and this method is called. The prelude still needs to be written to the wire + // so the response is properly framed. + if (_prelude?.Length > 0) + { + _logger.LogDebug($"Writing prelude of {_prelude.Length} bytes to HTTP stream."); + + var combinedLength = _prelude.Length + PreludeDelimiter.Length; + var combined = ArrayPool.Shared.Rent(combinedLength); + try + { + Buffer.BlockCopy(_prelude, 0, combined, 0, _prelude.Length); + Buffer.BlockCopy(PreludeDelimiter, 0, combined, _prelude.Length, PreludeDelimiter.Length); + + await _httpOutputStream.WriteAsync(combined, 0, combinedLength, cancellationToken); + await _httpOutputStream.FlushAsync(cancellationToken); + } + finally + { + ArrayPool.Shared.Return(combined); + } + } + + _httpStreamReady.Release(); + } + + /// + /// Called by RawStreamingHttpClient to wait until the handler + /// finishes writing (MarkCompleted or ReportError). + /// + internal async Task WaitForCompletionAsync(CancellationToken cancellationToken = default) + { + await _completionSignal.WaitAsync(cancellationToken); + } + + internal async Task WriteAsync(byte[] buffer, CancellationToken cancellationToken = default) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + await WriteAsync(buffer, 0, buffer.Length, cancellationToken); + } + + /// + /// Asynchronously writes a portion of a byte array to the response stream. + /// + /// The byte array containing data to write. + /// The zero-based byte offset in buffer at which to begin copying bytes. + /// The number of bytes to write. + /// Optional cancellation token. + /// A task representing the asynchronous operation. + /// Thrown if the stream is already completed or an error has been reported. + public async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + { + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + if (offset < 0 || offset > buffer.Length) + throw new ArgumentOutOfRangeException(nameof(offset)); + if (count < 0 || offset + count > buffer.Length) + throw new ArgumentOutOfRangeException(nameof(count)); + + // Wait for the HTTP stream to be ready (first write only blocks) + await _httpStreamReady.WaitAsync(_httpStreamWaitTimeout, cancellationToken); + try + { + _logger.LogDebug("Writing chunk to HTTP response stream."); + + lock (_lock) + { + // Only throw on error, not on completed. For buffered ASP.NET Core responses + // (e.g. Results.Json), the pipeline completes and LambdaBootstrap calls + // MarkCompleted() before the pre-start buffer has been flushed to the wire. + // The buffered data still needs to be written even after MarkCompleted. + if (_hasError) + throw new InvalidOperationException("Cannot write to a stream after an error has been reported."); + _bytesWritten += count; + } + + await _httpOutputStream.WriteAsync(buffer, offset, count, cancellationToken); + await _httpOutputStream.FlushAsync(cancellationToken); + } + finally + { + // Re-release so subsequent writes don't block + _httpStreamReady.Release(); + } + } + + /// + /// Reports an error that occurred during streaming. + /// This will send error information via HTTP trailing headers. + /// + /// The exception to report. + /// Thrown if the stream is already completed or an error has already been reported. + internal void ReportError(Exception exception) + { + if (exception == null) + throw new ArgumentNullException(nameof(exception)); + + lock (_lock) + { + if (_isCompleted) + throw new InvalidOperationException("Cannot report an error after the stream has been completed."); + if (_hasError) + throw new InvalidOperationException("An error has already been reported for this stream."); + + _hasError = true; + _reportedError = exception; + _isCompleted = true; + } + // Signal completion so RawStreamingHttpClient can write error trailers and finish + _completionSignal.Release(); + } + + internal void MarkCompleted() + { + bool shouldReleaseLock = false; + lock (_lock) + { + // Release lock if not already completed, otherwise do nothing (idempotent) + if (!_isCompleted) + { + shouldReleaseLock = true; + } + _isCompleted = true; + } + + if (shouldReleaseLock) + { + // Signal completion so RawStreamingHttpClient can write the final chunk and finish + _completionSignal.Release(); + } + } + + private void ThrowIfCompletedOrError() + { + if (_isCompleted) + throw new InvalidOperationException("Cannot write to a completed stream."); + if (_hasError) + throw new InvalidOperationException("Cannot write to a stream after an error has been reported."); + } + + /// + /// Disposes the stream. After calling Dispose, no further writes or error reports should be made. + /// + /// + protected virtual void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + try { _httpStreamReady.Release(); } catch (SemaphoreFullException) { /* Ignore if already released */ } + _httpStreamReady.Dispose(); + + try { _completionSignal.Release(); } catch (SemaphoreFullException) { /* Ignore if already released */ } + _completionSignal.Dispose(); + } + + _disposedValue = true; + } + } + + /// + /// Dispose of the stream. After calling Dispose, no further writes or error reports should be made. + /// + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + } +} diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamContext.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamContext.cs new file mode 100644 index 000000000..970c43138 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamContext.cs @@ -0,0 +1,59 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// Internal context class used by ResponseStreamFactory to track per-invocation streaming state. + /// + internal class ResponseStreamContext + { + /// + /// The AWS request ID for the current invocation. + /// + public string AwsRequestId { get; set; } + + /// + /// Whether CreateStream() has been called for this invocation. + /// + public bool StreamCreated { get; set; } + + /// + /// The ResponseStream instance if created. + /// + public ResponseStream Stream { get; set; } + + /// + /// The RuntimeApiClient used to start the streaming HTTP POST. + /// + public RuntimeApiClient RuntimeApiClient { get; set; } + + /// + /// Cancellation token for the current invocation. + /// + public CancellationToken CancellationToken { get; set; } + + /// + /// The Task representing the in-flight HTTP POST to the Runtime API. + /// Started when CreateStream() is called, completes when the stream is finalized. + /// + public Task SendTask { get; set; } + } +} diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamFactory.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamFactory.cs new file mode 100644 index 000000000..27b34e8db --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamFactory.cs @@ -0,0 +1,133 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// Factory for creating streaming responses in AWS Lambda functions. + /// Call CreateStream() within your handler to opt into response streaming for that invocation. + /// + internal static class ResponseStreamFactory + { + // For on-demand mode (single invocation at a time) + private static ResponseStreamContext _onDemandContext; + + // For multi-concurrency mode (multiple concurrent invocations) + private static readonly AsyncLocal _asyncLocalContext = new AsyncLocal(); + + /// + /// Creates a streaming response for the current invocation. + /// Can only be called once per invocation. + /// + /// + /// + /// Thrown if called outside an invocation context. + /// Thrown if called more than once per invocation. + public static ResponseStream CreateStream(byte[] prelude) + { +#if NET8_0_OR_GREATER + var context = GetCurrentContext(); + + if (context == null) + { + throw new InvalidOperationException( + "ResponseStreamFactory.CreateStream() can only be called within a Lambda handler invocation."); + } + + if (context.StreamCreated) + { + throw new InvalidOperationException( + "ResponseStreamFactory.CreateStream() can only be called once per invocation."); + } + + var lambdaStream = new ResponseStream(prelude); + context.Stream = lambdaStream; + context.StreamCreated = true; + + // Start the HTTP POST to the Runtime API. + // This runs concurrently — SerializeToStreamAsync will block + // until the handler finishes writing or reports an error. + context.SendTask = context.RuntimeApiClient.StartStreamingResponseAsync( + context.AwsRequestId, lambdaStream, context.CancellationToken); + + return lambdaStream; +#else + throw new NotImplementedException(); +#endif + } + + // Internal methods for LambdaBootstrap to manage state + + internal static void InitializeInvocation( + string awsRequestId, bool isMultiConcurrency, + RuntimeApiClient runtimeApiClient, CancellationToken cancellationToken) + { + var context = new ResponseStreamContext + { + AwsRequestId = awsRequestId, + StreamCreated = false, + Stream = null, + RuntimeApiClient = runtimeApiClient, + CancellationToken = cancellationToken + }; + + if (isMultiConcurrency) + { + _asyncLocalContext.Value = context; + } + else + { + _onDemandContext = context; + } + } + + internal static ResponseStream GetStreamIfCreated(bool isMultiConcurrency) + { + var context = isMultiConcurrency ? _asyncLocalContext.Value : _onDemandContext; + return context?.Stream; + } + + /// + /// Returns the Task for the in-flight HTTP send, or null if streaming wasn't started. + /// LambdaBootstrap awaits this after the handler returns to ensure the HTTP request completes. + /// + internal static Task GetSendTask(bool isMultiConcurrency) + { + var context = isMultiConcurrency ? _asyncLocalContext.Value : _onDemandContext; + return context?.SendTask; + } + + internal static void CleanupInvocation(bool isMultiConcurrency) + { + if (isMultiConcurrency) + { + _asyncLocalContext.Value = null; + } + else + { + _onDemandContext = null; + } + } + + private static ResponseStreamContext GetCurrentContext() + { + // Check multi-concurrency first (AsyncLocal), then on-demand + return _asyncLocalContext.Value ?? _onDemandContext; + } + } +} diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamLambdaCoreInitializerIsolated.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamLambdaCoreInitializerIsolated.cs new file mode 100644 index 000000000..b86864480 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/ResponseStreamLambdaCoreInitializerIsolated.cs @@ -0,0 +1,61 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER + +using System; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.Core.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +#pragma warning disable CA2252 +namespace Amazon.Lambda.RuntimeSupport +{ + /// + /// This class is used to connect the created by to Amazon.Lambda.Core with it's public interfaces. + /// The deployed Lambda function might be referencing an older version of Amazon.Lambda.Core that does not have the public interfaces for response streaming, + /// so this class is used to avoid a direct dependency on Amazon.Lambda.Core in the rest of the response streaming implementation. + /// + /// Any code referencing this class must wrap the code around a try/catch for to allow for the case where the Lambda function + /// is deployed with an older version of Amazon.Lambda.Core that does not have the response streaming interfaces. + /// + /// + internal class ResponseStreamLambdaCoreInitializerIsolated + { + /// + /// Initalize Amazon.Lambda.Core with a factory method for creating that wraps the internal implementation. + /// + internal static void InitializeCore() + { +#if !ANALYZER_UNIT_TESTS // This precompiler directive is used to avoid the unit tests from needing a dependency on Amazon.Lambda.Core. + Func factory = (byte[] prelude) => new ImplLambdaResponseStream(ResponseStreamFactory.CreateStream(prelude)); + LambdaResponseStreamFactory.SetLambdaResponseStream(factory); +#endif + } + + /// + /// Implements the interface by wrapping a . This is used to connect the internal response streaming implementation to the public interfaces in Amazon.Lambda.Core. + /// + internal class ImplLambdaResponseStream : ILambdaResponseStream + { + private readonly ResponseStream _innerStream; + + internal ImplLambdaResponseStream(ResponseStream innerStream) + { + _innerStream = innerStream; + } + + /// + public long BytesWritten => _innerStream.BytesWritten; + + /// + public bool HasError => _innerStream.HasError; + + /// + public void Dispose() => _innerStream.Dispose(); + + /// + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) => _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } + } +} +#endif diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/StreamingConstants.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/StreamingConstants.cs new file mode 100644 index 000000000..43ac607b7 --- /dev/null +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Bootstrap/ResponseStreaming/StreamingConstants.cs @@ -0,0 +1,43 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +namespace Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming +{ + /// + /// Constants used for Lambda response streaming. + /// + internal static class StreamingConstants + { + /// + /// Header name for Lambda response mode. + /// + public const string ResponseModeHeader = "Lambda-Runtime-Function-Response-Mode"; + + /// + /// Value for streaming response mode. + /// + public const string StreamingResponseMode = "streaming"; + + /// + /// Trailer header name for error type. + /// + public const string ErrorTypeTrailer = "Lambda-Runtime-Function-Error-Type"; + + /// + /// Trailer header name for error body. + /// + public const string ErrorBodyTrailer = "Lambda-Runtime-Function-Error-Body"; + } +} diff --git a/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs b/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs index daa9fff24..0cddfcd2a 100644 --- a/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs +++ b/Libraries/src/Amazon.Lambda.RuntimeSupport/Client/RuntimeApiClient.cs @@ -20,6 +20,7 @@ using System.Threading; using System.Threading.Tasks; using Amazon.Lambda.RuntimeSupport.Bootstrap; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; namespace Amazon.Lambda.RuntimeSupport { @@ -177,6 +178,34 @@ public Task ReportRestoreErrorAsync(Exception exception, String errorType = null #endif +#if NET8_0_OR_GREATER + /// + /// Start sending a streaming response to the Lambda Runtime API. + /// Uses a raw TCP connection with chunked transfer encoding to support HTTP/1.1 + /// trailing headers for error reporting, which .NET's HttpClient does not support. + /// The actual data is written by the handler via ResponseStream.WriteAsync, which flows + /// through a ChunkedStreamWriter to the TCP connection. + /// This Task completes when the stream is finalized (MarkCompleted or error). + /// + /// The ID of the function request being responded to. + /// The ResponseStream that will provide the streaming data. + /// The optional cancellation token to use. + /// A Task representing the in-flight HTTP POST. The returned IDisposable is the RawStreamingHttpClient that owns the TCP connection. + internal virtual async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + if (awsRequestId == null) throw new ArgumentNullException(nameof(awsRequestId)); + if (responseStream == null) throw new ArgumentNullException(nameof(responseStream)); + + var userAgent = _httpClient.DefaultRequestHeaders.UserAgent.ToString(); + var rawClient = new RawStreamingHttpClient(LambdaEnvironment.RuntimeServerHostAndPort); + + await rawClient.SendStreamingResponseAsync(awsRequestId, responseStream, userAgent, cancellationToken); + + return rawClient; + } +#endif + /// /// Send a response to a function invocation to the Runtime API as an asynchronous operation. /// diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Amazon.Lambda.RuntimeSupport.IntegrationTests.csproj b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Amazon.Lambda.RuntimeSupport.IntegrationTests.csproj index 86a3b5c1e..d206a1f1c 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Amazon.Lambda.RuntimeSupport.IntegrationTests.csproj +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Amazon.Lambda.RuntimeSupport.IntegrationTests.csproj @@ -1,7 +1,7 @@  - net8.0 + net10.0 @@ -19,19 +19,19 @@ - - - - - - + + + + + + all runtime; build; native; contentfiles; analyzers - + - + diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/BaseCustomRuntimeTest.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/BaseCustomRuntimeTest.cs index c220a671e..314aa45c4 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/BaseCustomRuntimeTest.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/BaseCustomRuntimeTest.cs @@ -17,7 +17,7 @@ public class BaseCustomRuntimeTest { public const int FUNCTION_MEMORY_MB = 512; - protected static readonly RegionEndpoint TestRegion = RegionEndpoint.USWest2; + public static readonly RegionEndpoint TestRegion = RegionEndpoint.USWest2; protected static readonly string LAMBDA_ASSUME_ROLE_POLICY = @" { @@ -63,7 +63,7 @@ protected BaseCustomRuntimeTest(IntegrationTestFixture fixture, string functionN /// /// /// - protected async Task CleanUpTestResources(AmazonS3Client s3Client, AmazonLambdaClient lambdaClient, + public async Task CleanUpTestResources(AmazonS3Client s3Client, AmazonLambdaClient lambdaClient, AmazonIdentityManagementServiceClient iamClient, bool roleAlreadyExisted) { await DeleteFunctionIfExistsAsync(lambdaClient); @@ -109,7 +109,7 @@ await iamClient.DetachRolePolicyAsync(new DetachRolePolicyRequest } } - protected async Task PrepareTestResources(IAmazonS3 s3Client, IAmazonLambda lambdaClient, + public async Task PrepareTestResources(IAmazonS3 s3Client, IAmazonLambda lambdaClient, AmazonIdentityManagementServiceClient iamClient) { var roleAlreadyExisted = await ValidateAndSetIamRoleArn(iamClient); @@ -288,7 +288,7 @@ protected async Task CreateFunctionAsync(IAmazonLambda lambdaClient, string buck Handler = Handler, MemorySize = FUNCTION_MEMORY_MB, Timeout = 30, - Runtime = Runtime.Dotnet6, + Runtime = Runtime.Dotnet10, Role = ExecutionRoleArn }; @@ -351,7 +351,16 @@ private string GetDeploymentZipPath() if (!File.Exists(deploymentZipFile)) { - throw new NoDeploymentPackageFoundException(); + var message = new StringBuilder(); + message.AppendLine($"Deployment package for {DeploymentPackageZipRelativePath} not found at expected path: {deploymentZipFile}"); + message.AppendLine("Available Test Bundles:"); + foreach (var kvp in _fixture.TestAppPaths) + { + message.AppendLine($"{kvp.Key}: {kvp.Value}"); + } + + + throw new NoDeploymentPackageFoundException(message.ToString()); } return deploymentZipFile; @@ -380,7 +389,9 @@ private static string FindUp(string path, string fileOrDirectoryName, bool combi protected class NoDeploymentPackageFoundException : Exception { + public NoDeploymentPackageFoundException() { } + public NoDeploymentPackageFoundException(string message) : base(message) { } } private ApplicationLogLevel ConvertRuntimeLogLevel(RuntimeLogLevel runtimeLogLevel) diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/CustomRuntimeTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/CustomRuntimeTests.cs index b548d5ba0..8ab008d66 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/CustomRuntimeTests.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/CustomRuntimeTests.cs @@ -48,7 +48,7 @@ public async Task TestAllNET8HandlersAsync() public class CustomRuntimeTests : BaseCustomRuntimeTest { - public enum TargetFramework { NET6, NET8} + public enum TargetFramework { NET8 } private TargetFramework _targetFramework; diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/CommandLineWrapper.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/CommandLineWrapper.cs index aa8651eae..ea6fd059e 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/CommandLineWrapper.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/CommandLineWrapper.cs @@ -1,5 +1,6 @@ using System; using System.Diagnostics; +using System.Text; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -31,6 +32,7 @@ public static async Task Run(string command, string arguments, string workingDir tcs.TrySetResult(true); }; + var output = new StringBuilder(); try { // Attach event handlers @@ -39,6 +41,7 @@ public static async Task Run(string command, string arguments, string workingDir if (!string.IsNullOrEmpty(args.Data)) { Console.WriteLine(args.Data); + output.Append(args.Data); } }; @@ -47,6 +50,7 @@ public static async Task Run(string command, string arguments, string workingDir if (!string.IsNullOrEmpty(args.Data)) { Console.WriteLine(args.Data); + output.Append(args.Data); } }; @@ -78,6 +82,7 @@ public static async Task Run(string command, string arguments, string workingDir catch (Exception ex) { Console.WriteLine("Exception: " + ex); + Console.WriteLine(output.ToString()); if (!process.HasExited) { process.Kill(); @@ -87,4 +92,4 @@ public static async Task Run(string command, string arguments, string workingDir Assert.True(process.ExitCode == 0, $"Command '{command} {arguments}' failed."); } } -} \ No newline at end of file +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/LambdaToolsHelper.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/LambdaToolsHelper.cs index 42a02aac6..154c84f75 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/LambdaToolsHelper.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/Helpers/LambdaToolsHelper.cs @@ -10,6 +10,9 @@ public static class LambdaToolsHelper public static string GetTempTestAppDirectory(string workingDirectory, string testAppPath) { +#if DEBUG + return Path.GetFullPath(Path.Combine(workingDirectory, testAppPath)); +#else var customTestAppPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName()); Directory.CreateDirectory(customTestAppPath); @@ -17,6 +20,7 @@ public static string GetTempTestAppDirectory(string workingDirectory, string tes CopyDirectory(currentDir, customTestAppPath); return Path.Combine(customTestAppPath, testAppPath); +#endif } public static async Task InstallLambdaTools() @@ -78,4 +82,4 @@ private static void CopyDirectory(DirectoryInfo dir, string destDirName) CopyDirectory(subDir, tempPath); } } -} \ No newline at end of file +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestCollection.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestCollection.cs index c9ce90e35..6e066eb28 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestCollection.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestCollection.cs @@ -2,8 +2,8 @@ namespace Amazon.Lambda.RuntimeSupport.IntegrationTests; -[CollectionDefinition("Integration Tests")] -public class IntegrationTestCollection : ICollectionFixture +[CollectionDefinition("Integration Tests", DisableParallelization = true)] +public class IntegrationTestCollection : ICollectionFixture, ICollectionFixture { -} \ No newline at end of file +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestFixture.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestFixture.cs index 89d62d61f..b8c71519e 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestFixture.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/IntegrationTestFixture.cs @@ -14,10 +14,11 @@ public class IntegrationTestFixture : IAsyncLifetime public async Task InitializeAsync() { + var toolPath = await LambdaToolsHelper.InstallLambdaTools(); + var testAppPath = LambdaToolsHelper.GetTempTestAppDirectory( "../../../../../../..", "Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/CustomRuntimeFunctionTest"); - var toolPath = await LambdaToolsHelper.InstallLambdaTools(); _tempPaths.AddRange([testAppPath, toolPath] ); await LambdaToolsHelper.LambdaPackage(toolPath, "net8.0", testAppPath); TestAppPaths[@"CustomRuntimeFunctionTest\bin\Release\net8.0\CustomRuntimeFunctionTest.zip"] = Path.Combine(testAppPath, @"bin\Release\net8.0\CustomRuntimeFunctionTest.zip"); @@ -25,7 +26,6 @@ public async Task InitializeAsync() testAppPath = LambdaToolsHelper.GetTempTestAppDirectory( "../../../../../../..", "Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/CustomRuntimeAspNetCoreMinimalApiTest"); - toolPath = await LambdaToolsHelper.InstallLambdaTools(); _tempPaths.AddRange([testAppPath, toolPath] ); await LambdaToolsHelper.LambdaPackage(toolPath, "net8.0", testAppPath); TestAppPaths[@"CustomRuntimeAspNetCoreMinimalApiTest\bin\Release\net8.0\CustomRuntimeAspNetCoreMinimalApiTest.zip"] = Path.Combine(testAppPath, @"bin\Release\net8.0\CustomRuntimeAspNetCoreMinimalApiTest.zip"); @@ -33,19 +33,27 @@ public async Task InitializeAsync() testAppPath = LambdaToolsHelper.GetTempTestAppDirectory( "../../../../../../..", "Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/CustomRuntimeAspNetCoreMinimalApiCustomSerializerTest"); - toolPath = await LambdaToolsHelper.InstallLambdaTools(); _tempPaths.AddRange([testAppPath, toolPath] ); await LambdaToolsHelper.LambdaPackage(toolPath, "net8.0", testAppPath); TestAppPaths[@"CustomRuntimeAspNetCoreMinimalApiCustomSerializerTest\bin\Release\net8.0\CustomRuntimeAspNetCoreMinimalApiCustomSerializerTest.zip"] = Path.Combine(testAppPath, @"bin\Release\net8.0\CustomRuntimeAspNetCoreMinimalApiCustomSerializerTest.zip"); + + testAppPath = LambdaToolsHelper.GetTempTestAppDirectory( + "../../../../../../..", + "Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers"); + _tempPaths.AddRange([testAppPath, toolPath]); + await LambdaToolsHelper.LambdaPackage(toolPath, "net10.0", testAppPath); + TestAppPaths[@"ResponseStreamingFunctionHandlers\bin\Release\net10.0\ResponseStreamingFunctionHandlers.zip"] = Path.Combine(testAppPath, "bin", "Release", "net10.0", "ResponseStreamingFunctionHandlers.zip"); } public Task DisposeAsync() { +#if !DEBUG foreach (var tempPath in _tempPaths) { LambdaToolsHelper.CleanUp(tempPath); } +#endif return Task.CompletedTask; } diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/ResponseStreamingTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/ResponseStreamingTests.cs new file mode 100644 index 000000000..006df6d15 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.IntegrationTests/ResponseStreamingTests.cs @@ -0,0 +1,133 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Amazon.IdentityManagement; +using Amazon.Lambda.Model; +using Amazon.Runtime.EventStreams; +using Amazon.S3; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.IntegrationTests +{ + [Collection("Integration Tests")] + public class ResponseStreamingTests : BaseCustomRuntimeTest + { + private readonly static string s_functionName = "IntegTestResponseStreamingFunctionHandlers" + DateTime.Now.Ticks; + + private readonly ResponseStreamingTestsFixture _streamFixture; + + public ResponseStreamingTests(IntegrationTestFixture fixture, ResponseStreamingTestsFixture streamFixture) + : base(fixture, s_functionName, "ResponseStreamingFunctionHandlers.zip", @"ResponseStreamingFunctionHandlers\bin\Release\net10.0\ResponseStreamingFunctionHandlers.zip", "ResponseStreamingFunctionHandlers") + { + _streamFixture = streamFixture; + } + + [Fact] + public async Task SimpleFunctionHandler() + { + await _streamFixture.EnsureResourcesDeployedAsync(this); + + var evnts = await InvokeFunctionAsync(nameof(SimpleFunctionHandler)); + Assert.True(evnts.Any()); + + var content = GetCombinedStreamContent(evnts); + Assert.Equal("Hello, World!", content); + } + + [Fact] + public async Task StreamContentHandler() + { + await _streamFixture.EnsureResourcesDeployedAsync(this); + + var evnts = await InvokeFunctionAsync(nameof(StreamContentHandler)); + Assert.True(evnts.Length > 5); + + var content = GetCombinedStreamContent(evnts); + Assert.Contains("Line 9999", content); + Assert.EndsWith("Finish stream content\n", content); + } + + [Fact] + public async Task UnhandledExceptionHandler() + { + await _streamFixture.EnsureResourcesDeployedAsync(this); + + var evnts = await InvokeFunctionAsync(nameof(UnhandledExceptionHandler)); + Assert.True(evnts.Any()); + + var completeEvent = evnts.Last() as InvokeWithResponseStreamCompleteEvent; + Assert.Equal("InvalidOperationException", completeEvent.ErrorCode); + Assert.Contains("This is an unhandled exception", completeEvent.ErrorDetails); + Assert.Contains("stackTrace", completeEvent.ErrorDetails); + } + + private async Task InvokeFunctionAsync(string handlerScenario) + { + using var client = new AmazonLambdaClient(TestRegion); + + var request = new InvokeWithResponseStreamRequest + { + FunctionName = base.FunctionName, + Payload = new MemoryStream(System.Text.Encoding.UTF8.GetBytes($"\"{handlerScenario}\"")), + InvocationType = ResponseStreamingInvocationType.RequestResponse + }; + + var response = await client.InvokeWithResponseStreamAsync(request); + var evnts = response.EventStream.AsEnumerable().ToArray(); + return evnts; + } + + private string GetCombinedStreamContent(IEventStreamEvent[] events) + { + var sb = new StringBuilder(); + foreach (var evnt in events) + { + if (evnt is InvokeResponseStreamUpdate chunk) + { + var text = System.Text.Encoding.UTF8.GetString(chunk.Payload.ToArray()); + sb.Append(text); + } + } + return sb.ToString(); + } + } + + public class ResponseStreamingTestsFixture : IAsyncLifetime + { + private readonly AmazonLambdaClient _lambdaClient = new AmazonLambdaClient(BaseCustomRuntimeTest.TestRegion); + private readonly AmazonS3Client _s3Client = new AmazonS3Client(BaseCustomRuntimeTest.TestRegion); + private readonly AmazonIdentityManagementServiceClient _iamClient = new AmazonIdentityManagementServiceClient(BaseCustomRuntimeTest.TestRegion); + bool _resourcesCreated; + bool _roleAlreadyExisted; + + ResponseStreamingTests _tests; + + public async Task EnsureResourcesDeployedAsync(ResponseStreamingTests tests) + { + if (_resourcesCreated) + return; + + _tests = tests; + _roleAlreadyExisted = await _tests.PrepareTestResources(_s3Client, _lambdaClient, _iamClient); + + _resourcesCreated = true; + } + + public async Task DisposeAsync() + { + await _tests.CleanUpTestResources(_s3Client, _lambdaClient, _iamClient, _roleAlreadyExisted); + + _lambdaClient.Dispose(); + _s3Client.Dispose(); + _iamClient.Dispose(); + } + + public Task InitializeAsync() => Task.CompletedTask; + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/HandlerTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/HandlerTests.cs index 80f9d13d0..e257b688e 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/HandlerTests.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/HandlerTests.cs @@ -31,7 +31,7 @@ namespace Amazon.Lambda.RuntimeSupport.UnitTests { - [Collection("Bootstrap")] + [Collection("ResponseStreamFactory")] public class HandlerTests { private const string AggregateExceptionTestMarker = "AggregateExceptionTesting"; diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs index e1636ff16..e7f36a377 100644 --- a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaBootstrapTests.cs @@ -14,12 +14,14 @@ */ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Net.Http; using System.Text; using System.Threading.Tasks; using Xunit; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; using Amazon.Lambda.RuntimeSupport.Bootstrap; using static Amazon.Lambda.RuntimeSupport.Bootstrap.Constants; @@ -29,6 +31,7 @@ namespace Amazon.Lambda.RuntimeSupport.UnitTests /// Tests to test LambdaBootstrap when it's constructed using its actual constructor. /// Tests of the static GetLambdaBootstrap methods can be found in LambdaBootstrapWrapperTests. /// + [Collection("ResponseStreamFactory")] public class LambdaBootstrapTests { readonly TestHandler _testFunction; @@ -283,5 +286,159 @@ public void IsCallPreJitTest() environmentVariables.SetEnvironmentVariable(ENVIRONMENT_VARIABLE_AWS_LAMBDA_INITIALIZATION_TYPE, AWS_LAMBDA_INITIALIZATION_TYPE_PC); Assert.True(UserCodeInit.IsCallPreJit(environmentVariables)); } + + // --- Streaming Integration Tests --- + + private TestStreamingRuntimeApiClient CreateStreamingClient() + { + var envVars = new TestEnvironmentVariables(); + var headers = new Dictionary> + { + { RuntimeApiHeaders.HeaderAwsRequestId, new List { "streaming-request-id" } }, + { RuntimeApiHeaders.HeaderInvokedFunctionArn, new List { "invoked_function_arn" } }, + { RuntimeApiHeaders.HeaderAwsTenantId, new List { "tenant_id" } } + }; + return new TestStreamingRuntimeApiClient(envVars, headers); + } + + /// + /// Property 2: CreateStream Enables Streaming Mode + /// When a handler calls ResponseStreamFactory.CreateStream(), the response is transmitted + /// using streaming mode. LambdaBootstrap awaits the send task. + /// **Validates: Requirements 1.4, 6.1, 6.2, 6.3, 6.4** + /// + [Fact] + public async Task StreamingMode_HandlerCallsCreateStream_SendTaskAwaited() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(Encoding.UTF8.GetBytes("hello")); + return new InvocationResponse(Stream.Null, false); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + Assert.True(streamingClient.StartStreamingResponseAsyncCalled); + Assert.False(streamingClient.SendResponseAsyncCalled); + } + + /// + /// Property 3: Default Mode Is Buffered + /// When a handler does not call ResponseStreamFactory.CreateStream(), the response + /// is transmitted using buffered mode via SendResponseAsync. + /// **Validates: Requirements 1.5, 7.2** + /// + [Fact] + public async Task BufferedMode_HandlerDoesNotCallCreateStream_UsesSendResponse() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + var outputStream = new MemoryStream(Encoding.UTF8.GetBytes("buffered response")); + return new InvocationResponse(outputStream); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + Assert.False(streamingClient.StartStreamingResponseAsyncCalled); + Assert.True(streamingClient.SendResponseAsyncCalled); + } + + /// + /// Property 14: Exception After Writes Uses Trailers + /// When a handler throws an exception after writing data to an IResponseStream, + /// the error is reported via trailers (ReportErrorAsync) rather than standard error reporting. + /// **Validates: Requirements 5.6, 5.7** + /// + [Fact] + public async Task MidstreamError_ExceptionAfterWrites_ReportsViaTrailers() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(Encoding.UTF8.GetBytes("partial data")); + throw new InvalidOperationException("midstream failure"); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + // Error should be reported via trailers on the stream, not via standard error reporting + Assert.True(streamingClient.StartStreamingResponseAsyncCalled); + Assert.NotNull(streamingClient.LastStreamingResponseStream); + Assert.True(streamingClient.LastStreamingResponseStream.HasError); + Assert.False(streamingClient.ReportInvocationErrorAsyncExceptionCalled); + } + + /// + /// Property 15: Exception Before CreateStream Uses Standard Error + /// When a handler throws an exception before calling ResponseStreamFactory.CreateStream(), + /// the error is reported using the standard Lambda error reporting mechanism. + /// **Validates: Requirements 5.7, 7.1** + /// + [Fact] + public async Task PreStreamError_ExceptionBeforeCreateStream_UsesStandardErrorReporting() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + throw new InvalidOperationException("pre-stream failure"); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + Assert.False(streamingClient.StartStreamingResponseAsyncCalled); + Assert.True(streamingClient.ReportInvocationErrorAsyncExceptionCalled); + } + + /// + /// State Isolation: ResponseStreamFactory state is cleared after each invocation. + /// **Validates: Requirements 6.5, 8.9** + /// + [Fact] + public async Task Cleanup_ResponseStreamFactoryStateCleared_AfterInvocation() + { + var streamingClient = CreateStreamingClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(Encoding.UTF8.GetBytes("data")); + return new InvocationResponse(Stream.Null, false); + }; + + using (var bootstrap = new LambdaBootstrap(handler, null)) + { + bootstrap.Client = streamingClient; + await bootstrap.InvokeOnceAsync(); + } + + // After invocation, factory state should be cleaned up + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(false)); + Assert.Null(ResponseStreamFactory.GetSendTask(false)); + } } } diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaResponseStreamingCoreTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaResponseStreamingCoreTests.cs new file mode 100644 index 000000000..0d5c20c86 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/LambdaResponseStreamingCoreTests.cs @@ -0,0 +1,558 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER +#pragma warning disable CA2252 + +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.Core.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + // ───────────────────────────────────────────────────────────────────────────── + // HttpResponseStreamPrelude.ToByteArray() tests + // ───────────────────────────────────────────────────────────────────────────── + + public class HttpResponseStreamPreludeTests + { + private static JsonDocument ParsePrelude(HttpResponseStreamPrelude prelude) + => JsonDocument.Parse(prelude.ToByteArray()); + + [Fact] + public void ToByteArray_EmptyPrelude_ProducesEmptyJsonObject() + { + var prelude = new HttpResponseStreamPrelude(); + var doc = ParsePrelude(prelude); + + Assert.Equal(JsonValueKind.Object, doc.RootElement.ValueKind); + // No properties should be present + Assert.False(doc.RootElement.TryGetProperty("statusCode", out _)); + Assert.False(doc.RootElement.TryGetProperty("headers", out _)); + Assert.False(doc.RootElement.TryGetProperty("multiValueHeaders", out _)); + Assert.False(doc.RootElement.TryGetProperty("cookies", out _)); + } + + [Fact] + public void ToByteArray_WithStatusCode_IncludesStatusCode() + { + var prelude = new HttpResponseStreamPrelude { StatusCode = HttpStatusCode.OK }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("statusCode", out var sc)); + Assert.Equal(200, sc.GetInt32()); + } + + [Fact] + public void ToByteArray_WithHeaders_IncludesHeaders() + { + var prelude = new HttpResponseStreamPrelude + { + Headers = new Dictionary + { + ["Content-Type"] = "application/json", + ["X-Custom"] = "value" + } + }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("headers", out var headers)); + Assert.Equal("application/json", headers.GetProperty("Content-Type").GetString()); + Assert.Equal("value", headers.GetProperty("X-Custom").GetString()); + } + + [Fact] + public void ToByteArray_WithMultiValueHeaders_IncludesMultiValueHeaders() + { + var prelude = new HttpResponseStreamPrelude + { + MultiValueHeaders = new Dictionary> + { + ["Set-Cookie"] = new List { "a=1", "b=2" } + } + }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("multiValueHeaders", out var mvh)); + var cookies = mvh.GetProperty("Set-Cookie"); + Assert.Equal(JsonValueKind.Array, cookies.ValueKind); + Assert.Equal(2, cookies.GetArrayLength()); + } + + [Fact] + public void ToByteArray_WithCookies_IncludesCookies() + { + var prelude = new HttpResponseStreamPrelude + { + Cookies = new List { "session=abc", "pref=dark" } + }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("cookies", out var cookies)); + Assert.Equal(JsonValueKind.Array, cookies.ValueKind); + Assert.Equal(2, cookies.GetArrayLength()); + Assert.Equal("session=abc", cookies[0].GetString()); + } + + [Fact] + public void ToByteArray_AllFieldsPopulated_ProducesCorrectJson() + { + var prelude = new HttpResponseStreamPrelude + { + StatusCode = HttpStatusCode.Created, + Headers = new Dictionary { ["X-Req"] = "1" }, + MultiValueHeaders = new Dictionary> { ["X-Multi"] = new List { "a", "b" } }, + Cookies = new List { "c=1" } + }; + var doc = ParsePrelude(prelude); + + Assert.Equal(201, doc.RootElement.GetProperty("statusCode").GetInt32()); + Assert.Equal("1", doc.RootElement.GetProperty("headers").GetProperty("X-Req").GetString()); + Assert.Equal(2, doc.RootElement.GetProperty("multiValueHeaders").GetProperty("X-Multi").GetArrayLength()); + Assert.Equal("c=1", doc.RootElement.GetProperty("cookies")[0].GetString()); + } + + [Fact] + public void ToByteArray_EmptyCollections_OmitsThoseFields() + { + var prelude = new HttpResponseStreamPrelude + { + StatusCode = HttpStatusCode.OK, + Headers = new Dictionary(), // empty — should be omitted + MultiValueHeaders = new Dictionary>(), // empty + Cookies = new List() // empty + }; + var doc = ParsePrelude(prelude); + + Assert.True(doc.RootElement.TryGetProperty("statusCode", out _)); + Assert.False(doc.RootElement.TryGetProperty("headers", out _)); + Assert.False(doc.RootElement.TryGetProperty("multiValueHeaders", out _)); + Assert.False(doc.RootElement.TryGetProperty("cookies", out _)); + } + + [Fact] + public void ToByteArray_ProducesValidUtf8() + { + var prelude = new HttpResponseStreamPrelude + { + StatusCode = HttpStatusCode.OK, + Headers = new Dictionary { ["Content-Type"] = "text/plain; charset=utf-8" } + }; + var bytes = prelude.ToByteArray(); + + // Should not throw + var text = Encoding.UTF8.GetString(bytes); + Assert.NotEmpty(text); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // LambdaResponseStream (Stream subclass) tests + // ───────────────────────────────────────────────────────────────────────────── + + public class LambdaResponseStreamTests + { + /// + /// Creates a LambdaResponseStream backed by a real ResponseStream wired to a MemoryStream. + /// + private static async Task<(LambdaResponseStream lambdaStream, MemoryStream httpOutput)> CreateWiredLambdaStream() + { + var inner = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + await inner.SetHttpOutputStreamAsync(output); + + var implStream = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var lambdaStream = new LambdaResponseStream(implStream); + return (lambdaStream, output); + } + + [Fact] + public void LambdaResponseStream_IsStreamSubclass() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.IsAssignableFrom(stream); + } + + [Fact] + public void CanWrite_IsTrue() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.True(stream.CanWrite); + } + + [Fact] + public void CanRead_IsFalse() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.False(stream.CanRead); + } + + [Fact] + public void CanSeek_IsFalse() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.False(stream.CanSeek); + } + + [Fact] + public void Read_ThrowsNotImplementedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => stream.Read(new byte[1], 0, 1)); + } + + [Fact] + public void ReadAsync_ThrowsNotImplementedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + // ReadAsync throws synchronously (not async) — capture the thrown task + var ex = Assert.Throws( + () => { var _ = stream.ReadAsync(new byte[1], 0, 1, CancellationToken.None); }); + Assert.NotNull(ex); + } + + [Fact] + public void Seek_ThrowsNotImplementedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => stream.Seek(0, SeekOrigin.Begin)); + } + + [Fact] + public void Position_Get_ThrowsNotSupportedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => _ = stream.Position); + } + + [Fact] + public void Position_Set_ThrowsNotSupportedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => stream.Position = 0); + } + + [Fact] + public void SetLength_ThrowsNotSupportedException() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var stream = new LambdaResponseStream(impl); + + Assert.Throws(() => stream.SetLength(100)); + } + + [Fact] + public async Task WriteAsync_WritesRawBytesToHttpStream() + { + var (stream, output) = await CreateWiredLambdaStream(); + var data = Encoding.UTF8.GetBytes("hello streaming"); + + await stream.WriteAsync(data, 0, data.Length); + + Assert.Equal(data, output.ToArray()); + } + + [Fact] + public async Task Write_SyncOverload_WritesRawBytes() + { + var (stream, output) = await CreateWiredLambdaStream(); + var data = new byte[] { 1, 2, 3 }; + + stream.Write(data, 0, data.Length); + + Assert.Equal(data, output.ToArray()); + } + + [Fact] + public async Task Length_ReflectsBytesWritten() + { + var (stream, _) = await CreateWiredLambdaStream(); + var data = new byte[42]; + + await stream.WriteAsync(data, 0, data.Length); + + Assert.Equal(42, stream.Length); + Assert.Equal(42, stream.BytesWritten); + } + + [Fact] + public async Task Flush_IsNoOp() + { + var (stream, _) = await CreateWiredLambdaStream(); + // Should not throw + stream.Flush(); + } + + [Fact] + public async Task WriteAsync_ByteArrayOverload_WritesFullArray() + { + var (stream, output) = await CreateWiredLambdaStream(); + var data = new byte[] { 0xDE, 0xAD, 0xBE, 0xEF }; + + await stream.WriteAsync(data); + + Assert.Equal(data, output.ToArray()); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // ImplLambdaResponseStream (bridge class) tests + // ───────────────────────────────────────────────────────────────────────────── + + public class ImplLambdaResponseStreamTests + { + [Fact] + public async Task WriteAsync_DelegatesToInnerResponseStream() + { + var inner = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + await inner.SetHttpOutputStreamAsync(output); + + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + var data = new byte[] { 1, 2, 3 }; + + await impl.WriteAsync(data, 0, data.Length); + + Assert.Equal(data, output.ToArray()); + } + + [Fact] + public async Task BytesWritten_ReflectsInnerStreamBytesWritten() + { + var inner = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + await inner.SetHttpOutputStreamAsync(output); + + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + await impl.WriteAsync(new byte[7], 0, 7); + + Assert.Equal(7, impl.BytesWritten); + } + + [Fact] + public void HasError_InitiallyFalse() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + + Assert.False(impl.HasError); + } + + [Fact] + public void HasError_TrueAfterReportError() + { + var inner = new ResponseStream(Array.Empty()); + inner.ReportError(new Exception("test")); + + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + + Assert.True(impl.HasError); + } + + [Fact] + public void Dispose_DisposesInnerStream() + { + var inner = new ResponseStream(Array.Empty()); + var impl = new ResponseStreamLambdaCoreInitializerIsolated.ImplLambdaResponseStream(inner); + + // Should not throw + impl.Dispose(); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // LambdaResponseStreamFactory tests + // ───────────────────────────────────────────────────────────────────────────── + + [Collection("ResponseStreamFactory")] + public class LambdaResponseStreamFactoryTests : IDisposable + { + + public LambdaResponseStreamFactoryTests() + { + // Wire up the factory via the initializer (same as production bootstrap does) + ResponseStreamLambdaCoreInitializerIsolated.InitializeCore(); + } + + public void Dispose() + { + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + } + + private void InitializeInvocation(string requestId = "test-req") + { + var envVars = new TestEnvironmentVariables(); + var client = new NoOpStreamingRuntimeApiClient(envVars); + ResponseStreamFactory.InitializeInvocation(requestId, false, client, CancellationToken.None); + } + + /// + /// Minimal RuntimeApiClient that accepts StartStreamingResponseAsync without real HTTP. + /// + private class NoOpStreamingRuntimeApiClient : RuntimeApiClient + { + public NoOpStreamingRuntimeApiClient(IEnvironmentVariables envVars) + : base(envVars, new TestHelpers.NoOpInternalRuntimeApiClient()) { } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + // Provide the HTTP output stream so writes don't block + await responseStream.SetHttpOutputStreamAsync(new MemoryStream(), cancellationToken); + await responseStream.WaitForCompletionAsync(cancellationToken); + return new NoOpDisposable(); + } + } + + [Fact] + public void CreateStream_ReturnsLambdaResponseStream() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.NotNull(stream); + Assert.IsType(stream); + } + + [Fact] + public void CreateStream_ReturnsStreamSubclass() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.IsAssignableFrom(stream); + } + + [Fact] + public void CreateStream_ReturnedStream_IsWritable() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.True(stream.CanWrite); + } + + [Fact] + public void CreateStream_ReturnedStream_IsNotSeekable() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.False(stream.CanSeek); + } + + [Fact] + public void CreateStream_ReturnedStream_IsNotReadable() + { + InitializeInvocation(); + + var stream = LambdaResponseStreamFactory.CreateStream(); + + Assert.False(stream.CanRead); + } + + [Fact] + public void CreateHttpStream_WithPrelude_ReturnsLambdaResponseStream() + { + InitializeInvocation(); + + var prelude = new HttpResponseStreamPrelude { StatusCode = HttpStatusCode.OK }; + var stream = LambdaResponseStreamFactory.CreateHttpStream(prelude); + + Assert.NotNull(stream); + Assert.IsType(stream); + } + + [Fact] + public void CreateHttpStream_PassesSerializedPreludeToFactory() + { + // Capture the prelude bytes passed to the inner factory + byte[] capturedPrelude = null; + LambdaResponseStreamFactory.SetLambdaResponseStream(prelude => + { + capturedPrelude = prelude; + // Return a minimal stub that satisfies the interface + return new StubLambdaResponseStream(); + }); + + var httpPrelude = new HttpResponseStreamPrelude + { + StatusCode = HttpStatusCode.Created, + Headers = new Dictionary { ["X-Test"] = "1" } + }; + LambdaResponseStreamFactory.CreateHttpStream(httpPrelude); + + Assert.NotNull(capturedPrelude); + Assert.True(capturedPrelude.Length > 0); + + // Verify the bytes are valid JSON containing the status code + var doc = JsonDocument.Parse(capturedPrelude); + Assert.Equal(201, doc.RootElement.GetProperty("statusCode").GetInt32()); + } + + [Fact] + public void CreateStream_PassesEmptyPreludeToFactory() + { + byte[] capturedPrelude = null; + LambdaResponseStreamFactory.SetLambdaResponseStream(prelude => + { + capturedPrelude = prelude; + return new StubLambdaResponseStream(); + }); + + LambdaResponseStreamFactory.CreateStream(); + + Assert.NotNull(capturedPrelude); + Assert.Empty(capturedPrelude); + } + + private class StubLambdaResponseStream : ILambdaResponseStream + { + public long BytesWritten => 0; + public bool HasError => false; + public void Dispose() { } + public Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) + => Task.CompletedTask; + } + } +} +#endif diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RawStreamingHttpClientTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RawStreamingHttpClientTests.cs new file mode 100644 index 000000000..e203d6968 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RawStreamingHttpClientTests.cs @@ -0,0 +1,502 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +#if NET8_0_OR_GREATER + +using System; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + // ───────────────────────────────────────────────────────────────────────────── + // RawStreamingHttpClient tests + // ───────────────────────────────────────────────────────────────────────────── + + public class RawStreamingHttpClientTests + { + // --- Constructor / host parsing --- + + [Fact] + public void Constructor_HostAndPort_ParsedCorrectly() + { + using var client = new RawStreamingHttpClient("localhost:9001"); + // No exception means parsing succeeded. Fields are private but + // we verify indirectly via Dispose not throwing. + } + + [Fact] + public void Constructor_HostOnly_DefaultsToPort80() + { + using var client = new RawStreamingHttpClient("localhost"); + // Should not throw — defaults port to 80 + } + + [Fact] + public void Constructor_HighPort_ParsedCorrectly() + { + using var client = new RawStreamingHttpClient("127.0.0.1:65535"); + } + + // --- Dispose --- + + [Fact] + public void Dispose_CalledTwice_DoesNotThrow() + { + var client = new RawStreamingHttpClient("localhost:9001"); + client.Dispose(); + client.Dispose(); + } + + [Fact] + public void Dispose_WithoutConnect_DoesNotThrow() + { + var client = new RawStreamingHttpClient("localhost:9001"); + client.Dispose(); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // WriteTerminatorWithTrailersAsync tests + // ───────────────────────────────────────────────────────────────────────────── + + public class WriteTerminatorWithTrailersAsyncTests + { + private static (RawStreamingHttpClient client, MemoryStream output) CreateClientWithMemoryStream() + { + var client = new RawStreamingHttpClient("localhost:9001"); + var output = new MemoryStream(); + client._networkStream = output; + return (client, output); + } + + [Fact] + public async Task WriteTerminator_StartsWithZeroChunk() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new Exception("test"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + Assert.StartsWith("0\r\n", written); + } + + [Fact] + public async Task WriteTerminator_ContainsErrorTypeTrailer() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new InvalidOperationException("bad op"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + Assert.Contains($"{StreamingConstants.ErrorTypeTrailer}: InvalidOperationException\r\n", written); + } + + [Fact] + public async Task WriteTerminator_ContainsErrorBodyTrailerHeader() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new Exception("some error"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + Assert.Contains($"{StreamingConstants.ErrorBodyTrailer}: ", written); + } + + [Fact] + public async Task WriteTerminator_ErrorBodyIsBase64Encoded() + { + var (client, output) = CreateClientWithMemoryStream(); + const string errorMessage = "something broke"; + + await client.WriteTerminatorWithTrailersAsync( + new Exception(errorMessage), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + + // Extract the Base64 value from the error body trailer + var prefix = $"{StreamingConstants.ErrorBodyTrailer}: "; + var start = written.IndexOf(prefix, StringComparison.Ordinal) + prefix.Length; + var end = written.IndexOf("\r\n", start, StringComparison.Ordinal); + var base64Value = written.Substring(start, end - start); + + // Should be valid Base64 + var decoded = Encoding.UTF8.GetString(Convert.FromBase64String(base64Value)); + Assert.Contains(errorMessage, decoded); + } + + [Fact] + public async Task WriteTerminator_ErrorBodyBase64ContainsNoNewlines() + { + var (client, output) = CreateClientWithMemoryStream(); + + // Use an exception with a stack trace that would produce multi-line JSON + Exception caughtException; + try { throw new InvalidOperationException("multi\nline\nerror"); } + catch (Exception ex) { caughtException = ex; } + + await client.WriteTerminatorWithTrailersAsync( + caughtException, CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + + // Extract just the error body trailer line + var prefix = $"{StreamingConstants.ErrorBodyTrailer}: "; + var start = written.IndexOf(prefix, StringComparison.Ordinal) + prefix.Length; + var end = written.IndexOf("\r\n", start, StringComparison.Ordinal); + var base64Value = written.Substring(start, end - start); + + // The Base64 value itself must not contain any newlines + Assert.DoesNotContain("\n", base64Value); + Assert.DoesNotContain("\r", base64Value); + } + + [Fact] + public async Task WriteTerminator_EndsWithEmptyLine() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new Exception("test"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + // Must end with \r\n\r\n — the last trailer line's \r\n plus the empty terminator line + Assert.EndsWith("\r\n\r\n", written); + } + + [Fact] + public async Task WriteTerminator_CorrectWireFormat() + { + var (client, output) = CreateClientWithMemoryStream(); + + await client.WriteTerminatorWithTrailersAsync( + new ArgumentException("bad arg"), CancellationToken.None); + + var written = Encoding.UTF8.GetString(output.ToArray()); + var lines = written.Split("\r\n"); + + // Line 0: "0" (zero-length chunk) + Assert.Equal("0", lines[0]); + // Line 1: error type trailer + Assert.StartsWith($"{StreamingConstants.ErrorTypeTrailer}: ", lines[1]); + // Line 2: error body trailer (Base64) + Assert.StartsWith($"{StreamingConstants.ErrorBodyTrailer}: ", lines[2]); + // Line 3: empty (end of trailers) + Assert.Equal("", lines[3]); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // ReadAndDiscardResponseAsync tests + // ───────────────────────────────────────────────────────────────────────────── + + public class ReadAndDiscardResponseAsyncTests + { + private static (RawStreamingHttpClient client, MemoryStream input) CreateClientWithResponse(string httpResponse) + { + var client = new RawStreamingHttpClient("localhost:9001"); + var input = new MemoryStream(Encoding.ASCII.GetBytes(httpResponse)); + client._networkStream = input; + return (client, input); + } + + [Fact] + public async Task ReadAndDiscard_HeadersOnly_CompletesSuccessfully() + { + var (client, _) = CreateClientWithResponse( + "HTTP/1.1 202 Accepted\r\nContent-Length: 0\r\n\r\n"); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + // Should complete without error + } + + [Fact] + public async Task ReadAndDiscard_WithBody_ReadsFullBody() + { + var body = "OK"; + var (client, _) = CreateClientWithResponse( + $"HTTP/1.1 200 OK\r\nContent-Length: {body.Length}\r\n\r\n{body}"); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + } + + [Fact] + public async Task ReadAndDiscard_NoContentLength_CompletesAfterHeaders() + { + var (client, _) = CreateClientWithResponse( + "HTTP/1.1 202 Accepted\r\n\r\n"); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + } + + [Fact] + public async Task ReadAndDiscard_EmptyStream_CompletesSuccessfully() + { + var client = new RawStreamingHttpClient("localhost:9001"); + client._networkStream = new MemoryStream(Array.Empty()); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + } + + [Fact] + public async Task ReadAndDiscard_PartialBody_WaitsForFullBody() + { + // Content-Length says 10 but we provide all 10 bytes + var body = "0123456789"; + var (client, _) = CreateClientWithResponse( + $"HTTP/1.1 200 OK\r\nContent-Length: 10\r\n\r\n{body}"); + + await client.ReadAndDiscardResponseAsync(CancellationToken.None); + } + + [Fact] + public async Task ReadAndDiscard_CancellationToken_Respected() + { + // Use a stream that blocks on read to test cancellation + var cts = new CancellationTokenSource(); + cts.Cancel(); + + var client = new RawStreamingHttpClient("localhost:9001"); + client._networkStream = new MemoryStream(Encoding.ASCII.GetBytes( + "HTTP/1.1 200 OK\r\nContent-Length: 100\r\n\r\n")); + + // Should not throw — ReadAndDiscardResponseAsync catches exceptions + await client.ReadAndDiscardResponseAsync(cts.Token); + } + } + + // ───────────────────────────────────────────────────────────────────────────── + // ChunkedStreamWriter tests + // ───────────────────────────────────────────────────────────────────────────── + + public class ChunkedStreamWriterTests + { + [Fact] + public void CanWrite_IsTrue() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.True(writer.CanWrite); + } + + [Fact] + public void CanRead_IsFalse() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.False(writer.CanRead); + } + + [Fact] + public void CanSeek_IsFalse() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.False(writer.CanSeek); + } + + [Fact] + public void Constructor_NullStream_ThrowsArgumentNullException() + { + Assert.Throws(() => new ChunkedStreamWriter(null)); + } + + [Fact] + public void Length_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Length); + } + + [Fact] + public void Position_Get_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Position); + } + + [Fact] + public void Position_Set_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Position = 0); + } + + [Fact] + public void Read_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Read(new byte[1], 0, 1)); + } + + [Fact] + public void Seek_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.Seek(0, SeekOrigin.Begin)); + } + + [Fact] + public void SetLength_ThrowsNotSupportedException() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + Assert.Throws(() => writer.SetLength(0)); + } + + [Fact] + public async Task WriteAsync_ByteArray_ProducesCorrectChunkFormat() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = Encoding.UTF8.GetBytes("Hello"); + await writer.WriteAsync(data, 0, data.Length); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + // "Hello" is 5 bytes = 0x5 + Assert.Equal("5\r\nHello\r\n", output); + } + + [Fact] + public async Task WriteAsync_ReadOnlyMemory_ProducesCorrectChunkFormat() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = Encoding.UTF8.GetBytes("Hi"); + await writer.WriteAsync(new ReadOnlyMemory(data)); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + Assert.Equal("2\r\nHi\r\n", output); + } + + [Fact] + public async Task WriteAsync_ZeroBytes_WritesNothing() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + await writer.WriteAsync(Array.Empty(), 0, 0); + + Assert.Equal(0, inner.Length); + } + + [Fact] + public async Task WriteAsync_ReadOnlyMemory_ZeroBytes_WritesNothing() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + await writer.WriteAsync(ReadOnlyMemory.Empty); + + Assert.Equal(0, inner.Length); + } + + [Fact] + public async Task WriteAsync_MultipleChunks_EachCorrectlyFormatted() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + await writer.WriteAsync(Encoding.UTF8.GetBytes("AB"), 0, 2); + await writer.WriteAsync(Encoding.UTF8.GetBytes("CDE"), 0, 3); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + Assert.Equal("2\r\nAB\r\n3\r\nCDE\r\n", output); + } + + [Fact] + public async Task WriteAsync_LargeChunk_HexSizeCorrect() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = new byte[256]; + Array.Fill(data, (byte)'X'); + await writer.WriteAsync(data, 0, data.Length); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + // 256 = 0x100 + Assert.StartsWith("100\r\n", output); + Assert.EndsWith("\r\n", output); + } + + [Fact] + public async Task WriteAsync_WithOffset_WritesCorrectSlice() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = Encoding.UTF8.GetBytes("ABCDE"); + await writer.WriteAsync(data, 1, 3); // "BCD" + + var output = Encoding.ASCII.GetString(inner.ToArray()); + Assert.Equal("3\r\nBCD\r\n", output); + } + + [Fact] + public void Write_Sync_ProducesCorrectChunkFormat() + { + using var inner = new MemoryStream(); + using var writer = new ChunkedStreamWriter(inner); + + var data = Encoding.UTF8.GetBytes("OK"); + writer.Write(data, 0, data.Length); + + var output = Encoding.ASCII.GetString(inner.ToArray()); + Assert.Equal("2\r\nOK\r\n", output); + } + + [Fact] + public async Task FlushAsync_DelegatesToInnerStream() + { + var flushCalled = false; + var inner = new FlushTrackingStream(() => flushCalled = true); + using var writer = new ChunkedStreamWriter(inner); + + await writer.FlushAsync(CancellationToken.None); + + Assert.True(flushCalled); + } + + [Fact] + public void Flush_DelegatesToInnerStream() + { + var flushCalled = false; + var inner = new FlushTrackingStream(() => flushCalled = true); + using var writer = new ChunkedStreamWriter(inner); + + writer.Flush(); + + Assert.True(flushCalled); + } + + /// + /// A minimal writable stream that tracks Flush calls. + /// + private class FlushTrackingStream : MemoryStream + { + private readonly Action _onFlush; + public FlushTrackingStream(Action onFlush) => _onFlush = onFlush; + public override void Flush() { _onFlush(); base.Flush(); } + public override Task FlushAsync(CancellationToken cancellationToken) + { + _onFlush(); + return base.FlushAsync(cancellationToken); + } + } + } +} +#endif diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamFactoryTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamFactoryTests.cs new file mode 100644 index 000000000..cc9a19af2 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamFactoryTests.cs @@ -0,0 +1,284 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + [Collection("ResponseStreamFactory")] + public class ResponseStreamFactoryTests : IDisposable + { + private const long MaxResponseSize = 20 * 1024 * 1024; + + public void Dispose() + { + // Clean up both modes to avoid test pollution + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + } + + /// + /// A minimal RuntimeApiClient subclass for testing that overrides StartStreamingResponseAsync + /// to avoid real HTTP calls while tracking invocations. + /// + private class MockStreamingRuntimeApiClient : RuntimeApiClient + { + public bool StartStreamingCalled { get; private set; } + public string LastAwsRequestId { get; private set; } + public ResponseStream LastResponseStream { get; private set; } + public TaskCompletionSource SendTaskCompletion { get; } = new TaskCompletionSource(); + + public MockStreamingRuntimeApiClient() + : base(new TestEnvironmentVariables(), new TestHelpers.NoOpInternalRuntimeApiClient()) + { + } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + StartStreamingCalled = true; + LastAwsRequestId = awsRequestId; + LastResponseStream = responseStream; + await SendTaskCompletion.Task; + return new NoOpDisposable(); + } + } + + private void InitializeWithMock(string requestId, bool isMultiConcurrency, MockStreamingRuntimeApiClient mockClient) + { + ResponseStreamFactory.InitializeInvocation( + requestId, isMultiConcurrency, + mockClient, CancellationToken.None); + } + + // --- Property 1: CreateStream Returns Valid Stream --- + + /// + /// Property 1: CreateStream Returns Valid Stream - on-demand mode. + /// Validates: Requirements 1.3, 2.2, 2.3 + /// + [Fact] + public void CreateStream_OnDemandMode_ReturnsValidStream() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-1", isMultiConcurrency: false, mock); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + + Assert.NotNull(stream); + Assert.IsAssignableFrom(stream); + } + + /// + /// Property 1: CreateStream Returns Valid Stream - multi-concurrency mode. + /// Validates: Requirements 1.3, 2.2, 2.3 + /// + [Fact] + public void CreateStream_MultiConcurrencyMode_ReturnsValidStream() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-2", isMultiConcurrency: true, mock); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + + Assert.NotNull(stream); + Assert.IsAssignableFrom(stream); + } + + // --- Property 4: Single Stream Per Invocation --- + + /// + /// Property 4: Single Stream Per Invocation - calling CreateStream twice throws. + /// Validates: Requirements 2.5, 2.6 + /// + [Fact] + public void CreateStream_CalledTwice_ThrowsInvalidOperationException() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-3", isMultiConcurrency: false, mock); + ResponseStreamFactory.CreateStream(Array.Empty()); + + Assert.Throws(() => ResponseStreamFactory.CreateStream(Array.Empty())); + } + + [Fact] + public void CreateStream_OutsideInvocationContext_ThrowsInvalidOperationException() + { + // No InitializeInvocation called + Assert.Throws(() => ResponseStreamFactory.CreateStream(Array.Empty())); + } + + // --- CreateStream starts HTTP POST --- + + /// + /// Validates that CreateStream calls StartStreamingResponseAsync on the RuntimeApiClient. + /// Validates: Requirements 1.3, 1.4, 2.2, 2.3, 2.4 + /// + [Fact] + public void CreateStream_CallsStartStreamingResponseAsync() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-start", isMultiConcurrency: false, mock); + + ResponseStreamFactory.CreateStream(Array.Empty()); + + Assert.True(mock.StartStreamingCalled); + Assert.Equal("req-start", mock.LastAwsRequestId); + Assert.NotNull(mock.LastResponseStream); + } + + // --- GetSendTask --- + + /// + /// Validates that GetSendTask returns the task from the HTTP POST. + /// Validates: Requirements 5.1, 7.3 + /// + [Fact] + public void GetSendTask_AfterCreateStream_ReturnsNonNullTask() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-send", isMultiConcurrency: false, mock); + + ResponseStreamFactory.CreateStream(Array.Empty()); + + var sendTask = ResponseStreamFactory.GetSendTask(isMultiConcurrency: false); + Assert.NotNull(sendTask); + } + + [Fact] + public void GetSendTask_BeforeCreateStream_ReturnsNull() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-nosend", isMultiConcurrency: false, mock); + + var sendTask = ResponseStreamFactory.GetSendTask(isMultiConcurrency: false); + Assert.Null(sendTask); + } + + [Fact] + public void GetSendTask_NoContext_ReturnsNull() + { + Assert.Null(ResponseStreamFactory.GetSendTask(isMultiConcurrency: false)); + } + + // --- Internal methods --- + + [Fact] + public void InitializeInvocation_OnDemand_SetsUpContext() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-4", isMultiConcurrency: false, mock); + + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false)); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream); + } + + [Fact] + public void InitializeInvocation_MultiConcurrency_SetsUpContext() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-5", isMultiConcurrency: true, mock); + + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true)); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream); + } + + [Fact] + public void GetStreamIfCreated_AfterCreateStream_ReturnsStream() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-6", isMultiConcurrency: false, mock); + ResponseStreamFactory.CreateStream(Array.Empty()); + + var retrieved = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false); + Assert.NotNull(retrieved); + } + + [Fact] + public void GetStreamIfCreated_NoContext_ReturnsNull() + { + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false)); + } + + [Fact] + public void CleanupInvocation_ClearsState() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-7", isMultiConcurrency: false, mock); + ResponseStreamFactory.CreateStream(Array.Empty()); + + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false)); + Assert.Throws(() => ResponseStreamFactory.CreateStream(Array.Empty())); + } + + // --- Property 16: State Isolation Between Invocations --- + + /// + /// Property 16: State Isolation Between Invocations - state from one invocation doesn't leak to the next. + /// Validates: Requirements 6.5, 8.9 + /// + [Fact] + public void StateIsolation_SequentialInvocations_NoLeakage() + { + var mock = new MockStreamingRuntimeApiClient(); + + // First invocation - streaming + InitializeWithMock("req-8a", isMultiConcurrency: false, mock); + var stream1 = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream1); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + + // Second invocation - should start fresh + InitializeWithMock("req-8b", isMultiConcurrency: false, mock); + Assert.Null(ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: false)); + + var stream2 = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream2); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + } + + /// + /// Property 16: State Isolation - multi-concurrency mode uses AsyncLocal. + /// Validates: Requirements 2.9, 2.10 + /// + [Fact] + public async Task StateIsolation_MultiConcurrency_UsesAsyncLocal() + { + var mock = new MockStreamingRuntimeApiClient(); + InitializeWithMock("req-9", isMultiConcurrency: true, mock); + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + Assert.NotNull(stream); + + bool childSawNull = false; + await Task.Run(() => + { + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + childSawNull = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true) == null; + }); + + Assert.True(childSawNull); + } + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamTests.cs new file mode 100644 index 000000000..cd2c00fd2 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/ResponseStreamTests.cs @@ -0,0 +1,447 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + public class ResponseStreamTests + { + /// + /// Helper: creates a ResponseStream and wires up a MemoryStream as the HTTP output stream. + /// Returns both so tests can inspect what was written. + /// + private static async Task<(ResponseStream stream, MemoryStream httpOutput)> CreateWiredStream() + { + var rs = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + await rs.SetHttpOutputStreamAsync(output); + return (rs, output); + } + + // ---- Basic state tests ---- + + [Fact] + public void Constructor_InitializesStateCorrectly() + { + var stream = new ResponseStream(Array.Empty()); + + Assert.Equal(0, stream.BytesWritten); + Assert.False(stream.HasError); + Assert.Null(stream.ReportedError); + } + + [Fact] + public async Task WriteAsync_WithOffset_WritesCorrectSlice() + { + var (stream, httpOutput) = await CreateWiredStream(); + var data = new byte[] { 0, 1, 2, 3, 0 }; + + await stream.WriteAsync(data, 1, 3); + + // Raw bytes {1,2,3} written directly — no chunked encoding + var expected = new byte[] { 1, 2, 3 }; + Assert.Equal(expected, httpOutput.ToArray()); + } + + [Fact] + public async Task WriteAsync_MultipleWrites_EachAppearsImmediately() + { + var (stream, httpOutput) = await CreateWiredStream(); + + var data = new byte[] { 0xAA }; + await stream.WriteAsync(data, 0, data.Length); + var afterFirst = httpOutput.ToArray().Length; + Assert.True(afterFirst > 0, "First chunk should be on the HTTP stream immediately after WriteAsync returns"); + + await stream.WriteAsync(new byte[] { 0xBB, 0xCC }, 0, 2); + var afterSecond = httpOutput.ToArray().Length; + Assert.True(afterSecond > afterFirst, "Second chunk should appear on the HTTP stream immediately"); + + Assert.Equal(3, stream.BytesWritten); + } + + [Fact] + public async Task WriteAsync_BlocksUntilSetHttpOutputStream() + { + var rs = new ResponseStream(Array.Empty()); + var httpOutput = new MemoryStream(); + var writeStarted = new ManualResetEventSlim(false); + var writeCompleted = new ManualResetEventSlim(false); + + // Start a write on a background thread — it should block + var writeTask = Task.Run(async () => + { + writeStarted.Set(); + await rs.WriteAsync(new byte[] { 1, 2, 3 }, 0, 3); + writeCompleted.Set(); + }); + + // Wait for the write to start, then verify it hasn't completed + writeStarted.Wait(TimeSpan.FromSeconds(2)); + await Task.Delay(100); // give it a moment + Assert.False(writeCompleted.IsSet, "WriteAsync should block until SetHttpOutputStream is called"); + + // Now provide the HTTP stream — the write should complete + await rs.SetHttpOutputStreamAsync(httpOutput); + await writeTask; + + Assert.True(writeCompleted.IsSet); + Assert.True(httpOutput.ToArray().Length > 0); + } + + [Fact] + public async Task MarkCompleted_ReleasesCompletionSignal() + { + var (stream, _) = await CreateWiredStream(); + + var waitTask = stream.WaitForCompletionAsync(); + Assert.False(waitTask.IsCompleted, "WaitForCompletionAsync should block before MarkCompleted"); + + stream.MarkCompleted(); + + // Should complete within a reasonable time + var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(2))); + Assert.Same(waitTask, completed); + } + + [Fact] + public async Task ReportErrorAsync_ReleasesCompletionSignal() + { + var (stream, _) = await CreateWiredStream(); + + var waitTask = stream.WaitForCompletionAsync(); + Assert.False(waitTask.IsCompleted, "WaitForCompletionAsync should block before ReportErrorAsync"); + + stream.ReportError(new Exception("test error")); + + var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(2))); + Assert.Same(waitTask, completed); + Assert.True(stream.HasError); + } + + [Fact] + public async Task WriteAsync_AfterMarkCompleted_StillSucceeds() + { + var (stream, output) = await CreateWiredStream(); + await stream.WriteAsync(new byte[] { 1 }, 0, 1); + stream.MarkCompleted(); + + // Writes after MarkCompleted are allowed — buffered ASP.NET Core responses + // (e.g. Results.Json) may flush pre-start buffer data after the pipeline + // completes and LambdaBootstrap calls MarkCompleted. + await stream.WriteAsync(new byte[] { 2 }, 0, 1); + + Assert.Equal(new byte[] { 1, 2 }, output.ToArray()); + } + + [Fact] + public async Task WriteAsync_AfterReportError_Throws() + { + var (stream, _) = await CreateWiredStream(); + await stream.WriteAsync(new byte[] { 1 }, 0, 1); + stream.ReportError(new Exception("test")); + + await Assert.ThrowsAsync( + () => stream.WriteAsync(new byte[] { 2 }, 0, 1)); + } + + [Fact] + public async Task ReportErrorAsync_SetsErrorState() + { + var stream = new ResponseStream(Array.Empty()); + var exception = new InvalidOperationException("something broke"); + + stream.ReportError(exception); + + Assert.True(stream.HasError); + Assert.Same(exception, stream.ReportedError); + } + + [Fact] + public async Task ReportErrorAsync_AfterCompleted_Throws() + { + var stream = new ResponseStream(Array.Empty()); + stream.MarkCompleted(); + + Assert.Throws( + () => stream.ReportError(new Exception("test"))); + } + + [Fact] + public async Task ReportErrorAsync_CalledTwice_Throws() + { + var stream = new ResponseStream(Array.Empty()); + stream.ReportError(new Exception("first")); + + Assert.Throws( + () => stream.ReportError(new Exception("second"))); + } + + [Fact] + public async Task WriteAsync_NullBuffer_ThrowsArgumentNull() + { + var (stream, _) = await CreateWiredStream(); + + await Assert.ThrowsAsync(() => stream.WriteAsync((byte[])null, 0, 0)); + } + + [Fact] + public async Task WriteAsync_NullBufferWithOffset_ThrowsArgumentNull() + { + var (stream, _) = await CreateWiredStream(); + + await Assert.ThrowsAsync(() => stream.WriteAsync(null, 0, 0)); + } + + [Fact] + public async Task ReportErrorAsync_NullException_ThrowsArgumentNull() + { + var stream = new ResponseStream(Array.Empty()); + + Assert.Throws(() => stream.ReportError(null)); + } + + [Fact] + public async Task Dispose_CalledTwice_DoesNotThrow() + { + var stream = new ResponseStream(Array.Empty()); + stream.Dispose(); + // Second dispose should be a no-op + stream.Dispose(); + } + + // ---- Prelude tests ---- + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_WritesPreludeBeforeHandlerData() + { + var prelude = new byte[] { 0x01, 0x02, 0x03 }; + var rs = new ResponseStream(prelude); + var output = new MemoryStream(); + + await rs.SetHttpOutputStreamAsync(output); + + // Prelude bytes + 8-byte null delimiter should be written before any handler data + var written = output.ToArray(); + Assert.True(written.Length >= prelude.Length + 8, "Prelude + delimiter should be written"); + Assert.Equal(prelude, written[..prelude.Length]); + Assert.Equal(new byte[8], written[prelude.Length..(prelude.Length + 8)]); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_WithEmptyPrelude_WritesNoPreludeBytes() + { + var rs = new ResponseStream(Array.Empty()); + var output = new MemoryStream(); + + await rs.SetHttpOutputStreamAsync(output); + + // Empty prelude — nothing written yet (handler hasn't written anything) + Assert.Empty(output.ToArray()); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_HandlerDataAppendsAfterDelimiter() + { + var prelude = new byte[] { 0xAA, 0xBB }; + var rs = new ResponseStream(prelude); + var output = new MemoryStream(); + + await rs.SetHttpOutputStreamAsync(output); + await rs.WriteAsync(new byte[] { 0xFF }, 0, 1); + + var written = output.ToArray(); + // Layout: [prelude][8 null bytes][handler data] + int expectedMinLength = prelude.Length + 8 + 1; + Assert.Equal(expectedMinLength, written.Length); + Assert.Equal(new byte[] { 0xFF }, written[^1..]); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_NullPrelude_WritesNoPreludeBytes() + { + var rs = new ResponseStream(null); + var output = new MemoryStream(); + + await rs.SetHttpOutputStreamAsync(output); + + Assert.Empty(output.ToArray()); + } + + // ---- Prelude + delimiter single-chunk tests (via ChunkedStreamWriter) ---- + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_ViaChunkedWriter_ProducesSingleChunk() + { + var preludeJson = Encoding.UTF8.GetBytes("{\"statusCode\":200}"); + var rs = new ResponseStream(preludeJson); + var rawOutput = new MemoryStream(); + var chunkedWriter = new ChunkedStreamWriter(rawOutput); + + await rs.SetHttpOutputStreamAsync(chunkedWriter); + + var wireBytes = Encoding.ASCII.GetString(rawOutput.ToArray()); + + // The prelude (18 bytes) + delimiter (8 bytes) = 26 bytes = 0x1A + // Should be exactly one chunk: "1A\r\n{prelude}{8 null bytes}\r\n" + var expectedDataLength = preludeJson.Length + 8; // 26 + var expectedHex = expectedDataLength.ToString("X"); + Assert.StartsWith($"{expectedHex}\r\n", wireBytes); + + // Verify there is only one chunk header (only one hex size prefix) + var chunkCount = 0; + var remaining = wireBytes; + while (remaining.Length > 0) + { + var crlfIndex = remaining.IndexOf("\r\n", StringComparison.Ordinal); + if (crlfIndex < 0) break; + var sizeStr = remaining.Substring(0, crlfIndex); + if (int.TryParse(sizeStr, System.Globalization.NumberStyles.HexNumber, null, out var chunkSize) && chunkSize >= 0) + { + chunkCount++; + // Skip past: hex\r\n{data}\r\n + remaining = remaining.Substring(crlfIndex + 2 + chunkSize + 2); + } + else + { + break; + } + } + Assert.Equal(1, chunkCount); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_ViaChunkedWriter_DelimiterImmediatelyFollowsPrelude() + { + var preludeJson = Encoding.UTF8.GetBytes("{\"statusCode\":201}"); + var rs = new ResponseStream(preludeJson); + var rawOutput = new MemoryStream(); + var chunkedWriter = new ChunkedStreamWriter(rawOutput); + + await rs.SetHttpOutputStreamAsync(chunkedWriter); + + // Parse the chunk to get the raw data payload + var wireBytes = rawOutput.ToArray(); + var wireStr = Encoding.ASCII.GetString(wireBytes); + var firstCrlf = wireStr.IndexOf("\r\n", StringComparison.Ordinal); + var dataStart = firstCrlf + 2; + var dataLength = preludeJson.Length + 8; + var chunkData = new byte[dataLength]; + Array.Copy(wireBytes, dataStart, chunkData, 0, dataLength); + + // First part should be the prelude JSON + Assert.Equal(preludeJson, chunkData[..preludeJson.Length]); + // Immediately followed by 8 null bytes (delimiter) + Assert.Equal(new byte[8], chunkData[preludeJson.Length..]); + } + + [Fact] + public async Task SetHttpOutputStreamAsync_WithPrelude_ViaChunkedWriter_HandlerDataInSeparateChunk() + { + var preludeJson = Encoding.UTF8.GetBytes("{\"statusCode\":200}"); + var rs = new ResponseStream(preludeJson); + var rawOutput = new MemoryStream(); + var chunkedWriter = new ChunkedStreamWriter(rawOutput); + + await rs.SetHttpOutputStreamAsync(chunkedWriter); + await rs.WriteAsync(Encoding.UTF8.GetBytes("body data"), 0, 9); + + var wireStr = Encoding.ASCII.GetString(rawOutput.ToArray()); + + // Should have exactly 2 chunks: one for prelude+delimiter, one for body + var chunkCount = 0; + var remaining = wireStr; + while (remaining.Length > 0) + { + var crlfIndex = remaining.IndexOf("\r\n", StringComparison.Ordinal); + if (crlfIndex < 0) break; + var sizeStr = remaining.Substring(0, crlfIndex); + if (int.TryParse(sizeStr, System.Globalization.NumberStyles.HexNumber, null, out var chunkSize) && chunkSize >= 0) + { + chunkCount++; + remaining = remaining.Substring(crlfIndex + 2 + chunkSize + 2); + } + else + { + break; + } + } + Assert.Equal(2, chunkCount); + } + + // ---- MarkCompleted idempotency ---- + + [Fact] + public async Task MarkCompleted_CalledTwice_DoesNotThrowOrDoubleRelease() + { + var (stream, _) = await CreateWiredStream(); + + stream.MarkCompleted(); + // Second call should be a no-op — semaphore should not be double-released + stream.MarkCompleted(); + + // WaitForCompletionAsync should complete exactly once without hanging + var waitTask = stream.WaitForCompletionAsync(); + var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(2))); + Assert.Same(waitTask, completed); + } + + [Fact] + public async Task ReportError_ThenMarkCompleted_MarkCompletedIsNoOp() + { + var stream = new ResponseStream(Array.Empty()); + stream.ReportError(new Exception("error")); + + // MarkCompleted after ReportError should not throw and not double-release + stream.MarkCompleted(); + + // WaitForCompletionAsync should complete (released by ReportError) + var waitTask = stream.WaitForCompletionAsync(); + var completed = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(2))); + Assert.Same(waitTask, completed); + } + + // ---- BytesWritten tracking ---- + + [Fact] + public async Task BytesWritten_TracksAcrossMultipleWrites() + { + var (stream, _) = await CreateWiredStream(); + + await stream.WriteAsync(new byte[10], 0, 10); + await stream.WriteAsync(new byte[5], 0, 5); + + Assert.Equal(15, stream.BytesWritten); + } + + [Fact] + public async Task BytesWritten_ReflectsOffsetAndCount() + { + var (stream, _) = await CreateWiredStream(); + + await stream.WriteAsync(new byte[10], 2, 6); // only 6 bytes + + Assert.Equal(6, stream.BytesWritten); + } + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RuntimeApiClientTests.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RuntimeApiClientTests.cs new file mode 100644 index 000000000..71102ddf1 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/RuntimeApiClientTests.cs @@ -0,0 +1,211 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + /// + /// Tests for RuntimeApiClient streaming and buffered behavior. + /// Validates Properties 7, 8, 10, 13, 18. + /// + public class RuntimeApiClientTests + { + private const long MaxResponseSize = 20 * 1024 * 1024; + + /// + /// Mock HttpMessageHandler that captures the request for header inspection. + /// It completes the ResponseStream and returns immediately without reading + /// the content body, avoiding the SerializeToStreamAsync blocking issue. + /// + private class MockHttpMessageHandler : HttpMessageHandler + { + public HttpRequestMessage CapturedRequest { get; private set; } + private readonly ResponseStream _responseStream; + + public MockHttpMessageHandler(ResponseStream responseStream) + { + _responseStream = responseStream; + } + + protected override Task SendAsync( + HttpRequestMessage request, CancellationToken cancellationToken) + { + CapturedRequest = request; + + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK)); + } + } + + private static RuntimeApiClient CreateClientWithMockHandler( + ResponseStream stream, out MockHttpMessageHandler handler) + { + handler = new MockHttpMessageHandler(stream); + var httpClient = new HttpClient(handler); + var envVars = new TestEnvironmentVariables(); + envVars.SetEnvironmentVariable("AWS_LAMBDA_RUNTIME_API", "localhost:9001"); + return new RuntimeApiClient(envVars, httpClient); + } + + // --- Property 7: Streaming Response Mode Header --- + // Note: Properties 7, 8, 13 test the HttpClient-based streaming path which is only used on pre-NET8 targets. + // On NET8+, StartStreamingResponseAsync uses RawStreamingHttpClient (raw TCP) which doesn't go through HttpClient. + +#if !NET8_0_OR_GREATER + /// + /// Property 7: Streaming Response Mode Header + /// For any streaming response, the HTTP request should include + /// "Lambda-Runtime-Function-Response-Mode: streaming". + /// **Validates: Requirements 4.1** + /// + [Fact] + public async Task StartStreamingResponseAsync_IncludesStreamingResponseModeHeader() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out var handler); + + await client.StartStreamingResponseAsync("req-1", stream, CancellationToken.None); + + Assert.NotNull(handler.CapturedRequest); + Assert.True(handler.CapturedRequest.Headers.Contains(StreamingConstants.ResponseModeHeader)); + var values = handler.CapturedRequest.Headers.GetValues(StreamingConstants.ResponseModeHeader).ToList(); + Assert.Single(values); + Assert.Equal(StreamingConstants.StreamingResponseMode, values[0]); + } + + // --- Property 8: Chunked Transfer Encoding Header --- + + /// + /// Property 8: Chunked Transfer Encoding Header + /// For any streaming response, the HTTP request should include + /// "Transfer-Encoding: chunked". + /// **Validates: Requirements 4.2** + /// + [Fact] + public async Task StartStreamingResponseAsync_IncludesChunkedTransferEncodingHeader() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out var handler); + + await client.StartStreamingResponseAsync("req-2", stream, CancellationToken.None); + + Assert.NotNull(handler.CapturedRequest); + Assert.True(handler.CapturedRequest.Headers.TransferEncodingChunked); + } + + // --- Property 13: Trailer Declaration Header --- + + /// + /// Property 13: Trailer Declaration Header + /// For any streaming response, the HTTP request should include a "Trailer" header + /// declaring the error trailer headers upfront (since we cannot know at request + /// start whether an error will occur). + /// **Validates: Requirements 5.4** + /// + [Fact] + public async Task StartStreamingResponseAsync_DeclaresTrailerHeaderUpfront() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out var handler); + + await client.StartStreamingResponseAsync("req-3", stream, CancellationToken.None); + + Assert.NotNull(handler.CapturedRequest); + Assert.True(handler.CapturedRequest.Headers.Contains("Trailer")); + var trailerValue = string.Join(", ", handler.CapturedRequest.Headers.GetValues("Trailer")); + Assert.Contains(StreamingConstants.ErrorTypeTrailer, trailerValue); + Assert.Contains(StreamingConstants.ErrorBodyTrailer, trailerValue); + } +#endif + + // --- Property 10: Buffered Responses Exclude Streaming Headers --- + + /// + /// Mock HttpMessageHandler that captures the request for buffered response header inspection. + /// Returns an Accepted (202) response since that's what the InternalRuntimeApiClient expects. + /// + private class BufferedMockHttpMessageHandler : HttpMessageHandler + { + public HttpRequestMessage CapturedRequest { get; private set; } + + protected override Task SendAsync( + HttpRequestMessage request, CancellationToken cancellationToken) + { + CapturedRequest = request; + return Task.FromResult(new HttpResponseMessage(HttpStatusCode.Accepted)); + } + } + + /// + /// Property 10: Buffered Responses Exclude Streaming Headers + /// For any buffered response (where CreateStream was not called), the HTTP request + /// should not include "Lambda-Runtime-Function-Response-Mode" or + /// "Transfer-Encoding: chunked" or "Trailer" headers. + /// **Validates: Requirements 4.6** + /// + [Fact] + public async Task SendResponseAsync_BufferedResponse_ExcludesStreamingHeaders() + { + var bufferedHandler = new BufferedMockHttpMessageHandler(); + var httpClient = new HttpClient(bufferedHandler); + var envVars = new TestEnvironmentVariables(); + envVars.SetEnvironmentVariable("AWS_LAMBDA_RUNTIME_API", "localhost:9001"); + var client = new RuntimeApiClient(envVars, httpClient); + + var outputStream = new MemoryStream(new byte[] { 1, 2, 3 }); + await client.SendResponseAsync("req-buffered", outputStream, CancellationToken.None); + + Assert.NotNull(bufferedHandler.CapturedRequest); + // Buffered responses must not include streaming-specific headers + Assert.False(bufferedHandler.CapturedRequest.Headers.Contains(StreamingConstants.ResponseModeHeader), + "Buffered response should not include Lambda-Runtime-Function-Response-Mode header"); + Assert.NotEqual(true, bufferedHandler.CapturedRequest.Headers.TransferEncodingChunked); + Assert.False(bufferedHandler.CapturedRequest.Headers.Contains("Trailer"), + "Buffered response should not include Trailer header"); + } + + // --- Argument validation --- + +#if NET8_0_OR_GREATER + [Fact] + public async Task StartStreamingResponseAsync_NullRequestId_ThrowsArgumentNullException() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out _); + + await Assert.ThrowsAsync( + () => client.StartStreamingResponseAsync(null, stream, CancellationToken.None)); + } + + [Fact] + public async Task StartStreamingResponseAsync_NullResponseStream_ThrowsArgumentNullException() + { + var stream = new ResponseStream(Array.Empty()); + var client = CreateClientWithMockHandler(stream, out _); + + await Assert.ThrowsAsync( + () => client.StartStreamingResponseAsync("req-5", null, CancellationToken.None)); + } +#endif + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/StreamingE2EWithMoq.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/StreamingE2EWithMoq.cs new file mode 100644 index 000000000..f46c76f13 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/StreamingE2EWithMoq.cs @@ -0,0 +1,545 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport.UnitTests.TestHelpers; +using Xunit; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + [CollectionDefinition("ResponseStreamFactory")] + public class ResponseStreamFactoryCollection { } + + /// + /// End-to-end integration tests for the true-streaming architecture. + /// These tests exercise the full pipeline: LambdaBootstrap → ResponseStreamFactory → + /// ResponseStream → captured HTTP output stream. + /// + [Collection("ResponseStreamFactory")] + public class StreamingE2EWithMoq : IDisposable + { + public void Dispose() + { + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: false); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + } + + // ─── Helpers ──────────────────────────────────────────────────────────────── + + private static Dictionary> MakeHeaders(string requestId = "test-request-id") + => new Dictionary> + { + { RuntimeApiHeaders.HeaderAwsRequestId, new List { requestId } }, + { RuntimeApiHeaders.HeaderInvokedFunctionArn, new List { "arn:aws:lambda:us-east-1:123456789012:function:test" } }, + { RuntimeApiHeaders.HeaderAwsTenantId, new List { "tenant-id" } }, + { RuntimeApiHeaders.HeaderTraceId, new List { "trace-id" } }, + { RuntimeApiHeaders.HeaderDeadlineMs, new List { "9999999999999" } }, + }; + + /// + /// A capturing RuntimeApiClient that records the raw bytes written to the HTTP output stream + /// by SerializeToStreamAsync. + /// + private class CapturingStreamingRuntimeApiClient : RuntimeApiClient, IRuntimeApiClient + { + private readonly IEnvironmentVariables _envVars; + private readonly Dictionary> _headers; + + public bool StartStreamingCalled { get; private set; } + public bool SendResponseCalled { get; private set; } + public bool ReportInvocationErrorCalled { get; private set; } + public byte[] CapturedHttpBytes { get; private set; } + public ResponseStream LastResponseStream { get; private set; } + public Stream LastBufferedOutputStream { get; private set; } + + public new Amazon.Lambda.RuntimeSupport.Helpers.IConsoleLoggerWriter ConsoleLogger { get; } = new Helpers.LogLevelLoggerWriter(new SystemEnvironmentVariables()); + + public CapturingStreamingRuntimeApiClient( + IEnvironmentVariables envVars, + Dictionary> headers) + : base(envVars, new NoOpInternalRuntimeApiClient()) + { + _envVars = envVars; + _headers = headers; + } + + public new async Task GetNextInvocationAsync(CancellationToken cancellationToken = default) + { + _headers[RuntimeApiHeaders.HeaderTraceId] = new List { Guid.NewGuid().ToString() }; + var inputStream = new MemoryStream(new byte[0]); + return new InvocationRequest + { + InputStream = inputStream, + LambdaContext = new LambdaContext( + new RuntimeApiHeaders(_headers), + new LambdaEnvironment(_envVars), + new TestDateTimeHelper(), + new Helpers.SimpleLoggerWriter(_envVars)) + }; + } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + StartStreamingCalled = true; + LastResponseStream = responseStream; + + // Use a real MemoryStream as the HTTP output stream so we capture actual bytes + var captureStream = new MemoryStream(); + await responseStream.SetHttpOutputStreamAsync(captureStream, cancellationToken); + + // Wait for the handler to finish writing (mirrors real RawStreamingHttpClient behavior) + await responseStream.WaitForCompletionAsync(cancellationToken); + CapturedHttpBytes = captureStream.ToArray(); + return new NoOpDisposable(); + } + + public new async Task SendResponseAsync(string awsRequestId, Stream outputStream, CancellationToken cancellationToken = default) + { + SendResponseCalled = true; + if (outputStream != null) + { + var ms = new MemoryStream(); + await outputStream.CopyToAsync(ms); + ms.Position = 0; + LastBufferedOutputStream = ms; + } + } + + public new Task ReportInvocationErrorAsync(string awsRequestId, Exception exception, CancellationToken cancellationToken = default) + { + ReportInvocationErrorCalled = true; + return Task.CompletedTask; + } + + public new Task ReportInitializationErrorAsync(Exception exception, string errorType = null, CancellationToken cancellationToken = default) + => Task.CompletedTask; + + public new Task ReportInitializationErrorAsync(string errorType, CancellationToken cancellationToken = default) + => Task.CompletedTask; + +#if NET8_0_OR_GREATER + public new Task RestoreNextInvocationAsync(CancellationToken cancellationToken = default) => Task.CompletedTask; + public new Task ReportRestoreErrorAsync(Exception exception, string errorType = null, CancellationToken cancellationToken = default) => Task.CompletedTask; +#endif + } + + private static CapturingStreamingRuntimeApiClient CreateClient(string requestId = "test-request-id") + => new CapturingStreamingRuntimeApiClient(new TestEnvironmentVariables(), MakeHeaders(requestId)); + + /// + /// End-to-end: all data is transmitted correctly (content round-trip). + /// Requirements: 3.2, 4.3, 10.1 + /// + [Fact] + public async Task Streaming_AllDataTransmitted_ContentRoundTrip() + { + var client = CreateClient(); + var payload = Encoding.UTF8.GetBytes("integration test payload"); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(payload); + return new InvocationResponse(Stream.Null, false); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + var output = client.CapturedHttpBytes; + Assert.NotNull(output); + + var outputStr = Encoding.UTF8.GetString(output); + Assert.Contains("integration test payload", outputStr); + } + + /// + /// End-to-end: stream is finalized (final chunk written, BytesWritten matches). + /// Requirements: 3.2, 4.3, 10.1 + /// + [Fact] + public async Task Streaming_StreamFinalized_BytesWrittenMatchesPayload() + { + var client = CreateClient(); + var data = Encoding.UTF8.GetBytes("finalization check"); + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(data); + return new InvocationResponse(Stream.Null, false); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.NotNull(client.LastResponseStream); + Assert.Equal(data.Length, client.LastResponseStream.BytesWritten); + } + + // ─── 10.2 End-to-end buffered response ────────────────────────────────────── + + /// + /// End-to-end: handler does NOT call CreateStream — response goes via buffered path. + /// Verifies SendResponseAsync is called and streaming headers are absent. + /// Requirements: 1.5, 4.6, 9.4 + /// + [Fact] + public async Task Buffered_HandlerDoesNotCallCreateStream_UsesSendResponsePath() + { + var client = CreateClient(); + var responseBody = Encoding.UTF8.GetBytes("buffered response body"); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + return new InvocationResponse(new MemoryStream(responseBody)); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.False(client.StartStreamingCalled, "StartStreamingResponseAsync should NOT be called for buffered mode"); + Assert.True(client.SendResponseCalled, "SendResponseAsync should be called for buffered mode"); + Assert.Null(client.CapturedHttpBytes); + } + + /// + /// End-to-end: buffered response body is transmitted correctly. + /// Requirements: 1.5, 4.6, 9.4 + /// + [Fact] + public async Task Buffered_ResponseBodyTransmittedCorrectly() + { + var client = CreateClient(); + var responseBody = Encoding.UTF8.GetBytes("hello buffered world"); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + return new InvocationResponse(new MemoryStream(responseBody)); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.SendResponseCalled); + Assert.NotNull(client.LastBufferedOutputStream); + var received = new MemoryStream(); + await client.LastBufferedOutputStream.CopyToAsync(received); + Assert.Equal(responseBody, received.ToArray()); + } + + /// + /// End-to-end: midstream error sets error state on ResponseStream with exception details. + /// In production, RawStreamingHttpClient reads this state and writes trailing headers. + /// Requirements: 5.2, 5.3 + /// + [Fact] + public async Task MidstreamError_SetsErrorStateWithExceptionDetails() + { + var client = CreateClient(); + const string errorMessage = "something went wrong mid-stream"; + + LambdaBootstrapHandler handler = async (invocation) => + { + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + await stream.WriteAsync(Encoding.UTF8.GetBytes("some data")); + throw new InvalidOperationException(errorMessage); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.StartStreamingCalled); + Assert.NotNull(client.LastResponseStream); + Assert.True(client.LastResponseStream.HasError); + Assert.NotNull(client.LastResponseStream.ReportedError); + Assert.IsType(client.LastResponseStream.ReportedError); + Assert.Equal(errorMessage, client.LastResponseStream.ReportedError.Message); + + // Verify the handler's data was still captured before the error + var output = Encoding.UTF8.GetString(client.CapturedHttpBytes); + Assert.Contains("some data", output); + } + + // ─── 10.4 Multi-concurrency ────────────────────────────────────────────────── + + /// + /// Multi-concurrency: concurrent invocations use AsyncLocal for state isolation. + /// Each invocation independently uses streaming or buffered mode without interference. + /// Requirements: 2.9, 6.5, 8.9 + /// + [Fact] + public async Task MultiConcurrency_ConcurrentInvocations_StateIsolated() + { + const int concurrency = 3; + var results = new ConcurrentDictionary(); + var barrier = new SemaphoreSlim(0, concurrency); + var allStarted = new SemaphoreSlim(0, concurrency); + + // Simulate concurrent invocations using AsyncLocal directly + var tasks = new List(); + for (int i = 0; i < concurrency; i++) + { + var requestId = $"req-{i}"; + var payload = $"payload-{i}"; + tasks.Add(Task.Run(async () => + { + var mockClient = new MockMultiConcurrencyStreamingClient(); + ResponseStreamFactory.InitializeInvocation( + requestId, + isMultiConcurrency: true, + mockClient, + CancellationToken.None); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + allStarted.Release(); + + // Wait until all tasks have started (to ensure true concurrency) + await barrier.WaitAsync(); + + await stream.WriteAsync(Encoding.UTF8.GetBytes(payload)); + stream.MarkCompleted(); + + // Verify this invocation's stream is still accessible + var retrieved = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true); + results[requestId] = retrieved != null ? payload : "MISSING"; + + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + })); + } + + // Wait for all tasks to start, then release the barrier + for (int i = 0; i < concurrency; i++) + await allStarted.WaitAsync(); + barrier.Release(concurrency); + + await Task.WhenAll(tasks); + + // Each invocation should have seen its own stream + Assert.Equal(concurrency, results.Count); + for (int i = 0; i < concurrency; i++) + Assert.Equal($"payload-{i}", results[$"req-{i}"]); + } + + /// + /// Multi-concurrency: streaming and buffered invocations can run concurrently without interference. + /// Requirements: 2.9, 6.5, 8.9 + /// + [Fact] + public async Task MultiConcurrency_StreamingAndBufferedMixedConcurrently_NoInterference() + { + var streamingResults = new ConcurrentBag(); + var bufferedResults = new ConcurrentBag(); + var barrier = new SemaphoreSlim(0, 4); + var allStarted = new SemaphoreSlim(0, 4); + + var tasks = new List(); + + // 2 streaming invocations + for (int i = 0; i < 2; i++) + { + var requestId = $"stream-{i}"; + tasks.Add(Task.Run(async () => + { + var mockClient = new MockMultiConcurrencyStreamingClient(); + ResponseStreamFactory.InitializeInvocation( + requestId, + isMultiConcurrency: true, mockClient, CancellationToken.None); + + var stream = ResponseStreamFactory.CreateStream(Array.Empty()); + allStarted.Release(); + await barrier.WaitAsync(); + + await stream.WriteAsync(Encoding.UTF8.GetBytes("streaming data")); + stream.MarkCompleted(); + + var retrieved = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true); + streamingResults.Add(retrieved != null); + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + })); + } + + // 2 buffered invocations (no CreateStream) + for (int i = 0; i < 2; i++) + { + var requestId = $"buffered-{i}"; + tasks.Add(Task.Run(async () => + { + var mockClient = new MockMultiConcurrencyStreamingClient(); + ResponseStreamFactory.InitializeInvocation( + requestId, + isMultiConcurrency: true, mockClient, CancellationToken.None); + + allStarted.Release(); + await barrier.WaitAsync(); + + // No CreateStream — buffered mode + var retrieved = ResponseStreamFactory.GetStreamIfCreated(isMultiConcurrency: true); + bufferedResults.Add(retrieved == null); // should be null (no stream created) + ResponseStreamFactory.CleanupInvocation(isMultiConcurrency: true); + })); + } + + for (int i = 0; i < 4; i++) + await allStarted.WaitAsync(); + barrier.Release(4); + + await Task.WhenAll(tasks); + + Assert.Equal(2, streamingResults.Count); + Assert.All(streamingResults, r => Assert.True(r, "Streaming invocation should have a stream")); + + Assert.Equal(2, bufferedResults.Count); + Assert.All(bufferedResults, r => Assert.True(r, "Buffered invocation should have no stream")); + } + + /// + /// Minimal mock RuntimeApiClient for multi-concurrency tests. + /// Accepts StartStreamingResponseAsync calls without real HTTP. + /// + private class MockMultiConcurrencyStreamingClient : RuntimeApiClient + { + public MockMultiConcurrencyStreamingClient() + : base(new TestEnvironmentVariables(), new NoOpInternalRuntimeApiClient()) { } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + // Provide the HTTP output stream so writes don't block + await responseStream.SetHttpOutputStreamAsync(new MemoryStream()); + await responseStream.WaitForCompletionAsync(); + return new NoOpDisposable(); + } + } + + // ─── 10.5 Backward compatibility ──────────────────────────────────────────── + + /// + /// Backward compatibility: existing handler signatures (event + ILambdaContext) work without modification. + /// Requirements: 9.1, 9.2, 9.3 + /// + [Fact] + public async Task BackwardCompat_ExistingHandlerSignature_WorksUnchanged() + { + var client = CreateClient(); + bool handlerCalled = false; + + // Simulate a classic handler that returns a buffered response + LambdaBootstrapHandler handler = async (invocation) => + { + handlerCalled = true; + await Task.Yield(); + return new InvocationResponse(new MemoryStream(Encoding.UTF8.GetBytes("classic response"))); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(handlerCalled); + Assert.True(client.SendResponseCalled); + Assert.False(client.StartStreamingCalled); + } + + /// + /// Backward compatibility: no regression in buffered response behavior — response body is correct. + /// Requirements: 9.4, 9.5 + /// + [Fact] + public async Task BackwardCompat_BufferedResponse_NoRegression() + { + var client = CreateClient(); + var expected = Encoding.UTF8.GetBytes("no regression here"); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + return new InvocationResponse(new MemoryStream(expected)); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.SendResponseCalled); + Assert.NotNull(client.LastBufferedOutputStream); + var received = new MemoryStream(); + await client.LastBufferedOutputStream.CopyToAsync(received); + Assert.Equal(expected, received.ToArray()); + } + + /// + /// Backward compatibility: handler that returns null OutputStream still works. + /// Requirements: 9.4 + /// + [Fact] + public async Task BackwardCompat_NullOutputStream_HandledGracefully() + { + var client = CreateClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + return new InvocationResponse(Stream.Null, false); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + + // Should not throw + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.SendResponseCalled); + } + + /// + /// Backward compatibility: handler that throws before CreateStream uses standard error path. + /// Requirements: 9.5 + /// + [Fact] + public async Task BackwardCompat_HandlerThrows_StandardErrorReportingUsed() + { + var client = CreateClient(); + + LambdaBootstrapHandler handler = async (invocation) => + { + await Task.Yield(); + throw new Exception("classic handler error"); + }; + + using var bootstrap = new LambdaBootstrap(handler, null); + bootstrap.Client = client; + await bootstrap.InvokeOnceAsync(); + + Assert.True(client.ReportInvocationErrorCalled); + Assert.False(client.StartStreamingCalled); + } + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/NoOpInternalRuntimeApiClient.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/NoOpInternalRuntimeApiClient.cs new file mode 100644 index 000000000..9fa0434cd --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/NoOpInternalRuntimeApiClient.cs @@ -0,0 +1,60 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests.TestHelpers +{ + /// + /// A no-op implementation of IInternalRuntimeApiClient for unit tests + /// that need to construct a RuntimeApiClient without real HTTP calls. + /// + internal class NoOpInternalRuntimeApiClient : IInternalRuntimeApiClient + { + private static readonly SwaggerResponse EmptyStatusResponse = + new SwaggerResponse(200, new Dictionary>(), new StatusResponse()); + + public Task> ErrorAsync( + string lambda_Runtime_Function_Error_Type, string errorJson, CancellationToken cancellationToken) + => Task.FromResult(EmptyStatusResponse); + + public Task> NextAsync(CancellationToken cancellationToken) + => Task.FromResult(new SwaggerResponse(200, new Dictionary>(), Stream.Null)); + + public Task> ResponseAsync(string awsRequestId, Stream outputStream) + => Task.FromResult(EmptyStatusResponse); + + public Task> ResponseAsync( + string awsRequestId, Stream outputStream, CancellationToken cancellationToken) + => Task.FromResult(EmptyStatusResponse); + + public Task> ErrorWithXRayCauseAsync( + string awsRequestId, string lambda_Runtime_Function_Error_Type, + string errorJson, string xrayCause, CancellationToken cancellationToken) + => Task.FromResult(EmptyStatusResponse); + +#if NET8_0_OR_GREATER + public Task> RestoreNextAsync(CancellationToken cancellationToken) + => Task.FromResult(new SwaggerResponse(200, new Dictionary>(), Stream.Null)); + + public Task> RestoreErrorAsync( + string lambda_Runtime_Function_Error_Type, string errorJson, CancellationToken cancellationToken) + => Task.FromResult(EmptyStatusResponse); +#endif + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestStreamingRuntimeApiClient.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestStreamingRuntimeApiClient.cs new file mode 100644 index 000000000..1cd6fa09e --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/Amazon.Lambda.RuntimeSupport.UnitTests/TestHelpers/TestStreamingRuntimeApiClient.cs @@ -0,0 +1,142 @@ +/* + * Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +using Amazon.Lambda.RuntimeSupport.Client.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport.Helpers; +using Amazon.Lambda.RuntimeSupport.UnitTests.TestHelpers; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Amazon.Lambda.RuntimeSupport.UnitTests +{ + /// + /// A RuntimeApiClient subclass for testing LambdaBootstrap streaming integration. + /// Extends RuntimeApiClient so the (RuntimeApiClient)Client cast in LambdaBootstrap works. + /// Overrides StartStreamingResponseAsync to avoid real HTTP calls. + /// + internal class TestStreamingRuntimeApiClient : RuntimeApiClient, IRuntimeApiClient + { + private readonly IEnvironmentVariables _environmentVariables; + private readonly Dictionary> _headers; + + public new IConsoleLoggerWriter ConsoleLogger { get; } = new LogLevelLoggerWriter(new SystemEnvironmentVariables()); + + public TestStreamingRuntimeApiClient(IEnvironmentVariables environmentVariables, Dictionary> headers) + : base(environmentVariables, new NoOpInternalRuntimeApiClient()) + { + _environmentVariables = environmentVariables; + _headers = headers; + } + + // Tracking flags + public bool GetNextInvocationAsyncCalled { get; private set; } + public bool ReportInitializationErrorAsyncExceptionCalled { get; private set; } + public bool ReportInvocationErrorAsyncExceptionCalled { get; private set; } + public bool SendResponseAsyncCalled { get; private set; } + public bool StartStreamingResponseAsyncCalled { get; private set; } + + public string LastTraceId { get; private set; } + public byte[] FunctionInput { get; set; } + public Stream LastOutputStream { get; private set; } + public Exception LastRecordedException { get; private set; } + public ResponseStream LastStreamingResponseStream { get; private set; } + + public new async Task GetNextInvocationAsync(CancellationToken cancellationToken = default) + { + GetNextInvocationAsyncCalled = true; + + LastTraceId = Guid.NewGuid().ToString(); + _headers[RuntimeApiHeaders.HeaderTraceId] = new List() { LastTraceId }; + + var inputStream = new MemoryStream(FunctionInput == null ? new byte[0] : FunctionInput); + inputStream.Position = 0; + + return new InvocationRequest() + { + InputStream = inputStream, + LambdaContext = new LambdaContext( + new RuntimeApiHeaders(_headers), + new LambdaEnvironment(_environmentVariables), + new TestDateTimeHelper(), new SimpleLoggerWriter(_environmentVariables)) + }; + } + + public new Task ReportInitializationErrorAsync(Exception exception, String errorType = null, CancellationToken cancellationToken = default) + { + LastRecordedException = exception; + ReportInitializationErrorAsyncExceptionCalled = true; + return Task.CompletedTask; + } + + public new Task ReportInitializationErrorAsync(string errorType, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } + + public new Task ReportInvocationErrorAsync(string awsRequestId, Exception exception, CancellationToken cancellationToken = default) + { + LastRecordedException = exception; + ReportInvocationErrorAsyncExceptionCalled = true; + return Task.CompletedTask; + } + + public new async Task SendResponseAsync(string awsRequestId, Stream outputStream, CancellationToken cancellationToken = default) + { + if (outputStream != null) + { + LastOutputStream = new MemoryStream((int)outputStream.Length); + outputStream.CopyTo(LastOutputStream); + LastOutputStream.Position = 0; + } + + SendResponseAsyncCalled = true; + } + + internal override async Task StartStreamingResponseAsync( + string awsRequestId, ResponseStream responseStream, CancellationToken cancellationToken = default) + { + StartStreamingResponseAsyncCalled = true; + LastStreamingResponseStream = responseStream; + + // Simulate the HTTP stream being available + await responseStream.SetHttpOutputStreamAsync(new MemoryStream(), cancellationToken); + + // Wait for the handler to finish writing (mirrors real SerializeToStreamAsync behavior) + await responseStream.WaitForCompletionAsync(); + + return new NoOpDisposable(); + } + +#if NET8_0_OR_GREATER + public new Task RestoreNextInvocationAsync(CancellationToken cancellationToken = default) + => Task.CompletedTask; + + public new Task ReportRestoreErrorAsync(Exception exception, String errorType = null, CancellationToken cancellationToken = default) + => Task.CompletedTask; +#endif + } + + /// + /// A no-op IDisposable for test overrides of StartStreamingResponseAsync. + /// + internal class NoOpDisposable : IDisposable + { + public void Dispose() { } + } +} diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/Function.cs b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/Function.cs new file mode 100644 index 000000000..8c645ff5b --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/Function.cs @@ -0,0 +1,56 @@ +#pragma warning disable CA2252 + +using Amazon.Lambda.Core; +using Amazon.Lambda.Core.ResponseStreaming; +using Amazon.Lambda.RuntimeSupport; +using Amazon.Lambda.Serialization.SystemTextJson; + +// The function handler that will be called for each Lambda event +var handler = async (string input, ILambdaContext context) => +{ + using var stream = LambdaResponseStreamFactory.CreateStream(); + + switch(input) + { + case $"{nameof(SimpleFunctionHandler)}": + await SimpleFunctionHandler(stream, context); + break; + case $"{nameof(StreamContentHandler)}": + await StreamContentHandler(stream, context); + break; + case $"{nameof(UnhandledExceptionHandler)}": + await UnhandledExceptionHandler(stream, context); + break; + default: + throw new ArgumentException($"Unknown handler scenario {input}"); + } +}; + +async Task SimpleFunctionHandler(Stream stream, ILambdaContext context) +{ + using var writer = new StreamWriter(stream); + await writer.WriteAsync("Hello, World!"); +} + +async Task StreamContentHandler(Stream stream, ILambdaContext context) +{ + using var writer = new StreamWriter(stream); + + await writer.WriteLineAsync("Starting stream content..."); + for(var i = 0; i < 10000; i++) + { + await writer.WriteLineAsync($"Line {i}"); + } + await writer.WriteLineAsync("Finish stream content"); +} + +async Task UnhandledExceptionHandler(Stream stream, ILambdaContext context) +{ + using var writer = new StreamWriter(stream); + await writer.WriteAsync("This method will fail"); + throw new InvalidOperationException("This is an unhandled exception"); +} + +await LambdaBootstrapBuilder.Create(handler, new DefaultLambdaJsonSerializer()) + .Build() + .RunAsync(); diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/ResponseStreamingFunctionHandlers.csproj b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/ResponseStreamingFunctionHandlers.csproj new file mode 100644 index 000000000..fa81eaa17 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/ResponseStreamingFunctionHandlers.csproj @@ -0,0 +1,19 @@ + + + Exe + net10.0 + enable + enable + true + Lambda + + true + + true + + + + + + + \ No newline at end of file diff --git a/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/aws-lambda-tools-defaults.json b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/aws-lambda-tools-defaults.json new file mode 100644 index 000000000..3042c3978 --- /dev/null +++ b/Libraries/test/Amazon.Lambda.RuntimeSupport.Tests/ResponseStreamingFunctionHandlers/aws-lambda-tools-defaults.json @@ -0,0 +1,15 @@ +{ + "Information": [ + "This file provides default values for the deployment wizard inside Visual Studio and the AWS Lambda commands added to the .NET Core CLI.", + "To learn more about the Lambda commands with the .NET Core CLI execute the following command at the command line in the project root directory.", + "dotnet lambda help", + "All the command line options for the Lambda command can be specified in this file." + ], + "profile": "default", + "region": "us-west-2", + "configuration": "Release", + "function-runtime": "dotnet10", + "function-memory-size": 512, + "function-timeout": 30, + "function-handler": "ResponseStreamingFunctionHandlers" +} \ No newline at end of file