Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-439f346.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Optimized JSON marshalling performance for JSON RPC, REST JSON and RPCv2 Cbor protocols."
}
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,21 @@
whose NULL marshallers handle null validation. -->
<Match>
<Class name="software.amazon.awssdk.protocols.json.internal.marshall.JsonProtocolMarshaller"/>
<Method name="doMarshall"/>
<Or>
<Method name="doMarshall"/>
<Method name="marshallFieldViaRegistry"/>
</Or>
<Bug pattern="NP_LOAD_OF_KNOWN_NULL_VALUE"/>
</Match>

<!-- Intentional benign-race get-then-put on ConcurrentHashMap. SdkField instances are
static final, and the registry always returns the same marshaller for a given
(location, marshallingType) pair, so concurrent puts are idempotent. Using get()
instead of computeIfAbsent() avoids the latter's bucket-level synchronization
overhead on every call. -->
<Match>
<Class name="software.amazon.awssdk.protocols.json.internal.marshall.JsonProtocolMarshaller"/>
<Method name="marshallFieldViaRegistry"/>
<Bug pattern="AT_OPERATION_SEQUENCE_ON_CONCURRENT_ABSTRACTION"/>
</Match>
</FindBugsFilter>
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,22 @@
import static software.amazon.awssdk.http.Header.TRANSFER_ENCODING;

import java.io.ByteArrayInputStream;
import java.math.BigDecimal;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.SdkPojo;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.protocol.MarshallingKnownType;
import software.amazon.awssdk.core.protocol.MarshallingType;
import software.amazon.awssdk.core.traits.PayloadTrait;
import software.amazon.awssdk.core.traits.RequiredTrait;
Expand Down Expand Up @@ -61,6 +66,14 @@ public class JsonProtocolMarshaller implements ProtocolMarshaller<SdkHttpFullReq

private static final JsonMarshallerRegistry MARSHALLER_REGISTRY = createMarshallerRegistry();

// Caches the resolved marshaller for non-PAYLOAD fields, keyed by SdkField identity.
// SdkField instances are static final per generated model class, so identity-based lookup is correct.
// The cache is effectively bounded by the total number of non-payload SdkField instances across all
// loaded service models — each SdkField is inserted at most once, and no eviction is needed.
// ConcurrentHashMap is used for thread safety; the one-time put per SdkField is negligible.
private static final ConcurrentHashMap<SdkField<?>, JsonMarshaller<Object>> MARSHALLER_CACHE =
new ConcurrentHashMap<>();

private final URI endpoint;
private final StructuredJsonGenerator jsonGenerator;
private final SdkHttpFullRequest.Builder request;
Expand Down Expand Up @@ -214,17 +227,21 @@ void doMarshall(SdkPojo pojo) {
} else if (isExplicitPayloadMember(field)) {
marshallExplicitJsonPayload(field, val);
} else if (val != null) {
marshallField(field, val);
if (field.location() == MarshallLocation.PAYLOAD) {
// HOT PATH: switch-based dispatch, no registry, no interface dispatch
marshallPayloadField(field, val);
} else {
// WARM PATH: cached registry lookup + interface dispatch
marshallFieldViaRegistry(field, val);
}
} else if (field.location() != MarshallLocation.PAYLOAD) {
// Null payload fields that aren't required are no-op in the marshaller registry.
// We short circuit to avoid the registry lookup and dispatch overhead.
// Non payload locations (path, header, query) have null marshallers with
// different behavior, so they must still go through marshallField.
marshallField(field, val);
// Null non-payload: must go through registry (null marshallers vary by location)
marshallFieldViaRegistry(field, val);
} else if (field.containsTrait(RequiredTrait.class, TraitType.REQUIRED_TRAIT)) {
throw new IllegalArgumentException(
String.format("Parameter '%s' must not be null", field.locationName()));
}
// else: null payload field, not required → no-op
}
}

Expand Down Expand Up @@ -312,6 +329,111 @@ private SdkHttpFullRequest finishMarshalling() {
return request.build();
}

