11import json
22import unittest
33
4- from tests .aws_lambda import LocalLambdaServer , start_local_lambda
4+ import boto3
5+
6+ from tests .sam import LocalLambdaServer , start_local_lambda
57
68
79class TestRuntimeLayer (unittest .TestCase ):
810 lambda_server : LocalLambdaServer = None
911
12+ @classmethod
13+ def isLocal (cls ) -> bool :
14+ return True
15+
1016 @classmethod
1117 def setUpClass (cls ):
12- cls .lambda_server = start_local_lambda ()
18+ if cls .isLocal ():
19+ cls .lambda_server = start_local_lambda ()
1320
1421 def test_script (self ):
15- lambda_client = self .lambda_server . get_client ()
22+ lambda_client = self .get_client ()
1623 response = lambda_client .invoke (FunctionName = "ExampleFunction" , Payload = json .dumps ({'x' : 1 }))
1724 raw_payload = response ['Payload' ].read ().decode ('utf-8' )
18- json_payload = json .loads (raw_payload )
19- result = json_payload ['result' ]
25+ result = json .loads (raw_payload )
2026 self .assertEqual (result , 2 )
2127
28+ def get_client (self ):
29+ return self .lambda_server .get_client () if self .isLocal () else boto3 .client ('lambda' )
30+
2231 def test_lowercase_extension (self ):
23- lambda_client = self .lambda_server . get_client ()
32+ lambda_client = self .get_client ()
2433 response = lambda_client .invoke (FunctionName = "LowerCaseExtensionFunction" , Payload = json .dumps ({'x' : 1 }))
2534 raw_payload = response ['Payload' ].read ().decode ('utf-8' )
26- json_payload = json .loads (raw_payload )
27- result = json_payload ['result' ]
35+ result = json .loads (raw_payload )
2836 self .assertEqual (result , 2 )
2937
3038 def test_multiple_arguments (self ):
31- lambda_client = self .lambda_server . get_client ()
39+ lambda_client = self .get_client ()
3240 payload = {'x' : 'bar' , 'y' : 1 }
3341 response = lambda_client .invoke (FunctionName = "MultipleArgumentsFunction" , Payload = json .dumps (payload ))
3442 raw_payload = response ['Payload' ].read ().decode ('utf-8' )
35- json_payload = json .loads (raw_payload )
36- result = json_payload ['result' ]
43+ result = json .loads (raw_payload )
3744 self .assertDictEqual (result , payload )
3845
3946 @unittest .skip ('Lambda local does not pass errors properly' )
4047 def test_missing_source_file (self ):
41- lambda_client = self .lambda_server . get_client ()
48+ lambda_client = self .get_client ()
4249 response = lambda_client .invoke (FunctionName = "MissingSourceFileFunction" , Payload = json .dumps ({'y' : 1 }))
4350 raw_payload = response ['Payload' ].read ().decode ('utf-8' )
4451 json_payload = json .loads (raw_payload )
@@ -47,7 +54,7 @@ def test_missing_source_file(self):
4754
4855 @unittest .skip ('Lambda local does not pass errors properly' )
4956 def test_missing_function (self ):
50- lambda_client = self .lambda_server . get_client ()
57+ lambda_client = self .get_client ()
5158 response = lambda_client .invoke (FunctionName = "MissingFunctionFunction" , Payload = json .dumps ({'y' : 1 }))
5259 raw_payload = response ['Payload' ].read ().decode ('utf-8' )
5360 json_payload = json .loads (raw_payload )
@@ -56,7 +63,7 @@ def test_missing_function(self):
5663
5764 @unittest .skip ('Lambda local does not pass errors properly' )
5865 def test_missing_argument (self ):
59- lambda_client = self .lambda_server . get_client ()
66+ lambda_client = self .get_client ()
6067 response = lambda_client .invoke (FunctionName = "ExampleFunction" )
6168 raw_payload = response ['Payload' ].read ().decode ('utf-8' )
6269 json_payload = json .loads (raw_payload )
@@ -65,7 +72,7 @@ def test_missing_argument(self):
6572
6673 @unittest .skip ('Lambda local does not pass errors properly' )
6774 def test_unused_argument (self ):
68- lambda_client = self .lambda_server . get_client ()
75+ lambda_client = self .get_client ()
6976 response = lambda_client .invoke (FunctionName = "ExampleFunction" , Payload = json .dumps ({'x' : 1 , 'y' : 1 }))
7077 raw_payload = response ['Payload' ].read ().decode ('utf-8' )
7178 json_payload = json .loads (raw_payload )
@@ -74,7 +81,7 @@ def test_unused_argument(self):
7481
7582 @unittest .skip ('Lambda local does not pass errors properly' )
7683 def test_long_argument (self ):
77- lambda_client = self .lambda_server . get_client ()
84+ lambda_client = self .get_client ()
7885 payload = {x : x for x in range (0 , 100000 )}
7986 response = lambda_client .invoke (FunctionName = "VariableArgumentsFunction" , Payload = json .dumps (payload ))
8087 raw_payload = response ['Payload' ].read ().decode ('utf-8' )
@@ -83,7 +90,7 @@ def test_long_argument(self):
8390
8491 @unittest .skip ('Lambda local does not pass errors properly' )
8592 def test_missing_library (self ):
86- lambda_client = self .lambda_server . get_client ()
93+ lambda_client = self .get_client ()
8794 response = lambda_client .invoke (FunctionName = "MissingLibraryFunction" , Payload = json .dumps ({'y' : 1 }))
8895 raw_payload = response ['Payload' ].read ().decode ('utf-8' )
8996 json_payload = json .loads (raw_payload )
@@ -92,4 +99,5 @@ def test_missing_library(self):
9299
93100 @classmethod
94101 def tearDownClass (cls ):
95- cls .lambda_server .kill ()
102+ if cls .isLocal ():
103+ cls .lambda_server .kill ()
0 commit comments