Skip to content

Commit 2453fb7

Browse files
committed
Implementing ProtocolRpcV2Cbor FeatureID
1 parent 5988db3 commit 2453fb7

3 files changed

Lines changed: 168 additions & 3 deletions

File tree

codegen/src/main/java/software/amazon/awssdk/codegen/poet/client/specs/JsonProtocolSpec.java

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@
2525
import com.squareup.javapoet.TypeName;
2626
import com.squareup.javapoet.TypeVariableName;
2727
import com.squareup.javapoet.WildcardTypeName;
28+
import java.util.ArrayList;
29+
import java.util.List;
2830
import java.util.Optional;
2931
import java.util.concurrent.CompletableFuture;
32+
import java.util.function.Consumer;
3033
import java.util.function.Function;
3134
import javax.lang.model.element.Modifier;
35+
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
3236
import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer;
3337
import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier;
3438
import software.amazon.awssdk.awscore.eventstream.RestEventStreamAsyncResponseTransformer;
@@ -48,6 +52,7 @@
4852
import software.amazon.awssdk.codegen.poet.client.traits.RequestCompressionTrait;
4953
import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils;
5054
import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper;
55+
import software.amazon.awssdk.core.ApiName;
5156
import software.amazon.awssdk.core.SdkPojoBuilder;
5257
import software.amazon.awssdk.core.SdkResponse;
5358
import software.amazon.awssdk.core.async.AsyncRequestBody;
@@ -56,6 +61,7 @@
5661
import software.amazon.awssdk.core.client.handler.ClientExecutionParams;
5762
import software.amazon.awssdk.core.http.HttpResponseHandler;
5863
import software.amazon.awssdk.core.protocol.VoidSdkResponse;
64+
import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId;
5965
import software.amazon.awssdk.protocols.cbor.AwsCborProtocolFactory;
6066
import software.amazon.awssdk.protocols.core.ExceptionMetadata;
6167
import software.amazon.awssdk.protocols.json.AwsJsonProtocol;
@@ -224,7 +230,10 @@ public CodeBlock executionHandler(OperationModel opModel) {
224230
.add(discoveredEndpoint(opModel))
225231
.add(credentialType(opModel, model))
226232
.add(".withRequestConfiguration(clientConfiguration)")
227-
.add(".withInput($L)\n", opModel.getInput().getVariableName())
233+
.add(".withInput($L)\n",
234+
model.getMetadata().isRpcV2CborProtocol() ?
235+
"applyRpcV2CborUserAgent(" + opModel.getInput().getVariableName() + ")" :
236+
opModel.getInput().getVariableName())
228237
.add(".withMetricCollector(apiCallMetricCollector)")
229238
.add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel))
230239
.add(HttpChecksumTrait.create(opModel));
@@ -320,7 +329,10 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper
320329

321330
builder.add(RequestCompressionTrait.create(opModel, model))
322331
.add(".withInput($L)$L)",
323-
opModel.getInput().getVariableName(), asyncResponseTransformerVariable(isStreaming, isRestJson, opModel))
332+
intermediateModel.getMetadata().isRpcV2CborProtocol() ?
333+
"applyRpcV2CborUserAgent(" + opModel.getInput().getVariableName() + ")" :
334+
opModel.getInput().getVariableName(),
335+
asyncResponseTransformerVariable(isStreaming, isRestJson, opModel))
324336
.add(opModel.getEndpointDiscovery() != null ? ");" : ";");
325337