/**
* Marshalls a PAYLOAD-location field using a switch on {@link MarshallingKnownType} instead of
* registry lookup and interface dispatch. Each case is a monomorphic call site that the JIT can inline.
*/
@SuppressWarnings("unchecked")
private void marshallPayloadField(SdkField<?> field, Object val) {
MarshallingKnownType knownType = field.marshallingType().getKnownType();
if (knownType == null) {
marshallFieldViaRegistry(field, val);
return;
}

StructuredJsonGenerator gen = marshallerContext.jsonGenerator();
String fieldName = field.locationName();

switch (knownType) {
case STRING:
gen.writeFieldName(fieldName);
gen.writeValue((String) val);
break;
case INTEGER:
gen.writeFieldName(fieldName);
gen.writeValue((int) (Integer) val);
break;
case LONG:
gen.writeFieldName(fieldName);
gen.writeValue((long) (Long) val);
break;
case SHORT:
gen.writeFieldName(fieldName);
gen.writeValue((short) (Short) val);
break;
case BYTE:
gen.writeFieldName(fieldName);
gen.writeValue((byte) (Byte) val);
break;
case FLOAT:
gen.writeFieldName(fieldName);
gen.writeValue((float) (Float) val);
break;
case DOUBLE:
gen.writeFieldName(fieldName);
gen.writeValue((double) (Double) val);
break;
case BIG_DECIMAL:
gen.writeFieldName(fieldName);
gen.writeValue((BigDecimal) val);
break;
case BOOLEAN:
gen.writeFieldName(fieldName);
gen.writeValue((boolean) (Boolean) val);
break;
case INSTANT:
// Delegate to existing INSTANT marshaller to preserve TimestampFormatTrait handling.
// Note: INSTANT marshaller writes the field name itself.
SimpleTypeJsonMarshaller.INSTANT.marshall((Instant) val, marshallerContext,
fieldName, (SdkField<Instant>) field);
break;
case SDK_BYTES:
gen.writeFieldName(fieldName);
gen.writeValue(((SdkBytes) val).asByteBuffer());
break;
case SDK_POJO:
SimpleTypeJsonMarshaller.SDK_POJO.marshall((SdkPojo) val, marshallerContext,
fieldName, (SdkField<SdkPojo>) field);
break;
case LIST:
SimpleTypeJsonMarshaller.LIST.marshall((List<?>) val, marshallerContext,
fieldName, (SdkField<List<?>>) field);
break;
case MAP:
SimpleTypeJsonMarshaller.MAP.marshall((Map<String, ?>) val, marshallerContext,
fieldName, (SdkField<Map<String, ?>>) field);
break;
case DOCUMENT:
SimpleTypeJsonMarshaller.DOCUMENT.marshall((Document) val, marshallerContext,
fieldName, (SdkField<Document>) field);
break;
default:
// Unknown type — fall back to registry lookup
marshallFieldViaRegistry(field, val);
break;
}
}

@SuppressWarnings("unchecked")
private void marshallFieldViaRegistry(SdkField<?> field, Object val) {
if (val == null) {
MARSHALLER_REGISTRY.getMarshaller(field.location(), field.marshallingType(), val)
.marshall(val, marshallerContext, field.locationName(), (SdkField<Object>) field);
return;
}
// Use get-before-put instead of computeIfAbsent. ConcurrentHashMap.get() is a single lock-free
// volatile read, whereas computeIfAbsent() has additional overhead even on cache hits (bucket-level
// synchronization bookkeeping). The benign-race on first access is safe: SdkField instances are
// static final, and the registry always returns the same marshaller for a given (location, type) pair,
// so concurrent puts are idempotent.
JsonMarshaller<Object> marshaller = MARSHALLER_CACHE.get(field);
if (marshaller == null) {
marshaller = MARSHALLER_REGISTRY.getMarshaller(field.location(), field.marshallingType(), val);
MARSHALLER_CACHE.put(field, marshaller);
}
marshaller.marshall(val, marshallerContext, field.locationName(), (SdkField<Object>) field);
}

