Skip to content

Commit ec24dc6

Browse files
authored
Merge pull request #133 from aws-beam/fix-aws_cloudfront_keyvaluestore
Fix aws cloudfront keyvaluestore
2 parents 1dd433f + da61565 commit ec24dc6

3 files changed

Lines changed: 49 additions & 16 deletions

File tree

lib/aws_codegen/rest_service.ex

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ defmodule AWS.CodeGen.RestService do
3333
"#{if action.language == :elixir, do: ":", else: ""}#{result}"
3434
end
3535

36-
def url_path(action) do
36+
def url_path(context, action) do
3737
Enum.reduce(action.url_parameters, action.request_uri, fn parameter, acc ->
3838
multi_segment = Parameter.multi_segment?(parameter, acc)
3939

@@ -54,7 +54,11 @@ defmodule AWS.CodeGen.RestService do
5454
if multi_segment do
5555
Enum.join(["\", aws_util:encode_multi_segment_uri(", parameter.code_name, "), \""])
5656
else
57-
Enum.join(["\", aws_util:encode_uri(", parameter.code_name, "), \""])
57+
if context.module_name == "aws_cloudfront_keyvaluestore" do
58+
Enum.join(["\", aws_util:encode_uri(", parameter.code_name, ", full), \""])
59+
else
60+
Enum.join(["\", aws_util:encode_uri(", parameter.code_name, "), \""])
61+
end
5862
end
5963
end
6064

priv/rest.erl.eex

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,13 +66,15 @@ end) %>
6666
<%= AWS.CodeGen.Types.return_type(context.language, action)%>.
6767
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.required_function_parameters(action) %>, QueryMap, HeadersMap, Options0)
6868
when is_map(Client), is_map(QueryMap), is_map(HeadersMap), is_list(Options0) ->
69-
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
69+
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(context, action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
7070
<%= if !String.contains?("Bucket", AWS.CodeGen.RestService.required_function_parameters(action)) do %><% else %> Bucket = undefined,<% end %><% end %>
7171
SuccessStatusCode = <%= inspect(action.success_status_code) %>,
72-
{SendBodyAsBinary, Options1} = proplists_take(send_body_as_binary, Options0, <%= action.send_body_as_binary? %>),
72+
{SendBodyAsBinary, Options1} = proplists_take(send_body_as_binary, Options0, <%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>true<% else %><%= action.send_body_as_binary? %><% end %>),
7373
{ReceiveBodyAsBinary, Options2} = proplists_take(receive_body_as_binary, Options1, <%= action.receive_body_as_binary? %>),
7474
Options = [{send_body_as_binary, SendBodyAsBinary},
75-
{receive_body_as_binary, ReceiveBodyAsBinary}
75+
{receive_body_as_binary, ReceiveBodyAsBinary}<%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>,
76+
{account, get_account_id(KvsARN)},
77+
{sign_with_v4a, true}<% end %>
7678
| Options2],
7779
<%= if length(action.request_header_parameters) > 0 do %>
7880
Headers0 =
@@ -122,14 +124,16 @@ end) %>
122124
<%= AWS.CodeGen.Types.return_type(context.language, action)%>.
123125
<%= action.function_name %>(Client<%= if context.module_name == "aws_apigatewaymanagementapi" do %>, ApiId, Stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, Input0, Options0) ->
124126
Method = <%= AWS.CodeGen.RestService.Action.method(action) %>,
125-
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
127+
Path = ["<%= if context.module_name == "aws_apigatewaymanagementapi" do %>/", Stage, "<% end %><%= AWS.CodeGen.RestService.Action.url_path(context, action) %>"],<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>
126128
<%= if !String.contains?("Bucket", AWS.CodeGen.RestService.required_function_parameters(action)) do %><% else %> Bucket = undefined,<% end %><% end %>
127129
SuccessStatusCode = <%= inspect(action.success_status_code) %>,
128-
{SendBodyAsBinary, Options1} = proplists_take(send_body_as_binary, Options0, <%= action.send_body_as_binary? %>),
130+
{SendBodyAsBinary, Options1} = proplists_take(send_body_as_binary, Options0, <%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>true<% else %><%= action.send_body_as_binary? %><% end %>),
129131
{ReceiveBodyAsBinary, Options2} = proplists_take(receive_body_as_binary, Options1, <%= action.receive_body_as_binary? %>),
130132
Options = [{send_body_as_binary, SendBodyAsBinary},
131133
{receive_body_as_binary, ReceiveBodyAsBinary},
132-
{append_sha256_content_hash, <%= Enum.member?(["put_bucket_cors", "put_bucket_lifecycle", "put_bucket_tagging", "delete_objects"], action.function_name) %>}
134+
{append_sha256_content_hash, <%= Enum.member?(["put_bucket_cors", "put_bucket_lifecycle", "put_bucket_tagging", "delete_objects"], action.function_name) %>}<%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>,
135+
{account, get_account_id(KvsARN)},
136+
{sign_with_v4a, true}<% end %>
133137
| Options2],
134138
<%= if length(action.request_header_parameters) > 0 do %>
135139
HeadersMapping = [<%= for parameter <- Enum.drop(action.request_header_parameters, -1) do %>
@@ -209,7 +213,7 @@ do_request(Client, Method, Path, Query, Headers0, Input, Options, SuccessStatusC
209213
Client1 = Client#{service => <<"<%= context.signing_name %>">><%= if context.is_global do %>,
210214
region => <<"<%= context.credential_scope %>">><% end %>},
211215
<%= if context.endpoint_prefix == "s3-control" do %>AccountId = proplists:get_value(<<"x-amz-account-id">>, Headers0),
212-
DefaultHost = build_host(AccountId, <<"<%= context.endpoint_prefix %>">>, Client1),<% else %><%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>DefaultHost = build_host(<<"<%= context.endpoint_prefix %>">>, Client1, Bucket),<%else %>DefaultHost = build_host(<<"<%= context.endpoint_prefix %>">>, Client1),<% end %><% end %>
216+
DefaultHost = build_host(AccountId, <<"<%= context.endpoint_prefix %>">>, Client1),<% else %><%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>DefaultHost = build_host(<<"<%= context.endpoint_prefix %>">>, Client1, Bucket),<% else %><%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>DefaultHost = build_host(proplists:get_value(account, Options), <<"cloudfront-kvs">>, Client1),<% else %>DefaultHost = build_host(<<"<%= context.endpoint_prefix %>">>, Client1),<% end %><% end %><% end %>
213217
URL0 = build_url(DefaultHost, Path, Client1<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>, Bucket<% end %>),
214218
PathBin = erlang:iolist_to_binary(Path),
215219
{URL1, Host} = aws_util:apply_endpoint_url_override(URL0, DefaultHost, PathBin, <<"<%= context.endpoint_url_env_var %>">>),
@@ -219,8 +223,12 @@ do_request(Client, Method, Path, Query, Headers0, Input, Options, SuccessStatusC
219223
],
220224
Payload =
221225
case proplists:get_value(send_body_as_binary, Options) of
222-
true ->
223-
maps:get(<<"Body">>, Input, <<"">>);
226+
true when is_list(Input) ->
227+
proplists:get_value(<<"Body">>, Input, <<"">>);
228+
true when Input =:= undefined ->
229+
<<"">>;
230+
true ->
231+
maps:get(<<"Body">>, Input, <<"">>);
224232
false ->
225233
encode_payload(Input)
226234
end,
@@ -233,7 +241,7 @@ do_request(Client, Method, Path, Query, Headers0, Input, Options, SuccessStatusC
233241
Headers1 = aws_request:add_headers(AdditionalHeaders, Headers0),
234242

235243
MethodBin = aws_request:method_to_binary(Method),
236-
SignedHeaders = aws_request:sign_request(Client1, MethodBin, URL, Headers1, Payload<%= if context.module_name == "aws_apigatewaymanagementapi" or String.contains?(context.module_name, "aws_bedrock") do %>, [{uri_encode_path, true}]<% else %><% end %>),
244+
SignedHeaders = aws_request:sign_request(Client1, MethodBin, URL, Headers1, Payload<%= if context.module_name == "aws_apigatewaymanagementapi" or String.contains?(context.module_name, "aws_bedrock") do %>, [{uri_encode_path, true}]<% else %><%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>, [{sign_with_v4a, true}, {uri_encode_path, false}]<% end %><% end %>),
237245
Response = hackney:request(Method, URL, SignedHeaders, Payload, Options),
238246
DecodeBody = not proplists:get_value(receive_body_as_binary, Options),
239247
handle_response(Response, SuccessStatusCode, DecodeBody).
@@ -305,6 +313,22 @@ build_host(undefined, _EndpointPrefix, _Client) ->
305313
error(missing_account_id);
306314
build_host(AccountId, EndpointPrefix, #{region := Region, endpoint := Endpoint}) ->
307315
aws_util:binary_join([AccountId, EndpointPrefix, Region, Endpoint],
316+
<<".">>).<% else %><%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>
317+
build_host(_AccountPrefix, _EndpointPrefix, #{region := <<"local">>, endpoint := Endpoint}) ->
318+
Endpoint;
319+
build_host(_AccountPrefix, _EndpointPrefix, #{region := <<"local">>}) ->
320+
<<"localhost">>;
321+
build_host(AccountPrefix, EndpointPrefix, #{region := <<"global">>, endpoint := Endpoint}) ->
322+
aws_util:binary_join([AccountPrefix, EndpointPrefix, <<"global">>, Endpoint], <<".">>).
323+
<% else %><%= if context.endpoint_prefix == "s3-control" do %>
324+
build_host(_AccountPrefix, _EndpointPrefix, #{region := <<"local">>, endpoint := Endpoint}) ->
325+
Endpoint;
326+
build_host(_AccountPrefix, _EndpointPrefix, #{region := <<"local">>}) ->
327+
<<"localhost">>;
328+
build_host(undefined, _EndpointPrefix, _Client) ->
329+
error(missing_account_id);
330+
build_host(AccountPrefix, EndpointPrefix, #{region := Region, endpoint := Endpoint}) ->
331+
aws_util:binary_join([_AccountPrefix, EndpointPrefix, Region, Endpoint],
308332
<<".">>).<% else %>
309333
<%= if context.endpoint_prefix == "s3" do %><%= if context.is_global do %>
310334
build_host(EndpointPrefix, #{endpoint := Endpoint}, undefined) ->
@@ -333,7 +357,7 @@ build_host(_EndpointPrefix, #{region := <<"local">>}) ->
333357
build_host(EndpointPrefix, #{endpoint := Endpoint}) ->
334358
aws_util:binary_join([EndpointPrefix, Endpoint], <<".">>).<% else %>
335359
build_host(EndpointPrefix, #{region := Region, endpoint := Endpoint}) ->
336-
aws_util:binary_join([EndpointPrefix, Region, Endpoint], <<".">>).<% end %><% end %><% end %><% end %>
360+
aws_util:binary_join([EndpointPrefix, Region, Endpoint], <<".">>).<% end %><% end %><% end %><% end %><% end %><% end %>
337361
<%= if AWS.CodeGen.RestService.Context.s3_context?(context) do %>build_url(Host0, Path0, Client, Bucket) ->
338362
Proto = aws_client:proto(Client),
339363
%% Mocks are notoriously bad with host-style requests, just skip it and use path-style for anything local
@@ -353,7 +377,8 @@ build_host(EndpointPrefix, #{region := Region, endpoint := Endpoint}) ->
353377
Host1
354378
end,
355379
Port = aws_client:port(Client),
356-
aws_util:binary_join([Proto, <<"://">>, Host, <<":">>, Port, Path], <<"">>).<% else %>build_url(Host, Path0, Client) ->
380+
aws_util:binary_join([Proto, <<"://">>, Host, <<":">>, Port, Path], <<"">>).<% else %>
381+
build_url(Host, Path0, Client) ->
357382
Proto = aws_client:proto(Client),
358383
Path = erlang:iolist_to_binary(Path0),
359384
Port = aws_client:port(Client),
@@ -364,3 +389,7 @@ encode_payload(undefined) ->
364389
<<>>;
365390
encode_payload(Input) ->
366391
<%= context.encode %>.
392+
<%= if context.module_name == "aws_cloudfront_keyvaluestore" do %>
393+
get_account_id(Arn) ->
394+
[<<"arn">>, <<"aws">>, <<"cloudfront">>, <<>>, AccountId, _Rest] = binary:split(Arn, <<":">>, [global]),
395+
AccountId.<% end %>

priv/rest.ex.eex

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ defmodule <%= context.module_name %> do
7373
"""<% end %><%= if action.method == "GET" do %>
7474
@spec <%= action.function_name %>(map()<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>, String.t() | atom()<% end %><%= AWS.CodeGen.Types.function_parameter_types(action.method, action, false)%>, list()) :: <%= AWS.CodeGen.Types.return_type(context.language, action)%>
7575
def <%= action.function_name %>(%Client{} = client<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>, stage<% end %><%= AWS.CodeGen.RestService.function_parameters(action) %>, options \\ []) do
76-
url_path = "<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>/#{stage}<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"
76+
url_path = "<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>/#{stage}<% end %><%= AWS.CodeGen.RestService.Action.url_path(context, action) %>"
7777
headers = []<%= for parameter <- action.request_header_parameters do %>
7878
headers = if !is_nil(<%= parameter.code_name %>) do
7979
[{"<%= parameter.location_name %>", <%= parameter.code_name %>} | headers]
@@ -117,7 +117,7 @@ defmodule <%= context.module_name %> do
117117
Request.request_rest(client, meta, :get, url_path, query_params, headers, nil, options, <%= inspect(action.success_status_code) %>)<% else %>
118118
@spec <%= action.function_name %>(map()<%= AWS.CodeGen.Types.function_parameter_types(action.method, action, false)%>, <%= if context.module_name == "AWS.ApiGatewayManagementApi" do %> String.t() | atom(), <% end %><%= AWS.CodeGen.Types.function_argument_type(context.language, action)%>, list()) :: <%= AWS.CodeGen.Types.return_type(context.language, action)%>
119119
def <%= action.function_name %>(%Client{} = client<%= AWS.CodeGen.RestService.function_parameters(action) %>, <%= if context.module_name == "AWS.ApiGatewayManagementApi" do %> stage, <% end %>input, options \\ []) do
120-
url_path = "<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>/#{stage}<% end %><%= AWS.CodeGen.RestService.Action.url_path(action) %>"<%= if length(action.request_header_parameters) > 0 do %>
120+
url_path = "<%= if context.module_name == "AWS.ApiGatewayManagementApi" do %>/#{stage}<% end %><%= AWS.CodeGen.RestService.Action.url_path(context, action) %>"<%= if length(action.request_header_parameters) > 0 do %>
121121
{headers, input} =
122122
[<%= for parameter <- action.request_header_parameters do %>
123123
{"<%= parameter.name %>", "<%= parameter.location_name %>"},<% end %>

0 commit comments

Comments
 (0)