@@ -9,10 +9,17 @@ import (
99 "bytes"
1010 "context"
1111 "fmt"
12+ "io"
1213 "log"
14+ "math"
1315 "math/rand"
16+ "net"
17+ "net/http"
1418 "os"
19+ "os/exec"
20+ "path/filepath"
1521 "strings"
22+ "sync"
1623 "sync/atomic"
1724 "testing"
1825 "time"
@@ -130,3 +137,105 @@ func TestRuntimeAPILoopWithConcurrencyPanic(t *testing.T) {
130137 assert .Greater (t , idx2 , idx1 )
131138 assert .Greater (t , idx3 , idx2 )
132139}
140+
141+ func TestConcurrencyWithRIE (t * testing.T ) {
142+ containerCmd := ""
143+ if _ , err := exec .LookPath ("finch" ); err == nil {
144+ containerCmd = "finch"
145+ } else if _ , err := exec .LookPath ("docker" ); err == nil {
146+ containerCmd = "docker"
147+ } else {
148+ t .Skip ("finch or docker required" )
149+ }
150+
151+ testDir := t .TempDir ()
152+ handlerBuild := exec .Command ("go" , "build" , "-o" , filepath .Join (testDir , "bootstrap" ), "./testdata/sleep.go" )
153+ handlerBuild .Env = append (os .Environ (), "GOOS=linux" )
154+ require .NoError (t , handlerBuild .Run ())
155+
156+ nInvokes := 10
157+ concurrency := 3
158+ sleepMs := 1000
159+ batches := int (math .Ceil (float64 (nInvokes ) / float64 (concurrency )))
160+ expectedMaxDuration := time .Duration (float64 (batches * sleepMs )* 1.1 ) * time .Millisecond // 10% margin for retries, network overhead, scheduling
161+
162+ // Find an available port
163+ listener , err := net .Listen ("tcp" , "127.0.0.1:0" )
164+ require .NoError (t , err )
165+ port := listener .Addr ().(* net.TCPAddr ).Port
166+ listener .Close ()
167+
168+ cmd := exec .Command (containerCmd , "run" , "--rm" ,
169+ "-v" , testDir + ":/var/runtime:ro,delegated" ,
170+ "-p" , fmt .Sprintf ("%d:8080" , port ),
171+ "-e" , fmt .Sprintf ("AWS_LAMBDA_MAX_CONCURRENCY=%d" , concurrency ),
172+ "public.ecr.aws/lambda/provided:al2023" ,
173+ "bootstrap" )
174+ stdout , err := cmd .StdoutPipe ()
175+ require .NoError (t , err )
176+ stderr , err := cmd .StderrPipe ()
177+ require .NoError (t , err )
178+
179+ var logBuf strings.Builder
180+ logDone := make (chan struct {})
181+ go func () {
182+ _ , _ = io .Copy (io .MultiWriter (os .Stderr , & logBuf ), io .MultiReader (stdout , stderr ))
183+ close (logDone )
184+
185+ }()
186+
187+ require .NoError (t , cmd .Start ())
188+ t .Cleanup (func () { _ = cmd .Process .Kill () })
189+
190+ time .Sleep (5 * time .Second ) // Wait for container to start and pull image if needed
191+
192+ client := & http.Client {Timeout : 15 * time .Second }
193+ invokeURL := fmt .Sprintf ("http://127.0.0.1:%d/2015-03-31/functions/function/invocations" , port )
194+
195+ start := time .Now ()
196+ var wg sync.WaitGroup
197+ ctx , cancel := context .WithTimeout (context .Background (), 20 * time .Second )
198+ defer cancel ()
199+ for range nInvokes {
200+ wg .Add (1 )
201+ go func () {
202+ defer wg .Done ()
203+ for {
204+ select {
205+ case <- ctx .Done ():
206+ return
207+ default :
208+ }
209+ time .Sleep (50 * time .Millisecond )
210+ body := strings .NewReader (fmt .Sprintf (`{"sleep_ms":%d}` , sleepMs ))
211+ resp , err := client .Post (invokeURL , "application/json" , body )
212+ if err != nil {
213+ continue
214+ }
215+ _ , _ = io .Copy (io .Discard , resp .Body )
216+ _ = resp .Body .Close ()
217+ if resp .StatusCode == 400 {
218+ continue
219+ }
220+ return
221+ }
222+ }()
223+ }
224+ wg .Wait ()
225+ duration := time .Since (start )
226+
227+ t .Logf ("Completed %d invocations in %v" , nInvokes , duration )
228+
229+ _ = cmd .Process .Kill ()
230+ _ = cmd .Wait ()
231+ <- logDone
232+
233+ logs := logBuf .String ()
234+ processingCount := strings .Count (logs , "processing" )
235+ completedCount := strings .Count (logs , "completed" )
236+
237+ assert .Equal (t , nInvokes , processingCount , "expected %d processing logs" , nInvokes )
238+ assert .Equal (t , nInvokes , completedCount , "expected %d completed logs" , nInvokes )
239+ assert .Less (t , duration , expectedMaxDuration , "concurrent execution should complete faster than sequential" )
240+
241+ }
0 commit comments