private void marshallField(SdkField<?> field, Object val) {
MARSHALLER_REGISTRY.getMarshaller(field.location(), field.marshallingType(), val)
.marshall(val, marshallerContext, field.locationName(), (SdkField<Object>) field);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.protocols.json.internal.marshall;

import static org.assertj.core.api.Assertions.assertThat;

import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.core.SdkField;
import software.amazon.awssdk.core.SdkPojo;
import software.amazon.awssdk.core.protocol.MarshallLocation;
import software.amazon.awssdk.core.protocol.MarshallingType;
import software.amazon.awssdk.core.traits.LocationTrait;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.protocols.core.OperationInfo;
import software.amazon.awssdk.protocols.core.ProtocolMarshaller;
import software.amazon.awssdk.protocols.json.AwsJsonProtocol;
import software.amazon.awssdk.protocols.json.AwsJsonProtocolMetadata;
import software.amazon.awssdk.protocols.json.internal.AwsStructuredPlainJsonFactory;

/**
* Tests that the cached non-payload marshalling path in
* {@link JsonProtocolMarshaller#marshallFieldViaRegistry} produces correct output
* and that the cache is populated after the first call.
*/
class CachedNonPayloadMarshallingTest {

private static final URI ENDPOINT = URI.create("http://localhost");
private static final String CONTENT_TYPE = "application/x-amz-json-1.0";
private static final OperationInfo OP_INFO = OperationInfo.builder()
.httpMethod(SdkHttpMethod.POST)
.hasImplicitPayloadMembers(true)
.build();
private static final AwsJsonProtocolMetadata METADATA =
AwsJsonProtocolMetadata.builder()
.protocol(AwsJsonProtocol.AWS_JSON)
.contentType(CONTENT_TYPE)
.build();

// ---- HEADER tests ----

@Test
void header_string_producesCorrectHeader() {
SdkField<String> field = headerField("x-custom-header", obj -> "headerValue");
SdkPojo pojo = new SimplePojo(field);

SdkHttpFullRequest result = createMarshaller().marshall(pojo);

assertThat(result.firstMatchingHeader("x-custom-header"))
.isPresent()
.hasValue("headerValue");
}

@Test
void header_string_secondCall_usesCachedMarshaller() {
// Use the SAME SdkField instance for both calls so the cache is shared
SdkField<String> field = headerField("x-custom-header", obj -> "headerValue");

// First call — populates the internal marshaller cache
SdkPojo pojo1 = new SimplePojo(field);
SdkHttpFullRequest result1 = createMarshaller().marshall(pojo1);

// Second call — should use cached marshaller
SdkPojo pojo2 = new SimplePojo(field);
SdkHttpFullRequest result2 = createMarshaller().marshall(pojo2);

// Both calls produce identical header output, confirming the cached path works
assertThat(result1.firstMatchingHeader("x-custom-header"))
.isPresent()
.hasValue("headerValue");
assertThat(result2.firstMatchingHeader("x-custom-header"))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This technically doesnt prove that the cache is used 🙃

I'd suggest either changing the test name, or making MARSHALLER_CACHE package private so we can assert something like cache.size() > 0

.isPresent()
.hasValue("headerValue");
}

// ---- QUERY_PARAM tests ----

@Test
void queryParam_string_producesCorrectQueryParam() {
SdkField<String> field = queryParamField("myParam", obj -> "paramValue");
SdkPojo pojo = new SimplePojo(field);

SdkHttpFullRequest result = createMarshaller().marshall(pojo);

assertThat(result.rawQueryParameters().get("myParam"))
.isNotNull()
.containsExactly("paramValue");
}

private static SdkField<String> headerField(String headerName,
java.util.function.Function<Object, String> getter) {
return SdkField.<String>builder(MarshallingType.STRING)
.memberName(headerName)
.getter(getter)
.setter((obj, val) -> { })
.traits(LocationTrait.builder()
.location(MarshallLocation.HEADER)
.locationName(headerName)
.build())
.build();
}

private static SdkField<String> queryParamField(String paramName,
java.util.function.Function<Object, String> getter) {
return SdkField.<String>builder(MarshallingType.STRING)
.memberName(paramName)
.getter(getter)
.setter((obj, val) -> { })
.traits(LocationTrait.builder()
.location(MarshallLocation.QUERY_PARAM)
.locationName(paramName)
.build())
.build();
}

private static ProtocolMarshaller<SdkHttpFullRequest> createMarshaller() {
return JsonProtocolMarshallerBuilder.create()
.endpoint(ENDPOINT)
.jsonGenerator(AwsStructuredPlainJsonFactory
.SDK_JSON_FACTORY.createWriter(CONTENT_TYPE))
.contentType(CONTENT_TYPE)
.operationInfo(OP_INFO)
.sendExplicitNullForPayload(false)
.protocolMetadata(METADATA)
.build();
}

private static final class SimplePojo implements SdkPojo {
private final List<SdkField<?>> fields;

SimplePojo(SdkField<?>... fields) {
this.fields = Arrays.asList(fields);
}

@Override
public List<SdkField<?>> sdkFields() {
return fields;
}

@Override
public boolean equalsBySdkFields(Object other) {
return other instanceof SimplePojo;
}

@Override
public Map<String, SdkField<?>> sdkFieldNameToField() {
return Collections.emptyMap();
}
}
}
Loading
Loading