Skip to content

Commit 0788f5d

Browse files
authored
Merge pull request #32 from bakdata/feature/api-gateway-support
Feature/api gateway support
2 parents a285fed + 891bfb9 commit 0788f5d

10 files changed

Lines changed: 124 additions & 30 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ __pycache__/
44
.idea/
55
*.iml
66
build/
7-
venv/
7+
venv/
8+
packaged.yaml

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ cat response.txt
4040

4141
The expected result should look similar to this:
4242
```json
43-
{"result":2}
43+
2
4444
```
4545

4646
### Using packages
@@ -77,7 +77,7 @@ cat response.txt
7777

7878
The expected result should look similar to this:
7979
```json
80-
{"result":[4,5,6]}
80+
[4,5,6]
8181
```
8282

8383
## Provided layers
@@ -213,7 +213,7 @@ The handler must be separated by `.`, e.g., `script.handler`.
213213

214214
The lambda payload is unwrapped as named arguments to the R function to call, e.g., `{"x":1}` is unwrapped to `handler(x=1)`.
215215

216-
The lambda function returns whatever is returned by the R function as a JSON object with `result` as a root element.
216+
The lambda function returns whatever is returned by the R function as a JSON object.
217217

218218
### Building custom layers
219219

runtime/src/runtime.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ throwRuntimeError <- function(error, REQUEST_ID) {
6464

6565
postResult <- function(result, REQUEST_ID) {
6666
url <- paste0(API_ENDPOINT, "invocation/", REQUEST_ID, "/response")
67-
res <- POST(url, body = list(result = result), encode = "json")
67+
res <- POST(url, body = toJSON(result, auto_unbox = TRUE), encode = "raw", content_type_json())
6868
loginfo("Posted result:\n%s", to_str(res))
6969
}
7070

template.yaml

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ Transform: 'AWS::Serverless-2016-10-31'
33
Globals:
44
Function:
55
Runtime: provided
6-
CodeUri: tests/R/
76
Timeout: 300
87
MemorySize: 3008
98
Layers:
@@ -13,42 +12,72 @@ Resources:
1312
Type: 'AWS::Serverless::Function'
1413
Properties:
1514
Handler: script.handler
15+
CodeUri: tests/R/
16+
FunctionName: ExampleFunction
1617
LowerCaseExtensionFunction:
1718
Type: 'AWS::Serverless::Function'
1819
Properties:
1920
Handler: lowercase.handler
21+
CodeUri: tests/R/
22+
FunctionName: LowerCaseExtensionFunction
2023
MissingFunctionFunction:
2124
Type: 'AWS::Serverless::Function'
2225
Properties:
2326
Handler: script.handler_missing
27+
CodeUri: tests/R/
28+
FunctionName: MissingFunctionFunction
2429
MissingSourceFileFunction:
2530
Type: 'AWS::Serverless::Function'
2631
Properties:
2732
Handler: missing.handler
33+
CodeUri: tests/R/
34+
FunctionName: MissingSourceFileFunction
2835
MultipleArgumentsFunction:
2936
Type: 'AWS::Serverless::Function'
3037
Properties:
3138
Handler: script.handler_with_multiple_arguments
39+
CodeUri: tests/R/
40+
FunctionName: MultipleArgumentsFunction
3241
VariableArgumentsFunction:
3342
Type: 'AWS::Serverless::Function'
3443
Properties:
3544
Handler: script.handler_with_variable_arguments
45+
CodeUri: tests/R/
46+
FunctionName: VariableArgumentsFunction
3647
MatrixFunction:
3748
Type: 'AWS::Serverless::Function'
3849
Properties:
3950
Handler: matrix.handler
51+
CodeUri: tests/R/
4052
Layers:
4153
- !Ref RecommendedLayer
54+
FunctionName: MatrixFunction
4255
MissingLibraryFunction:
4356
Type: 'AWS::Serverless::Function'
4457
Properties:
4558
Handler: matrix.handler
59+
CodeUri: tests/R/
60+
FunctionName: MissingLibraryFunction
4661
AWSFunction:
4762
Type: 'AWS::Serverless::Function'
4863
Properties:
4964
Handler: aws.handler
65+
CodeUri: tests/R/
5066
Layers:
5167
- !Ref AWSLayer
68+
FunctionName: AWSFunction
69+
ApiFunction:
70+
Type: 'AWS::Serverless::Function'
71+
Properties:
72+
Handler: api.handler
73+
CodeUri: tests/R/
74+
FunctionName: ApiFunction
75+
Events:
76+
Api:
77+
Type: Api
78+
Properties:
79+
Path: '/hello'
80+
Method: GET
5281
RuntimeLayer:
5382
Type: AWS::Serverless::LayerVersion
5483
Properties:

tests/R/api.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
library(jsonlite)
2+
3+
handler <- function(headers, multiValueHeaders, queryStringParameters, multiValueQueryStringParameters, pathParameters, body, ...) {
4+
return(
5+
list(
6+
statusCode = 200,
7+
headers = list("Content-Type" = "application/json"),
8+
body = toJSON(list(hello = queryStringParameters$who), auto_unbox = TRUE)
9+
)
10+
)
11+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,31 @@
77
from tests import wait_for_port
88

99

10+
class LocalApi:
11+
12+
def __init__(self, host: str = '127.0.0.1', port: int = 3000):
13+
self.host = host
14+
self.port = port
15+
self.process = Popen(['sam', 'local', 'start-api', '--host', self.host, '--port', str(self.port)])
16+
17+
def kill(self):
18+
self.process.kill()
19+
return_code = self.process.wait()
20+
logging.info('Killed server with code %s', return_code)
21+
22+
def wait(self, interval: int = 10, retries: int = 6):
23+
wait_for_port(self.port, self.host, interval=interval, retries=retries)
24+
25+
def get_uri(self) -> str:
26+
return 'http://{}:{}'.format(self.host, self.port)
27+
28+
29+
def start_local_api() -> LocalApi:
30+
server = LocalApi()
31+
server.wait()
32+
return server
33+
34+
1035
class LocalLambdaServer:
1136

1237
def __init__(self, host: str = '127.0.0.1', port: int = 3001):

tests/test_api.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import unittest
2+
3+
import requests
4+
5+
from tests.sam import LocalApi, start_local_api
6+
7+
8+
class TestApi(unittest.TestCase):
9+
api: LocalApi = None
10+
11+
@classmethod
12+
def setUpClass(cls):
13+
cls.api = start_local_api()
14+
15+
def test_matrix(self):
16+
response = requests.get('%s/hello' % self.api.get_uri(), params={'who': 'World'})
17+
result = response.json()
18+
self.assertDictEqual({'hello': 'World'}, result)
19+
20+
@classmethod
21+
def tearDownClass(cls):
22+
cls.api.kill()

tests/test_aws.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import unittest
33

4-
from tests.aws_lambda import LocalLambdaServer, start_local_lambda
4+
from tests.sam import LocalLambdaServer, start_local_lambda
55

66

77
class TestAWSLayer(unittest.TestCase):
@@ -15,8 +15,7 @@ def test_s3_get_object(self):
1515
lambda_client = self.lambda_server.get_client()
1616
response = lambda_client.invoke(FunctionName="AWSFunction")
1717
raw_payload = response['Payload'].read().decode('utf-8')
18-
json_payload = json.loads(raw_payload)
19-
result = json_payload['result']
18+
result = json.loads(raw_payload)
2019
self.assertEqual(len(result), 1)
2120
self.assertDictEqual(result[0], {
2221
"DRG.Definition": "039 - EXTRACRANIAL PROCEDURES W/O CC/MCC",

tests/test_matrix.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22
import unittest
33

4-
from tests.aws_lambda import LocalLambdaServer, start_local_lambda
4+
from tests.sam import LocalLambdaServer, start_local_lambda
55

66

77
class TestRecommendedLayer(unittest.TestCase):
@@ -15,8 +15,7 @@ def test_matrix(self):
1515
lambda_client = self.lambda_server.get_client()
1616
response = lambda_client.invoke(FunctionName="MatrixFunction")
1717
raw_payload = response['Payload'].read().decode('utf-8')
18-
json_payload = json.loads(raw_payload)
19-
result = json_payload['result']
18+
result = json.loads(raw_payload)
2019
self.assertEqual(len(result), 3)
2120
self.assertIn(4, result)
2221
self.assertIn(5, result)

tests/test_runtime.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,51 @@
11
import json
22
import unittest
33

4-
from tests.aws_lambda import LocalLambdaServer, start_local_lambda
4+
import boto3
5+
6+
from tests.sam import LocalLambdaServer, start_local_lambda
57

68

79
class TestRuntimeLayer(unittest.TestCase):
810
lambda_server: LocalLambdaServer = None
911

12+
@classmethod
13+
def isLocal(cls) -> bool:
14+
return True
15+
1016
@classmethod
1117
def setUpClass(cls):
12-
cls.lambda_server = start_local_lambda()
18+
if cls.isLocal():
19+
cls.lambda_server = start_local_lambda()
1320

1421
def test_script(self):
15-
lambda_client = self.lambda_server.get_client()
22+
lambda_client = self.get_client()
1623
response = lambda_client.invoke(FunctionName="ExampleFunction", Payload=json.dumps({'x': 1}))
1724
raw_payload = response['Payload'].read().decode('utf-8')
18-
json_payload = json.loads(raw_payload)
19-
result = json_payload['result']
25+
result = json.loads(raw_payload)
2026
self.assertEqual(result, 2)
2127

28+
def get_client(self):
29+
return self.lambda_server.get_client() if self.isLocal() else boto3.client('lambda')
30+
2231
def test_lowercase_extension(self):
23-
lambda_client = self.lambda_server.get_client()
32+
lambda_client = self.get_client()
2433
response = lambda_client.invoke(FunctionName="LowerCaseExtensionFunction", Payload=json.dumps({'x': 1}))
2534
raw_payload = response['Payload'].read().decode('utf-8')
26-
json_payload = json.loads(raw_payload)
27-
result = json_payload['result']
35+
result = json.loads(raw_payload)
2836
self.assertEqual(result, 2)
2937

3038
def test_multiple_arguments(self):
31-
lambda_client = self.lambda_server.get_client()
39+
lambda_client = self.get_client()
3240
payload = {'x': 'bar', 'y': 1}
3341
response = lambda_client.invoke(FunctionName="MultipleArgumentsFunction", Payload=json.dumps(payload))
3442
raw_payload = response['Payload'].read().decode('utf-8')
35-
json_payload = json.loads(raw_payload)
36-
result = json_payload['result']
43+
result = json.loads(raw_payload)
3744
self.assertDictEqual(result, payload)
3845

3946
@unittest.skip('Lambda local does not pass errors properly')
4047
def test_missing_source_file(self):
41-
lambda_client = self.lambda_server.get_client()
48+
lambda_client = self.get_client()
4249
response = lambda_client.invoke(FunctionName="MissingSourceFileFunction", Payload=json.dumps({'y': 1}))
4350
raw_payload = response['Payload'].read().decode('utf-8')
4451
json_payload = json.loads(raw_payload)
@@ -47,7 +54,7 @@ def test_missing_source_file(self):
4754

4855
@unittest.skip('Lambda local does not pass errors properly')
4956
def test_missing_function(self):
50-
lambda_client = self.lambda_server.get_client()
57+
lambda_client = self.get_client()
5158
response = lambda_client.invoke(FunctionName="MissingFunctionFunction", Payload=json.dumps({'y': 1}))
5259
raw_payload = response['Payload'].read().decode('utf-8')
5360
json_payload = json.loads(raw_payload)
@@ -56,7 +63,7 @@ def test_missing_function(self):
5663

5764
@unittest.skip('Lambda local does not pass errors properly')
5865
def test_missing_argument(self):
59-
lambda_client = self.lambda_server.get_client()
66+
lambda_client = self.get_client()
6067
response = lambda_client.invoke(FunctionName="ExampleFunction")
6168
raw_payload = response['Payload'].read().decode('utf-8')
6269
json_payload = json.loads(raw_payload)
@@ -65,7 +72,7 @@ def test_missing_argument(self):
6572

6673
@unittest.skip('Lambda local does not pass errors properly')
6774
def test_unused_argument(self):
68-
lambda_client = self.lambda_server.get_client()
75+
lambda_client = self.get_client()
6976
response = lambda_client.invoke(FunctionName="ExampleFunction", Payload=json.dumps({'x': 1, 'y': 1}))
7077
raw_payload = response['Payload'].read().decode('utf-8')
7178
json_payload = json.loads(raw_payload)
@@ -74,7 +81,7 @@ def test_unused_argument(self):
7481

7582
@unittest.skip('Lambda local does not pass errors properly')
7683
def test_long_argument(self):
77-
lambda_client = self.lambda_server.get_client()
84+
lambda_client = self.get_client()
7885
payload = {x: x for x in range(0, 100000)}
7986
response = lambda_client.invoke(FunctionName="VariableArgumentsFunction", Payload=json.dumps(payload))
8087
raw_payload = response['Payload'].read().decode('utf-8')
@@ -83,7 +90,7 @@ def test_long_argument(self):
8390

8491
@unittest.skip('Lambda local does not pass errors properly')
8592
def test_missing_library(self):
86-
lambda_client = self.lambda_server.get_client()
93+
lambda_client = self.get_client()
8794
response = lambda_client.invoke(FunctionName="MissingLibraryFunction", Payload=json.dumps({'y': 1}))
8895
raw_payload = response['Payload'].read().decode('utf-8')
8996
json_payload = json.loads(raw_payload)
@@ -92,4 +99,5 @@ def test_missing_library(self):
9299

93100
@classmethod
94101
def tearDownClass(cls):
95-
cls.lambda_server.kill()
102+
if cls.isLocal():
103+
cls.lambda_server.kill()

0 commit comments

Comments
 (0)