|
25 | 25 | import com.squareup.javapoet.TypeName; |
26 | 26 | import com.squareup.javapoet.TypeVariableName; |
27 | 27 | import com.squareup.javapoet.WildcardTypeName; |
| 28 | +import java.util.ArrayList; |
| 29 | +import java.util.List; |
28 | 30 | import java.util.Optional; |
29 | 31 | import java.util.concurrent.CompletableFuture; |
| 32 | +import java.util.function.Consumer; |
30 | 33 | import java.util.function.Function; |
31 | 34 | import javax.lang.model.element.Modifier; |
| 35 | +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; |
32 | 36 | import software.amazon.awssdk.awscore.eventstream.EventStreamAsyncResponseTransformer; |
33 | 37 | import software.amazon.awssdk.awscore.eventstream.EventStreamTaggedUnionPojoSupplier; |
34 | 38 | import software.amazon.awssdk.awscore.eventstream.RestEventStreamAsyncResponseTransformer; |
|
48 | 52 | import software.amazon.awssdk.codegen.poet.client.traits.RequestCompressionTrait; |
49 | 53 | import software.amazon.awssdk.codegen.poet.eventstream.EventStreamUtils; |
50 | 54 | import software.amazon.awssdk.codegen.poet.model.EventStreamSpecHelper; |
| 55 | +import software.amazon.awssdk.core.ApiName; |
51 | 56 | import software.amazon.awssdk.core.SdkPojoBuilder; |
52 | 57 | import software.amazon.awssdk.core.SdkResponse; |
53 | 58 | import software.amazon.awssdk.core.async.AsyncRequestBody; |
|
56 | 61 | import software.amazon.awssdk.core.client.handler.ClientExecutionParams; |
57 | 62 | import software.amazon.awssdk.core.http.HttpResponseHandler; |
58 | 63 | import software.amazon.awssdk.core.protocol.VoidSdkResponse; |
| 64 | +import software.amazon.awssdk.core.useragent.BusinessMetricFeatureId; |
59 | 65 | import software.amazon.awssdk.protocols.cbor.AwsCborProtocolFactory; |
60 | 66 | import software.amazon.awssdk.protocols.core.ExceptionMetadata; |
61 | 67 | import software.amazon.awssdk.protocols.json.AwsJsonProtocol; |
@@ -224,7 +230,10 @@ public CodeBlock executionHandler(OperationModel opModel) { |
224 | 230 | .add(discoveredEndpoint(opModel)) |
225 | 231 | .add(credentialType(opModel, model)) |
226 | 232 | .add(".withRequestConfiguration(clientConfiguration)") |
227 | | - .add(".withInput($L)\n", opModel.getInput().getVariableName()) |
| 233 | + .add(".withInput($L)\n", |
| 234 | + model.getMetadata().isRpcV2CborProtocol() ? |
| 235 | + "applyRpcV2CborUserAgent(" + opModel.getInput().getVariableName() + ")" : |
| 236 | + opModel.getInput().getVariableName()) |
228 | 237 | .add(".withMetricCollector(apiCallMetricCollector)") |
229 | 238 | .add(HttpChecksumRequiredTrait.putHttpChecksumAttribute(opModel)) |
230 | 239 | .add(HttpChecksumTrait.create(opModel)); |
@@ -320,7 +329,10 @@ public CodeBlock asyncExecutionHandler(IntermediateModel intermediateModel, Oper |
320 | 329 |
|
321 | 330 | builder.add(RequestCompressionTrait.create(opModel, model)) |
322 | 331 | .add(".withInput($L)$L)", |
323 | | - opModel.getInput().getVariableName(), asyncResponseTransformerVariable(isStreaming, isRestJson, opModel)) |
| 332 | + intermediateModel.getMetadata().isRpcV2CborProtocol() ? |
| 333 | + "applyRpcV2CborUserAgent(" + opModel.getInput().getVariableName() + ")" : |
| 334 | + opModel.getInput().getVariableName(), |
| 335 | + asyncResponseTransformerVariable(isStreaming, isRestJson, opModel)) |
324 | 336 | .add(opModel.getEndpointDiscovery() != null ? ");" : ";"); |
325 | 337 |
|
326 | 338 | if (opModel.hasStreamingOutput()) { |
@@ -568,4 +580,49 @@ private String protocolFactoryLiteral(IntermediateModel model, OperationModel op |
568 | 580 | private boolean isRestJson(IntermediateModel model) { |
569 | 581 | return model.getMetadata().getProtocol() == Protocol.REST_JSON; |
570 | 582 | } |
| 583 | + |
| 584 | + @Override |
| 585 | + public List<MethodSpec> additionalMethods() { |
| 586 | + List<MethodSpec> methods = new ArrayList<>(); |
| 587 | + |
| 588 | + applyRpcV2CborUserAgentMethod().ifPresent(methods::add); |
| 589 | + |
| 590 | + return methods; |
| 591 | + } |
| 592 | + |
| 593 | + private Optional<MethodSpec> applyRpcV2CborUserAgentMethod() { |
| 594 | + if (!model.getMetadata().isRpcV2CborProtocol()) { |
| 595 | + return Optional.empty(); |
| 596 | + } |
| 597 | + |
| 598 | + TypeVariableName typeVariableName = |
| 599 | + TypeVariableName.get("T", poetExtensions.getModelClass(model.getSdkRequestBaseClassName())); |
| 600 | + |
| 601 | + ParameterizedTypeName parameterizedTypeName = ParameterizedTypeName |
| 602 | + .get(ClassName.get(Consumer.class), ClassName.get(AwsRequestOverrideConfiguration.Builder.class)); |
| 603 | + |
| 604 | + CodeBlock codeBlock = CodeBlock.builder() |
| 605 | + .addStatement("$T userAgentApplier = b -> " |
| 606 | + + "b.addApiName($T.builder().name($S).version($S).build())", |
| 607 | + parameterizedTypeName, ApiName.class, |
| 608 | + "sdk-metrics", |
| 609 | + BusinessMetricFeatureId.PROTOCOL_RPC_V2_CBOR.value()) |
| 610 | + .addStatement("$T overrideConfiguration =\n" |
| 611 | + + " request.overrideConfiguration().map(c -> c.toBuilder()" |
| 612 | + + ".applyMutation(userAgentApplier).build())\n" |
| 613 | + + " .orElse((AwsRequestOverrideConfiguration.builder()" |
| 614 | + + ".applyMutation(userAgentApplier).build()))", |
| 615 | + AwsRequestOverrideConfiguration.class) |
| 616 | + .addStatement("return (T) request.toBuilder().overrideConfiguration(overrideConfiguration)" |
| 617 | + + ".build()") |
| 618 | + .build(); |
| 619 | + |
| 620 | + return Optional.of(MethodSpec.methodBuilder("applyRpcV2CborUserAgent") |
| 621 | + .addModifiers(Modifier.PRIVATE, Modifier.STATIC) |
| 622 | + .addParameter(typeVariableName, "request") |
| 623 | + .addTypeVariable(typeVariableName) |
| 624 | + .addCode(codeBlock) |
| 625 | + .returns(typeVariableName) |
| 626 | + .build()); |
| 627 | + } |
571 | 628 | } |
0 commit comments