326338
if (opModel.hasStreamingOutput()) {
@@ -568,4 +580,49 @@ private String protocolFactoryLiteral(IntermediateModel model, OperationModel op
568580
private boolean isRestJson(IntermediateModel model) {
569581
return model.getMetadata().getProtocol() == Protocol.REST_JSON;
570582
}
583+
584+
@Override
585+
public List<MethodSpec> additionalMethods() {
586+
List<MethodSpec> methods = new ArrayList<>();
587+
588+
applyRpcV2CborUserAgentMethod().ifPresent(methods::add);
589+
590+
return methods;
591+
}
592+
593+
private Optional<MethodSpec> applyRpcV2CborUserAgentMethod() {
594+
if (!model.getMetadata().isRpcV2CborProtocol()) {
595+
return Optional.empty();
596+
}
597+
598+
TypeVariableName typeVariableName =
599+
TypeVariableName.get("T", poetExtensions.getModelClass(model.getSdkRequestBaseClassName()));
600+
601+
ParameterizedTypeName parameterizedTypeName = ParameterizedTypeName
602+
.get(ClassName.get(Consumer.class), ClassName.get(AwsRequestOverrideConfiguration.Builder.class));
603+
604+
CodeBlock codeBlock = CodeBlock.builder()
605+
.addStatement("$T userAgentApplier = b -> "
606+
+ "b.addApiName($T.builder().name($S).version($S).build())",
607+
parameterizedTypeName, ApiName.class,
608+
"sdk-metrics",
609+
BusinessMetricFeatureId.PROTOCOL_RPC_V2_CBOR.value())
610+
.addStatement("$T overrideConfiguration =\n"
611+
+ " request.overrideConfiguration().map(c -> c.toBuilder()"
612+
+ ".applyMutation(userAgentApplier).build())\n"
613+
+ " .orElse((AwsRequestOverrideConfiguration.builder()"
614+
+ ".applyMutation(userAgentApplier).build()))",
615+
AwsRequestOverrideConfiguration.class)
616+
.addStatement("return (T) request.toBuilder().overrideConfiguration(overrideConfiguration)"
617+
+ ".build()")
618+
.build();
619+
620+
return Optional.of(MethodSpec.methodBuilder("applyRpcV2CborUserAgent")
621+
.addModifiers(Modifier.PRIVATE, Modifier.STATIC)
622+
.addParameter(typeVariableName, "request")
623+
.addTypeVariable(typeVariableName)
624+
.addCode(codeBlock)
625+
.returns(typeVariableName)
626+
.build());
627+
}
571628
}

core/sdk-core/src/main/java/software/amazon/awssdk/core/useragent/BusinessMetricFeatureId.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
/**
2323
* An enum class representing a short form of identity providers to record in the UA string.
2424
*
25-
* Unimplemented metrics: I,J,K,M,O,S,U-c
25+
* Unimplemented metrics: I,J,K,O,S,U-c
2626
* Unsupported metrics (these will never be added): A,H
2727
*/
2828
@SdkProtectedApi
@@ -35,6 +35,7 @@ public enum BusinessMetricFeatureId {
3535
RETRY_MODE_ADAPTIVE("F"),
3636
S3_TRANSFER("G"),
3737
GZIP_REQUEST_COMPRESSION("L"), //TODO(metrics): Not working, compression happens after header
38+
PROTOCOL_RPC_V2_CBOR("M"),
3839
ENDPOINT_OVERRIDE("N"),
3940
ACCOUNT_ID_MODE_PREFERRED("P"),
4041
ACCOUNT_ID_MODE_DISABLED("Q"),
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.protocol.tests;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
20+
import static software.amazon.awssdk.core.useragent.BusinessMetricCollection.METRIC_SEARCH_PATTERN;
21+
22+
import java.util.List;
23+
import java.util.Map;
24+
import org.junit.jupiter.api.BeforeEach;
25+
import org.junit.jupiter.api.Test;
26+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
27+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
28+
import software.amazon.awssdk.core.interceptor.Context;
29+
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
30+
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
31+
import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId;
32+
import software.amazon.awssdk.regions.Region;
33+
import software.amazon.awssdk.services.protocolsmithyrpcv2.ProtocolSmithyrpcv2AsyncClient;
34+
import software.amazon.awssdk.services.protocolsmithyrpcv2.ProtocolSmithyrpcv2AsyncClientBuilder;
35+
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClient;
36+
import software.amazon.awssdk.services.protocolrestjson.ProtocolRestJsonAsyncClientBuilder;
37+
38+
class ProtocolRpcV2CborUserAgentTest {
39+
private CapturingInterceptor interceptor;
40+
41+
private static final String USER_AGENT_HEADER_NAME = "User-Agent";
42+
private static final StaticCredentialsProvider CREDENTIALS_PROVIDER =
43+
StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"));
44+
45+
@BeforeEach
46+
public void setup() {
47+
this.interceptor = new CapturingInterceptor();
48+
}
49+
50+
@Test
51+
void when_rpcV2CborProtocolIsUsed_correctMetricIsAdded() {
52+
ProtocolSmithyrpcv2AsyncClientBuilder clientBuilder = asyncClientBuilderForRpcV2Cbor();
53+
54+
assertThatThrownBy(() -> clientBuilder.build().operationWithNoInputOrOutput(r -> {}).join())
55+
.hasMessageContaining("stop");
56+
57+
String userAgent = assertAndGetUserAgentString();
58+
assertThat(userAgent).matches(METRIC_SEARCH_PATTERN.apply(BusinessMetricFeatureId.PROTOCOL_RPC_V2_CBOR.value()));
59+
}
60+
61+
@Test
62+
void when_nonRpcV2CborProtocolIsUsed_rpcV2CborMetricIsNotAdded() {
63+
ProtocolRestJsonAsyncClientBuilder clientBuilder = asyncClientBuilderForRestJson();
64+
65+
assertThatThrownBy(() -> clientBuilder.build().allTypes(r -> {}).join())
66+
.hasMessageContaining("stop");
67+
68+
String userAgent = assertAndGetUserAgentString();
69+
assertThat(userAgent).doesNotMatch(METRIC_SEARCH_PATTERN.apply(BusinessMetricFeatureId.PROTOCOL_RPC_V2_CBOR.value()));
70+
}
71+
72+
private String assertAndGetUserAgentString() {
73+
Map<String, List<String>> headers = interceptor.context.httpRequest().headers();
74+
assertThat(headers).containsKey(USER_AGENT_HEADER_NAME);
75+
return headers.get(USER_AGENT_HEADER_NAME).get(0);
76+
}
77+
78+
private ProtocolSmithyrpcv2AsyncClientBuilder asyncClientBuilderForRpcV2Cbor() {
79+
return ProtocolSmithyrpcv2AsyncClient.builder()
80+
.region(Region.US_WEST_2)
81+
.credentialsProvider(CREDENTIALS_PROVIDER)
82+
.overrideConfiguration(c -> c.addExecutionInterceptor(interceptor));
83+
}
84+
85+
private ProtocolRestJsonAsyncClientBuilder asyncClientBuilderForRestJson() {
86+
return ProtocolRestJsonAsyncClient.builder()
87+
.region(Region.US_WEST_2)
88+
.credentialsProvider(CREDENTIALS_PROVIDER)
89+
.overrideConfiguration(c -> c.addExecutionInterceptor(interceptor));
90+
}
91+
92+
public static class CapturingInterceptor implements ExecutionInterceptor {
93+
private Context.BeforeTransmission context;
94+
private ExecutionAttributes executionAttributes;
95+
96+
@Override
97+
public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) {
98+
this.context = context;
99+
this.executionAttributes = executionAttributes;
100+
throw new RuntimeException("stop");
101+
}
102+
103+
public ExecutionAttributes executionAttributes() {
104+
return executionAttributes;
105+
}
106+
}
107+
}

0 commit comments

Comments
 (0)