@@ -45,24 +45,10 @@ public void RegisterInitializerRequests(HandlerWrapper handlerWrapper)
4545 {
4646 Amazon . Lambda . Core . SnapshotRestore . RegisterBeforeSnapshot ( async ( ) =>
4747 {
48- // Construct specialized HttpClient that will intercept requests and saved them inside
49- // LambdaSnapstartInitializerHttpMessageHandler.CapturedHttpRequests.
50- //
51- // They will be processed later by SnapstartHelperLambdaRequests.ExecuteSnapstartInitRequests which will
52- // route them correctly through a simulated lambda pipeline.
53- var messageHandlerThatCollectsRequests = new LambdaSnapstartInitializerHttpMessageHandler ( _lambdaEventSource ) ;
54-
55- var httpClientThatCollectsRequests = new HttpClient ( messageHandlerThatCollectsRequests ) ;
56- httpClientThatCollectsRequests . BaseAddress = LambdaSnapstartInitializerHttpMessageHandler . BaseUri ;
57-
58- // "Invoke" each registered request function. Requests will be captured inside.
59- // LambdaSnapstartInitializerHttpMessageHandler.CapturedHttpRequests.
60- await Registrar . Execute ( httpClientThatCollectsRequests ) ;
61-
62- // Request are now in CapturedHttpRequests. Serialize each one into a json object
63- // and execute the request through the lambda pipeline (ie handlerWrapper).
64- foreach ( var json in LambdaSnapstartInitializerHttpMessageHandler . CapturedHttpRequestsJson )
48+ foreach ( var req in Registrar . GetAllRequests ( ) )
6549 {
50+ var json = await SnapstartHelperLambdaRequests . SerializeToJson ( req , _lambdaEventSource ) ;
51+
6652 await SnapstartHelperLambdaRequests . ExecuteSnapstartInitRequests ( json , times : 5 , handlerWrapper ) ;
6753 }
6854 } ) ;
@@ -74,17 +60,19 @@ public void RegisterInitializerRequests(HandlerWrapper handlerWrapper)
7460
7561 internal class BeforeSnapstartRequestRegistrar
7662 {
77- private readonly List < Func < HttpClient , Task > > _beforeSnapstartFuncs = new ( ) ;
63+ private readonly List < Func < IEnumerable < HttpRequestMessage > > > _beforeSnapstartRequests = new ( ) ;
7864
79- public void Register ( Func < HttpClient , Task > beforeSnapstartRequest )
65+
66+ public void Register ( Func < IEnumerable < HttpRequestMessage > > beforeSnapstartRequests )
8067 {
81- _beforeSnapstartFuncs . Add ( beforeSnapstartRequest ) ;
68+ _beforeSnapstartRequests . Add ( beforeSnapstartRequests ) ;
8269 }
8370
84- internal async Task Execute ( HttpClient client )
71+ internal IEnumerable < HttpRequestMessage > GetAllRequests ( )
8572 {
86- foreach ( var f in _beforeSnapstartFuncs )
87- await f ( client ) ;
73+ foreach ( var batch in _beforeSnapstartRequests )
74+ foreach ( var r in batch ( ) )
75+ yield return r ;
8876 }
8977 }
9078
@@ -113,6 +101,8 @@ private class HelperLambdaContext : ILambdaContext, ICognitoIdentity, IClientCon
113101
114102 private static class SnapstartHelperLambdaRequests
115103 {
104+ private static readonly Uri _baseUri = new Uri ( "http://localhost" ) ;
105+
116106 public static async Task ExecuteSnapstartInitRequests ( string jsonRequest , int times , HandlerWrapper handlerWrapper )
117107 {
118108 var dummyRequest = new InvocationRequest (
@@ -131,25 +121,21 @@ public static async Task ExecuteSnapstartInitRequests(string jsonRequest, int ti
131121 }
132122 }
133123 }
134- }
135-
136- private class LambdaSnapstartInitializerHttpMessageHandler : HttpMessageHandler
137- {
138- private readonly LambdaEventSource _lambdaEventSource ;
139-
140- public static Uri BaseUri { get ; } = new Uri ( "http://localhost" ) ;
141-
142- public static List < string > CapturedHttpRequestsJson { get ; } = new ( ) ;
143-
144- public LambdaSnapstartInitializerHttpMessageHandler ( LambdaEventSource lambdaEventSource )
145- {
146- _lambdaEventSource = lambdaEventSource ;
147- }
148124
149- protected override async Task < HttpResponseMessage > SendAsync ( HttpRequestMessage request , CancellationToken cancellationToken )
125+ public static async Task < string > SerializeToJson ( HttpRequestMessage request , LambdaEventSource lambdaType )
150126 {
151127 if ( null == request . RequestUri )
152- return new HttpResponseMessage ( HttpStatusCode . OK ) ;
128+ {
129+ throw new ArgumentException ( $ "{ nameof ( HttpRequestMessage . RequestUri ) } must be set.", nameof ( request ) ) ;
130+ }
131+
132+ if ( request . RequestUri . IsAbsoluteUri )
133+ {
134+ throw new ArgumentException ( $ "{ nameof ( HttpRequestMessage . RequestUri ) } must be relative.", nameof ( request ) ) ;
135+ }
136+
137+ // make request absolut (relative to localhost) otherwise parsing the query will not work
138+ request . RequestUri = new Uri ( _baseUri , request . RequestUri ) ;
153139
154140 var duckRequest = new
155141 {
@@ -160,12 +146,12 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
160146 kvp => kvp . Value . FirstOrDefault ( ) ,
161147 StringComparer . OrdinalIgnoreCase ) ,
162148 HttpMethod = request . Method . ToString ( ) ,
163- Path = "/" + BaseUri . MakeRelativeUri ( request . RequestUri ) ,
149+ Path = "/" + _baseUri . MakeRelativeUri ( request . RequestUri ) ,
164150 RawQuery = request . RequestUri ? . Query ,
165151 Query = QueryHelpers . ParseNullableQuery ( request . RequestUri ? . Query )
166152 } ;
167-
168- string translatedRequestJson = _lambdaEventSource switch
153+
154+ string translatedRequestJson = lambdaType switch
169155 {
170156 LambdaEventSource . ApplicationLoadBalancer =>
171157 JsonSerializer . Serialize (
@@ -213,17 +199,13 @@ protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage
213199 } ,
214200 LambdaRequestTypeClasses . Default . APIGatewayProxyRequest ) ,
215201 _ => throw new NotImplementedException (
216- $ "Unknown { nameof ( LambdaEventSource ) } : { Enum . GetName ( _lambdaEventSource ) } ")
202+ $ "Unknown { nameof ( LambdaEventSource ) } : { Enum . GetName ( lambdaType ) } ")
217203 } ;
218204
219- // NOTE: Any object added to CapturedHttpRequests must have it's type added
220- // to the
221- CapturedHttpRequestsJson . Add ( translatedRequestJson ) ;
222-
223- return new HttpResponseMessage ( HttpStatusCode . OK ) ;
205+ return translatedRequestJson ;
224206 }
225207
226- private async Task < string > ReadContent ( HttpRequestMessage r )
208+ private static async Task < string > ReadContent ( HttpRequestMessage r )
227209 {
228210 if ( r . Content == null )
229211 return string . Empty ;
0 commit comments