Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package software.amazon.smithy.python.aws.codegen.customizations.apigateway;

import java.util.List;
import software.amazon.smithy.codegen.core.Symbol;
import software.amazon.smithy.codegen.core.SymbolReference;
import software.amazon.smithy.model.shapes.ShapeId;
import software.amazon.smithy.python.codegen.CodegenUtils;
import software.amazon.smithy.python.codegen.GenerationContext;
import software.amazon.smithy.python.codegen.integrations.PythonIntegration;
import software.amazon.smithy.python.codegen.integrations.RuntimeClientPlugin;
import software.amazon.smithy.utils.SmithyInternalApi;

/**
* Adds a runtime plugin that sets the {@code Accept: application/json} header on
* Amazon API Gateway requests.
*/
@SmithyInternalApi
public class ApiGatewayIntegration implements PythonIntegration {

private static final ShapeId API_GATEWAY_SERVICE_ID =
ShapeId.from("com.amazonaws.apigateway#BackplaneControlService");

public static final String ACCEPT_HEADER_PLUGIN = """
def accept_header_plugin(config: $1T):
config.interceptors.append($2T())
""";

@Override
public List<RuntimeClientPlugin> getClientPlugins(GenerationContext context) {
if (!context.applicationProtocol().isHttpProtocol()) {
return List.of();
}

final String pluginFile = "accept_header";
final String moduleName = context.settings().moduleName();

final SymbolReference acceptHeaderPlugin = SymbolReference.builder()
.symbol(Symbol.builder()
.namespace(String.format("%s.%s", moduleName, pluginFile), ".")
.definitionFile(String.format("./src/%s/%s.py", moduleName, pluginFile))
.name("accept_header_plugin")
.build())
.build();
final SymbolReference acceptHeaderInterceptor = SymbolReference.builder()
.symbol(Symbol.builder()
.namespace("smithy_aws_core.interceptors.api_gateway", ".")
.name("ApiGatewayAcceptHeaderInterceptor")
.build())
.build();

return List.of(
RuntimeClientPlugin.builder()
.servicePredicate((model, service) -> service.getId().equals(API_GATEWAY_SERVICE_ID))
.pythonPlugin(acceptHeaderPlugin)
.writeAdditionalFiles((c) -> {
String filename = "src/%s/%s.py".formatted(moduleName, pluginFile);
c.writerDelegator()
.useFileWriter(
filename,
moduleName + ".",
writer -> writer.write(ACCEPT_HEADER_PLUGIN,
CodegenUtils.getConfigSymbol(c.settings()),
acceptHeaderInterceptor));
return List.of(filename);
})
.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
#

software.amazon.smithy.python.aws.codegen.customizations.apigateway.ApiGatewayIntegration
software.amazon.smithy.python.aws.codegen.AwsAuthIntegration
software.amazon.smithy.python.aws.codegen.AwsProtocolsIntegration
software.amazon.smithy.python.aws.codegen.AwsServiceIdIntegration
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "feature",
"description": "Add an Amazon API Gateway customization that sets the `Accept: application/json` request header."
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any

from smithy_core.interceptors import Interceptor, RequestContext
from smithy_http import Field
from smithy_http.aio.interfaces import HTTPRequest


class ApiGatewayAcceptHeaderInterceptor(Interceptor[Any, Any, HTTPRequest, None]):
"""Sets the Accept header to application/json on API Gateway requests."""

def modify_before_signing(
self, context: RequestContext[Any, HTTPRequest]
) -> HTTPRequest:
request = context.transport_request
request.fields.set_field(Field(name="Accept", values=["application/json"]))
return request
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import Mock

from smithy_aws_core.interceptors.api_gateway import ApiGatewayAcceptHeaderInterceptor
from smithy_core import URI
from smithy_core.interceptors import RequestContext
from smithy_core.types import TypedProperties
from smithy_http import Field, Fields
from smithy_http.aio import HTTPRequest


def _request(fields: Fields) -> HTTPRequest:
destination = URI(host="apigateway.us-east-1.amazonaws.com", path="/restapis")
return HTTPRequest(destination=destination, method="GET", fields=fields)


def test_sets_accept_header() -> None:
interceptor = ApiGatewayAcceptHeaderInterceptor()
request = _request(Fields())
context = RequestContext(
request=Mock(), properties=TypedProperties(), transport_request=request
)

result = interceptor.modify_before_signing(context)

assert result.fields["Accept"].values == ["application/json"]


def test_overwrites_existing_accept_header() -> None:
interceptor = ApiGatewayAcceptHeaderInterceptor()
fields = Fields([Field(name="Accept", values=["application/hal+json"])])
request = _request(fields)
context = RequestContext(
request=Mock(), properties=TypedProperties(), transport_request=request
)

result = interceptor.modify_before_signing(context)

assert result.fields["Accept"].values == ["application/json"]
Loading