diff --git a/aws/aws-sigv4/build.gradle.kts b/aws/aws-sigv4/build.gradle.kts index bc0cf141b2..e85a5034d9 100644 --- a/aws/aws-sigv4/build.gradle.kts +++ b/aws/aws-sigv4/build.gradle.kts @@ -2,7 +2,7 @@ import com.google.gradle.osdetector.OsDetector plugins { id("smithy-java.module-conventions") - alias(libs.plugins.jmh) + id("smithy-java.jmh-conventions") alias(libs.plugins.osdetector) } @@ -31,9 +31,6 @@ afterEvaluate { } jmh { - iterations = 3 warmupIterations = 2 - fork = 1 - // profilers.add("async:output=flamegraph") - // profilers.add("gc") + iterations = 3 } diff --git a/aws/client/aws-client-awsjson/build.gradle.kts b/aws/client/aws-client-awsjson/build.gradle.kts index 3024526ca1..b74b7a6fe9 100644 --- a/aws/client/aws-client-awsjson/build.gradle.kts +++ b/aws/client/aws-client-awsjson/build.gradle.kts @@ -18,6 +18,11 @@ dependencies { testImplementation(libs.smithy.aws.protocol.tests) } +protocolTestRuns { + run("native") { systemProperty("smithy-java.json-provider", "smithy") } + run("jackson") { systemProperty("smithy-java.json-provider", "jackson") } +} + val generator = "software.amazon.smithy.java.protocoltests.generators.ProtocolTestGenerator" addGenerateSrcsTask(generator, "awsJson1_0", "aws.protocoltests.json10#JsonRpc10") addGenerateSrcsTask(generator, "awsJson1_1", "aws.protocoltests.json#JsonProtocol") diff --git a/aws/client/aws-client-awsquery/build.gradle.kts b/aws/client/aws-client-awsquery/build.gradle.kts index e21c5af5f0..1ff2a0f70f 100644 --- a/aws/client/aws-client-awsquery/build.gradle.kts +++ b/aws/client/aws-client-awsquery/build.gradle.kts @@ -11,6 +11,7 @@ extra["moduleName"] = "software.amazon.smithy.java.aws.client.awsquery" dependencies { api(project(":client:client-http")) api(project(":codecs:xml-codec")) + api(project(":codecs:codec-commons", configuration = "shadow")) api(project(":io")) api(libs.smithy.aws.traits) @@ -18,6 +19,11 @@ dependencies { testImplementation(libs.smithy.aws.protocol.tests) } +protocolTestRuns { + run("native") { systemProperty("smithy-java.xml-provider", "smithy") } + run("stax") { } +} + val generator = "software.amazon.smithy.java.protocoltests.generators.ProtocolTestGenerator" addGenerateSrcsTask(generator, "awsQuery", "aws.protocoltests.query#AwsQuery") addGenerateSrcsTask(generator, "ec2Query", "aws.protocoltests.ec2#AwsEc2") diff --git a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryClientProtocol.java b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryClientProtocol.java index cc3102042f..f776304126 100644 --- a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryClientProtocol.java +++ b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQueryClientProtocol.java @@ -69,7 +69,7 @@ public HttpRequest SmithyUri endpoint ) { String operationName = operation.schema().id().getName(); - QueryFormSerializer serializer = new QueryFormSerializer( + QueryFormSerializer serializer = QueryFormSerializer.acquire( QueryFormSerializer.QueryVariant.AWS_QUERY, operationName, version); diff --git a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQuerySchemaExtensions.java b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQuerySchemaExtensions.java index a12f971141..0efde5c9de 100644 --- a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQuerySchemaExtensions.java +++ b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/AwsQuerySchemaExtensions.java @@ -5,17 +5,19 @@ package software.amazon.smithy.java.aws.client.awsquery; -import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import software.amazon.smithy.aws.traits.protocols.Ec2QueryNameTrait; import software.amazon.smithy.java.core.schema.Schema; import software.amazon.smithy.java.core.schema.SchemaExtensionKey; import software.amazon.smithy.java.core.schema.SchemaExtensionProvider; import software.amazon.smithy.java.core.schema.TraitKey; +import software.amazon.smithy.model.shapes.ShapeType; +import software.amazon.smithy.model.traits.TimestampFormatTrait; import software.amazon.smithy.utils.SmithyInternalApi; import software.amazon.smithy.utils.StringUtils; /** - * Pre-computes the URL-encoded member-name bytes for AWS Query and EC2 Query protocols once per {@link Schema}. + * Pre-computes URL-encoded member-name bytes and list/map metadata for AWS Query and EC2 Query protocols. */ @SmithyInternalApi public final class AwsQuerySchemaExtensions @@ -23,13 +25,34 @@ public final class AwsQuerySchemaExtensions public static final SchemaExtensionKey KEY = new SchemaExtensionKey<>(); + private static final byte[] MEMBER_BYTES = "member".getBytes(StandardCharsets.UTF_8); + private static final byte[] KEY_BYTES = "key".getBytes(StandardCharsets.UTF_8); + private static final byte[] VALUE_BYTES = "value".getBytes(StandardCharsets.UTF_8); + private static final byte[] ENTRY_BYTES = "entry".getBytes(StandardCharsets.UTF_8); + /** - * Pre-encoded member-name bytes for both query variants. + * Pre-computed query binding data for a schema. * - * @param awsQueryNameBytes Bytes to use as the awsQuery name. Never null. - * @param ec2QueryNameBytes Bytes to use as the ec2Query name. Never null. + * @param awsQueryNameBytes Bytes for awsQuery member name. Null for non-members. + * @param ec2QueryNameBytes Bytes for ec2Query member name. Null for non-members. + * @param listFlattened Whether this list member has xmlFlattened trait. + * @param listMemberNameBytes Pre-computed list member name bytes (for non-flattened lists). Null if flattened. + * @param mapFlattened Whether this map member has xmlFlattened trait. + * @param mapKeyNameBytes Pre-computed map key name bytes. + * @param mapValueNameBytes Pre-computed map value name bytes. + * @param mapEntryNameBytes Pre-computed map entry name bytes (null if flattened). + * @param timestampFormat Pre-resolved timestamp format for timestamp members. */ - public record QueryMemberBinding(byte[] awsQueryNameBytes, byte[] ec2QueryNameBytes) {} + public record QueryMemberBinding( + byte[] awsQueryNameBytes, + byte[] ec2QueryNameBytes, + boolean listFlattened, + byte[] listMemberNameBytes, + boolean mapFlattened, + byte[] mapKeyNameBytes, + byte[] mapValueNameBytes, + byte[] mapEntryNameBytes, + TimestampFormatTrait.Format timestampFormat) {} @Override public SchemaExtensionKey key() { @@ -42,9 +65,73 @@ public QueryMemberBinding provide(Schema schema) { return null; } + byte[] awsName = encodeName(resolveAwsQueryName(schema)); + byte[] ec2Name = encodeName(resolveEc2QueryName(schema)); + + // Pre-compute list metadata if target is a list + boolean listFlattened = false; + byte[] listMemberNameBytes = null; + Schema target = schema.memberTarget(); + if (target != null && target.type() == ShapeType.LIST) { + listFlattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + if (!listFlattened) { + Schema listMember = target.listMember(); + if (listMember != null) { + var xmlName = listMember.getTrait(TraitKey.XML_NAME_TRAIT); + listMemberNameBytes = xmlName != null + ? xmlName.getValue().getBytes(StandardCharsets.UTF_8) + : MEMBER_BYTES; + } else { + listMemberNameBytes = MEMBER_BYTES; + } + } + } + + // Pre-compute map metadata if target is a map + boolean mapFlattened = false; + byte[] mapKeyNameBytes = null; + byte[] mapValueNameBytes = null; + byte[] mapEntryNameBytes = null; + if (target != null && target.type() == ShapeType.MAP) { + mapFlattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + Schema keySchema = target.mapKeyMember(); + Schema valueSchema = target.mapValueMember(); + if (keySchema != null) { + var keyXmlName = keySchema.getTrait(TraitKey.XML_NAME_TRAIT); + mapKeyNameBytes = keyXmlName != null + ? keyXmlName.getValue().getBytes(StandardCharsets.UTF_8) + : KEY_BYTES; + } else { + mapKeyNameBytes = KEY_BYTES; + } + if (valueSchema != null) { + var valueXmlName = valueSchema.getTrait(TraitKey.XML_NAME_TRAIT); + mapValueNameBytes = valueXmlName != null + ? valueXmlName.getValue().getBytes(StandardCharsets.UTF_8) + : VALUE_BYTES; + } else { + mapValueNameBytes = VALUE_BYTES; + } + mapEntryNameBytes = mapFlattened ? null : ENTRY_BYTES; + } + + // Pre-resolve timestamp format + TimestampFormatTrait.Format timestampFormat = null; + var tsFmt = schema.getTrait(TraitKey.TIMESTAMP_FORMAT_TRAIT); + if (tsFmt != null) { + timestampFormat = tsFmt.getFormat(); + } + return new QueryMemberBinding( - encodeName(resolveAwsQueryName(schema)), - encodeName(resolveEc2QueryName(schema))); + awsName, + ec2Name, + listFlattened, + listMemberNameBytes, + mapFlattened, + mapKeyNameBytes, + mapValueNameBytes, + mapEntryNameBytes, + timestampFormat); } private static String resolveAwsQueryName(Schema schema) { @@ -71,7 +158,7 @@ static byte[] encodeName(String name) { boolean needsEncoding = false; for (int i = 0; i < len; i++) { char c = name.charAt(i); - if (!FormUrlEncodedSink.isUnreserved(c)) { + if (c >= 128 || !QueryFormSerializer.UNRESERVED[c]) { needsEncoding = true; break; } @@ -83,11 +170,57 @@ static byte[] encodeName(String name) { return result; } - FormUrlEncodedSink tmp = new FormUrlEncodedSink(len * 3); - tmp.writeUrlEncoded(name); - ByteBuffer bb = tmp.finish(); - byte[] result = new byte[bb.remaining()]; - bb.get(result); + // Member names that need encoding are rare (non-ASCII names). + // Use a simple byte array builder for this cold path. + // Max 12 bytes per char (4-byte UTF-8, each byte percent-encoded to 3 bytes) + byte[] buf = new byte[len * 12]; + int pos = 0; + for (int i = 0; i < len; i++) { + char c = name.charAt(i); + if (c < 128 && QueryFormSerializer.UNRESERVED[c]) { + buf[pos++] = (byte) c; + } else if (c < 0x80) { + int off = c * 3; + buf[pos++] = QueryFormSerializer.PERCENT_ENCODED[off]; + buf[pos++] = QueryFormSerializer.PERCENT_ENCODED[off + 1]; + buf[pos++] = QueryFormSerializer.PERCENT_ENCODED[off + 2]; + } else if (c < 0x800) { + int b0 = 0xC0 | (c >> 6); + int b1 = 0x80 | (c & 0x3F); + System.arraycopy(QueryFormSerializer.PERCENT_ENCODED, b0 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(QueryFormSerializer.PERCENT_ENCODED, b1 * 3, buf, pos, 3); + pos += 3; + } else if (Character.isHighSurrogate(c) && i + 1 < len + && Character.isLowSurrogate(name.charAt(i + 1))) { + char low = name.charAt(++i); + int cp = Character.toCodePoint(c, low); + int b0 = 0xF0 | (cp >> 18); + int b1 = 0x80 | ((cp >> 12) & 0x3F); + int b2 = 0x80 | ((cp >> 6) & 0x3F); + int b3 = 0x80 | (cp & 0x3F); + System.arraycopy(QueryFormSerializer.PERCENT_ENCODED, b0 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(QueryFormSerializer.PERCENT_ENCODED, b1 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(QueryFormSerializer.PERCENT_ENCODED, b2 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(QueryFormSerializer.PERCENT_ENCODED, b3 * 3, buf, pos, 3); + pos += 3; + } else { + int b0 = 0xE0 | (c >> 12); + int b1 = 0x80 | ((c >> 6) & 0x3F); + int b2 = 0x80 | (c & 0x3F); + System.arraycopy(QueryFormSerializer.PERCENT_ENCODED, b0 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(QueryFormSerializer.PERCENT_ENCODED, b1 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(QueryFormSerializer.PERCENT_ENCODED, b2 * 3, buf, pos, 3); + pos += 3; + } + } + byte[] result = new byte[pos]; + System.arraycopy(buf, 0, result, 0, pos); return result; } } diff --git a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/Ec2QueryClientProtocol.java b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/Ec2QueryClientProtocol.java index 40d193db4e..9b57d423da 100644 --- a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/Ec2QueryClientProtocol.java +++ b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/Ec2QueryClientProtocol.java @@ -69,7 +69,7 @@ public HttpRequest SmithyUri endpoint ) { String operationName = operation.schema().id().getName(); - QueryFormSerializer serializer = new QueryFormSerializer( + QueryFormSerializer serializer = QueryFormSerializer.acquire( QueryFormSerializer.QueryVariant.EC2_QUERY, operationName, version); diff --git a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/FormUrlEncodedSink.java b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/FormUrlEncodedSink.java deleted file mode 100644 index d6551b0786..0000000000 --- a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/FormUrlEncodedSink.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.java.aws.client.awsquery; - -import java.nio.ByteBuffer; -import java.util.Arrays; - -/** - * A byte buffer sink for building URL-encoded form data using RFC 3986 percent-encoding. - * - *

This uses RFC 3986 unreserved characters (A-Z, a-z, 0-9, '-', '.', '_', '~') which pass through - * unencoded, while all other characters are percent-encoded as UTF-8 bytes. This differs from the - * application/x-www-form-urlencoded spec which encodes space as '+', but AWS Query protocol expects - * RFC 3986 encoding. - */ -final class FormUrlEncodedSink { - private static final byte[] HEX = { - '0', - '1', - '2', - '3', - '4', - '5', - '6', - '7', - '8', - '9', - 'A', - 'B', - 'C', - 'D', - 'E', - 'F' - }; - - private byte[] bytes; - private int pos; - - FormUrlEncodedSink() { - this.bytes = new byte[256]; - this.pos = 0; - } - - FormUrlEncodedSink(int initialCapacity) { - this.bytes = new byte[initialCapacity]; - this.pos = 0; - } - - void writeByte(int b) { - ensureCapacity(1); - bytes[pos++] = (byte) b; - } - - void writeBytes(byte[] b, int off, int len) { - ensureCapacity(len); - System.arraycopy(b, off, bytes, pos, len); - pos += len; - } - - @SuppressWarnings("deprecation") - void writeAscii(String s) { - int len = s.length(); - ensureCapacity(len); - s.getBytes(0, len, bytes, pos); - pos += len; - } - - void writeUrlEncoded(String s) { - int len = s.length(); - ensureCapacity(len * 3); - for (int i = 0; i < len; i++) { - char c = s.charAt(i); - if (isUnreserved(c)) { - bytes[pos++] = (byte) c; - } else if (c < 0x80) { - writePercentEncoded(c); - } else if (c < 0x800) { - writePercentEncoded(0xC0 | (c >> 6)); - writePercentEncoded(0x80 | (c & 0x3F)); - } else if (Character.isHighSurrogate(c) && i + 1 < len && Character.isLowSurrogate(s.charAt(i + 1))) { - char low = s.charAt(++i); - int cp = Character.toCodePoint(c, low); - writePercentEncoded(0xF0 | (cp >> 18)); - writePercentEncoded(0x80 | ((cp >> 12) & 0x3F)); - writePercentEncoded(0x80 | ((cp >> 6) & 0x3F)); - writePercentEncoded(0x80 | (cp & 0x3F)); - } else { - writePercentEncoded(0xE0 | (c >> 12)); - writePercentEncoded(0x80 | ((c >> 6) & 0x3F)); - writePercentEncoded(0x80 | (c & 0x3F)); - } - } - } - - @SuppressWarnings("deprecation") - void writeInt(int value) { - String s = Integer.toString(value); - int len = s.length(); - ensureCapacity(len); - s.getBytes(0, len, bytes, pos); - pos += len; - } - - ByteBuffer finish() { - return ByteBuffer.wrap(bytes, 0, pos); - } - - static boolean isUnreserved(char c) { - return (c >= 'A' && c <= 'Z') - || (c >= 'a' && c <= 'z') - || (c >= '0' && c <= '9') - || c == '-' - || c == '.' - || c == '_' - || c == '~'; - } - - private void writePercentEncoded(int b) { - bytes[pos++] = '%'; - bytes[pos++] = HEX[(b >> 4) & 0xF]; - bytes[pos++] = HEX[b & 0xF]; - } - - private void ensureCapacity(int len) { - int required = pos + len; - if (required > bytes.length) { - bytes = Arrays.copyOf(bytes, Math.max(required, bytes.length + (bytes.length >> 1))); - } - } -} diff --git a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/QueryFormSerializer.java b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/QueryFormSerializer.java index 16e3eb44f5..176e905a59 100644 --- a/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/QueryFormSerializer.java +++ b/aws/client/aws-client-awsquery/src/main/java/software/amazon/smithy/java/aws/client/awsquery/QueryFormSerializer.java @@ -12,6 +12,9 @@ import java.time.Instant; import java.util.Arrays; import java.util.function.BiConsumer; +import software.amazon.smithy.java.codecs.commons.NumberCodec; +import software.amazon.smithy.java.codecs.commons.StripedPool; +import software.amazon.smithy.java.codecs.commons.TimestampCodec; import software.amazon.smithy.java.core.schema.Schema; import software.amazon.smithy.java.core.schema.SerializableStruct; import software.amazon.smithy.java.core.schema.TraitKey; @@ -26,18 +29,13 @@ /** * Form-urlencoded serializer for both {@code awsQuery} and {@code ec2Query} protocols. * - *

The two protocols share the same wire format but differ in member name resolution, - * list serialization, and map support. The {@link QueryVariant} flag controls these differences. + *

Writes directly to an internal byte array buffer. Instances are pooled via a striped + * lock-free pool to avoid per-request allocation of both the serializer and its buffer. */ final class QueryFormSerializer implements ShapeSerializer { - /** - * Selects the protocol-specific serialization behavior. - */ enum QueryVariant { - /** Standard AWS Query protocol. */ AWS_QUERY, - /** EC2 Query protocol. */ EC2_QUERY } @@ -48,96 +46,295 @@ enum QueryVariant { private static final byte[] KEY = "key".getBytes(StandardCharsets.UTF_8); private static final byte[] VALUE = "value".getBytes(StandardCharsets.UTF_8); - private final FormUrlEncodedSink sink; - private final QueryVariant variant; - - private byte[][] prefixCache = new byte[8][]; + static final boolean[] UNRESERVED = new boolean[128]; + static final byte[] PERCENT_ENCODED = new byte[256 * 3]; + + static { + for (int c = 'A'; c <= 'Z'; c++) + UNRESERVED[c] = true; + for (int c = 'a'; c <= 'z'; c++) + UNRESERVED[c] = true; + for (int c = '0'; c <= '9'; c++) + UNRESERVED[c] = true; + UNRESERVED['-'] = true; + UNRESERVED['.'] = true; + UNRESERVED['_'] = true; + UNRESERVED['~'] = true; + + byte[] HEX = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'}; + for (int b = 0; b < 256; b++) { + int off = b * 3; + PERCENT_ENCODED[off] = '%'; + PERCENT_ENCODED[off + 1] = HEX[(b >> 4) & 0xF]; + PERCENT_ENCODED[off + 2] = HEX[b & 0xF]; + } + } + + private static final int DEFAULT_BUF_SIZE = 1024; + private static final int MAX_CACHEABLE_BUF = DEFAULT_BUF_SIZE * 4; + + record AcquireContext(QueryVariant variant, String action, String version) {} + + private static final StripedPool POOL = + new StripedPool<>() { + @Override + protected QueryFormSerializer create(AcquireContext ctx) { + return new QueryFormSerializer(); + } + + @Override + protected boolean canPool(QueryFormSerializer s) { + return true; + } + + @Override + protected void prepareForPool(QueryFormSerializer s) { + if (s.buf.length > MAX_CACHEABLE_BUF) { + s.buf = new byte[DEFAULT_BUF_SIZE]; + } + } + + @Override + protected boolean reset(QueryFormSerializer s, AcquireContext ctx) { + s.prefixLen = 0; + s.prefixDepth = 0; + s.pos = 0; + return true; + } + }; + + private byte[] buf; + private int pos; + private QueryVariant variant; + + private byte[] prefixBuf = new byte[128]; + private int prefixLen = 0; + private int[] prefixStack = new int[8]; private int prefixDepth = 0; private final ListItemSerializer listSerializer = new ListItemSerializer(); private final QueryMapSerializer mapSerializer = new QueryMapSerializer(); + private final MapValueSerializer mapValueSerializer = new MapValueSerializer(); - QueryFormSerializer(QueryVariant variant, String action, String version) { - this.variant = variant; - this.sink = new FormUrlEncodedSink(); - sink.writeBytes(ACTION_PREFIX, 0, ACTION_PREFIX.length); - sink.writeAscii(action); - sink.writeBytes(VERSION_PREFIX, 0, VERSION_PREFIX.length); - sink.writeAscii(version); + private QueryFormSerializer() { + this.buf = new byte[DEFAULT_BUF_SIZE]; } - ByteBuffer finish() { - return sink.finish(); + @SuppressWarnings("deprecation") + static QueryFormSerializer acquire(QueryVariant variant, String action, String version) { + QueryFormSerializer s = POOL.acquire(new AcquireContext(variant, action, version)); + s.variant = variant; + + int headerLen = ACTION_PREFIX.length + action.length() + VERSION_PREFIX.length + version.length(); + s.ensureCapacity(headerLen); + System.arraycopy(ACTION_PREFIX, 0, s.buf, 0, ACTION_PREFIX.length); + s.pos = ACTION_PREFIX.length; + action.getBytes(0, action.length(), s.buf, s.pos); + s.pos += action.length(); + System.arraycopy(VERSION_PREFIX, 0, s.buf, s.pos, VERSION_PREFIX.length); + s.pos += VERSION_PREFIX.length; + version.getBytes(0, version.length(), s.buf, s.pos); + s.pos += version.length(); + return s; } - private void writeParam(byte[] key, String value) { - sink.writeByte('&'); - writeCurrentPrefix(); - if (prefixDepth > 0) { - sink.writeByte('.'); - } - sink.writeBytes(key, 0, key.length); - sink.writeByte('='); - sink.writeUrlEncoded(value); + ByteBuffer finish() { + ByteBuffer result = ByteBuffer.wrap(buf, 0, pos); + POOL.release(this); + return result; } - private void writeCurrentPrefix() { - for (int i = 0; i < prefixDepth; i++) { - if (i > 0) { - sink.writeByte('.'); - } - sink.writeBytes(prefixCache[i], 0, prefixCache[i].length); + private void ensureCapacity(int needed) { + int required = pos + needed; + if (required > buf.length) { + buf = Arrays.copyOf(buf, Math.max(required, buf.length + (buf.length >> 1))); } } - private void pushPrefix(byte[] prefix) { - ensurePrefixCacheCapacity(); - prefixCache[prefixDepth++] = prefix; + private void pushPrefix(byte[] name) { + if (prefixDepth >= prefixStack.length) { + prefixStack = Arrays.copyOf(prefixStack, prefixStack.length * 2); + } + prefixStack[prefixDepth++] = prefixLen; + int needed = name.length + (prefixLen > 0 ? 1 : 0); + if (prefixLen + needed > prefixBuf.length) { + prefixBuf = Arrays.copyOf(prefixBuf, Math.max(prefixLen + needed, prefixBuf.length * 2)); + } + if (prefixLen > 0) { + prefixBuf[prefixLen++] = '.'; + } + System.arraycopy(name, 0, prefixBuf, prefixLen, name.length); + prefixLen += name.length; } - private void pushIndexedPrefix(byte[] base, int index) { - ensurePrefixCacheCapacity(); - prefixCache[prefixDepth++] = encodeIndexedPrefix(base, index); + private void pushPrefixWithIndex(byte[] name, int index) { + if (prefixDepth >= prefixStack.length) { + prefixStack = Arrays.copyOf(prefixStack, prefixStack.length * 2); + } + prefixStack[prefixDepth++] = prefixLen; + int indexLen = NumberCodec.digitCount(index); + int needed = (prefixLen > 0 ? 1 : 0) + name.length + 1 + indexLen; + if (prefixLen + needed > prefixBuf.length) { + prefixBuf = Arrays.copyOf(prefixBuf, Math.max(prefixLen + needed, prefixBuf.length * 2)); + } + if (prefixLen > 0) { + prefixBuf[prefixLen++] = '.'; + } + System.arraycopy(name, 0, prefixBuf, prefixLen, name.length); + prefixLen += name.length; + prefixBuf[prefixLen++] = '.'; + prefixLen = NumberCodec.writeInt(prefixBuf, prefixLen, index); } private void pushIndexPrefix(int index) { - ensurePrefixCacheCapacity(); - prefixCache[prefixDepth++] = encodeIndex(index); - } - - private void ensurePrefixCacheCapacity() { - if (prefixDepth >= prefixCache.length) { - prefixCache = Arrays.copyOf(prefixCache, prefixCache.length * 2); + if (prefixDepth >= prefixStack.length) { + prefixStack = Arrays.copyOf(prefixStack, prefixStack.length * 2); + } + prefixStack[prefixDepth++] = prefixLen; + int indexLen = NumberCodec.digitCount(index); + int needed = (prefixLen > 0 ? 1 : 0) + indexLen; + if (prefixLen + needed > prefixBuf.length) { + prefixBuf = Arrays.copyOf(prefixBuf, Math.max(prefixLen + needed, prefixBuf.length * 2)); } + if (prefixLen > 0) { + prefixBuf[prefixLen++] = '.'; + } + prefixLen = NumberCodec.writeInt(prefixBuf, prefixLen, index); } private void popPrefix() { - prefixDepth--; + prefixLen = prefixStack[--prefixDepth]; } - @SuppressWarnings("deprecation") - private byte[] encodeIndexedPrefix(byte[] base, int index) { - String indexStr = Integer.toString(index); - byte[] result = new byte[base.length + 1 + indexStr.length()]; - System.arraycopy(base, 0, result, 0, base.length); - result[base.length] = '.'; - indexStr.getBytes(0, indexStr.length(), result, base.length + 1); - return result; + private void writeUrlEncodedAsciiBytes(byte[] data, int dataLen) { + for (int i = 0; i < dataLen; i++) { + int b = data[i] & 0xFF; + if (b < 128 && UNRESERVED[b]) { + buf[pos++] = data[i]; + } else { + int off = b * 3; + buf[pos] = PERCENT_ENCODED[off]; + buf[pos + 1] = PERCENT_ENCODED[off + 1]; + buf[pos + 2] = PERCENT_ENCODED[off + 2]; + pos += 3; + } + } } - @SuppressWarnings("deprecation") - private byte[] encodeIndex(int index) { - String indexStr = Integer.toString(index); - byte[] result = new byte[indexStr.length()]; - indexStr.getBytes(0, indexStr.length(), result, 0); - return result; + private void writeUrlEncoded(String s) { + int len = s.length(); + for (int i = 0; i < len; i++) { + char c = s.charAt(i); + if (c < 0x80) { + if (UNRESERVED[c]) { + buf[pos++] = (byte) c; + } else { + int off = c * 3; + buf[pos] = PERCENT_ENCODED[off]; + buf[pos + 1] = PERCENT_ENCODED[off + 1]; + buf[pos + 2] = PERCENT_ENCODED[off + 2]; + pos += 3; + } + } else if (c < 0x800) { + int b0 = 0xC0 | (c >> 6); + int b1 = 0x80 | (c & 0x3F); + System.arraycopy(PERCENT_ENCODED, b0 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(PERCENT_ENCODED, b1 * 3, buf, pos, 3); + pos += 3; + } else if (Character.isHighSurrogate(c) && i + 1 < len && Character.isLowSurrogate(s.charAt(i + 1))) { + char low = s.charAt(++i); + int cp = Character.toCodePoint(c, low); + int b0 = 0xF0 | (cp >> 18); + int b1 = 0x80 | ((cp >> 12) & 0x3F); + int b2 = 0x80 | ((cp >> 6) & 0x3F); + int b3 = 0x80 | (cp & 0x3F); + System.arraycopy(PERCENT_ENCODED, b0 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(PERCENT_ENCODED, b1 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(PERCENT_ENCODED, b2 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(PERCENT_ENCODED, b3 * 3, buf, pos, 3); + pos += 3; + } else { + int b0 = 0xE0 | (c >> 12); + int b1 = 0x80 | ((c >> 6) & 0x3F); + int b2 = 0x80 | (c & 0x3F); + System.arraycopy(PERCENT_ENCODED, b0 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(PERCENT_ENCODED, b1 * 3, buf, pos, 3); + pos += 3; + System.arraycopy(PERCENT_ENCODED, b2 * 3, buf, pos, 3); + pos += 3; + } + } } - // --- Member name resolution (protocol-specific) --- + /** + * Writes "&prefix.key=" to the buffer. Used by top-level member serialization. + */ + private void writeKeyPrefix(byte[] key, int maxValueSize) { + ensureCapacity(1 + prefixLen + 1 + key.length + 1 + maxValueSize); + buf[pos++] = '&'; + if (prefixLen > 0) { + System.arraycopy(prefixBuf, 0, buf, pos, prefixLen); + pos += prefixLen; + buf[pos++] = '.'; + } + System.arraycopy(key, 0, buf, pos, key.length); + pos += key.length; + buf[pos++] = '='; + } /** - * Read the pre-computed URL-encoded member-name bytes from the schema extension. + * Writes "&prefix=" to the buffer (no key/dot). Used by map value serialization. */ + private void writePrefixEquals(int maxValueSize) { + ensureCapacity(1 + prefixLen + 1 + maxValueSize); + buf[pos++] = '&'; + if (prefixLen > 0) { + System.arraycopy(prefixBuf, 0, buf, pos, prefixLen); + pos += prefixLen; + } + buf[pos++] = '='; + } + + private void writeParam(byte[] key, String value) { + writeKeyPrefix(key, value.length() * 3); + writeUrlEncoded(value); + } + + private void writeParamBoolean(byte[] key, boolean value) { + writeKeyPrefix(key, 5); + pos = NumberCodec.writeBoolean(buf, pos, value); + } + + private void writeParamInt(byte[] key, int value) { + writeKeyPrefix(key, 11); + pos = NumberCodec.writeInt(buf, pos, value); + } + + private void writeParamLong(byte[] key, long value) { + writeKeyPrefix(key, 20); + pos = NumberCodec.writeLong(buf, pos, value); + } + + private void writeParamDouble(byte[] key, double value) { + writeKeyPrefix(key, 25); + pos = NumberCodec.writeDouble(buf, pos, value); + } + + private void writeParamFloat(byte[] key, float value) { + writeKeyPrefix(key, 15); + pos = NumberCodec.writeFloat(buf, pos, value); + } + + private void writeParamTimestampDirect(byte[] key, Instant value) { + writeKeyPrefix(key, 30); + pos = TimestampCodec.writeIso8601(buf, pos, value); + } + private byte[] getMemberNameBytes(Schema schema) { var ext = schema.getExtension(AwsQuerySchemaExtensions.KEY); if (ext == null) { @@ -146,12 +343,9 @@ private byte[] getMemberNameBytes(Schema schema) { return variant == QueryVariant.AWS_QUERY ? ext.awsQueryNameBytes() : ext.ec2QueryNameBytes(); } - // --- Struct --- - @Override public void writeStruct(Schema schema, SerializableStruct struct) { if (schema.isMember()) { - // Member schemas always have a non-null QueryMemberBinding (see provider). pushPrefix(getMemberNameBytes(schema)); struct.serializeMembers(this); popPrefix(); @@ -160,8 +354,6 @@ public void writeStruct(Schema schema, SerializableStruct struct) { } } - // --- List (protocol-specific) --- - @Override public void writeList(Schema schema, T listState, int size, BiConsumer consumer) { if (variant == QueryVariant.EC2_QUERY) { @@ -177,9 +369,6 @@ private void writeAwsQueryList( int size, BiConsumer consumer ) { - boolean flattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); - Schema memberSchema = schema.listMember(); - if (schema.isMember()) { pushPrefix(getMemberNameBytes(schema)); } @@ -192,12 +381,24 @@ private void writeAwsQueryList( return; } + var ext = schema.getExtension(AwsQuerySchemaExtensions.KEY); + boolean flattened; byte[] memberNameBytes; - if (flattened) { + if (ext != null && ext.listMemberNameBytes() != null) { + flattened = ext.listFlattened(); + memberNameBytes = ext.listMemberNameBytes(); + } else if (ext != null && ext.listFlattened()) { + flattened = true; memberNameBytes = null; } else { - var xmlName = memberSchema.getTrait(TraitKey.XML_NAME_TRAIT); - memberNameBytes = xmlName != null ? xmlName.getValue().getBytes(StandardCharsets.UTF_8) : MEMBER; + flattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + if (flattened) { + memberNameBytes = null; + } else { + Schema memberSchema = schema.listMember(); + var xmlName = memberSchema.getTrait(TraitKey.XML_NAME_TRAIT); + memberNameBytes = xmlName != null ? xmlName.getValue().getBytes(StandardCharsets.UTF_8) : MEMBER; + } } listSerializer.reset(memberNameBytes, flattened); @@ -209,7 +410,6 @@ private void writeAwsQueryList( } private void writeEc2List(Schema schema, T listState, int size, BiConsumer consumer) { - // EC2 Query lists are always flattened - no .member. segment if (schema.isMember()) { pushPrefix(getMemberNameBytes(schema)); } @@ -230,9 +430,13 @@ private void writeEc2List(Schema schema, T listState, int size, BiConsumer 0) { + System.arraycopy(prefixBuf, 0, buf, pos, prefixLen); + pos += prefixLen; + } + buf[pos++] = '='; } private final class ListItemSerializer implements ShapeSerializer { @@ -250,10 +454,34 @@ private void pushIndexedMemberPrefix() { if (flattened) { pushIndexPrefix(index); } else { - pushIndexedPrefix(memberNameBytes, index); + pushPrefixWithIndex(memberNameBytes, index); } } + /** + * Writes "&prefix.memberName.index=" (or "&prefix.index=" if flattened) to the buffer. + */ + private void writeIndexedKeyPrefix(int maxValueSize) { + int indexLen = NumberCodec.digitCount(index); + int memberPartLen = flattened ? indexLen : (memberNameBytes.length + 1 + indexLen); + ensureCapacity(1 + prefixLen + (prefixLen > 0 ? 1 : 0) + memberPartLen + 1 + maxValueSize); + buf[pos++] = '&'; + if (prefixLen > 0) { + System.arraycopy(prefixBuf, 0, buf, pos, prefixLen); + pos += prefixLen; + buf[pos++] = '.'; + } + if (flattened) { + pos = NumberCodec.writeInt(buf, pos, index); + } else { + System.arraycopy(memberNameBytes, 0, buf, pos, memberNameBytes.length); + pos += memberNameBytes.length; + buf[pos++] = '.'; + pos = NumberCodec.writeInt(buf, pos, index); + } + buf[pos++] = '='; + } + @Override public void writeStruct(Schema schema, SerializableStruct struct) { pushIndexedMemberPrefix(); @@ -280,75 +508,108 @@ public void writeMap(Schema schema, T mapState, int size, BiConsumer 0 ? "Infinity" : "-Infinity"); + if (!Float.isFinite(value)) { + writeIndexedKeyPrefix(9); + pos = NumberCodec.writeNonFiniteFloat(buf, pos, value); } else { - writeIndexedParam(Float.toString(value)); + writeIndexedKeyPrefix(15); + pos = NumberCodec.writeFloat(buf, pos, value); } + index++; } @Override public void writeDouble(Schema schema, double value) { - if (Double.isNaN(value)) { - writeIndexedParam("NaN"); - } else if (Double.isInfinite(value)) { - writeIndexedParam(value > 0 ? "Infinity" : "-Infinity"); + if (!Double.isFinite(value)) { + writeIndexedKeyPrefix(9); + pos = NumberCodec.writeNonFiniteDouble(buf, pos, value); } else { - writeIndexedParam(Double.toString(value)); + writeIndexedKeyPrefix(25); + pos = NumberCodec.writeDouble(buf, pos, value); } + index++; } @Override public void writeBigInteger(Schema schema, BigInteger value) { - writeIndexedParam(value.toString()); + writeIndexedKeyPrefix(64); + pos = NumberCodec.writeBigInteger(buf, pos, value); + index++; } @Override public void writeBigDecimal(Schema schema, BigDecimal value) { - writeIndexedParam(value.toPlainString()); + writeIndexedKeyPrefix(NumberCodec.maxBigDecimalLength(value)); + pos = NumberCodec.writeBigDecimal(buf, pos, value); + index++; } @Override public void writeString(Schema schema, String value) { - writeIndexedParam(value); + writeIndexedKeyPrefix(value.length() * 3); + writeUrlEncoded(value); + index++; } @Override public void writeBlob(Schema schema, ByteBuffer value) { - writeIndexedParam(ByteBufferUtils.base64Encode(value)); + byte[] encoded = ByteBufferUtils.base64EncodeToBytes(value); + writeIndexedKeyPrefix(encoded.length * 3); + writeUrlEncodedAsciiBytes(encoded, encoded.length); + index++; } @Override public void writeTimestamp(Schema schema, Instant value) { - TimestampFormatter formatter = TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME); - writeIndexedParam(formatter.writeString(value)); + TimestampFormatTrait.Format fmt = resolveTimestampFormat(schema); + if (fmt == TimestampFormatTrait.Format.DATE_TIME) { + writeIndexedKeyPrefix(30); + pos = TimestampCodec.writeIso8601(buf, pos, value); + index++; + } else if (fmt == TimestampFormatTrait.Format.EPOCH_SECONDS) { + writeIndexedKeyPrefix(30); + pos = TimestampCodec.writeEpochSeconds(buf, pos, value.getEpochSecond(), value.getNano()); + index++; + } else if (fmt == TimestampFormatTrait.Format.HTTP_DATE) { + writeIndexedKeyPrefix(90); + writeHttpDateUrlEncoded(value); + index++; + } else { + TimestampFormatter formatter = TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME); + String formatted = formatter.writeString(value); + writeIndexedKeyPrefix(formatted.length() * 3); + writeUrlEncoded(formatted); + index++; + } } @Override @@ -361,27 +622,13 @@ public void writeNull(Schema schema) { index++; } - private void writeIndexedParam(String value) { - sink.writeByte('&'); - writeCurrentPrefix(); - if (prefixDepth > 0) { - sink.writeByte('.'); - } - if (flattened) { - sink.writeInt(index); - } else { - sink.writeBytes(memberNameBytes, 0, memberNameBytes.length); - sink.writeByte('.'); - sink.writeInt(index); - } - sink.writeByte('='); - sink.writeUrlEncoded(value); + private void writeIndexedParamInt(int value) { + writeIndexedKeyPrefix(11); + pos = NumberCodec.writeInt(buf, pos, value); index++; } } - // --- Map (awsQuery only) --- - @Override public void writeMap(Schema schema, T mapState, int size, BiConsumer consumer) { if (variant == QueryVariant.EC2_QUERY) { @@ -436,7 +683,7 @@ public void writeEntry( if (flattened) { pushIndexPrefix(index); } else { - pushIndexedPrefix(entryNameBytes, index); + pushPrefixWithIndex(entryNameBytes, index); } writeParam(keyNameBytes, key); @@ -450,8 +697,6 @@ public void writeEntry( } } - private final MapValueSerializer mapValueSerializer = new MapValueSerializer(); - private final class MapValueSerializer implements ShapeSerializer { @Override public void writeStruct(Schema schema, SerializableStruct struct) { @@ -500,75 +745,96 @@ public void writeMap(Schema schema, T mapState, int size, BiConsumer 0 ? "Infinity" : "-Infinity"); + if (!Float.isFinite(value)) { + writePrefixEquals(9); + pos = NumberCodec.writeNonFiniteFloat(buf, pos, value); } else { - writeValueParam(Float.toString(value)); + writePrefixEquals(15); + pos = NumberCodec.writeFloat(buf, pos, value); } } @Override public void writeDouble(Schema schema, double value) { - if (Double.isNaN(value)) { - writeValueParam("NaN"); - } else if (Double.isInfinite(value)) { - writeValueParam(value > 0 ? "Infinity" : "-Infinity"); + if (!Double.isFinite(value)) { + writePrefixEquals(9); + pos = NumberCodec.writeNonFiniteDouble(buf, pos, value); } else { - writeValueParam(Double.toString(value)); + writePrefixEquals(25); + pos = NumberCodec.writeDouble(buf, pos, value); } } @Override public void writeBigInteger(Schema schema, BigInteger value) { - writeValueParam(value.toString()); + writePrefixEquals(64); + pos = NumberCodec.writeBigInteger(buf, pos, value); } @Override public void writeBigDecimal(Schema schema, BigDecimal value) { - writeValueParam(value.toPlainString()); + writePrefixEquals(NumberCodec.maxBigDecimalLength(value)); + pos = NumberCodec.writeBigDecimal(buf, pos, value); } @Override public void writeString(Schema schema, String value) { - writeValueParam(value); + writePrefixEquals(value.length() * 3); + writeUrlEncoded(value); } @Override public void writeBlob(Schema schema, ByteBuffer value) { - writeValueParam(ByteBufferUtils.base64Encode(value)); + byte[] encoded = ByteBufferUtils.base64EncodeToBytes(value); + writePrefixEquals(encoded.length * 3); + writeUrlEncodedAsciiBytes(encoded, encoded.length); } @Override public void writeTimestamp(Schema schema, Instant value) { - TimestampFormatter formatter = TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME); - writeValueParam(formatter.writeString(value)); + TimestampFormatTrait.Format fmt = resolveTimestampFormat(schema); + if (fmt == TimestampFormatTrait.Format.DATE_TIME) { + writePrefixEquals(30); + pos = TimestampCodec.writeIso8601(buf, pos, value); + } else if (fmt == TimestampFormatTrait.Format.EPOCH_SECONDS) { + writePrefixEquals(30); + pos = TimestampCodec.writeEpochSeconds(buf, pos, value.getEpochSecond(), value.getNano()); + } else if (fmt == TimestampFormatTrait.Format.HTTP_DATE) { + writePrefixEquals(90); + writeHttpDateUrlEncoded(value); + } else { + TimestampFormatter formatter = TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME); + String formatted = formatter.writeString(value); + writePrefixEquals(formatted.length() * 3); + writeUrlEncoded(formatted); + } } @Override @@ -579,73 +845,71 @@ public void writeDocument(Schema schema, Document value) { @Override public void writeNull(Schema schema) {} - private void writeValueParam(String value) { - sink.writeByte('&'); - writeCurrentPrefix(); - sink.writeByte('='); - sink.writeUrlEncoded(value); + private void writeValueParamInt(int value) { + writePrefixEquals(11); + pos = NumberCodec.writeInt(buf, pos, value); } } - // --- Scalar writes --- - @Override public void writeBoolean(Schema schema, boolean value) { - writeParam(getMemberNameBytes(schema), value ? "true" : "false"); + writeParamBoolean(getMemberNameBytes(schema), value); } @Override public void writeByte(Schema schema, byte value) { - writeParam(getMemberNameBytes(schema), Byte.toString(value)); + writeParamInt(getMemberNameBytes(schema), value); } @Override public void writeShort(Schema schema, short value) { - writeParam(getMemberNameBytes(schema), Short.toString(value)); + writeParamInt(getMemberNameBytes(schema), value); } @Override public void writeInteger(Schema schema, int value) { - writeParam(getMemberNameBytes(schema), Integer.toString(value)); + writeParamInt(getMemberNameBytes(schema), value); } @Override public void writeLong(Schema schema, long value) { - writeParam(getMemberNameBytes(schema), Long.toString(value)); + writeParamLong(getMemberNameBytes(schema), value); } @Override public void writeFloat(Schema schema, float value) { - byte[] memberNameBytes = getMemberNameBytes(schema); - if (Float.isNaN(value)) { - writeParam(memberNameBytes, "NaN"); - } else if (Float.isInfinite(value)) { - writeParam(memberNameBytes, value > 0 ? "Infinity" : "-Infinity"); + byte[] key = getMemberNameBytes(schema); + if (!Float.isFinite(value)) { + writeKeyPrefix(key, 9); + pos = NumberCodec.writeNonFiniteFloat(buf, pos, value); } else { - writeParam(memberNameBytes, Float.toString(value)); + writeParamFloat(key, value); } } @Override public void writeDouble(Schema schema, double value) { - byte[] memberNameBytes = getMemberNameBytes(schema); - if (Double.isNaN(value)) { - writeParam(memberNameBytes, "NaN"); - } else if (Double.isInfinite(value)) { - writeParam(memberNameBytes, value > 0 ? "Infinity" : "-Infinity"); + byte[] key = getMemberNameBytes(schema); + if (!Double.isFinite(value)) { + writeKeyPrefix(key, 9); + pos = NumberCodec.writeNonFiniteDouble(buf, pos, value); } else { - writeParam(memberNameBytes, Double.toString(value)); + writeParamDouble(key, value); } } @Override public void writeBigInteger(Schema schema, BigInteger value) { - writeParam(getMemberNameBytes(schema), value.toString()); + byte[] key = getMemberNameBytes(schema); + writeKeyPrefix(key, 64); + pos = NumberCodec.writeBigInteger(buf, pos, value); } @Override public void writeBigDecimal(Schema schema, BigDecimal value) { - writeParam(getMemberNameBytes(schema), value.toPlainString()); + byte[] key = getMemberNameBytes(schema); + writeKeyPrefix(key, NumberCodec.maxBigDecimalLength(value)); + pos = NumberCodec.writeBigDecimal(buf, pos, value); } @Override @@ -655,13 +919,28 @@ public void writeString(Schema schema, String value) { @Override public void writeBlob(Schema schema, ByteBuffer value) { - writeParam(getMemberNameBytes(schema), ByteBufferUtils.base64Encode(value)); + byte[] key = getMemberNameBytes(schema); + byte[] encoded = ByteBufferUtils.base64EncodeToBytes(value); + writeKeyPrefix(key, encoded.length * 3); + writeUrlEncodedAsciiBytes(encoded, encoded.length); } @Override public void writeTimestamp(Schema schema, Instant value) { - TimestampFormatter formatter = TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME); - writeParam(getMemberNameBytes(schema), formatter.writeString(value)); + TimestampFormatTrait.Format fmt = resolveTimestampFormat(schema); + byte[] key = getMemberNameBytes(schema); + if (fmt == TimestampFormatTrait.Format.DATE_TIME) { + writeParamTimestampDirect(key, value); + } else if (fmt == TimestampFormatTrait.Format.EPOCH_SECONDS) { + writeKeyPrefix(key, 30); + pos = TimestampCodec.writeEpochSeconds(buf, pos, value.getEpochSecond(), value.getNano()); + } else if (fmt == TimestampFormatTrait.Format.HTTP_DATE) { + writeKeyPrefix(key, 90); + writeHttpDateUrlEncoded(value); + } else { + TimestampFormatter formatter = TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME); + writeParam(key, formatter.writeString(value)); + } } @Override @@ -671,4 +950,20 @@ public void writeDocument(Schema schema, Document value) { @Override public void writeNull(Schema schema) {} + + private final byte[] httpDateTmp = new byte[40]; + + private void writeHttpDateUrlEncoded(Instant value) { + int len = TimestampCodec.writeHttpDate(httpDateTmp, 0, value); + writeUrlEncodedAsciiBytes(httpDateTmp, len); + } + + private static TimestampFormatTrait.Format resolveTimestampFormat(Schema schema) { + var ext = schema.getExtension(AwsQuerySchemaExtensions.KEY); + if (ext != null && ext.timestampFormat() != null) { + return ext.timestampFormat(); + } + var trait = schema.getTrait(TraitKey.TIMESTAMP_FORMAT_TRAIT); + return trait != null ? trait.getFormat() : TimestampFormatTrait.Format.DATE_TIME; + } } diff --git a/aws/client/aws-client-awsquery/src/test/java/software/amazon/smithy/java/aws/client/awsquery/FormUrlEncodedSinkTest.java b/aws/client/aws-client-awsquery/src/test/java/software/amazon/smithy/java/aws/client/awsquery/FormUrlEncodedSinkTest.java deleted file mode 100644 index 09cfc3e33d..0000000000 --- a/aws/client/aws-client-awsquery/src/test/java/software/amazon/smithy/java/aws/client/awsquery/FormUrlEncodedSinkTest.java +++ /dev/null @@ -1,288 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -package software.amazon.smithy.java.aws.client.awsquery; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.equalTo; -import static software.amazon.smithy.java.io.ByteBufferUtils.getUTF8String; - -import java.nio.ByteBuffer; -import java.nio.charset.StandardCharsets; -import java.util.stream.Stream; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.junit.jupiter.params.provider.ValueSource; - -class FormUrlEncodedSinkTest { - - @ParameterizedTest - @ValueSource(strings = { - "ABCDEFGHIJKLMNOPQRSTUVWXYZ", - "abcdefghijklmnopqrstuvwxyz", - "0123456789", - "-._~", - "Hello", - "test123", - "a-b.c_d~e" - }) - void unreservedCharactersPassThrough(String input) { - var sink = new FormUrlEncodedSink(); - sink.writeUrlEncoded(input); - assertThat(getUTF8String(sink.finish()), equalTo(input)); - } - - @ParameterizedTest - @MethodSource("reservedCharactersProvider") - void reservedCharactersArePercentEncoded(String input, String expected) { - var sink = new FormUrlEncodedSink(); - sink.writeUrlEncoded(input); - assertThat(getUTF8String(sink.finish()), equalTo(expected)); - } - - static Stream reservedCharactersProvider() { - return Stream.of( - Arguments.of(" ", "%20"), - Arguments.of("!", "%21"), - Arguments.of("#", "%23"), - Arguments.of("$", "%24"), - Arguments.of("%", "%25"), - Arguments.of("&", "%26"), - Arguments.of("'", "%27"), - Arguments.of("(", "%28"), - Arguments.of(")", "%29"), - Arguments.of("*", "%2A"), - Arguments.of("+", "%2B"), - Arguments.of(",", "%2C"), - Arguments.of("/", "%2F"), - Arguments.of(":", "%3A"), - Arguments.of(";", "%3B"), - Arguments.of("=", "%3D"), - Arguments.of("?", "%3F"), - Arguments.of("@", "%40"), - Arguments.of("[", "%5B"), - Arguments.of("]", "%5D"), - Arguments.of("hello world", "hello%20world"), - Arguments.of("a=b&c=d", "a%3Db%26c%3Dd"), - Arguments.of("foo/bar", "foo%2Fbar")); - } - - @ParameterizedTest - @MethodSource("utf8TwoByteProvider") - void twoByteUtf8CharactersAreEncoded(String input, String expected) { - var sink = new FormUrlEncodedSink(); - sink.writeUrlEncoded(input); - assertThat(getUTF8String(sink.finish()), equalTo(expected)); - } - - static Stream utf8TwoByteProvider() { - return Stream.of( - Arguments.of("é", "%C3%A9"), - Arguments.of("ñ", "%C3%B1"), - Arguments.of("ü", "%C3%BC"), - Arguments.of("café", "caf%C3%A9"), - Arguments.of("©", "%C2%A9")); - } - - @ParameterizedTest - @MethodSource("utf8ThreeByteProvider") - void threeByteUtf8CharactersAreEncoded(String input, String expected) { - var sink = new FormUrlEncodedSink(); - sink.writeUrlEncoded(input); - assertThat(getUTF8String(sink.finish()), equalTo(expected)); - } - - static Stream utf8ThreeByteProvider() { - return Stream.of( - Arguments.of("€", "%E2%82%AC"), - Arguments.of("中", "%E4%B8%AD"), - Arguments.of("日本", "%E6%97%A5%E6%9C%AC"), - Arguments.of("☃", "%E2%98%83")); - } - - @ParameterizedTest - @MethodSource("utf8FourByteProvider") - void fourByteUtf8SurrogatePairsAreEncoded(String input, String expected) { - var sink = new FormUrlEncodedSink(); - sink.writeUrlEncoded(input); - assertThat(getUTF8String(sink.finish()), equalTo(expected)); - } - - static Stream utf8FourByteProvider() { - return Stream.of( - Arguments.of("🎉", "%F0%9F%8E%89"), - Arguments.of("😀", "%F0%9F%98%80"), - Arguments.of("𝄞", "%F0%9D%84%9E"), - Arguments.of("hello🎉world", "hello%F0%9F%8E%89world")); - } - - @Test - void writeUrlEncodedWithEmptyString() { - var sink = new FormUrlEncodedSink(); - sink.writeUrlEncoded(""); - assertThat(getUTF8String(sink.finish()), equalTo("")); - } - - @Test - void writeUrlEncodedWithMixedContent() { - var sink = new FormUrlEncodedSink(); - sink.writeUrlEncoded("Hello World! café 日本 🎉"); - assertThat(getUTF8String(sink.finish()), - equalTo("Hello%20World%21%20caf%C3%A9%20%E6%97%A5%E6%9C%AC%20%F0%9F%8E%89")); - } - - @ParameterizedTest - @ValueSource(ints = {0, 1, 9, 10, 99, 100, 999, 1000, 12345, 999999, Integer.MAX_VALUE}) - void writeIntPositiveValues(int value) { - var sink = new FormUrlEncodedSink(); - sink.writeInt(value); - assertThat(getUTF8String(sink.finish()), equalTo(Integer.toString(value))); - } - - @ParameterizedTest - @ValueSource(ints = {-1, -10, -999, Integer.MIN_VALUE}) - void writeIntNegativeValues(int value) { - var sink = new FormUrlEncodedSink(); - sink.writeInt(value); - assertThat(getUTF8String(sink.finish()), equalTo(Integer.toString(value))); - } - - @Test - void writeAsciiSimpleString() { - var sink = new FormUrlEncodedSink(); - sink.writeAscii("Action=GetUser"); - assertThat(getUTF8String(sink.finish()), equalTo("Action=GetUser")); - } - - @Test - void writeAsciiEmptyString() { - var sink = new FormUrlEncodedSink(); - sink.writeAscii(""); - assertThat(getUTF8String(sink.finish()), equalTo("")); - } - - @Test - void writeByteSingleByte() { - var sink = new FormUrlEncodedSink(); - sink.writeByte('&'); - assertThat(getUTF8String(sink.finish()), equalTo("&")); - } - - @Test - void writeByteMultipleBytes() { - var sink = new FormUrlEncodedSink(); - sink.writeByte('a'); - sink.writeByte('='); - sink.writeByte('b'); - assertThat(getUTF8String(sink.finish()), equalTo("a=b")); - } - - @Test - void writeBytesFromArray() { - var sink = new FormUrlEncodedSink(); - byte[] data = "Hello".getBytes(StandardCharsets.UTF_8); - sink.writeBytes(data, 0, data.length); - assertThat(getUTF8String(sink.finish()), equalTo("Hello")); - } - - @Test - void writeBytesWithOffset() { - var sink = new FormUrlEncodedSink(); - byte[] data = "xxHelloxx".getBytes(StandardCharsets.UTF_8); - sink.writeBytes(data, 2, 5); - assertThat(getUTF8String(sink.finish()), equalTo("Hello")); - } - - @Test - void combineMultipleWriteOperations() { - var sink = new FormUrlEncodedSink(); - sink.writeAscii("Action=Test"); - sink.writeByte('&'); - sink.writeAscii("Index="); - sink.writeInt(42); - sink.writeByte('&'); - sink.writeAscii("Name="); - sink.writeUrlEncoded("hello world"); - assertThat(getUTF8String(sink.finish()), equalTo("Action=Test&Index=42&Name=hello%20world")); - } - - @Test - void bufferGrowsBeyondInitialCapacity() { - var sink = new FormUrlEncodedSink(8); - sink.writeAscii("This is a much longer string that exceeds the initial capacity"); - assertThat(getUTF8String(sink.finish()), - equalTo("This is a much longer string that exceeds the initial capacity")); - } - - @Test - void bufferGrowsWithUrlEncodedContent() { - var sink = new FormUrlEncodedSink(10); - sink.writeUrlEncoded("Special chars: !@#$%^&*()"); - assertThat(getUTF8String(sink.finish()), - equalTo("Special%20chars%3A%20%21%40%23%24%25%5E%26%2A%28%29")); - } - - @Test - void finishReturnsByteBufferWithCorrectPosition() { - var sink = new FormUrlEncodedSink(); - sink.writeAscii("test"); - ByteBuffer result = sink.finish(); - assertThat(result.position(), equalTo(0)); - assertThat(result.remaining(), equalTo(4)); - } - - @Test - void hexEncodingUsesUppercase() { - var sink = new FormUrlEncodedSink(); - sink.writeUrlEncoded("ÿ"); - String result = getUTF8String(sink.finish()); - assertThat(result, equalTo("%C3%BF")); - assertThat(result.contains("a") || result.contains("b") - || result.contains("c") - || result.contains("d") - || result.contains("e") - || result.contains("f"), equalTo(false)); - } - - @Test - void unpairedHighSurrogateIsEncodedAsSingleCharacter() { - var sink = new FormUrlEncodedSink(); - // High surrogate \uD83C without a following low surrogate - sink.writeUrlEncoded("a\uD83Cb"); - String result = getUTF8String(sink.finish()); - // High surrogate encoded as 3-byte sequence, then 'b' passes through - assertThat(result, equalTo("a%ED%A0%BCb")); - } - - @Test - void highSurrogateFollowedByNonSurrogateEncodesEachSeparately() { - var sink = new FormUrlEncodedSink(); - // High surrogate \uD83C followed by regular char 'X' (not a low surrogate) - sink.writeUrlEncoded("\uD83CX"); - String result = getUTF8String(sink.finish()); - // High surrogate encoded as 3-byte, then X passes through - assertThat(result, equalTo("%ED%A0%BCX")); - } - - @Test - void highSurrogateAtEndOfStringIsEncoded() { - var sink = new FormUrlEncodedSink(); - // High surrogate at end with no following character - sink.writeUrlEncoded("test\uD83C"); - String result = getUTF8String(sink.finish()); - assertThat(result, equalTo("test%ED%A0%BC")); - } - - @Test - void lowSurrogateAloneIsEncoded() { - var sink = new FormUrlEncodedSink(); - // Lone low surrogate (no preceding high surrogate) - sink.writeUrlEncoded("a\uDE89b"); - String result = getUTF8String(sink.finish()); - assertThat(result, equalTo("a%ED%BA%89b")); - } -} diff --git a/aws/client/aws-client-awsquery/src/test/java/software/amazon/smithy/java/aws/client/awsquery/UrlEncodingTest.java b/aws/client/aws-client-awsquery/src/test/java/software/amazon/smithy/java/aws/client/awsquery/UrlEncodingTest.java new file mode 100644 index 0000000000..77c096793b --- /dev/null +++ b/aws/client/aws-client-awsquery/src/test/java/software/amazon/smithy/java/aws/client/awsquery/UrlEncodingTest.java @@ -0,0 +1,165 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.aws.client.awsquery; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +import java.nio.charset.StandardCharsets; +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +/** + * Tests URL encoding correctness of the encoding used by QueryFormSerializer. + * Uses {@link AwsQuerySchemaExtensions#encodeName} which exercises the same + * lookup tables and encoding logic. + */ +class UrlEncodingTest { + + private static String urlEncode(String input) { + return new String(AwsQuerySchemaExtensions.encodeName(input), StandardCharsets.UTF_8); + } + + @ParameterizedTest + @ValueSource(strings = { + "ABCDEFGHIJKLMNOPQRSTUVWXYZ", + "abcdefghijklmnopqrstuvwxyz", + "0123456789", + "-._~", + "Hello", + "test123", + "a-b.c_d~e" + }) + void unreservedCharactersPassThrough(String input) { + assertThat(urlEncode(input), equalTo(input)); + } + + @ParameterizedTest + @MethodSource("reservedCharactersProvider") + void reservedCharactersArePercentEncoded(String input, String expected) { + assertThat(urlEncode(input), equalTo(expected)); + } + + static Stream reservedCharactersProvider() { + return Stream.of( + Arguments.of(" ", "%20"), + Arguments.of("!", "%21"), + Arguments.of("#", "%23"), + Arguments.of("$", "%24"), + Arguments.of("%", "%25"), + Arguments.of("&", "%26"), + Arguments.of("'", "%27"), + Arguments.of("(", "%28"), + Arguments.of(")", "%29"), + Arguments.of("*", "%2A"), + Arguments.of("+", "%2B"), + Arguments.of(",", "%2C"), + Arguments.of("/", "%2F"), + Arguments.of(":", "%3A"), + Arguments.of(";", "%3B"), + Arguments.of("=", "%3D"), + Arguments.of("?", "%3F"), + Arguments.of("@", "%40"), + Arguments.of("[", "%5B"), + Arguments.of("]", "%5D"), + Arguments.of("hello world", "hello%20world"), + Arguments.of("a=b&c=d", "a%3Db%26c%3Dd"), + Arguments.of("foo/bar", "foo%2Fbar")); + } + + @ParameterizedTest + @MethodSource("utf8TwoByteProvider") + void twoByteUtf8CharactersAreEncoded(String input, String expected) { + assertThat(urlEncode(input), equalTo(expected)); + } + + static Stream utf8TwoByteProvider() { + return Stream.of( + Arguments.of("é", "%C3%A9"), + Arguments.of("ñ", "%C3%B1"), + Arguments.of("ü", "%C3%BC"), + Arguments.of("café", "caf%C3%A9"), + Arguments.of("©", "%C2%A9")); + } + + @ParameterizedTest + @MethodSource("utf8ThreeByteProvider") + void threeByteUtf8CharactersAreEncoded(String input, String expected) { + assertThat(urlEncode(input), equalTo(expected)); + } + + static Stream utf8ThreeByteProvider() { + return Stream.of( + Arguments.of("€", "%E2%82%AC"), + Arguments.of("中", "%E4%B8%AD"), + Arguments.of("日本", "%E6%97%A5%E6%9C%AC"), + Arguments.of("☃", "%E2%98%83")); + } + + @ParameterizedTest + @MethodSource("utf8FourByteProvider") + void fourByteUtf8SurrogatePairsAreEncoded(String input, String expected) { + assertThat(urlEncode(input), equalTo(expected)); + } + + static Stream utf8FourByteProvider() { + return Stream.of( + Arguments.of("🎉", "%F0%9F%8E%89"), + Arguments.of("😀", "%F0%9F%98%80"), + Arguments.of("𝄞", "%F0%9D%84%9E"), + Arguments.of("hello🎉world", "hello%F0%9F%8E%89world")); + } + + @Test + void writeUrlEncodedWithEmptyString() { + assertThat(urlEncode(""), equalTo("")); + } + + @Test + void writeUrlEncodedWithMixedContent() { + assertThat(urlEncode("Hello World! café 日本 🎉"), + equalTo("Hello%20World%21%20caf%C3%A9%20%E6%97%A5%E6%9C%AC%20%F0%9F%8E%89")); + } + + @Test + void hexEncodingUsesUppercase() { + String result = urlEncode("ÿ"); + assertThat(result, equalTo("%C3%BF")); + assertThat(result.contains("a") || result.contains("b") + || result.contains("c") + || result.contains("d") + || result.contains("e") + || result.contains("f"), equalTo(false)); + } + + @Test + void unpairedHighSurrogateIsEncodedAsSingleCharacter() { + String result = urlEncode("a\uD83Cb"); + assertThat(result, equalTo("a%ED%A0%BCb")); + } + + @Test + void highSurrogateFollowedByNonSurrogateEncodesEachSeparately() { + String result = urlEncode("\uD83CX"); + assertThat(result, equalTo("%ED%A0%BCX")); + } + + @Test + void highSurrogateAtEndOfStringIsEncoded() { + String result = urlEncode("test\uD83C"); + assertThat(result, equalTo("test%ED%A0%BC")); + } + + @Test + void lowSurrogateAloneIsEncoded() { + String result = urlEncode("a\uDE89b"); + assertThat(result, equalTo("a%ED%BA%89b")); + } +} diff --git a/aws/client/aws-client-restjson/build.gradle.kts b/aws/client/aws-client-restjson/build.gradle.kts index 4805b2c749..8e895d7248 100644 --- a/aws/client/aws-client-restjson/build.gradle.kts +++ b/aws/client/aws-client-restjson/build.gradle.kts @@ -19,5 +19,10 @@ dependencies { testImplementation(libs.smithy.aws.protocol.tests) } +protocolTestRuns { + run("native") { systemProperty("smithy-java.json-provider", "smithy") } + run("jackson") { systemProperty("smithy-java.json-provider", "jackson") } +} + val generator = "software.amazon.smithy.java.protocoltests.generators.ProtocolTestGenerator" addGenerateSrcsTask(generator, "restJson1", "aws.protocoltests.restjson#RestJson") diff --git a/aws/client/aws-client-restxml/build.gradle.kts b/aws/client/aws-client-restxml/build.gradle.kts index a8e249170c..156e83d183 100644 --- a/aws/client/aws-client-restxml/build.gradle.kts +++ b/aws/client/aws-client-restxml/build.gradle.kts @@ -19,6 +19,11 @@ dependencies { testImplementation(libs.smithy.aws.protocol.tests) } +protocolTestRuns { + run("native") { systemProperty("smithy-java.xml-provider", "smithy") } + run("stax") { } +} + val generator = "software.amazon.smithy.java.protocoltests.generators.ProtocolTestGenerator" addGenerateSrcsTask(generator, "restXml", "aws.protocoltests.restxml#RestXml") addGenerateSrcsTask(generator, "restXmlWithNamespace", "aws.protocoltests.restxml.xmlns#RestXmlWithNamespace") diff --git a/aws/client/aws-client-rulesengine/build.gradle.kts b/aws/client/aws-client-rulesengine/build.gradle.kts index e756024d66..e9320f9ce3 100644 --- a/aws/client/aws-client-rulesengine/build.gradle.kts +++ b/aws/client/aws-client-rulesengine/build.gradle.kts @@ -1,6 +1,6 @@ plugins { id("smithy-java.module-conventions") - id("me.champeau.jmh") version "0.7.3" + id("smithy-java.jmh-conventions") } description = "This module provides AWS-Specific client rules engine functionality" @@ -34,18 +34,7 @@ configurations["jmhImplementation"].extendsFrom(lambdaModel) jmh { warmupIterations = 2 iterations = 3 - fork = 1 - // Allow filtering for specific benchmarks, e.g. -Pjmh.includes=S3EndpointBenchmark - includes.addAll( - providers - .gradleProperty("jmh.includes") - .map { listOf(it) } - .orElse(emptyList()), - ) - // profilers.add("async:output=flamegraph;dir=build/jmh-profiler") profilers.add("async:output=collapsed;dir=build/jmh-profiler") - // profilers.add("gc") - duplicateClassesStrategy = DuplicatesStrategy.EXCLUDE // don't dump a bunch of warnings. } // Clean cached bytecode before running benchmarks so stale compilations aren't reused. diff --git a/aws/server/aws-server-restjson/build.gradle.kts b/aws/server/aws-server-restjson/build.gradle.kts index 6fc287a525..193a8d6044 100644 --- a/aws/server/aws-server-restjson/build.gradle.kts +++ b/aws/server/aws-server-restjson/build.gradle.kts @@ -25,5 +25,10 @@ dependencies { testImplementation(libs.smithy.aws.protocol.tests) } +protocolTestRuns { + run("native") { systemProperty("smithy-java.json-provider", "smithy") } + run("jackson") { systemProperty("smithy-java.json-provider", "jackson") } +} + val generator = "software.amazon.smithy.java.protocoltests.generators.ProtocolTestGenerator" addGenerateSrcsTask(generator, "restJson1", "aws.protocoltests.restjson#RestJson", "server") diff --git a/benchmarks/serde-benchmarks/build.gradle.kts b/benchmarks/serde-benchmarks/build.gradle.kts index 3415285ece..3c378ff7d3 100644 --- a/benchmarks/serde-benchmarks/build.gradle.kts +++ b/benchmarks/serde-benchmarks/build.gradle.kts @@ -1,7 +1,7 @@ plugins { id("smithy-java.java-conventions") - alias(libs.plugins.shadow) - alias(libs.plugins.jmh) + id("com.gradleup.shadow") + id("smithy-java.jmh-conventions") id("software.amazon.smithy.gradle.smithy-base") } @@ -150,40 +150,22 @@ tasks.named("compileJmhJava") { dependsOn("smithyBuild") } -// All JMH parameters are configured here (single source of truth). -// Per-class annotations (@Warmup, @Measurement, @Fork, etc.) are not used. -// -// Fast mode: -Pjmh.fast (1 warmup, 3 measurement, 1 fork, 5s each) -// Profilers: -Pjmh.profilers=gc,stack (comma-separated JMH profiler names) -// Filter: -Pjmh.includes=RpcV2CborSerializeBenchmark.serialize // Test case: -Pjmh.testCaseId=rpcv2Cbor_PutItemRequest_BinaryData_S val fast = providers.gradleProperty("jmh.fast").isPresent jmh { - benchmarkMode.addAll("sample") - timeUnit = "ns" - warmupIterations = if (fast) 1 else 5 - warmup = if (fast) "5s" else "2s" - iterations = if (fast) 3 else 10 - timeOnIteration = if (fast) "5s" else "5s" - fork = 1 + benchmarkMode.set(listOf("sample")) + if (!fast) { + warmupIterations = 5 + iterations = 10 + } + timeOnIteration = "5s" jvmArgs.addAll( "-Xms1g", "-Xmx1g", "-XX:+UseG1GC", "-XX:+AlwaysPreTouch", "-Dsmithy-java.json-provider=smithy", - ) - includes.addAll( - providers - .gradleProperty("jmh.includes") - .map { listOf(it) } - .orElse(emptyList()), - ) - profilers.addAll( - providers - .gradleProperty("jmh.profilers") - .map { it.split(",") } - .orElse(emptyList()), + "-Dsmithy-java.xml-provider=smithy", ) providers.gradleProperty("jmh.testCaseId").orNull?.let { id -> val prop = objects.listProperty(String::class.java) diff --git a/buildSrc/build.gradle.kts b/buildSrc/build.gradle.kts index 3a37000475..b5453efeb4 100644 --- a/buildSrc/build.gradle.kts +++ b/buildSrc/build.gradle.kts @@ -15,6 +15,9 @@ dependencies { implementation(libs.spotless) implementation(libs.smithy.gradle.base) implementation(libs.dependency.analysis) + implementation(libs.pitest.gradle.plugin) + implementation(libs.jmh.gradle.plugin) + implementation(libs.shadow.gradle.plugin) // https://github.com/gradle/gradle/issues/15383 implementation(files(libs.javaClass.superclass.protectionDomain.codeSource.location)) diff --git a/buildSrc/src/main/kotlin/smithy-java.java-conventions.gradle.kts b/buildSrc/src/main/kotlin/smithy-java.java-conventions.gradle.kts index 9989c4d6f0..7de5d8e536 100644 --- a/buildSrc/src/main/kotlin/smithy-java.java-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/smithy-java.java-conventions.gradle.kts @@ -8,6 +8,7 @@ plugins { id("com.github.spotbugs") id("com.diffplug.spotless") id("com.autonomousapps.dependency-analysis") + id("info.solidsoft.pitest") id("smithy-java.utilities") } @@ -48,6 +49,7 @@ dependencies { testImplementation(libs.assertj.core) compileOnly("com.github.spotbugs:spotbugs-annotations:${spotbugs.toolVersion.get()}") testCompileOnly("com.github.spotbugs:spotbugs-annotations:${spotbugs.toolVersion.get()}") + "pitest"(libs.pitest.junit5.plugin) } tasks.withType { @@ -121,6 +123,26 @@ tasks.named("spotbugsTest") { enabled = false } +pitest { + targetClasses.set(setOf("software.amazon.smithy.*")) + targetTests.set(setOf("software.amazon.smithy.*")) + excludedClasses.set(setOf("*.GeneratedVersionProvider")) + threads.set(Runtime.getRuntime().availableProcessors()) + outputFormats.set(setOf("HTML", "XML")) + timestampedReports.set(false) + mutationThreshold.set(0) + failWhenNoMutations.set(false) +} + +tasks.named("pitest") { + val reportDir = project.layout.buildDirectory.dir("reports/pitest").map { it.asFile.absolutePath } + doLast { + val dir = reportDir.get() + logger.lifecycle("Pitest HTML report: file://${dir}/index.html") + logger.lifecycle("Pitest XML report: file://${dir}/mutations.xml") + } +} + /* * Repositories * ================================ diff --git a/buildSrc/src/main/kotlin/smithy-java.jmh-conventions.gradle.kts b/buildSrc/src/main/kotlin/smithy-java.jmh-conventions.gradle.kts new file mode 100644 index 0000000000..6c3a5a1075 --- /dev/null +++ b/buildSrc/src/main/kotlin/smithy-java.jmh-conventions.gradle.kts @@ -0,0 +1,54 @@ +import org.gradle.accessors.dm.LibrariesForLibs + +plugins { + id("me.champeau.jmh") +} + +val Project.libs get() = the() + +// Standardized JMH property namespace (all "jmh." prefix): +// -Pjmh.fast Reduce iterations for quick local runs +// -Pjmh.includes= Filter which benchmarks to run (comma-separated) +// -Pjmh.profilers= ";;"-separated profiler specs (e.g. "gc;;async:output=flamegraph;dir=build") +// -Pjmh.warmupIterations=N Override warmup iteration count +// -Pjmh.iterations=N Override measurement iteration count +// -Pjmh.fork=N Override fork count +val fast = providers.gradleProperty("jmh.fast").isPresent + +jmh { + jmhVersion = libs.versions.jmhCore.get() + benchmarkMode.addAll("avgt") + timeUnit = "ns" + fork = providers.gradleProperty("jmh.fork").orElse("1").get().toInt() + + warmupIterations = + if (fast) 1 + else providers.gradleProperty("jmh.warmupIterations").orElse("3").get().toInt() + warmup = if (fast) "1s" else "2s" + iterations = + if (fast) 2 + else providers.gradleProperty("jmh.iterations").orElse("5").get().toInt() + timeOnIteration = if (fast) "3s" else "5s" + + includes.addAll( + providers.gradleProperty("jmh.includes") + .map { it.split(",").filter { s -> s.isNotBlank() } } + .orElse(emptyList()), + ) + + profilers.addAll( + providers.gradleProperty("jmh.profilers") + .map { it.split(";;").filter { s -> s.isNotBlank() } } + .orElse(emptyList()), + ) + + forceGC = true + + jvmArgs.addAll( + providers.gradleProperty("jmh.jvmArgs") + .map { it.split(" ").filter { s -> s.isNotBlank() } } + .orElse(emptyList()), + ) + + duplicateClassesStrategy = DuplicatesStrategy.EXCLUDE +} diff --git a/buildSrc/src/main/kotlin/smithy-java.protocol-testing-conventions.gradle.kts b/buildSrc/src/main/kotlin/smithy-java.protocol-testing-conventions.gradle.kts index 9398de9fad..640cb71a12 100644 --- a/buildSrc/src/main/kotlin/smithy-java.protocol-testing-conventions.gradle.kts +++ b/buildSrc/src/main/kotlin/smithy-java.protocol-testing-conventions.gradle.kts @@ -14,5 +14,36 @@ tasks.named("spotbugsIt") { enabled = false } -// Ensure integ tests are executed as part of test suite -tasks["test"].finalizedBy("integ") +// Extension to allow modules to register additional protocol test runs with custom configuration. +abstract class ProtocolTestRunsExtension { + internal val runs = mutableListOf>>() + + fun run(name: String, action: Action) { + runs.add(name to action) + } +} + +val protocolTestRuns = extensions.create("protocolTestRuns") + +afterEvaluate { + val itSourceSet = project.the()["it"] + val runs = protocolTestRuns.runs + + if (runs.isNotEmpty()) { + tasks.named("integ") { + enabled = false + } + for ((name, action) in runs) { + val task = tasks.register("integ-$name") { + useJUnitPlatform() + testClassesDirs = itSourceSet.output.classesDirs + classpath = itSourceSet.runtimeClasspath + action.execute(this) + } + tasks.named("integ") { dependsOn(task) } + } + } + + // Ensure integ tests are executed as part of test suite + tasks["test"].finalizedBy("integ") +} diff --git a/client/client-rpcv2-json/build.gradle.kts b/client/client-rpcv2-json/build.gradle.kts index cdaa4e1c74..a37ffd576d 100644 --- a/client/client-rpcv2-json/build.gradle.kts +++ b/client/client-rpcv2-json/build.gradle.kts @@ -19,5 +19,10 @@ dependencies { testImplementation(libs.smithy.protocol.tests) } +protocolTestRuns { + run("native") { systemProperty("smithy-java.json-provider", "smithy") } + run("jackson") { systemProperty("smithy-java.json-provider", "jackson") } +} + val generator = "software.amazon.smithy.java.protocoltests.generators.ProtocolTestGenerator" addGenerateSrcsTask(generator, "rpcv2Json", "smithy.protocoltests.rpcv2Json#RpcV2JsonProtocol") diff --git a/codecs/cbor-codec/build.gradle.kts b/codecs/cbor-codec/build.gradle.kts index e114df5b49..5365793a12 100644 --- a/codecs/cbor-codec/build.gradle.kts +++ b/codecs/cbor-codec/build.gradle.kts @@ -11,6 +11,7 @@ extra["moduleName"] = "software.amazon.smithy.java.cbor" dependencies { api(project(":core")) + implementation(project(":codecs:codec-commons", configuration = "shadow")) testFixturesImplementation(libs.assertj.core) testImplementation(project(":codecs:json-codec", configuration = "shadow")) } diff --git a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerializer.java b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerializer.java index c9a9170eb9..34367ee625 100644 --- a/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerializer.java +++ b/codecs/cbor-codec/src/main/java/software/amazon/smithy/java/cbor/CborSerializer.java @@ -36,8 +36,8 @@ import java.nio.ByteBuffer; import java.time.Instant; import java.util.Arrays; -import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.BiConsumer; +import software.amazon.smithy.java.codecs.commons.StripedPool; import software.amazon.smithy.java.core.schema.Schema; import software.amazon.smithy.java.core.schema.SerializableStruct; import software.amazon.smithy.java.core.serde.MapSerializer; @@ -59,18 +59,7 @@ final class CborSerializer implements ShapeSerializer { private static final int DEFAULT_BUF_SIZE = 4096; private static final int MAX_CACHEABLE_BUF = DEFAULT_BUF_SIZE * 4; - private static final int POOL_SLOTS; - private static final int POOL_MASK; - private static final AtomicReferenceArray POOL; - private static final int MAX_PROBE = 3; - - static { - int processors = Runtime.getRuntime().availableProcessors(); - int raw = processors * 4; - POOL_SLOTS = Integer.highestOneBit(raw - 1) << 1; - POOL_MASK = POOL_SLOTS - 1; - POOL = new AtomicReferenceArray<>(POOL_SLOTS); - } + private static final StripedPool POOL = new CborStripedPool(); byte[] buf; int pos; @@ -101,50 +90,17 @@ final class CborSerializer implements ShapeSerializer { } static CborSerializer acquire() { - if (!Thread.currentThread().isVirtual()) { - int base = poolProbe(); - for (int i = 0; i < MAX_PROBE; i++) { - int idx = (base + i) & POOL_MASK; - CborSerializer s = POOL.getPlain(idx); - if (s != null && POOL.compareAndExchangeAcquire(idx, s, null) == s) { - s.pos = 0; - s.collectionMask = 0L; - s.collectionDepth = 0; - s.currentFieldNameTable = null; - return s; - } - } - } - return new CborSerializer(); + return POOL.acquire(null); } static void release(CborSerializer serializer, boolean exception) { - if (serializer.buf == null || serializer.sink != null || Thread.currentThread().isVirtual()) { - return; - } - if (serializer.buf.length > MAX_CACHEABLE_BUF) { - serializer.buf = new byte[DEFAULT_BUF_SIZE]; - } - int base = poolProbe(); - for (int i = 0; i < MAX_PROBE; i++) { - int idx = (base + i) & POOL_MASK; - if (POOL.getPlain(idx) == null - && POOL.compareAndExchangeRelease(idx, null, serializer) == null) { - return; - } - } - // Pool full, let GC collect + POOL.release(serializer); } ByteBuffer extractResult() { return ByteBuffer.wrap(Arrays.copyOf(buf, pos)); } - private static int poolProbe() { - long id = Thread.currentThread().threadId(); - return (int) (id ^ (id >>> 16)) & POOL_MASK; - } - private void ensureCapacity(int needed) { if (pos + needed > buf.length) { grow(needed); @@ -629,6 +585,34 @@ static byte[] encodeMemberName(String name) { return result; } + private static class CborStripedPool extends StripedPool { + @Override + protected CborSerializer create(Void ctx) { + return new CborSerializer(); + } + + @Override + protected boolean canPool(CborSerializer s) { + return s.buf != null && s.sink == null; + } + + @Override + protected void prepareForPool(CborSerializer s) { + if (s.buf.length > MAX_CACHEABLE_BUF) { + s.buf = new byte[DEFAULT_BUF_SIZE]; + } + } + + @Override + protected boolean reset(CborSerializer s, Void ctx) { + s.pos = 0; + s.collectionMask = 0L; + s.collectionDepth = 0; + s.currentFieldNameTable = null; + return true; + } + } + private final class CborStructSerializer implements ShapeSerializer { @Override diff --git a/codecs/codec-commons/build.gradle.kts b/codecs/codec-commons/build.gradle.kts new file mode 100644 index 0000000000..0ce2940d21 --- /dev/null +++ b/codecs/codec-commons/build.gradle.kts @@ -0,0 +1,60 @@ +plugins { + id("smithy-java.module-conventions") + id("smithy-java.fuzz-test") + id("com.gradleup.shadow") + id("smithy-java.jmh-conventions") +} + +pitest { + excludedClasses.add("software.amazon.smithy.java.codecs.commons.Schubfach*") +} + +description = "Shared utilities for Smithy codec implementations (number formatting, timestamps, base64)" + +extra["displayName"] = "Smithy :: Java :: Codec Commons" +extra["moduleName"] = "software.amazon.smithy.java.codecs.commons" + +dependencies { + api(libs.smithy.utils) + compileOnly(libs.fastdoubleparser) + testRuntimeOnly(libs.fastdoubleparser) +} + +tasks { + shadowJar { + archiveClassifier.set("") + mergeServiceFiles() + configurations = listOf(project.configurations.compileClasspath.get()) + dependencies { + include( + dependency( + libs.fastdoubleparser + .get() + .toString(), + ), + ) + relocate("ch.randelshofer", "software.amazon.smithy.java.codecs.commons.internal.shaded.ch.randelshofer") + } + } + jar { + finalizedBy(shadowJar) + } +} + +configurations { + shadow.get().extendsFrom(api.get()) +} + +configurePublishing { + customComponent = components["shadow"] as SoftwareComponent +} + +afterEvaluate { + val shadowComponent = components["shadow"] as AdhocComponentWithVariants + shadowComponent.addVariantsFromConfiguration(configurations.sourcesElements.get()) { + mapToMavenScope("runtime") + } + shadowComponent.addVariantsFromConfiguration(configurations.javadocElements.get()) { + mapToMavenScope("runtime") + } +} diff --git a/codecs/codec-commons/slow-unit-7be67b43995413bd38f081c6dfdc4e4167d1ac5a b/codecs/codec-commons/slow-unit-7be67b43995413bd38f081c6dfdc4e4167d1ac5a new file mode 100644 index 0000000000..7d396b191b --- /dev/null +++ b/codecs/codec-commons/slow-unit-7be67b43995413bd38f081c6dfdc4e4167d1ac5a @@ -0,0 +1 @@ +11e1003103 \ No newline at end of file diff --git a/codecs/codec-commons/src/fuzz/java/software/amazon/smithy/java/codecs/commons/NumberCodecFuzzTest.java b/codecs/codec-commons/src/fuzz/java/software/amazon/smithy/java/codecs/commons/NumberCodecFuzzTest.java new file mode 100644 index 0000000000..b654e3c38d --- /dev/null +++ b/codecs/codec-commons/src/fuzz/java/software/amazon/smithy/java/codecs/commons/NumberCodecFuzzTest.java @@ -0,0 +1,134 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; + +import com.code_intelligence.jazzer.junit.FuzzTest; +import java.nio.charset.StandardCharsets; +import java.time.Duration; + +class NumberCodecFuzzTest { + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + + @FuzzTest + void fuzzParseInt(byte[] input) { + String str = new String(input, StandardCharsets.US_ASCII); + Integer jdkResult = null; + Exception jdkError = null; + try { + jdkResult = Integer.parseInt(str); + } catch (Exception e) { + jdkError = e; + } + + Integer testResult = null; + Exception testError = null; + try { + testResult = assertTimeoutPreemptively(TIMEOUT, () -> NumberCodec.parseInt(input, 0, input.length)); + } catch (Exception e) { + testError = e; + } + + if (jdkError != null && testResult != null) { + throw new AssertionError( + "JDK failed but NumberCodec succeeded for: \"" + str + "\". JDK error: " + jdkError.getMessage() + + ", NumberCodec result: " + testResult, + jdkError); + } + if (jdkError == null && testError != null) { + throw new AssertionError( + "JDK succeeded but NumberCodec failed for: \"" + str + "\". JDK result: " + jdkResult + + ", NumberCodec error: " + testError.getMessage(), + testError); + } + if (jdkError != null) { + return; + } + assertEquals(jdkResult, testResult, "parseInt mismatch for: \"" + str + "\""); + } + + @FuzzTest + void fuzzParseLong(byte[] input) { + String str = new String(input, StandardCharsets.US_ASCII); + Long jdkResult = null; + Exception jdkError = null; + try { + jdkResult = Long.parseLong(str); + } catch (Exception e) { + jdkError = e; + } + + Long testResult = null; + Exception testError = null; + try { + testResult = assertTimeoutPreemptively(TIMEOUT, () -> NumberCodec.parseLong(input, 0, input.length)); + } catch (Exception e) { + testError = e; + } + + if (jdkError != null && testResult != null) { + throw new AssertionError( + "JDK failed but NumberCodec succeeded for: \"" + str + "\". JDK error: " + jdkError.getMessage() + + ", NumberCodec result: " + testResult, + jdkError); + } + if (jdkError == null && testError != null) { + throw new AssertionError( + "JDK succeeded but NumberCodec failed for: \"" + str + "\". JDK result: " + jdkResult + + ", NumberCodec error: " + testError.getMessage(), + testError); + } + if (jdkError != null) { + return; + } + assertEquals(jdkResult, testResult, "parseLong mismatch for: \"" + str + "\""); + } + + @FuzzTest + void fuzzWriteInt(int value) { + String expected = Integer.toString(value); + byte[] buf = new byte[12]; + int endPos = assertTimeoutPreemptively(TIMEOUT, () -> NumberCodec.writeInt(buf, 0, value)); + String actual = new String(buf, 0, endPos, StandardCharsets.US_ASCII); + assertEquals(expected, actual, "writeInt mismatch for: " + value); + } + + @FuzzTest + void fuzzWriteLong(long value) { + String expected = Long.toString(value); + byte[] buf = new byte[21]; + int endPos = assertTimeoutPreemptively(TIMEOUT, () -> NumberCodec.writeLong(buf, 0, value)); + String actual = new String(buf, 0, endPos, StandardCharsets.US_ASCII); + assertEquals(expected, actual, "writeLong mismatch for: " + value); + } + + @FuzzTest + void fuzzWriteDouble(double value) { + if (Double.isNaN(value) || Double.isInfinite(value)) { + return; + } + byte[] buf = new byte[32]; + int endPos = assertTimeoutPreemptively(TIMEOUT, () -> NumberCodec.writeDouble(buf, 0, value)); + String actual = new String(buf, 0, endPos, StandardCharsets.US_ASCII); + double roundTrip = Double.parseDouble(actual); + assertEquals(value, roundTrip, "writeDouble round-trip mismatch for: " + value); + } + + @FuzzTest + void fuzzWriteFloat(float value) { + if (Float.isNaN(value) || Float.isInfinite(value)) { + return; + } + byte[] buf = new byte[20]; + int endPos = assertTimeoutPreemptively(TIMEOUT, () -> NumberCodec.writeFloat(buf, 0, value)); + String actual = new String(buf, 0, endPos, StandardCharsets.US_ASCII); + float roundTrip = Float.parseFloat(actual); + assertEquals(value, roundTrip, "writeFloat round-trip mismatch for: " + value); + } +} diff --git a/codecs/codec-commons/src/fuzz/java/software/amazon/smithy/java/codecs/commons/TimestampCodecFuzzTest.java b/codecs/codec-commons/src/fuzz/java/software/amazon/smithy/java/codecs/commons/TimestampCodecFuzzTest.java new file mode 100644 index 0000000000..7f49277d44 --- /dev/null +++ b/codecs/codec-commons/src/fuzz/java/software/amazon/smithy/java/codecs/commons/TimestampCodecFuzzTest.java @@ -0,0 +1,195 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively; + +import com.code_intelligence.jazzer.junit.FuzzTest; +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeParseException; + +class TimestampCodecFuzzTest { + + private static final Duration TIMEOUT = Duration.ofSeconds(5); + private static final DateTimeFormatter HTTP_FORMATTER = + DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss 'GMT'").withZone(ZoneOffset.UTC); + + @FuzzTest + void fuzzParseIso8601(byte[] input) { + Instant smithyResult = assertTimeoutPreemptively(TIMEOUT, () -> { + try { + return TimestampCodec.parseIso8601(input, 0, input.length); + } catch (Exception e) { + throw new AssertionError("parseIso8601 threw on input length=" + input.length, e); + } + }); + + String str = new String(input, StandardCharsets.US_ASCII); + Instant jdkResult = null; + try { + jdkResult = Instant.parse(str); + } catch (DateTimeParseException e) { + // JDK couldn't parse it + } + + if (smithyResult != null && jdkResult != null) { + assertEquals(jdkResult, smithyResult, "parseIso8601 mismatch for: " + str); + } + if (smithyResult != null && jdkResult == null) { + throw new AssertionError("parseIso8601 accepted input that JDK rejected: " + str); + } + } + + @FuzzTest + void fuzzWriteIso8601RoundTrip(long epochSecond, int nano) { + nano = Math.abs(nano % 1_000_000_000); + if (epochSecond < -62167219200L || epochSecond > 253402300799L) { + return; + } + Instant instant = Instant.ofEpochSecond(epochSecond, nano); + + byte[] buf = new byte[64]; + int endPos = assertTimeoutPreemptively(TIMEOUT, () -> TimestampCodec.writeIso8601(buf, 0, instant)); + String written = new String(buf, 0, endPos, StandardCharsets.US_ASCII); + + Instant parsed = Instant.parse(written); + assertEquals(instant, parsed, "writeIso8601 round-trip mismatch for epoch=" + epochSecond + " nano=" + nano); + + Instant selfParsed = TimestampCodec.parseIso8601(buf, 0, endPos); + assertNotNull(selfParsed, "parseIso8601 returned null for our own output: " + written); + assertEquals(instant, selfParsed, "self round-trip mismatch for: " + written); + } + + @FuzzTest + void fuzzWriteHttpDateRoundTrip(long epochSecond) { + // Start at year 0001 -- HTTP-date has no year 0 concept and JDK formatter disagrees with + // proleptic Gregorian for year <= 0 + if (epochSecond < -62135596800L || epochSecond > 253402300799L) { + return; + } + Instant instant = Instant.ofEpochSecond(epochSecond); + + byte[] buf = new byte[64]; + int endPos = assertTimeoutPreemptively(TIMEOUT, () -> TimestampCodec.writeHttpDate(buf, 0, instant)); + String written = new String(buf, 0, endPos, StandardCharsets.US_ASCII); + + String jdkFormatted = HTTP_FORMATTER.format(instant); + assertEquals(jdkFormatted, written, "writeHttpDate mismatch for epoch=" + epochSecond); + + Instant selfParsed = TimestampCodec.parseHttpDate(buf, 0, endPos); + assertNotNull(selfParsed, "parseHttpDate returned null for our own output: " + written); + assertEquals(instant, selfParsed, "HTTP-date self round-trip mismatch for: " + written); + } + + @FuzzTest + void fuzzWriteEpochSecondsRoundTrip(long epochSecond, int nano) { + nano = Math.abs(nano % 1_000_000_000); + // Constrain to Instant's valid range so we can verify the full round-trip + if (epochSecond < Instant.MIN.getEpochSecond() || epochSecond > Instant.MAX.getEpochSecond()) { + return; + } + Instant instant = Instant.ofEpochSecond(epochSecond, nano); + final long es = instant.getEpochSecond(); + final int n = instant.getNano(); + + byte[] buf = new byte[64]; + int endPos = assertTimeoutPreemptively(TIMEOUT, () -> TimestampCodec.writeEpochSeconds(buf, 0, es, n)); + String written = new String(buf, 0, endPos, StandardCharsets.US_ASCII); + + Instant selfParsed = TimestampCodec.parseEpochSeconds(buf, 0, endPos); + assertNotNull(selfParsed, "parseEpochSeconds returned null for our own output: " + written); + assertEquals(instant, selfParsed, "epoch-seconds self round-trip mismatch for: " + written); + } + + @FuzzTest + void fuzzParseEpochSeconds(byte[] input) { + Instant smithyResult = assertTimeoutPreemptively(TIMEOUT, () -> { + try { + return TimestampCodec.parseEpochSeconds(input, 0, input.length); + } catch (Exception e) { + throw new AssertionError("parseEpochSeconds threw on input length=" + input.length, e); + } + }); + + // Differential: parse with JDK approach + String str = new String(input, StandardCharsets.US_ASCII); + Instant jdkResult = jdkParseEpochSeconds(str); + + if (smithyResult != null && jdkResult != null) { + assertEquals(jdkResult, smithyResult, "parseEpochSeconds mismatch for: " + str); + } + if (smithyResult != null && jdkResult == null) { + throw new AssertionError("parseEpochSeconds accepted input that JDK rejected: " + str); + } + if (smithyResult == null && jdkResult != null) { + // Our parser is stricter than BigDecimal -- acceptable differences: + boolean hasLeadingPlus = !str.isEmpty() && str.charAt(0) == '+'; + boolean hasWhitespace = !str.equals(str.trim()); + boolean missingIntegerPart = str.startsWith(".") || str.startsWith("-."); + if (!hasLeadingPlus && !hasWhitespace && !missingIntegerPart) { + throw new AssertionError("parseEpochSeconds rejected input that JDK accepted: " + str); + } + } + + // If we parsed it, verify round-trip + if (smithyResult != null) { + byte[] buf = new byte[64]; + int endPos = TimestampCodec.writeEpochSeconds(buf, + 0, + smithyResult.getEpochSecond(), + smithyResult.getNano()); + Instant reparsed = TimestampCodec.parseEpochSeconds(buf, 0, endPos); + assertNotNull(reparsed, "Re-parse failed for smithy output"); + assertEquals(smithyResult, reparsed, "epoch-seconds re-parse mismatch"); + } + } + + private static Instant jdkParseEpochSeconds(String str) { + try { + BigDecimal epoch = new BigDecimal(str); + if (Math.max(epoch.scale(), 0) > 9) + return null; + BigDecimal[] parts = epoch.divideAndRemainder(java.math.BigDecimal.ONE); + long seconds = parts[0].longValueExact(); + long nanos = parts[1].movePointRight(9).longValueExact(); + return Instant.ofEpochSecond(seconds, nanos); + } catch (Exception e) { + return null; + } + } + + @FuzzTest + void fuzzParseHttpDate(byte[] input) { + Instant smithyResult = assertTimeoutPreemptively(TIMEOUT, () -> { + try { + return TimestampCodec.parseHttpDate(input, 0, input.length); + } catch (Exception e) { + throw new AssertionError("parseHttpDate threw on input length=" + input.length, e); + } + }); + String str = new String(input, StandardCharsets.US_ASCII); + Instant jdkResult = null; + try { + jdkResult = HTTP_FORMATTER.parse(str, Instant::from); + } catch (Exception e) { + // JDK couldn't parse it + } + + if (smithyResult != null && jdkResult != null) { + assertEquals(jdkResult, smithyResult, "parseHttpDate mismatch for: " + str); + } + if (smithyResult != null && jdkResult == null) { + throw new AssertionError("parseHttpDate accepted input that JDK rejected: " + str); + } + } +} diff --git a/codecs/codec-commons/src/jmh/java/software/amazon/smithy/java/codecs/commons/NumberCodecBenchmark.java b/codecs/codec-commons/src/jmh/java/software/amazon/smithy/java/codecs/commons/NumberCodecBenchmark.java new file mode 100644 index 0000000000..9b66307295 --- /dev/null +++ b/codecs/codec-commons/src/jmh/java/software/amazon/smithy/java/codecs/commons/NumberCodecBenchmark.java @@ -0,0 +1,102 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import java.nio.charset.StandardCharsets; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.infra.Blackhole; + +@State(Scope.Benchmark) +public class NumberCodecBenchmark { + + private byte[] intBuf; + private byte[] longBuf; + private byte[] doubleBuf; + private byte[] writeBuf; + private static final int INT_VALUE = 1234567; + private static final long LONG_VALUE = 123456789012345L; + private static final double DOUBLE_VALUE = 3.141592653589793; + + @Setup(Level.Trial) + public void setup() { + intBuf = Integer.toString(INT_VALUE).getBytes(StandardCharsets.US_ASCII); + longBuf = Long.toString(LONG_VALUE).getBytes(StandardCharsets.US_ASCII); + doubleBuf = Double.toString(DOUBLE_VALUE).getBytes(StandardCharsets.US_ASCII); + writeBuf = new byte[32]; + } + + // --- parseInt: input is byte[], produce int --- + + @Benchmark + public void jdkParseInt(Blackhole bh) { + bh.consume(Integer.parseInt(new String(intBuf, 0, intBuf.length, StandardCharsets.US_ASCII))); + } + + @Benchmark + public void smithyParseInt(Blackhole bh) { + bh.consume(NumberCodec.parseInt(intBuf, 0, intBuf.length)); + } + + @Benchmark + public void jdkParseLong(Blackhole bh) { + bh.consume(Long.parseLong(new String(longBuf, 0, longBuf.length, StandardCharsets.US_ASCII))); + } + + @Benchmark + public void smithyParseLong(Blackhole bh) { + bh.consume(NumberCodec.parseLong(longBuf, 0, longBuf.length)); + } + + @Benchmark + public void jdkParseDouble(Blackhole bh) { + bh.consume(Double.parseDouble(new String(doubleBuf, 0, doubleBuf.length, StandardCharsets.US_ASCII))); + } + + @Benchmark + public void smithyParseDouble(Blackhole bh) { + bh.consume(NumberCodec.parseDouble(doubleBuf, 0, doubleBuf.length)); + } + + @Benchmark + public void jdkWriteInt(Blackhole bh) { + byte[] bytes = Integer.toString(INT_VALUE).getBytes(StandardCharsets.US_ASCII); + System.arraycopy(bytes, 0, writeBuf, 0, bytes.length); + bh.consume(writeBuf); + } + + @Benchmark + public void smithyWriteInt(Blackhole bh) { + bh.consume(NumberCodec.writeInt(writeBuf, 0, INT_VALUE)); + } + + @Benchmark + public void jdkWriteLong(Blackhole bh) { + byte[] bytes = Long.toString(LONG_VALUE).getBytes(StandardCharsets.US_ASCII); + System.arraycopy(bytes, 0, writeBuf, 0, bytes.length); + bh.consume(writeBuf); + } + + @Benchmark + public void smithyWriteLong(Blackhole bh) { + bh.consume(NumberCodec.writeLong(writeBuf, 0, LONG_VALUE)); + } + + @Benchmark + public void jdkWriteDouble(Blackhole bh) { + byte[] bytes = Double.toString(DOUBLE_VALUE).getBytes(StandardCharsets.US_ASCII); + System.arraycopy(bytes, 0, writeBuf, 0, bytes.length); + bh.consume(writeBuf); + } + + @Benchmark + public void smithyWriteDouble(Blackhole bh) { + bh.consume(NumberCodec.writeDouble(writeBuf, 0, DOUBLE_VALUE)); + } +} diff --git a/codecs/codec-commons/src/jmh/java/software/amazon/smithy/java/codecs/commons/StripedPoolBenchmark.java b/codecs/codec-commons/src/jmh/java/software/amazon/smithy/java/codecs/commons/StripedPoolBenchmark.java new file mode 100644 index 0000000000..a05551c8bc --- /dev/null +++ b/codecs/codec-commons/src/jmh/java/software/amazon/smithy/java/codecs/commons/StripedPoolBenchmark.java @@ -0,0 +1,100 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Group; +import org.openjdk.jmh.annotations.GroupThreads; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.infra.Blackhole; + +@State(Scope.Thread) +public class StripedPoolBenchmark { + + private static final int BUF_SIZE = 1024; + private static final int MAX_CACHEABLE_BUF = BUF_SIZE * 4; + + static final class FakeSerializer { + byte[] buf; + int pos; + + FakeSerializer() { + this.buf = new byte[BUF_SIZE]; + } + } + + private static final StripedPool STRIPED_POOL = new StripedPool<>() { + @Override + protected FakeSerializer create(Void ctx) { + return new FakeSerializer(); + } + + @Override + protected boolean canPool(FakeSerializer s) { + return true; + } + + @Override + protected void prepareForPool(FakeSerializer s) { + if (s.buf.length > MAX_CACHEABLE_BUF) { + s.buf = new byte[BUF_SIZE]; + } + } + + @Override + protected boolean reset(FakeSerializer s, Void ctx) { + s.pos = 0; + return true; + } + }; + + @Benchmark + public void stripedPoolAcquireRelease(Blackhole bh) { + FakeSerializer s = STRIPED_POOL.acquire(null); + bh.consume(s.pos); + STRIPED_POOL.release(s); + } + + @Benchmark + public void newAllocEachTime(Blackhole bh) { + FakeSerializer s = new FakeSerializer(); + bh.consume(s.pos); + } + + @Benchmark + @Fork(jvmArgsAppend = "-Djmh.executor=VIRTUAL") + public void stripedPoolVirtualThread(Blackhole bh) { + FakeSerializer s = STRIPED_POOL.acquire(null); + bh.consume(s.pos); + STRIPED_POOL.release(s); + } + + @Benchmark + @Fork(jvmArgsAppend = "-Djmh.executor=VIRTUAL") + public void newAllocVirtualThread(Blackhole bh) { + FakeSerializer s = new FakeSerializer(); + bh.consume(s.pos); + } + + @Benchmark + @Group("stripedContended") + @GroupThreads(8) + public void stripedPoolContended(Blackhole bh) { + FakeSerializer s = STRIPED_POOL.acquire(null); + bh.consume(s.pos); + STRIPED_POOL.release(s); + } + + @Benchmark + @Group("newAllocContended") + @GroupThreads(8) + public void newAllocContended(Blackhole bh) { + FakeSerializer s = new FakeSerializer(); + bh.consume(s.pos); + } +} diff --git a/codecs/codec-commons/src/jmh/java/software/amazon/smithy/java/codecs/commons/TimestampCodecBenchmark.java b/codecs/codec-commons/src/jmh/java/software/amazon/smithy/java/codecs/commons/TimestampCodecBenchmark.java new file mode 100644 index 0000000000..b9e4a334ee --- /dev/null +++ b/codecs/codec-commons/src/jmh/java/software/amazon/smithy/java/codecs/commons/TimestampCodecBenchmark.java @@ -0,0 +1,154 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.infra.Blackhole; + +@State(Scope.Benchmark) +public class TimestampCodecBenchmark { + + private static final Instant TEST_INSTANT = Instant.parse("2025-06-15T14:30:45.123456789Z"); + private static final Instant TEST_INSTANT_WHOLE = Instant.ofEpochSecond(1750000000L); + + private static final DateTimeFormatter ISO_FORMATTER = DateTimeFormatter.ISO_INSTANT; + private static final DateTimeFormatter HTTP_FORMATTER = + DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss 'GMT'").withZone(ZoneOffset.UTC); + + private byte[] iso8601Buf; + private byte[] httpDateBuf; + private byte[] epochSecondsBuf; + private byte[] epochSecondsFracBuf; + private byte[] writeBuf; + + @Setup(Level.Trial) + public void setup() { + iso8601Buf = "2025-06-15T14:30:45.123456789Z".getBytes(StandardCharsets.US_ASCII); + httpDateBuf = "Sun, 15 Jun 2025 14:30:45 GMT".getBytes(StandardCharsets.US_ASCII); + epochSecondsBuf = "1750000000".getBytes(StandardCharsets.US_ASCII); + epochSecondsFracBuf = "1750000045.123456789".getBytes(StandardCharsets.US_ASCII); + writeBuf = new byte[64]; + } + + // --- parseIso8601 --- + + @Benchmark + public void jdkParseIso8601(Blackhole bh) { + String s = new String(iso8601Buf, StandardCharsets.US_ASCII); + bh.consume(Instant.parse(s)); + } + + @Benchmark + public void smithyParseIso8601(Blackhole bh) { + bh.consume(TimestampCodec.parseIso8601(iso8601Buf, 0, iso8601Buf.length)); + } + + @Benchmark + public void jdkParseHttpDate(Blackhole bh) { + String s = new String(httpDateBuf, StandardCharsets.US_ASCII); + bh.consume(HTTP_FORMATTER.parse(s, Instant::from)); + } + + @Benchmark + public void smithyParseHttpDate(Blackhole bh) { + bh.consume(TimestampCodec.parseHttpDate(httpDateBuf, 0, httpDateBuf.length)); + } + + @Benchmark + public void jdkParseEpochSeconds(Blackhole bh) { + String s = new String(epochSecondsBuf, StandardCharsets.US_ASCII); + bh.consume(Instant.ofEpochSecond(Long.parseLong(s))); + } + + @Benchmark + public void smithyParseEpochSeconds(Blackhole bh) { + bh.consume(TimestampCodec.parseEpochSeconds(epochSecondsBuf, 0, epochSecondsBuf.length)); + } + + @Benchmark + public void jdkParseEpochSecondsFrac(Blackhole bh) { + String s = new String(epochSecondsFracBuf, StandardCharsets.US_ASCII); + int dot = s.indexOf('.'); + long secs = Long.parseLong(s.substring(0, dot)); + String frac = s.substring(dot + 1); + int nano = Integer.parseInt(frac) * (int) Math.pow(10, 9 - frac.length()); + bh.consume(Instant.ofEpochSecond(secs, nano)); + } + + @Benchmark + public void smithyParseEpochSecondsFrac(Blackhole bh) { + bh.consume(TimestampCodec.parseEpochSeconds(epochSecondsFracBuf, 0, epochSecondsFracBuf.length)); + } + + // --- writeIso8601 --- + + @Benchmark + public void jdkWriteIso8601(Blackhole bh) { + byte[] bytes = ISO_FORMATTER.format(TEST_INSTANT).getBytes(StandardCharsets.US_ASCII); + System.arraycopy(bytes, 0, writeBuf, 0, bytes.length); + bh.consume(writeBuf); + } + + @Benchmark + public void smithyWriteIso8601(Blackhole bh) { + bh.consume(TimestampCodec.writeIso8601(writeBuf, 0, TEST_INSTANT)); + } + + @Benchmark + public void jdkWriteHttpDate(Blackhole bh) { + byte[] bytes = HTTP_FORMATTER.format(TEST_INSTANT).getBytes(StandardCharsets.US_ASCII); + System.arraycopy(bytes, 0, writeBuf, 0, bytes.length); + bh.consume(writeBuf); + } + + @Benchmark + public void smithyWriteHttpDate(Blackhole bh) { + bh.consume(TimestampCodec.writeHttpDate(writeBuf, 0, TEST_INSTANT)); + } + + @Benchmark + public void jdkWriteEpochSeconds(Blackhole bh) { + Instant inst = TEST_INSTANT; + String s; + if (inst.getNano() == 0) { + s = Long.toString(inst.getEpochSecond()); + } else { + s = inst.getEpochSecond() + "." + String.format("%09d", inst.getNano()).replaceAll("0+$", ""); + } + byte[] bytes = s.getBytes(StandardCharsets.US_ASCII); + System.arraycopy(bytes, 0, writeBuf, 0, bytes.length); + bh.consume(writeBuf); + } + + @Benchmark + public void smithyWriteEpochSeconds(Blackhole bh) { + bh.consume( + TimestampCodec.writeEpochSeconds(writeBuf, 0, TEST_INSTANT.getEpochSecond(), TEST_INSTANT.getNano())); + } + + @Benchmark + public void jdkWriteEpochSecondsWhole(Blackhole bh) { + byte[] bytes = Long.toString(TEST_INSTANT_WHOLE.getEpochSecond()).getBytes(StandardCharsets.US_ASCII); + System.arraycopy(bytes, 0, writeBuf, 0, bytes.length); + bh.consume(writeBuf); + } + + @Benchmark + public void smithyWriteEpochSecondsWhole(Blackhole bh) { + bh.consume(TimestampCodec.writeEpochSeconds(writeBuf, + 0, + TEST_INSTANT_WHOLE.getEpochSecond(), + TEST_INSTANT_WHOLE.getNano())); + } +} diff --git a/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/NumberCodec.java b/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/NumberCodec.java new file mode 100644 index 0000000000..414bf5bc87 --- /dev/null +++ b/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/NumberCodec.java @@ -0,0 +1,606 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import ch.randelshofer.fastdoubleparser.JavaDoubleParser; +import ch.randelshofer.fastdoubleparser.JavaFloatParser; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Low-level utilities for reading/writing numeric primitives directly from/to byte arrays. + */ +@SmithyInternalApi +public final class NumberCodec { + + private NumberCodec() {} + + private static final VarHandle INT_HANDLE = + MethodHandles.byteArrayViewVarHandle(int[].class, ByteOrder.BIG_ENDIAN); + private static final VarHandle LONG_HANDLE = + MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.BIG_ENDIAN); + private static final byte[] INT_MAX_DIGITS = "2147483647".getBytes(StandardCharsets.UTF_8); + private static final byte[] INT_MIN_DIGITS = "2147483648".getBytes(StandardCharsets.UTF_8); + + private static final byte[] TRUE = {'t', 'r', 'u', 'e'}; + private static final byte[] FALSE = {'f', 'a', 'l', 's', 'e'}; + + private static final byte[] NAN = {'N', 'a', 'N'}; + private static final byte[] INF = {'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'}; + private static final byte[] NEG_INF = {'-', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y'}; + private static final byte[] QUOTED_NAN = {'"', 'N', 'a', 'N', '"'}; + private static final byte[] QUOTED_INF = {'"', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y', '"'}; + private static final byte[] QUOTED_NEG_INF = {'"', '-', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y', '"'}; + private static final long[] POWERS_OF_10 = { + 1L, + 10L, + 100L, + 1000L, + 10000L, + 100000L, + 1000000L, + 10000000L, + 100000000L, + 1000000000L, + 10000000000L, + 100000000000L, + 1000000000000L, + 10000000000000L, + 100000000000000L, + 1000000000000000L, + 10000000000000000L, + 100000000000000000L, + 1000000000000000000L + }; + + /** + * Parse an int from a byte array span. Handles optional leading minus sign. + */ + public static int parseInt(byte[] buf, int start, int len) { + if (len == 0) { + throw new NumberFormatException("Empty input"); + } + + boolean negative = false; + int i = start; + int end = start + len; + + if (buf[i] == '-') { + negative = true; + i++; + if (i == end) { + throw new NumberFormatException("Just a minus sign"); + } + } else if (buf[i] == '+') { + i++; + if (i == end) { + throw new NumberFormatException("Just a plus sign"); + } + } + + while (i < end - 1 && buf[i] == '0') { + i++; + } + + int digitLen = end - i; + if (digitLen > 10) { + throw new NumberFormatException("Integer overflow"); + } + + if (digitLen == 10) { + byte[] limit = negative ? INT_MIN_DIGITS : INT_MAX_DIGITS; + for (int j = 0; j < 10; j++) { + int cmp = (buf[i + j] & 0xFF) - (limit[j] & 0xFF); + if (cmp < 0) { + break; + } + if (cmp > 0) { + throw new NumberFormatException("Integer overflow"); + } + } + } + + int result = 0; + + while (i + 4 <= end) { + int v = parse4Digits(buf, i); + if (v < 0) { + break; + } + result = result * 10000 - v; + i += 4; + } + + while (i < end) { + int d = buf[i++] - '0'; + if (d < 0 || d > 9) { + throw new NumberFormatException("Invalid digit at position " + (i - 1 - start)); + } + result = result * 10 - d; + } + + return negative ? result : -result; + } + + /** + * Parse a long from a byte array span. Handles optional leading minus sign. + */ + public static long parseLong(byte[] buf, int start, int len) { + if (len == 0) { + throw new NumberFormatException("Empty input"); + } + + boolean negative = false; + int i = start; + int end = start + len; + + if (buf[i] == '-') { + negative = true; + i++; + if (i == end) { + throw new NumberFormatException("Just a minus sign"); + } + } else if (buf[i] == '+') { + i++; + if (i == end) { + throw new NumberFormatException("Just a plus sign"); + } + } + + while (i < end - 1 && buf[i] == '0') { + i++; + } + + int digitLen = end - i; + if (digitLen > 19) { + throw new NumberFormatException("Long overflow"); + } + + long result = 0; + + while (i + 8 <= end) { + long v = parse8Digits(buf, i); + if (v < 0) { + break; + } + result = result * 100000000L - v; + i += 8; + } + + while (i + 4 <= end) { + int v = parse4Digits(buf, i); + if (v < 0) { + break; + } + result = result * 10000 - v; + i += 4; + } + + while (i < end) { + int d = buf[i++] - '0'; + if (d < 0 || d > 9) { + throw new NumberFormatException("Invalid digit at position " + (i - 1 - start)); + } + long next = result * 10 - d; + if (next > result) { + throw new NumberFormatException("Long overflow"); + } + result = next; + } + + if (!negative) { + result = -result; + if (result < 0) { + throw new NumberFormatException("Long overflow"); + } + } + return result; + } + + /** + * Parses 4 ASCII digit bytes at buf[i..i+4) into a value 0..9999. + * Returns -1 if any byte is not a digit. + */ + private static int parse4Digits(byte[] buf, int i) { + int word = (int) INT_HANDLE.get(buf, i); + int sub = word - 0x30303030; + // (sub + 0x76767676) & 0x80808080 catches bytes > 9; + // sub & 0x80808080 catches bytes < 0 (original byte < '0') + if ((((sub + 0x76767676) | sub) & 0x80808080) != 0) { + return -1; + } + int d0 = (sub >>> 24) & 0xFF; + int d1 = (sub >>> 16) & 0xFF; + int d2 = (sub >>> 8) & 0xFF; + int d3 = sub & 0xFF; + return d0 * 1000 + d1 * 100 + d2 * 10 + d3; + } + + /** + * Parses 8 ASCII digit bytes at buf[i..i+8) into a value 0..99999999. + * Returns -1 if any byte is not a digit. + */ + private static long parse8Digits(byte[] buf, int i) { + long word = (long) LONG_HANDLE.get(buf, i); + long sub = word - 0x3030303030303030L; + if ((((sub + 0x7676767676767676L) | sub) & 0x8080808080808080L) != 0) { + return -1; + } + // SWAR: combine pairs -> quads -> final 8-digit value + long lo = sub & 0x000F000F000F000FL; + long hi = (sub >>> 8) & 0x000F000F000F000FL; + long pairs = hi * 10 + lo; + long lo2 = pairs & 0x0000007F0000007FL; + long hi2 = (pairs >>> 16) & 0x0000007F0000007FL; + long quads = hi2 * 100 + lo2; + int upper = (int) (quads >>> 32); + int lower = (int) quads; + return (long) upper * 10000 + lower; + } + + /** + * Parse a double from a byte array span using FastDoubleParser (no String allocation). + */ + public static double parseDouble(byte[] buf, int offset, int length) { + return JavaDoubleParser.parseDouble(buf, offset, length); + } + + /** + * Parse a float from a byte array span using FastDoubleParser (no String allocation). + */ + public static float parseFloat(byte[] buf, int offset, int length) { + return JavaFloatParser.parseFloat(buf, offset, length); + } + + public static int writeBoolean(byte[] buf, int pos, boolean value) { + byte[] bytes = value ? TRUE : FALSE; + System.arraycopy(bytes, 0, buf, pos, bytes.length); + return pos + bytes.length; + } + + private static final byte[] DIGIT_PAIRS = new byte[200]; + + static { + for (int i = 0; i < 100; i++) { + DIGIT_PAIRS[i * 2] = (byte) ('0' + i / 10); + DIGIT_PAIRS[i * 2 + 1] = (byte) ('0' + i % 10); + } + } + + private static final BigInteger TEN_TO_18 = BigInteger.valueOf(1_000_000_000_000_000_000L); + + public static int writeInt(byte[] buf, int pos, int value) { + if (value == 0) { + buf[pos] = '0'; + return pos + 1; + } + + if (value < 0) { + buf[pos++] = '-'; + if (value == Integer.MIN_VALUE) { + return writePositiveLong(buf, pos, 2147483648L); + } + value = -value; + } + + return writePositiveInt(buf, pos, value); + } + + public static int writeLong(byte[] buf, int pos, long value) { + if (value == 0) { + buf[pos] = '0'; + return pos + 1; + } + + if (value < 0) { + buf[pos++] = '-'; + if (value == Long.MIN_VALUE) { + byte[] minBytes = "9223372036854775808".getBytes(StandardCharsets.US_ASCII); + System.arraycopy(minBytes, 0, buf, pos, minBytes.length); + return pos + minBytes.length; + } + value = -value; + } + + return writePositiveLong(buf, pos, value); + } + + public static int writeBigInteger(byte[] buf, int pos, BigInteger value) { + if (value.signum() < 0) { + buf[pos++] = '-'; + value = value.negate(); + } + + if (value.compareTo(TEN_TO_18) < 0) { + return writePositiveLong(buf, pos, value.longValue()); + } + + BigInteger[] qr = value.divideAndRemainder(TEN_TO_18); + BigInteger high = qr[0]; + long low = qr[1].longValue(); + + if (high.compareTo(TEN_TO_18) < 0) { + pos = writePositiveLong(buf, pos, high.longValue()); + return writePaddedLong18(buf, pos, low); + } + + BigInteger[] qr2 = high.divideAndRemainder(TEN_TO_18); + long mid = qr2[1].longValue(); + + if (qr2[0].compareTo(TEN_TO_18) < 0) { + pos = writePositiveLong(buf, pos, qr2[0].longValue()); + pos = writePaddedLong18(buf, pos, mid); + return writePaddedLong18(buf, pos, low); + } + + String s = value.toString(); + return writeAsciiString(buf, pos, s); + } + + public static int writeDouble(byte[] buf, int pos, double value) { + long longValue = (long) value; + if (value == (double) longValue) { + return writeLong(buf, pos, longValue); + } + return Schubfach.writeDouble(buf, pos, value); + } + + public static int writeFloat(byte[] buf, int pos, float value) { + int intValue = (int) value; + if (value == (float) intValue) { + return writeInt(buf, pos, intValue); + } + return Schubfach.writeFloat(buf, pos, value); + } + + public static int writeNonFiniteFloat(byte[] buf, int pos, float value) { + byte[] bytes; + if (Float.isNaN(value)) { + bytes = NAN; + } else { + bytes = value > 0 ? INF : NEG_INF; + } + System.arraycopy(bytes, 0, buf, pos, bytes.length); + return pos + bytes.length; + } + + public static int writeNonFiniteDouble(byte[] buf, int pos, double value) { + byte[] bytes; + if (Double.isNaN(value)) { + bytes = NAN; + } else { + bytes = value > 0 ? INF : NEG_INF; + } + System.arraycopy(bytes, 0, buf, pos, bytes.length); + return pos + bytes.length; + } + + public static int writeNonFiniteFloatQuoted(byte[] buf, int pos, float value) { + byte[] bytes; + if (Float.isNaN(value)) { + bytes = QUOTED_NAN; + } else { + bytes = value > 0 ? QUOTED_INF : QUOTED_NEG_INF; + } + System.arraycopy(bytes, 0, buf, pos, bytes.length); + return pos + bytes.length; + } + + public static int writeNonFiniteDoubleQuoted(byte[] buf, int pos, double value) { + byte[] bytes; + if (Double.isNaN(value)) { + bytes = QUOTED_NAN; + } else { + bytes = value > 0 ? QUOTED_INF : QUOTED_NEG_INF; + } + System.arraycopy(bytes, 0, buf, pos, bytes.length); + return pos + bytes.length; + } + + public static int writeFloatFull(byte[] buf, int pos, float value) { + if (Float.isFinite(value)) { + return writeFloat(buf, pos, value); + } + return writeNonFiniteFloat(buf, pos, value); + } + + public static int writeDoubleFull(byte[] buf, int pos, double value) { + if (Double.isFinite(value)) { + return writeDouble(buf, pos, value); + } + return writeNonFiniteDouble(buf, pos, value); + } + + public static int writeFloatFullQuoted(byte[] buf, int pos, float value) { + if (Float.isFinite(value)) { + return writeFloat(buf, pos, value); + } + return writeNonFiniteFloatQuoted(buf, pos, value); + } + + public static int writeDoubleFullQuoted(byte[] buf, int pos, double value) { + if (Double.isFinite(value)) { + return writeDouble(buf, pos, value); + } + return writeNonFiniteDoubleQuoted(buf, pos, value); + } + + public static int writeBigDecimal(byte[] buf, int pos, BigDecimal value) { + int scale = value.scale(); + if (value.unscaledValue().bitLength() < 64) { + if (scale == 0) { + return writeLong(buf, pos, value.longValueExact()); + } + if (scale > 0 && scale < POWERS_OF_10.length) { + long unscaled = value.unscaledValue().longValue(); + if (unscaled < 0) { + buf[pos++] = '-'; + unscaled = -unscaled; + } + long divisor = POWERS_OF_10[scale]; + long intPart = unscaled / divisor; + long fracPart = unscaled - intPart * divisor; + pos = writeLong(buf, pos, intPart); + buf[pos++] = '.'; + for (int i = scale - 1; i >= 0; i--) { + long p10 = POWERS_OF_10[i]; + int d = (int) (fracPart / p10); + buf[pos++] = (byte) ('0' + d); + fracPart -= d * p10; + } + return pos; + } + } + return writeAsciiString(buf, pos, value.toPlainString()); + } + + public static int maxBigDecimalLength(BigDecimal value) { + int scale = value.scale(); + int bitLength = value.unscaledValue().bitLength(); + if (bitLength < 64 && scale >= 0 && scale < POWERS_OF_10.length) { + return 1 + 20 + 1 + scale; + } + int unscaledDigits = (int) (bitLength * 0.302) + 2; + if (scale <= 0) { + return 1 + unscaledDigits + (-scale); + } + return 1 + unscaledDigits + 1 + scale; + } + + @SuppressWarnings("deprecation") + private static int writeAsciiString(byte[] buf, int pos, String s) { + int len = s.length(); + s.getBytes(0, len, buf, pos); + return pos + len; + } + + private static int writePositiveInt(byte[] buf, int pos, int value) { + int digits = digitCount(value); + int end = pos + digits; + int p = end; + + while (value >= 10000 && p - 4 >= pos) { + int q = value / 10000; + int r = value - q * 10000; + value = q; + p -= 4; + write4Digits(buf, p, r); + } + + while (value >= 100) { + int q = value / 100; + int r = (value - q * 100) * 2; + value = q; + buf[--p] = DIGIT_PAIRS[r + 1]; + buf[--p] = DIGIT_PAIRS[r]; + } + + if (value >= 10) { + int r = value * 2; + buf[--p] = DIGIT_PAIRS[r + 1]; + buf[--p] = DIGIT_PAIRS[r]; + } else { + buf[--p] = (byte) ('0' + value); + } + + return end; + } + + private static final int DIGITS_10_8 = 100_000_000; + + private static int writePositiveLong(byte[] buf, int pos, long value) { + if (value <= Integer.MAX_VALUE) { + return writePositiveInt(buf, pos, (int) value); + } + + // Split into 32-bit chunks to avoid repeated 64-bit division + if (value < (long) DIGITS_10_8 * DIGITS_10_8) { + int lo = (int) (value % DIGITS_10_8); + int hi = (int) (value / DIGITS_10_8); + pos = writePositiveInt(buf, pos, hi); + return writePaddedInt8(buf, pos, lo); + } else { + long tmp = value / DIGITS_10_8; + int lo = (int) (value - tmp * DIGITS_10_8); + int mid = (int) (tmp % DIGITS_10_8); + int hi = (int) (tmp / DIGITS_10_8); + pos = writePositiveInt(buf, pos, hi); + pos = writePaddedInt8(buf, pos, mid); + return writePaddedInt8(buf, pos, lo); + } + } + + /** + * Writes exactly 8 digits (zero-padded) using 32-bit arithmetic only. + */ + private static int writePaddedInt8(byte[] buf, int pos, int value) { + int hi4 = value / 10000; + int lo4 = value - hi4 * 10000; + write4Digits(buf, pos, hi4); + write4Digits(buf, pos + 4, lo4); + return pos + 8; + } + + /** + * Writes 4 decimal digits (0000-9999) as a single int store (big-endian). + */ + private static void write4Digits(byte[] buf, int pos, int value) { + int q = (value * 5243) >>> 19; // value / 100 + int r = value - q * 100; + int d01 = DIGIT_PAIRS[q * 2] << 24 | (DIGIT_PAIRS[q * 2 + 1] & 0xFF) << 16; + int d23 = (DIGIT_PAIRS[r * 2] & 0xFF) << 8 | (DIGIT_PAIRS[r * 2 + 1] & 0xFF); + INT_HANDLE.set(buf, pos, d01 | d23); + } + + private static int writePaddedLong18(byte[] buf, int pos, long value) { + int end = pos + 18; + int p = end; + for (int i = 0; i < 9; i++) { + int r = (int) (value % 100) * 2; + value /= 100; + buf[--p] = DIGIT_PAIRS[r + 1]; + buf[--p] = DIGIT_PAIRS[r]; + } + return end; + } + + public static int digitCount(int value) { + if (value < 10) { + return 1; + } + if (value < 100) { + return 2; + } + if (value < 1000) { + return 3; + } + if (value < 10000) { + return 4; + } + if (value < 100000) { + return 5; + } + if (value < 1000000) { + return 6; + } + if (value < 10000000) { + return 7; + } + if (value < 100000000) { + return 8; + } + if (value < 1000000000) { + return 9; + } + return 10; + } + +} diff --git a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/Schubfach.java b/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/Schubfach.java similarity index 75% rename from codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/Schubfach.java rename to codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/Schubfach.java index 678cf809a3..2b3951e1a9 100644 --- a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/Schubfach.java +++ b/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/Schubfach.java @@ -3,18 +3,17 @@ * SPDX-License-Identifier: Apache-2.0 */ -package software.amazon.smithy.java.json.smithy; +package software.amazon.smithy.java.codecs.commons; import static java.lang.Double.doubleToRawLongBits; import static java.lang.Float.floatToRawIntBits; import static java.lang.Integer.numberOfLeadingZeros; -import static software.amazon.smithy.java.json.smithy.Schubfach.DoubleToDecimal.NON_SPECIAL; -import static software.amazon.smithy.java.json.smithy.Schubfach.MathUtils.flog10pow2; -import static software.amazon.smithy.java.json.smithy.Schubfach.MathUtils.flog10threeQuartersPow2; -import static software.amazon.smithy.java.json.smithy.Schubfach.MathUtils.flog2pow10; -import static software.amazon.smithy.java.json.smithy.Schubfach.MathUtils.g0; -import static software.amazon.smithy.java.json.smithy.Schubfach.MathUtils.g1; -import static software.amazon.smithy.java.json.smithy.Schubfach.MathUtils.multiplyHigh; +import static java.lang.Math.multiplyHigh; +import static software.amazon.smithy.java.codecs.commons.Schubfach.MathUtils.flog10pow2; +import static software.amazon.smithy.java.codecs.commons.Schubfach.MathUtils.flog10threeQuartersPow2; +import static software.amazon.smithy.java.codecs.commons.Schubfach.MathUtils.flog2pow10; +import static software.amazon.smithy.java.codecs.commons.Schubfach.MathUtils.g0; +import static software.amazon.smithy.java.codecs.commons.Schubfach.MathUtils.g1; final class Schubfach { @@ -22,74 +21,70 @@ private Schubfach() { } - /** Creates a reusable DoubleToDecimal instance for hot-path serialization. */ - static DoubleToDecimal createDoubleToDecimal() { - return new DoubleToDecimal(); - } - - /** Creates a reusable FloatToDecimal instance for hot-path serialization. */ - static FloatToDecimal createFloatToDecimal() { - return new FloatToDecimal(); - } - - public static int writeDouble(byte[] buf, int pos, double v) { - return writeDouble(buf, pos, v, new DoubleToDecimal()); - } - - /** - * Writes the decimal representation of {@code v} using a reusable {@link DoubleToDecimal} - * instance to avoid per-call allocation on hot paths. - */ - public static int writeDouble(byte[] buf, int pos, double v, DoubleToDecimal dtd) { - dtd.bytes = buf; - dtd.index = pos - 1; // pre-increment convention: first ++index writes to buf[pos] - return switch (dtd.toDecimal(v)) { - case DoubleToDecimal.NON_SPECIAL -> dtd.index + 1; - case DoubleToDecimal.PLUS_ZERO -> { - buf[pos] = '0'; - buf[pos + 1] = '.'; - buf[pos + 2] = '0'; - yield pos + 3; + static int writeDouble(byte[] buf, int pos, double v) { + long bits = doubleToRawLongBits(v); + long t = bits & DoubleToDecimal.T_MASK; + int bq = (int) (bits >>> DoubleToDecimal.P - 1) & DoubleToDecimal.BQ_MASK; + if (bq < DoubleToDecimal.BQ_MASK) { + if (bits < 0) { + buf[pos++] = '-'; } - case DoubleToDecimal.MINUS_ZERO -> { - buf[pos] = '-'; - buf[pos + 1] = '0'; - buf[pos + 2] = '.'; - buf[pos + 3] = '0'; - yield pos + 4; + if (bq != 0) { + int mq = -DoubleToDecimal.Q_MIN + 1 - bq; + long c = DoubleToDecimal.C_MIN | t; + if (0 < mq & mq < DoubleToDecimal.P) { + long f = c >> mq; + if (f << mq == c) { + return DoubleToDecimal.toChars(buf, pos, f, 0); + } + } + return DoubleToDecimal.toDecimal(buf, pos, -mq, c, 0); } - default -> throw new AssertionError("Infinity/NaN should be handled by caller"); - }; - } - - public static int writeFloat(byte[] buf, int pos, float v) { - return writeFloat(buf, pos, v, new FloatToDecimal()); + if (t != 0) { + return t < DoubleToDecimal.C_TINY + ? DoubleToDecimal.toDecimal(buf, pos, DoubleToDecimal.Q_MIN, 10 * t, -1) + : DoubleToDecimal.toDecimal(buf, pos, DoubleToDecimal.Q_MIN, t, 0); + } + // +0.0 or -0.0 (sign already written if negative) + buf[pos++] = '0'; + buf[pos++] = '.'; + buf[pos++] = '0'; + return pos; + } + throw new AssertionError("Infinity/NaN should be handled by caller"); } - /** - * Writes the decimal representation of {@code v} using a reusable {@link FloatToDecimal} - * instance to avoid per-call allocation on hot paths. - */ - public static int writeFloat(byte[] buf, int pos, float v, FloatToDecimal ftd) { - ftd.bytes = buf; - ftd.index = pos - 1; - return switch (ftd.toDecimal(v)) { - case FloatToDecimal.NON_SPECIAL -> ftd.index + 1; - case FloatToDecimal.PLUS_ZERO -> { - buf[pos] = '0'; - buf[pos + 1] = '.'; - buf[pos + 2] = '0'; - yield pos + 3; + static int writeFloat(byte[] buf, int pos, float v) { + int bits = floatToRawIntBits(v); + int t = bits & FloatToDecimal.T_MASK; + int bq = (bits >>> FloatToDecimal.P - 1) & FloatToDecimal.BQ_MASK; + if (bq < FloatToDecimal.BQ_MASK) { + if (bits < 0) { + buf[pos++] = '-'; + } + if (bq != 0) { + int mq = -FloatToDecimal.Q_MIN + 1 - bq; + int c = FloatToDecimal.C_MIN | t; + if (0 < mq & mq < FloatToDecimal.P) { + int f = c >> mq; + if (f << mq == c) { + return FloatToDecimal.toChars(buf, pos, f, 0); + } + } + return FloatToDecimal.toDecimal(buf, pos, -mq, c, 0); } - case FloatToDecimal.MINUS_ZERO -> { - buf[pos] = '-'; - buf[pos + 1] = '0'; - buf[pos + 2] = '.'; - buf[pos + 3] = '0'; - yield pos + 4; + if (t != 0) { + return t < FloatToDecimal.C_TINY + ? FloatToDecimal.toDecimal(buf, pos, FloatToDecimal.Q_MIN, 10 * t, -1) + : FloatToDecimal.toDecimal(buf, pos, FloatToDecimal.Q_MIN, t, 0); } - default -> throw new AssertionError("Infinity/NaN should be handled by caller"); - }; + // +0.0 or -0.0 (sign already written if negative) + buf[pos++] = '0'; + buf[pos++] = '.'; + buf[pos++] = '0'; + return pos; + } + throw new AssertionError("Infinity/NaN should be handled by caller"); } /** @@ -118,7 +113,7 @@ public static int writeFloat(byte[] buf, int pos, float v, FloatToDecimal ftd) { * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN * THE SOFTWARE. */ - public static final class FloatToDecimal { + static final class FloatToDecimal { /* For full details about this code see the following references: @@ -145,33 +140,20 @@ public static final class FloatToDecimal { // Minimum value of the exponent: -(2^(W-1)) - P + 3. static final int Q_MIN = (-1 << W - 1) - P + 3; - // Maximum value of the exponent: 2^(W-1) - P. - static final int Q_MAX = (1 << W - 1) - P; - - // 10^(E_MIN - 1) <= MIN_VALUE < 10^E_MIN - static final int E_MIN = -44; - - // 10^(E_MAX - 1) <= MAX_VALUE < 10^E_MAX - static final int E_MAX = 39; - // Threshold to detect tiny values, as in section 8.1.1 of [1] static final int C_TINY = 8; - // The minimum and maximum k, as in section 8 of [1] - static final int K_MIN = -45; - static final int K_MAX = 31; - // H is as in section 8 of [1]. static final int H = 9; // Minimum value of the significand of a normal value: 2^(P-1). - private static final int C_MIN = 1 << P - 1; + static final int C_MIN = 1 << P - 1; // Mask to extract the biased exponent. - private static final int BQ_MASK = (1 << W) - 1; + static final int BQ_MASK = (1 << W) - 1; // Mask to extract the fraction bits. - private static final int T_MASK = (1 << P - 1) - 1; + static final int T_MASK = (1 << P - 1) - 1; // Used in rop(). private static final long MASK_32 = (1L << 32) - 1; @@ -179,218 +161,9 @@ public static final class FloatToDecimal { // Used for left-to-tight digit extraction. private static final int MASK_28 = (1 << 28) - 1; - private static final int NON_SPECIAL = 0; - private static final int PLUS_ZERO = 1; - private static final int MINUS_ZERO = 2; - private static final int PLUS_INF = 3; - private static final int MINUS_INF = 4; - private static final int NAN = 5; - - /* - Room for the longer of the forms - -ddddd.dddd H + 2 characters - -0.00ddddddddd H + 5 characters - -d.ddddddddE-ee H + 6 characters - where there are H digits d - */ - public final int MAX_CHARS = H + 6; - - // Writes directly to the caller's buffer — set by writeFloat before calling toDecimal. - private byte[] bytes; - - // Index into bytes. Set to (pos - 1) by writeFloat; toDecimal uses pre-increment (++index). - private int index; - private FloatToDecimal() {} - /** - * Returns a string rendering of the {@code float} argument. - * - *

The characters of the result are all drawn from the ASCII set. - *

    - *
  • Any NaN, whether quiet or signaling, is rendered as - * {@code "NaN"}, regardless of the sign bit. - *
  • The infinities +∞ and -∞ are rendered as - * {@code "Infinity"} and {@code "-Infinity"}, respectively. - *
  • The positive and negative zeroes are rendered as - * {@code "0.0"} and {@code "-0.0"}, respectively. - *
  • A finite negative {@code v} is rendered as the sign - * '{@code -}' followed by the rendering of the magnitude -{@code v}. - *
  • A finite positive {@code v} is rendered in two stages: - *
      - *
    • Selection of a decimal: A well-defined - * decimal dv is selected - * to represent {@code v}. - *
    • Formatting as a string: The decimal - * dv is formatted as a string, - * either in plain or in computerized scientific notation, - * depending on its value. - *
    - *
- * - *

A decimal is a number of the form - * d×10i - * for some (unique) integers d > 0 and i such that - * d is not a multiple of 10. - * These integers are the significand and - * the exponent, respectively, of the decimal. - * The length of the decimal is the (unique) - * integer n meeting - * 10n-1d < 10n. - * - *

The decimal dv - * for a finite positive {@code v} is defined as follows: - *

    - *
  • Let R be the set of all decimals that round to {@code v} - * according to the usual round-to-closest rule of - * IEEE 754 floating-point arithmetic. - *
  • Let m be the minimal length over all decimals in R. - *
  • When m ≥ 2, let T be the set of all decimals - * in R with length m. - * Otherwise, let T be the set of all decimals - * in R with length 1 or 2. - *
  • Define dv as - * the decimal in T that is closest to {@code v}. - * Or if there are two such decimals in T, - * select the one with the even significand (there is exactly one). - *
- * - *

The (uniquely) selected decimal dv - * is then formatted. - * - *

Let d, i and n be the significand, exponent and - * length of dv, respectively. - * Further, let e = n + i - 1 and let - * d1dn - * be the usual decimal expansion of the significand. - * Note that d1 ≠ 0 ≠ dn. - *

    - *
  • Case -3 ≤ e < 0: - * dv is formatted as - * 0.00d1dn, - * where there are exactly -(n + i) zeroes between - * the decimal point and d1. - * For example, 123 × 10-4 is formatted as - * {@code 0.0123}. - *
  • Case 0 ≤ e < 7: - *
      - *
    • Subcase i ≥ 0: - * dv is formatted as - * d1dn00.0, - * where there are exactly i zeroes - * between dn and the decimal point. - * For example, 123 × 102 is formatted as - * {@code 12300.0}. - *
    • Subcase i < 0: - * dv is formatted as - * d1dn+i.dn+i+1dn. - * There are exactly -i digits to the right of - * the decimal point. - * For example, 123 × 10-1 is formatted as - * {@code 12.3}. - *
    - *
  • Case e < -3 or e ≥ 7: - * computerized scientific notation is used to format - * dv. - * Here e is formatted as by {@link Integer#toString(int)}. - *
      - *
    • Subcase n = 1: - * dv is formatted as - * d1.0Ee. - * For example, 1 × 1023 is formatted as - * {@code 1.0E23}. - *
    • Subcase n > 1: - * dv is formatted as - * d1.d2dnEe. - * For example, 123 × 10-21 is formatted as - * {@code 1.23E-19}. - *
    - *
- * - * @param v the {@code float} to be rendered. - * @return a string rendering of the argument. - */ - public static String toString(float v) { - return new FloatToDecimal().toDecimalString(v); - } - - private String toDecimalString(float v) { - switch (toDecimal(v)) { - case NON_SPECIAL: - return charsToString(); - case PLUS_ZERO: - return "0.0"; - case MINUS_ZERO: - return "-0.0"; - case PLUS_INF: - return "Infinity"; - case MINUS_INF: - return "-Infinity"; - default: - return "NaN"; - } - } - - /* - Returns - PLUS_ZERO iff v is 0.0 - MINUS_ZERO iff v is -0.0 - PLUS_INF iff v is POSITIVE_INFINITY - MINUS_INF iff v is NEGATIVE_INFINITY - NAN iff v is NaN - */ - private int toDecimal(float v) { - /* - For full details see references [2] and [1]. - - For finite v != 0, determine integers c and q such that - |v| = c 2^q and - Q_MIN <= q <= Q_MAX and - either 2^(P-1) <= c < 2^P (normal) - or 0 < c < 2^(P-1) and q = Q_MIN (subnormal) - */ - int bits = floatToRawIntBits(v); - int t = bits & T_MASK; - int bq = (bits >>> P - 1) & BQ_MASK; - if (bq < BQ_MASK) { - // index is set by writeFloat (to pos - 1) before this method is called. - if (bits < 0) { - append('-'); - } - if (bq != 0) { - // normal value. Here mq = -q - int mq = -Q_MIN + 1 - bq; - int c = C_MIN | t; - // The fast path discussed in section 8.2 of [1]. - if (0 < mq & mq < P) { - int f = c >> mq; - if (f << mq == c) { - return toChars(f, 0); - } - } - return toDecimal(-mq, c, 0); - } - if (t != 0) { - // subnormal value - return t < C_TINY - ? toDecimal(Q_MIN, 10 * t, -1) - : toDecimal(Q_MIN, t, 0); - } - return bits == 0 ? PLUS_ZERO : MINUS_ZERO; - } - if (t != 0) { - return NAN; - } - return bits > 0 ? PLUS_INF : MINUS_INF; - } - - private int toDecimal(int q, int c, int dk) { + static int toDecimal(byte[] buf, int pos, int q, int c, int dk) { /* The skeleton corresponds to figure 4 of [1]. The efficient computations are those summarized in figure 7. @@ -410,7 +183,7 @@ private int toDecimal(int q, int c, int dk) { rop: r_o' "r-o-prime" */ int out = c & 0x1; - long cb = c << 2; + long cb = (long) c << 2; long cbr = cb + 2; long cbl; int k; @@ -424,7 +197,7 @@ private int toDecimal(int q, int c, int dk) { cbl = cb - 2; k = flog10pow2(q); } else { - // irregular spacing0 + // irregular spacing cbl = cb - 1; k = flog10threeQuartersPow2(q); } @@ -454,7 +227,7 @@ private int toDecimal(int q, int c, int dk) { boolean upin = vbl + out <= sp10 << 2; boolean wpin = (tp10 << 2) + out <= vbr; if (upin != wpin) { - return toChars(upin ? sp10 : tp10, k); + return toChars(buf, pos, upin ? sp10 : tp10, k); } } @@ -469,14 +242,14 @@ private int toDecimal(int q, int c, int dk) { boolean win = (t << 2) + out <= vbr; if (uin != win) { // Exactly one of u or w lies in Rv. - return toChars(uin ? s : t, k + dk); + return toChars(buf, pos, uin ? s : t, k + dk); } /* Both u and w lie in Rv: determine the one closest to v. See section 9.4 of [1]. */ int cmp = vb - (s + t << 1); - return toChars(cmp < 0 || cmp == 0 && (s & 0x1) == 0 ? s : t, k + dk); + return toChars(buf, pos, cmp < 0 || cmp == 0 && (s & 0x1) == 0 ? s : t, k + dk); } /* @@ -492,7 +265,7 @@ private static int rop(long g, long cp) { /* Formats the decimal f 10^e. */ - private int toChars(int f, int e) { + static int toChars(byte[] buf, int pos, int f, int e) { /* For details not discussed here see section 10 of [1]. @@ -510,7 +283,7 @@ private int toChars(int f, int e) { 10^(H-1) <= f < 10^H fp 10^ep = f 10^(e-H) = 0.f 10^e */ - f *= MathUtils.pow10(H - len); + f *= (int) MathUtils.pow10(H - len); e += len; /* @@ -527,63 +300,60 @@ private int toChars(int f, int e) { int l = f - 100_000_000 * h; if (0 < e && e <= 7) { - return toChars1(h, l, e); + return toChars1(buf, pos, h, l, e); } if (-3 < e && e <= 0) { - return toChars2(h, l, e); + return toChars2(buf, pos, h, l, e); } - return toChars3(h, l, e); + return toChars3(buf, pos, h, l, e); } - private int toChars1(int h, int l, int e) { + private static int toChars1(byte[] buf, int pos, int h, int l, int e) { /* 0 < e <= 7: plain format without leading zeroes. Left-to-right digits extraction: algorithm 1 in [3], with b = 10, k = 8, n = 28. */ - appendDigit(h); + buf[pos++] = (byte) ('0' + h); int y = y(l); int t; int i = 1; for (; i < e; ++i) { t = 10 * y; - appendDigit(t >>> 28); + buf[pos++] = (byte) ('0' + (t >>> 28)); y = t & MASK_28; } - append('.'); + buf[pos++] = '.'; for (; i <= 8; ++i) { t = 10 * y; - appendDigit(t >>> 28); + buf[pos++] = (byte) ('0' + (t >>> 28)); y = t & MASK_28; } - removeTrailingZeroes(); - return NON_SPECIAL; + return removeTrailingZeroes(buf, pos); } - private int toChars2(int h, int l, int e) { + private static int toChars2(byte[] buf, int pos, int h, int l, int e) { // -3 < e <= 0: plain format with leading zeroes. - appendDigit(0); - append('.'); + buf[pos++] = '0'; + buf[pos++] = '.'; for (; e < 0; ++e) { - appendDigit(0); + buf[pos++] = '0'; } - appendDigit(h); - append8Digits(l); - removeTrailingZeroes(); - return NON_SPECIAL; + buf[pos++] = (byte) ('0' + h); + pos = append8Digits(buf, pos, l); + return removeTrailingZeroes(buf, pos); } - private int toChars3(int h, int l, int e) { + private static int toChars3(byte[] buf, int pos, int h, int l, int e) { // -3 >= e | e > 7: computerized scientific notation - appendDigit(h); - append('.'); - append8Digits(l); - removeTrailingZeroes(); - exponent(e - 1); - return NON_SPECIAL; + buf[pos++] = (byte) ('0' + h); + buf[pos++] = '.'; + pos = append8Digits(buf, pos, l); + pos = removeTrailingZeroes(buf, pos); + return exponent(buf, pos, e - 1); } - private void append8Digits(int m) { + private static int append8Digits(byte[] buf, int pos, int m) { /* Left-to-right digits extraction: algorithm 1 in [3], with b = 10, k = 8, n = 28. @@ -591,22 +361,24 @@ private void append8Digits(int m) { int y = y(m); for (int i = 0; i < 8; ++i) { int t = 10 * y; - appendDigit(t >>> 28); + buf[pos++] = (byte) ('0' + (t >>> 28)); y = t & MASK_28; } + return pos; } - private void removeTrailingZeroes() { - while (bytes[index] == '0') { - --index; + private static int removeTrailingZeroes(byte[] buf, int pos) { + while (buf[pos - 1] == '0') { + --pos; } // ... but do not remove the one directly to the right of '.' - if (bytes[index] == '.') { - ++index; + if (buf[pos - 1] == '.') { + ++pos; } + return pos; } - private int y(int a) { + private static int y(int a) { /* Algorithm 1 in [3] needs computation of floor((a + 1) 2^n / b^k) - 1 @@ -620,37 +392,24 @@ private int y(int a) { 193_428_131_138_340_668L) >>> 20) - 1; } - private void exponent(int e) { - append('E'); + private static int exponent(byte[] buf, int pos, int e) { + buf[pos++] = 'E'; if (e < 0) { - append('-'); + buf[pos++] = '-'; e = -e; } if (e < 10) { - appendDigit(e); - return; + buf[pos++] = (byte) ('0' + e); + return pos; } /* For n = 2, m = 1 the table in section 10 of [1] shows floor(e / 10) = floor(103 e / 2^10) */ int d = e * 103 >>> 10; - appendDigit(d); - appendDigit(e - 10 * d); - } - - private void append(int c) { - bytes[++index] = (byte) c; - } - - private void appendDigit(int d) { - bytes[++index] = (byte) ('0' + d); - } - - // Using the deprecated constructor enhances performance. - @SuppressWarnings("deprecation") - private String charsToString() { - return new String(bytes, 0, 0, index + 1); + buf[pos++] = (byte) ('0' + d); + buf[pos++] = (byte) ('0' + e - 10 * d); + return pos; } } @@ -837,18 +596,6 @@ static long g0(int k) { return g[k - K_MIN << 1 | 1]; } - //a Java port of https://github.com/plokhotnyuk/jsoniter-scala/blob/c70a293ac802dc2eb44165471d76d7df2d4657b6/jsoniter-scala-core/native/src/main/scala/com/github/plokhotnyuk/jsoniter_scala/core/JsonWriter.scala#L2027 - static long multiplyHigh(long x, long y) { - // Karatsuba technique for two positive ints - long x2 = x & 0xFFFFFFFFL; - long y2 = y & 0xFFFFFFFFL; - long b = x2 * y2; - long x1 = x >>> 32; - long y1 = y >>> 32; - long a = x1 * y1; - return (((b >>> 32) + (x1 + x2) * (y1 + y2) - b - a) >>> 32) + a; - } - /* The precomputed values for g1(int) and g0(int). The first entry must be for an exponent of K_MIN or less. @@ -2144,33 +1891,20 @@ static final class DoubleToDecimal { // Minimum value of the exponent: -(2^(W-1)) - P + 3. static final int Q_MIN = (-1 << W - 1) - P + 3; - // Maximum value of the exponent: 2^(W-1) - P. - static final int Q_MAX = (1 << W - 1) - P; - - // 10^(E_MIN - 1) <= MIN_VALUE < 10^E_MIN - static final int E_MIN = -323; - - // 10^(E_MAX - 1) <= MAX_VALUE < 10^E_MAX - static final int E_MAX = 309; - // Threshold to detect tiny values, as in section 8.1.1 of [1] static final long C_TINY = 3; - // The minimum and maximum k, as in section 8 of [1] - static final int K_MIN = -324; - static final int K_MAX = 292; - // H is as in section 8 of [1]. static final int H = 17; // Minimum value of the significand of a normal value: 2^(P-1). - private static final long C_MIN = 1L << P - 1; + static final long C_MIN = 1L << P - 1; // Mask to extract the biased exponent. - private static final int BQ_MASK = (1 << W) - 1; + static final int BQ_MASK = (1 << W) - 1; // Mask to extract the fraction bits. - private static final long T_MASK = (1L << P - 1) - 1; + static final long T_MASK = (1L << P - 1) - 1; // Used in rop(). private static final long MASK_63 = (1L << 63) - 1; @@ -2178,218 +1912,9 @@ static final class DoubleToDecimal { // Used for left-to-tight digit extraction. private static final int MASK_28 = (1 << 28) - 1; - static final int NON_SPECIAL = 0; - static final int PLUS_ZERO = 1; - static final int MINUS_ZERO = 2; - static final int PLUS_INF = 3; - static final int MINUS_INF = 4; - static final int NAN = 5; - - /* - Room for the longer of the forms - -ddddd.dddddddddddd H + 2 characters - -0.00ddddddddddddddddd H + 5 characters - -d.ddddddddddddddddE-eee H + 7 characters - where there are H digits d - */ - public final int MAX_CHARS = H + 7; - - // Writes directly to the caller's buffer — set by writeDouble before calling toDecimal. - private byte[] bytes; - - // Index into bytes. Set to (pos - 1) by writeDouble; toDecimal uses pre-increment (++index). - private int index; - private DoubleToDecimal() {} - /** - * Returns a string rendering of the {@code double} argument. - * - *

The characters of the result are all drawn from the ASCII set. - *

    - *
  • Any NaN, whether quiet or signaling, is rendered as - * {@code "NaN"}, regardless of the sign bit. - *
  • The infinities +∞ and -∞ are rendered as - * {@code "Infinity"} and {@code "-Infinity"}, respectively. - *
  • The positive and negative zeroes are rendered as - * {@code "0.0"} and {@code "-0.0"}, respectively. - *
  • A finite negative {@code v} is rendered as the sign - * '{@code -}' followed by the rendering of the magnitude -{@code v}. - *
  • A finite positive {@code v} is rendered in two stages: - *
      - *
    • Selection of a decimal: A well-defined - * decimal dv is selected - * to represent {@code v}. - *
    • Formatting as a string: The decimal - * dv is formatted as a string, - * either in plain or in computerized scientific notation, - * depending on its value. - *
    - *
- * - *

A decimal is a number of the form - * d×10i - * for some (unique) integers d > 0 and i such that - * d is not a multiple of 10. - * These integers are the significand and - * the exponent, respectively, of the decimal. - * The length of the decimal is the (unique) - * integer n meeting - * 10n-1d < 10n. - * - *

The decimal dv - * for a finite positive {@code v} is defined as follows: - *

    - *
  • Let R be the set of all decimals that round to {@code v} - * according to the usual round-to-closest rule of - * IEEE 754 floating-point arithmetic. - *
  • Let m be the minimal length over all decimals in R. - *
  • When m ≥ 2, let T be the set of all decimals - * in R with length m. - * Otherwise, let T be the set of all decimals - * in R with length 1 or 2. - *
  • Define dv as - * the decimal in T that is closest to {@code v}. - * Or if there are two such decimals in T, - * select the one with the even significand (there is exactly one). - *
- * - *

The (uniquely) selected decimal dv - * is then formatted. - * - *

Let d, i and n be the significand, exponent and - * length of dv, respectively. - * Further, let e = n + i - 1 and let - * d1dn - * be the usual decimal expansion of the significand. - * Note that d1 ≠ 0 ≠ dn. - *

    - *
  • Case -3 ≤ e < 0: - * dv is formatted as - * 0.00d1dn, - * where there are exactly -(n + i) zeroes between - * the decimal point and d1. - * For example, 123 × 10-4 is formatted as - * {@code 0.0123}. - *
  • Case 0 ≤ e < 7: - *
      - *
    • Subcase i ≥ 0: - * dv is formatted as - * d1dn00.0, - * where there are exactly i zeroes - * between dn and the decimal point. - * For example, 123 × 102 is formatted as - * {@code 12300.0}. - *
    • Subcase i < 0: - * dv is formatted as - * d1dn+i.dn+i+1dn. - * There are exactly -i digits to the right of - * the decimal point. - * For example, 123 × 10-1 is formatted as - * {@code 12.3}. - *
    - *
  • Case e < -3 or e ≥ 7: - * computerized scientific notation is used to format - * dv. - * Here e is formatted as by {@link Integer#toString(int)}. - *
      - *
    • Subcase n = 1: - * dv is formatted as - * d1.0Ee. - * For example, 1 × 1023 is formatted as - * {@code 1.0E23}. - *
    • Subcase n > 1: - * dv is formatted as - * d1.d2dnEe. - * For example, 123 × 10-21 is formatted as - * {@code 1.23E-19}. - *
    - *
- * - * @param v the {@code double} to be rendered. - * @return a string rendering of the argument. - */ - public static String toString(double v) { - return new DoubleToDecimal().toDecimalString(v); - } - - private String toDecimalString(double v) { - switch (toDecimal(v)) { - case NON_SPECIAL: - return charsToString(); - case PLUS_ZERO: - return "0.0"; - case MINUS_ZERO: - return "-0.0"; - case PLUS_INF: - return "Infinity"; - case MINUS_INF: - return "-Infinity"; - default: - return "NaN"; - } - } - - /* - Returns - PLUS_ZERO iff v is 0.0 - MINUS_ZERO iff v is -0.0 - PLUS_INF iff v is POSITIVE_INFINITY - MINUS_INF iff v is NEGATIVE_INFINITY - NAN iff v is NaN - */ - private int toDecimal(double v) { - /* - For full details see references [2] and [1]. - - For finite v != 0, determine integers c and q such that - |v| = c 2^q and - Q_MIN <= q <= Q_MAX and - either 2^(P-1) <= c < 2^P (normal) - or 0 < c < 2^(P-1) and q = Q_MIN (subnormal) - */ - long bits = doubleToRawLongBits(v); - long t = bits & T_MASK; - int bq = (int) (bits >>> P - 1) & BQ_MASK; - if (bq < BQ_MASK) { - // index is set by writeDouble (to pos - 1) before this method is called. - if (bits < 0) { - append('-'); - } - if (bq != 0) { - // normal value. Here mq = -q - int mq = -Q_MIN + 1 - bq; - long c = C_MIN | t; - // The fast path discussed in section 8.2 of [1]. - if (0 < mq & mq < P) { - long f = c >> mq; - if (f << mq == c) { - return toChars(f, 0); - } - } - return toDecimal(-mq, c, 0); - } - if (t != 0) { - // subnormal value - return t < C_TINY - ? toDecimal(Q_MIN, 10 * t, -1) - : toDecimal(Q_MIN, t, 0); - } - return bits == 0 ? PLUS_ZERO : MINUS_ZERO; - } - if (t != 0) { - return NAN; - } - return bits > 0 ? PLUS_INF : MINUS_INF; - } - - private int toDecimal(int q, long c, int dk) { + static int toDecimal(byte[] buf, int pos, int q, long c, int dk) { /* The skeleton corresponds to figure 4 of [1]. The efficient computations are those summarized in figure 7. @@ -2454,7 +1979,7 @@ private int toDecimal(int q, long c, int dk) { boolean upin = vbl + out <= sp10 << 2; boolean wpin = (tp10 << 2) + out <= vbr; if (upin != wpin) { - return toChars(upin ? sp10 : tp10, k); + return toChars(buf, pos, upin ? sp10 : tp10, k); } } @@ -2469,14 +1994,14 @@ private int toDecimal(int q, long c, int dk) { boolean win = (t << 2) + out <= vbr; if (uin != win) { // Exactly one of u or w lies in Rv. - return toChars(uin ? s : t, k + dk); + return toChars(buf, pos, uin ? s : t, k + dk); } /* Both u and w lie in Rv: determine the one closest to v. See section 9.4 of [1]. */ long cmp = vb - (s + t << 1); - return toChars(cmp < 0 || cmp == 0 && (s & 0x1) == 0 ? s : t, k + dk); + return toChars(buf, pos, cmp < 0 || cmp == 0 && (s & 0x1) == 0 ? s : t, k + dk); } /* @@ -2495,7 +2020,7 @@ private static long rop(long g1, long g0, long cp) { /* Formats the decimal f 10^e. */ - private int toChars(long f, int e) { + static int toChars(byte[] buf, int pos, long f, int e) { /* For details not discussed here see section 10 of [1]. @@ -2536,70 +2061,70 @@ private int toChars(long f, int e) { int m = (int) (hm - 100_000_000 * h); if (0 < e && e <= 7) { - return toChars1(h, m, l, e); + return toChars1(buf, pos, h, m, l, e); } if (-3 < e && e <= 0) { - return toChars2(h, m, l, e); + return toChars2(buf, pos, h, m, l, e); } - return toChars3(h, m, l, e); + return toChars3(buf, pos, h, m, l, e); } - private int toChars1(int h, int m, int l, int e) { + private static int toChars1(byte[] buf, int pos, int h, int m, int l, int e) { /* 0 < e <= 7: plain format without leading zeroes. Left-to-right digits extraction: algorithm 1 in [3], with b = 10, k = 8, n = 28. */ - appendDigit(h); + buf[pos++] = (byte) ('0' + h); int y = y(m); int t; int i = 1; for (; i < e; ++i) { t = 10 * y; - appendDigit(t >>> 28); + buf[pos++] = (byte) ('0' + (t >>> 28)); y = t & MASK_28; } - append('.'); + buf[pos++] = '.'; for (; i <= 8; ++i) { t = 10 * y; - appendDigit(t >>> 28); + buf[pos++] = (byte) ('0' + (t >>> 28)); y = t & MASK_28; } - lowDigits(l); - return NON_SPECIAL; + pos = lowDigits(buf, pos, l); + return pos; } - private int toChars2(int h, int m, int l, int e) { + private static int toChars2(byte[] buf, int pos, int h, int m, int l, int e) { // -3 < e <= 0: plain format with leading zeroes. - appendDigit(0); - append('.'); + buf[pos++] = '0'; + buf[pos++] = '.'; for (; e < 0; ++e) { - appendDigit(0); + buf[pos++] = '0'; } - appendDigit(h); - append8Digits(m); - lowDigits(l); - return NON_SPECIAL; + buf[pos++] = (byte) ('0' + h); + pos = append8Digits(buf, pos, m); + pos = lowDigits(buf, pos, l); + return pos; } - private int toChars3(int h, int m, int l, int e) { + private static int toChars3(byte[] buf, int pos, int h, int m, int l, int e) { // -3 >= e | e > 7: computerized scientific notation - appendDigit(h); - append('.'); - append8Digits(m); - lowDigits(l); - exponent(e - 1); - return NON_SPECIAL; + buf[pos++] = (byte) ('0' + h); + buf[pos++] = '.'; + pos = append8Digits(buf, pos, m); + pos = lowDigits(buf, pos, l); + pos = exponent(buf, pos, e - 1); + return pos; } - private void lowDigits(int l) { + private static int lowDigits(byte[] buf, int pos, int l) { if (l != 0) { - append8Digits(l); + pos = append8Digits(buf, pos, l); } - removeTrailingZeroes(); + return removeTrailingZeroes(buf, pos); } - private void append8Digits(int m) { + private static int append8Digits(byte[] buf, int pos, int m) { /* Left-to-right digits extraction: algorithm 1 in [3], with b = 10, k = 8, n = 28. @@ -2607,22 +2132,24 @@ private void append8Digits(int m) { int y = y(m); for (int i = 0; i < 8; ++i) { int t = 10 * y; - appendDigit(t >>> 28); + buf[pos++] = (byte) ('0' + (t >>> 28)); y = t & MASK_28; } + return pos; } - private void removeTrailingZeroes() { - while (bytes[index] == '0') { - --index; + private static int removeTrailingZeroes(byte[] buf, int pos) { + while (buf[pos - 1] == '0') { + --pos; } // ... but do not remove the one directly to the right of '.' - if (bytes[index] == '.') { - ++index; + if (buf[pos - 1] == '.') { + ++pos; } + return pos; } - private int y(int a) { + private static int y(int a) { /* Algorithm 1 in [3] needs computation of floor((a + 1) 2^n / b^k) - 1 @@ -2636,15 +2163,15 @@ private int y(int a) { 193_428_131_138_340_668L) >>> 20) - 1; } - private void exponent(int e) { - append('E'); + private static int exponent(byte[] buf, int pos, int e) { + buf[pos++] = 'E'; if (e < 0) { - append('-'); + buf[pos++] = '-'; e = -e; } if (e < 10) { - appendDigit(e); - return; + buf[pos++] = (byte) ('0' + e); + return pos; } int d; if (e >= 100) { @@ -2653,7 +2180,7 @@ private void exponent(int e) { floor(e / 100) = floor(1_311 e / 2^17) */ d = e * 1_311 >>> 17; - appendDigit(d); + buf[pos++] = (byte) ('0' + d); e -= 100 * d; } /* @@ -2661,22 +2188,9 @@ private void exponent(int e) { floor(e / 10) = floor(103 e / 2^10) */ d = e * 103 >>> 10; - appendDigit(d); - appendDigit(e - 10 * d); - } - - private void append(int c) { - bytes[++index] = (byte) c; - } - - private void appendDigit(int d) { - bytes[++index] = (byte) ('0' + d); - } - - // Using the deprecated constructor enhances performance. - @SuppressWarnings("deprecation") - private String charsToString() { - return new String(bytes, 0, 0, index + 1); + buf[pos++] = (byte) ('0' + d); + buf[pos++] = (byte) ('0' + e - 10 * d); + return pos; } } diff --git a/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/StripedPool.java b/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/StripedPool.java new file mode 100644 index 0000000000..32d6c1240f --- /dev/null +++ b/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/StripedPool.java @@ -0,0 +1,112 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import java.util.concurrent.atomic.AtomicReferenceArray; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * A striped lock-free object pool for reusing serializer/deserializer instances across requests. + * + *

Each concrete pool should be declared as a {@code private static final} anonymous subclass. + * The JIT devirtualizes callback methods via monomorphic inline caching at each callsite. + * + *

Pool slots are probed starting from a Fibonacci-hashed thread-ID with cache-line-sized + * strides to avoid false sharing under contention. Both platform and virtual threads use the + * pool as actual concurrency is bounded by carrier thread count, and the pool is sized at + * 4x available processors to absorb probe collisions. + * + * @param the pooled type + * @param context type passed to {@link #acquire} and forwarded to {@link #create}/{@link #reset} + */ +@SmithyInternalApi +public abstract class StripedPool { + + private static final int N_PROCS = Runtime.getRuntime().availableProcessors(); + + // 64-byte cache line / 4-byte compressed oop = 16 refs per line + private static final int STRIDE_SHIFT = 4; + private static final int MAX_PROBE = 4; + private final int slotMask; + private final AtomicReferenceArray pool; + + protected StripedPool() { + this(4); + } + + protected StripedPool(int multiplier) { + int raw = N_PROCS * multiplier; + int logicalSlots = Integer.highestOneBit(raw - 1) << 1; + this.slotMask = logicalSlots - 1; + this.pool = new AtomicReferenceArray<>(logicalSlots << STRIDE_SHIFT); + } + + /** + * Create a fresh instance when the pool is empty. + */ + protected abstract T create(C context); + + /** + * Always-run cleanup on release, before the canPool check. + * Use for clearing references that must not outlive the request (e.g., OutputStream sink). + */ + protected void cleanup(T item) {} + + /** + * Whether this item is eligible for pooling. Called after {@link #cleanup}. + * Return false to discard (e.g., internal buffer is null). + */ + protected abstract boolean canPool(T item); + + /** + * Pre-pool preparation, called only when {@link #canPool} returned true. + * Use for downsizing oversized buffers to bound memory. + */ + protected void prepareForPool(T item) {} + + /** + * Reinitialize a pooled item for reuse. Return true to accept the item. + * Return false if the item is unsuitable for this particular acquire (e.g., settings mismatch). + * A rejected item is put back into its pool slot and probing continues. + */ + protected abstract boolean reset(T item, C context); + + public final T acquire(C context) { + AtomicReferenceArray p = this.pool; + int mask = this.slotMask; + long id = Thread.currentThread().threadId(); + int base = (int) (id * 0x9E3779B97F4A7C15L >>> 32) & mask; + for (int i = 0; i < MAX_PROBE; i++) { + int idx = ((base + i) & mask) << STRIDE_SHIFT; + T s = p.getPlain(idx); + if (s != null && p.compareAndExchangeAcquire(idx, s, null) == s) { + if (reset(s, context)) { + return s; + } + p.weakCompareAndSetRelease(idx, null, s); + } + } + return create(context); + } + + public final void release(T item) { + cleanup(item); + if (!canPool(item)) { + return; + } + prepareForPool(item); + AtomicReferenceArray p = this.pool; + int mask = this.slotMask; + long id = Thread.currentThread().threadId(); + int base = (int) (id * 0x9E3779B97F4A7C15L >>> 32) & mask; + for (int i = 0; i < MAX_PROBE; i++) { + int idx = ((base + i) & mask) << STRIDE_SHIFT; + if (p.getPlain(idx) == null && p.weakCompareAndSetRelease(idx, null, item)) { + return; + } + } + } +} diff --git a/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/TimestampCodec.java b/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/TimestampCodec.java new file mode 100644 index 0000000000..baf47d50cb --- /dev/null +++ b/codecs/codec-commons/src/main/java/software/amazon/smithy/java/codecs/commons/TimestampCodec.java @@ -0,0 +1,772 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import java.time.Instant; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Low-level utilities for reading/writing timestamp values directly from/to byte arrays. + */ +@SmithyInternalApi +public final class TimestampCodec { + + private TimestampCodec() {} + + private static final byte[] DIGIT_PAIRS = new byte[200]; + + static { + for (int i = 0; i < 100; i++) { + DIGIT_PAIRS[i * 2] = (byte) ('0' + i / 10); + DIGIT_PAIRS[i * 2 + 1] = (byte) ('0' + i % 10); + } + } + + private static final byte[][] DAY_NAMES = { + null, + {'M', 'o', 'n'}, + {'T', 'u', 'e'}, + {'W', 'e', 'd'}, + {'T', 'h', 'u'}, + {'F', 'r', 'i'}, + {'S', 'a', 't'}, + {'S', 'u', 'n'}, + }; + + private static final int[] NANO_SCALE = { + 0, + 100_000_000, + 10_000_000, + 1_000_000, + 100_000, + 10_000, + 1_000, + 100, + 10, + 1 + }; + + private static final byte[][] MONTH_NAMES = { + null, + {'J', 'a', 'n'}, + {'F', 'e', 'b'}, + {'M', 'a', 'r'}, + {'A', 'p', 'r'}, + {'M', 'a', 'y'}, + {'J', 'u', 'n'}, + {'J', 'u', 'l'}, + {'A', 'u', 'g'}, + {'S', 'e', 'p'}, + {'O', 'c', 't'}, + {'N', 'o', 'v'}, + {'D', 'e', 'c'}, + }; + + /** + * Parses an ISO-8601 timestamp (YYYY-MM-DDThh:mm:ss[.nnn]Z) directly from a byte array. + * Returns null if the format doesn't match (caller should fall back to DateTimeFormatter). + */ + public static Instant parseIso8601(byte[] buf, int pos, int end) { + if (pos + 20 > end) { + return null; + } + + int d0 = digit(buf[pos]), d1 = digit(buf[pos + 1]), d2 = digit(buf[pos + 2]), d3 = digit(buf[pos + 3]); + if ((d0 | d1 | d2 | d3) < 0) { + return null; + } + int year = d0 * 1000 + d1 * 100 + d2 * 10 + d3; + pos += 4; + if (buf[pos++] != '-') { + return null; + } + + int m0 = digit(buf[pos]), m1 = digit(buf[pos + 1]); + if ((m0 | m1) < 0) { + return null; + } + int month = m0 * 10 + m1; + pos += 2; + if (month < 1 || month > 12) { + return null; + } + if (buf[pos++] != '-') { + return null; + } + + int dy0 = digit(buf[pos]), dy1 = digit(buf[pos + 1]); + if ((dy0 | dy1) < 0) { + return null; + } + int day = dy0 * 10 + dy1; + pos += 2; + if (day < 1 || day > maxDayOfMonth(year, month)) { + return null; + } + + if (buf[pos] != 'T') { + return null; + } + pos++; + + int h0 = digit(buf[pos]), h1 = digit(buf[pos + 1]); + if ((h0 | h1) < 0) { + return null; + } + int hour = h0 * 10 + h1; + pos += 2; + if (buf[pos++] != ':') { + return null; + } + + int mn0 = digit(buf[pos]), mn1 = digit(buf[pos + 1]); + if ((mn0 | mn1) < 0) { + return null; + } + int minute = mn0 * 10 + mn1; + pos += 2; + if (buf[pos++] != ':') { + return null; + } + + int s0 = digit(buf[pos]), s1 = digit(buf[pos + 1]); + if ((s0 | s1) < 0) { + return null; + } + int second = s0 * 10 + s1; + pos += 2; + if (hour > 23 || minute > 59 || second > 59) { + return null; + } + + int nano = 0; + if (pos < end && buf[pos] == '.') { + pos++; + int fracLen = 0; + while (pos < end) { + int d = buf[pos] - '0'; + if (d < 0 || d > 9) { + break; + } + if (fracLen < 9) { + nano = nano * 10 + d; + } + fracLen++; + pos++; + } + if (fracLen == 0 || fracLen > 9) { + return null; + } + nano *= NANO_SCALE[fracLen]; + } + + if (pos >= end || buf[pos] != 'Z') { + return null; + } + pos++; + if (pos != end) { + return null; + } + + long epochDay = computeEpochDay(year, month, day); + long epochSecond = epochDay * 86400 + hour * 3600 + minute * 60 + second; + return Instant.ofEpochSecond(epochSecond, nano); + } + + /** + * Parses an HTTP-date (e.g. "Mon, 01 Jan 2024 12:00:00 GMT") directly from a byte array. + * Returns null if the format doesn't match. + */ + public static Instant parseHttpDate(byte[] buf, int pos, int end) { + if (pos + 29 > end) { + return null; + } + + pos += 3; + if (buf[pos] != ',' || buf[pos + 1] != ' ') { + return null; + } + pos += 2; + + int dy0 = digit(buf[pos]), dy1 = digit(buf[pos + 1]); + if ((dy0 | dy1) < 0) { + return null; + } + int day = dy0 * 10 + dy1; + pos += 2; + if (day < 1) { + return null; + } + if (buf[pos++] != ' ') { + return null; + } + + int month = parseMonthName(buf[pos], buf[pos + 1], buf[pos + 2]); + if (month == -1) { + return null; + } + pos += 3; + if (buf[pos++] != ' ') { + return null; + } + + int y0 = digit(buf[pos]), y1 = digit(buf[pos + 1]), y2 = digit(buf[pos + 2]), y3 = digit(buf[pos + 3]); + if ((y0 | y1 | y2 | y3) < 0) { + return null; + } + int year = y0 * 1000 + y1 * 100 + y2 * 10 + y3; + pos += 4; + if (day > maxDayOfMonth(year, month)) { + return null; + } + if (buf[pos++] != ' ') { + return null; + } + + int h0 = digit(buf[pos]), h1 = digit(buf[pos + 1]); + if ((h0 | h1) < 0) { + return null; + } + int hour = h0 * 10 + h1; + pos += 2; + if (buf[pos++] != ':') { + return null; + } + + int mn0 = digit(buf[pos]), mn1 = digit(buf[pos + 1]); + if ((mn0 | mn1) < 0) { + return null; + } + int minute = mn0 * 10 + mn1; + pos += 2; + if (buf[pos++] != ':') { + return null; + } + + int s0 = digit(buf[pos]), s1 = digit(buf[pos + 1]); + if ((s0 | s1) < 0) { + return null; + } + int second = s0 * 10 + s1; + pos += 2; + + if (hour > 23 || minute > 59 || second > 59) { + return null; + } + + if (buf[pos] != ' ') { + return null; + } + pos++; + if (pos + 3 > end || buf[pos] != 'G' || buf[pos + 1] != 'M' || buf[pos + 2] != 'T') { + return null; + } + + long epochDay = computeEpochDay(year, month, day); + long epochSecond = epochDay * 86400 + hour * 3600 + minute * 60 + second; + return Instant.ofEpochSecond(epochSecond); + } + + /** + * Parses an epoch-seconds timestamp from a byte array. + * Handles both integer (e.g. "1234567890") and fractional (e.g. "1234567890.123") forms. + * Returns null if the format doesn't match. + */ + public static Instant parseEpochSeconds(byte[] buf, int pos, int end) { + if (pos >= end) { + return null; + } + + boolean negative = false; + if (buf[pos] == '-') { + negative = true; + pos++; + if (pos >= end) { + return null; + } + } + + long seconds = 0; + int intStart = pos; + while (pos < end && buf[pos] >= '0' && buf[pos] <= '9') { + long prev = seconds; + seconds = seconds * 10 + (buf[pos] - '0'); + if (seconds < prev) { + return null; // overflow + } + pos++; + } + if (pos == intStart) { + return null; + } + + int nano = 0; + boolean hasExponent = false; + if (pos < end && buf[pos] == '.') { + pos++; + int fracStart = pos; + while (pos < end && buf[pos] >= '0' && buf[pos] <= '9') { + pos++; + } + int fracLen = pos - fracStart; + + hasExponent = pos < end && (buf[pos] == 'e' || buf[pos] == 'E'); + if (!hasExponent) { + if (fracLen > 9) { + return null; + } + for (int i = fracStart; i < pos; i++) { + nano = nano * 10 + (buf[i] - '0'); + } + nano *= NANO_SCALE[fracLen]; + } + } else { + hasExponent = pos < end && (buf[pos] == 'e' || buf[pos] == 'E'); + } + + if (hasExponent) { + return parseEpochSecondsBigDecimal(buf, intStart - (negative ? 1 : 0), end); + } + + if (pos != end) { + return null; + } + + if (negative) { + if (nano == 0) { + if (-seconds < Instant.MIN.getEpochSecond()) { + return null; + } + return Instant.ofEpochSecond(-seconds); + } + long es = -seconds - 1; + if (es < Instant.MIN.getEpochSecond()) { + return null; + } + return Instant.ofEpochSecond(es, 1_000_000_000 - nano); + } + if (seconds > Instant.MAX.getEpochSecond()) { + return null; + } + return Instant.ofEpochSecond(seconds, nano); + } + + private static Instant parseEpochSecondsBigDecimal(byte[] buf, int start, int end) { + int pos = start; + boolean negative = pos < end && buf[pos] == '-'; + if (negative) { + pos++; + } + + long mantissa = 0; + int mantissaDigits = 0; + int pointOffset = -1; // digits after the decimal point + + while (pos < end && buf[pos] >= '0' && buf[pos] <= '9') { + if (mantissaDigits < 18) { + mantissa = mantissa * 10 + (buf[pos] - '0'); + } + mantissaDigits++; + pos++; + } + if (pos < end && buf[pos] == '.') { + pos++; + pointOffset = 0; + while (pos < end && buf[pos] >= '0' && buf[pos] <= '9') { + if (mantissaDigits < 18) { + mantissa = mantissa * 10 + (buf[pos] - '0'); + mantissaDigits++; + } + pointOffset++; + pos++; + } + } + + if (pos >= end || (buf[pos] != 'e' && buf[pos] != 'E')) { + return null; + } + pos++; + + boolean expNegative = false; + if (pos < end && (buf[pos] == '+' || buf[pos] == '-')) { + expNegative = buf[pos] == '-'; + pos++; + } + if (pos >= end) { + return null; + } + + int exponent = 0; + while (pos < end && buf[pos] >= '0' && buf[pos] <= '9') { + exponent = exponent * 10 + (buf[pos] - '0'); + if (exponent > 20) { + return null; // way outside long range + } + pos++; + } + if (pos != end) { + return null; + } + if (expNegative) { + exponent = -exponent; + } + + int shift = exponent - (pointOffset >= 0 ? pointOffset : 0); + + long seconds; + int nano = 0; + + if (shift >= 0) { + seconds = mantissa; + for (int i = 0; i < shift; i++) { + seconds *= 10; + if (seconds < 0) { + return null; // overflow + } + } + } else { + int fracDigits = -shift; + if (fracDigits > 9) { + return null; + } + long divisor = 1; + for (int i = 0; i < fracDigits; i++) { + divisor *= 10; + } + seconds = mantissa / divisor; + long fracPart = mantissa % divisor; + for (int i = fracDigits; i < 9; i++) { + fracPart *= 10; + } + nano = (int) fracPart; + } + + if (negative) { + if (nano == 0) { + seconds = -seconds; + if (seconds > 0) { + return null; // overflow + } + if (seconds < Instant.MIN.getEpochSecond()) { + return null; + } + return Instant.ofEpochSecond(seconds); + } + long es = -seconds - 1; + if (es < Instant.MIN.getEpochSecond()) { + return null; + } + return Instant.ofEpochSecond(es, 1_000_000_000 - nano); + } + if (seconds > Instant.MAX.getEpochSecond()) { + return null; + } + return Instant.ofEpochSecond(seconds, nano); + } + + private static int parseMonthName(byte c0, byte c1, byte c2) { + switch (c0) { + case 'J': + if (c1 == 'a' && c2 == 'n') { + return 1; + } + if (c1 == 'u' && c2 == 'n') { + return 6; + } + if (c1 == 'u' && c2 == 'l') { + return 7; + } + return -1; + case 'F': + if (c1 == 'e' && c2 == 'b') { + return 2; + } + return -1; + case 'M': + if (c1 == 'a' && c2 == 'r') { + return 3; + } + if (c1 == 'a' && c2 == 'y') { + return 5; + } + return -1; + case 'A': + if (c1 == 'p' && c2 == 'r') { + return 4; + } + if (c1 == 'u' && c2 == 'g') { + return 8; + } + return -1; + case 'S': + if (c1 == 'e' && c2 == 'p') { + return 9; + } + return -1; + case 'O': + if (c1 == 'c' && c2 == 't') { + return 10; + } + return -1; + case 'N': + if (c1 == 'o' && c2 == 'v') { + return 11; + } + return -1; + case 'D': + if (c1 == 'e' && c2 == 'c') { + return 12; + } + return -1; + default: + return -1; + } + } + + /** + * Writes an epoch-seconds timestamp directly. Writes "seconds" for whole seconds + * or "seconds.nanos" for fractional, with trailing zeros stripped. + */ + public static int writeEpochSeconds(byte[] buf, int pos, long epochSecond, int nano) { + if (nano == 0) { + return NumberCodec.writeLong(buf, pos, epochSecond); + } + + int fraction = nano; + if (epochSecond < 0) { + epochSecond += 1; + fraction = 1_000_000_000 - nano; + if (epochSecond == 0) { + buf[pos++] = '-'; + buf[pos++] = '0'; + } else { + pos = NumberCodec.writeLong(buf, pos, epochSecond); + } + } else { + pos = NumberCodec.writeLong(buf, pos, epochSecond); + } + + buf[pos++] = '.'; + + int hi = fraction / 1_000_000; + int mid = (fraction / 1_000) % 1_000; + int lo = fraction % 1_000; + buf[pos++] = (byte) ('0' + hi / 100); + buf[pos++] = (byte) ('0' + (hi / 10) % 10); + buf[pos++] = (byte) ('0' + hi % 10); + buf[pos++] = (byte) ('0' + mid / 100); + buf[pos++] = (byte) ('0' + (mid / 10) % 10); + buf[pos++] = (byte) ('0' + mid % 10); + buf[pos++] = (byte) ('0' + lo / 100); + buf[pos++] = (byte) ('0' + (lo / 10) % 10); + buf[pos++] = (byte) ('0' + lo % 10); + while (buf[pos - 1] == '0') { + pos--; + } + return pos; + } + + /** + * Writes an ISO-8601 timestamp directly to the byte buffer without quotes. + * Produces output like {@code 2025-01-15T10:30:00Z} or {@code 2025-01-15T10:30:00.123Z}. + */ + public static int writeIso8601(byte[] buf, int pos, Instant value) { + long epochSecond = value.getEpochSecond(); + int nano = value.getNano(); + + long epochDay; + int secondOfDay; + if (epochSecond >= 0) { + epochDay = epochSecond / 86400; + secondOfDay = (int) (epochSecond - epochDay * 86400); + } else { + epochDay = Math.floorDiv(epochSecond, 86400); + secondOfDay = (int) Math.floorMod(epochSecond, 86400); + } + int hour = secondOfDay / 3600; + int minute = (secondOfDay % 3600) / 60; + int second = secondOfDay % 60; + long z = epochDay + 719468; + long era = (z >= 0 ? z : z - 146096) / 146097; + long doe = z - era * 146097; + long yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + long doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + long mp = (5 * doy + 2) / 153; + int day = (int) (doy - (153 * mp + 2) / 5 + 1); + int month = (int) (mp < 10 ? mp + 3 : mp - 9); + int year = (int) (yoe + era * 400 + (month <= 2 ? 1 : 0)); + + if (year >= 0 && year <= 9999) { + int hi = year / 100; + int lo = year - hi * 100; + buf[pos++] = DIGIT_PAIRS[hi * 2]; + buf[pos++] = DIGIT_PAIRS[hi * 2 + 1]; + buf[pos++] = DIGIT_PAIRS[lo * 2]; + buf[pos++] = DIGIT_PAIRS[lo * 2 + 1]; + } else { + String yearStr = String.format("%04d", year); + for (int i = 0; i < yearStr.length(); i++) { + buf[pos++] = (byte) yearStr.charAt(i); + } + } + + buf[pos++] = '-'; + buf[pos++] = DIGIT_PAIRS[month * 2]; + buf[pos++] = DIGIT_PAIRS[month * 2 + 1]; + buf[pos++] = '-'; + buf[pos++] = DIGIT_PAIRS[day * 2]; + buf[pos++] = DIGIT_PAIRS[day * 2 + 1]; + buf[pos++] = 'T'; + buf[pos++] = DIGIT_PAIRS[hour * 2]; + buf[pos++] = DIGIT_PAIRS[hour * 2 + 1]; + buf[pos++] = ':'; + buf[pos++] = DIGIT_PAIRS[minute * 2]; + buf[pos++] = DIGIT_PAIRS[minute * 2 + 1]; + buf[pos++] = ':'; + buf[pos++] = DIGIT_PAIRS[second * 2]; + buf[pos++] = DIGIT_PAIRS[second * 2 + 1]; + + if (nano != 0) { + buf[pos++] = '.'; + // Write all 9 digits in 3 groups of 3 (avoids repeated division loops) + int hi = nano / 1_000_000; + int mid = (nano / 1_000) % 1_000; + int lo = nano % 1_000; + buf[pos] = (byte) ('0' + hi / 100); + buf[pos + 1] = (byte) ('0' + (hi / 10) % 10); + buf[pos + 2] = (byte) ('0' + hi % 10); + buf[pos + 3] = (byte) ('0' + mid / 100); + buf[pos + 4] = (byte) ('0' + (mid / 10) % 10); + buf[pos + 5] = (byte) ('0' + mid % 10); + buf[pos + 6] = (byte) ('0' + lo / 100); + buf[pos + 7] = (byte) ('0' + (lo / 10) % 10); + buf[pos + 8] = (byte) ('0' + lo % 10); + pos += 9; + while (buf[pos - 1] == '0') { + pos--; + } + } + + buf[pos++] = 'Z'; + return pos; + } + + /** + * Writes an HTTP-date timestamp directly to the byte buffer without quotes. + * Produces output like {@code Sat, 01 Jan 2026 00:00:00 GMT}. + */ + public static int writeHttpDate(byte[] buf, int pos, Instant value) { + long epochSecond = value.getEpochSecond(); + + long epochDay; + int secondOfDay; + if (epochSecond >= 0) { + epochDay = epochSecond / 86400; + secondOfDay = (int) (epochSecond - epochDay * 86400); + } else { + epochDay = Math.floorDiv(epochSecond, 86400); + secondOfDay = (int) Math.floorMod(epochSecond, 86400); + } + int hour = secondOfDay / 3600; + int minute = (secondOfDay % 3600) / 60; + int second = secondOfDay % 60; + long z = epochDay + 719468; + long era = (z >= 0 ? z : z - 146096) / 146097; + long doe = z - era * 146097; + long yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; + long doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + long mp = (5 * doy + 2) / 153; + int day = (int) (doy - (153 * mp + 2) / 5 + 1); + int month = (int) (mp < 10 ? mp + 3 : mp - 9); + int year = (int) (yoe + era * 400 + (month <= 2 ? 1 : 0)); + + int dow = (int) Math.floorMod(epochDay + 3, 7) + 1; + + byte[] dayName = DAY_NAMES[dow]; + buf[pos++] = dayName[0]; + buf[pos++] = dayName[1]; + buf[pos++] = dayName[2]; + buf[pos++] = ','; + buf[pos++] = ' '; + + buf[pos++] = DIGIT_PAIRS[day * 2]; + buf[pos++] = DIGIT_PAIRS[day * 2 + 1]; + buf[pos++] = ' '; + + byte[] monthName = MONTH_NAMES[month]; + buf[pos++] = monthName[0]; + buf[pos++] = monthName[1]; + buf[pos++] = monthName[2]; + buf[pos++] = ' '; + + if (year >= 0 && year <= 9999) { + int hi = year / 100; + int lo = year - hi * 100; + buf[pos++] = DIGIT_PAIRS[hi * 2]; + buf[pos++] = DIGIT_PAIRS[hi * 2 + 1]; + buf[pos++] = DIGIT_PAIRS[lo * 2]; + buf[pos++] = DIGIT_PAIRS[lo * 2 + 1]; + } else { + String yearStr = String.format("%04d", year); + for (int i = 0; i < yearStr.length(); i++) { + buf[pos++] = (byte) yearStr.charAt(i); + } + } + + buf[pos++] = ' '; + buf[pos++] = DIGIT_PAIRS[hour * 2]; + buf[pos++] = DIGIT_PAIRS[hour * 2 + 1]; + buf[pos++] = ':'; + buf[pos++] = DIGIT_PAIRS[minute * 2]; + buf[pos++] = DIGIT_PAIRS[minute * 2 + 1]; + buf[pos++] = ':'; + buf[pos++] = DIGIT_PAIRS[second * 2]; + buf[pos++] = DIGIT_PAIRS[second * 2 + 1]; + buf[pos++] = ' '; + buf[pos++] = 'G'; + buf[pos++] = 'M'; + buf[pos++] = 'T'; + return pos; + } + + private static int maxDayOfMonth(int year, int month) { + switch (month) { + case 2: + return (year % 4 == 0 && (year % 100 != 0 || year % 400 == 0)) ? 29 : 28; + case 4: + case 6: + case 9: + case 11: + return 30; + default: + return 31; + } + } + + private static long computeEpochDay(int year, int month, int day) { + long y = year; + long m = month; + if (m <= 2) { + y--; + m += 9; + } else { + m -= 3; + } + long era = (y >= 0 ? y : y - 399) / 400; + long yoe = y - era * 400; + long doy = (153 * m + 2) / 5 + day - 1; + long doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; + return era * 146097 + doe - 719468; + } + + private static int digit(byte b) { + int d = b - '0'; + if (d < 0 || d > 9) { + return -1; + } + return d; + } +} diff --git a/codecs/codec-commons/src/test/java/software/amazon/smithy/java/codecs/commons/NumberCodecTest.java b/codecs/codec-commons/src/test/java/software/amazon/smithy/java/codecs/commons/NumberCodecTest.java new file mode 100644 index 0000000000..7333ea7780 --- /dev/null +++ b/codecs/codec-commons/src/test/java/software/amazon/smithy/java/codecs/commons/NumberCodecTest.java @@ -0,0 +1,433 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +class NumberCodecTest { + + // --- parseInt --- + + @ParameterizedTest + @ValueSource(strings = { + "0", + "1", + "-1", + "9", + "10", + "99", + "100", + "999", + "1000", + "9999", + // Exercises SWAR parse4Digits for exactly 4 digits + "1234", + // 5 digits: SWAR parses 4 then falls through to single-digit loop + "12345", + // 8 digits: two rounds of SWAR + "12345678", + "-12345678", + // MAX and MIN boundaries + "2147483647", + "-2147483648", + // Just below MAX (ensures we don't over-reject) + "2147483646", + "-2147483647", + // Leading zeros must not trigger false overflow + "0042", + "00000000000000042", + "-0042", + "0000000000", + // Plus sign prefix + "+1", + "+2147483647", + "+0042" + }) + void parseInt(String input) { + int expected = Integer.parseInt(input); + byte[] buf = input.getBytes(StandardCharsets.US_ASCII); + assertEquals(expected, NumberCodec.parseInt(buf, 0, buf.length)); + } + + @ParameterizedTest + @ValueSource(strings = { + "", + "-", + "+", + "abc", + // Invalid digit in the middle (exercises SWAR fallback to single-digit validation) + "12a4", + // Invalid byte at position 0 in a 4-byte SWAR chunk + "x234", + // Invalid byte at position 3 in a 4-byte SWAR chunk + "123x", + // Overflow: one past MAX + "2147483648", + // Overflow: one past MIN magnitude + "-2147483649", + // Silent wraparound without length guard + "99999999999", + // 10-digit overflow that previously wrapped silently + "9005150711", + "3000000000", + "-9005150711" + }) + void parseIntRejects(String input) { + byte[] buf = input.getBytes(StandardCharsets.US_ASCII); + assertThrows(NumberFormatException.class, () -> NumberCodec.parseInt(buf, 0, buf.length)); + } + + // --- parseLong --- + + @ParameterizedTest + @ValueSource(strings = { + "0", + "-1", + "12345678", + // 8-digit SWAR boundary + "99999999", + "100000000", + // 9 digits: 8-digit SWAR + 1 remaining + "123456789", + // 12 digits: one 8-digit SWAR + one 4-digit SWAR + "123456789012", + "-123456789012", + // 16 digits: two rounds of 8-digit SWAR + "1234567890123456", + // MAX/MIN + "9223372036854775807", + "-9223372036854775808", + "9223372036854775806", + // Leading zeros + "000000000000000000001", + "-00042" + }) + void parseLong(String input) { + long expected = Long.parseLong(input); + byte[] buf = input.getBytes(StandardCharsets.US_ASCII); + assertEquals(expected, NumberCodec.parseLong(buf, 0, buf.length)); + } + + @ParameterizedTest + @ValueSource(strings = { + "", + "-", + // Invalid digit in single-digit loop after SWAR + "12345678x", + // Invalid byte at position 0 of 8-byte SWAR chunk + "x2345678", + // Invalid byte at position 7 of 8-byte SWAR chunk + "1234567x", + "9223372036854775808", + "-9223372036854775809", + "99999999999999999999" + }) + void parseLongRejects(String input) { + byte[] buf = input.getBytes(StandardCharsets.US_ASCII); + assertThrows(NumberFormatException.class, () -> NumberCodec.parseLong(buf, 0, buf.length)); + } + + // --- parseInt with offset --- + + static Stream parseWithOffsetCases() { + return Stream.of( + Arguments.of("xxx42yyy", 3, 2), + Arguments.of(" -7 ", 2, 2), + Arguments.of("ab1234cd", 2, 4)); + } + + @ParameterizedTest + @MethodSource("parseWithOffsetCases") + void parseIntWithOffset(String input, int start, int len) { + byte[] buf = input.getBytes(StandardCharsets.US_ASCII); + String slice = input.substring(start, start + len); + int expected = Integer.parseInt(slice); + assertEquals(expected, NumberCodec.parseInt(buf, start, len)); + } + + // --- writeInt --- + + @ParameterizedTest + @ValueSource(ints = { + 0, + 1, + -1, + 9, + 10, + 99, + 100, + 999, + 9999, + 10000, + 10001, + Integer.MAX_VALUE, + Integer.MIN_VALUE, + // Exercises digit-pair path for values in [100, 10000) + 123, + 1000, + 5678, + // Value with remainder exactly 0 after /10000 (tests write4Digits with r=0) + 20000, + // Exercises writePositiveInt 4-digit SWAR write path + 12345, + // Value >= 100000000 forces two rounds of SWAR 4-digit writes + 123456789, + 100000000, + 1000000000 + }) + void writeInt(int value) { + String expected = Integer.toString(value); + byte[] buf = new byte[12]; + int end = NumberCodec.writeInt(buf, 0, value); + assertEquals(expected, new String(buf, 0, end, StandardCharsets.US_ASCII)); + } + + // --- writeLong --- + // Values chosen to exercise every branch of digitCountLong (11-19 digits) + // and all three paths in writePositiveLong + + @ParameterizedTest + @ValueSource(longs = { + 0L, + -1L, + Long.MAX_VALUE, + Long.MIN_VALUE, + // Exactly Integer.MAX_VALUE as long (boundary: writePositiveLong delegates to writePositiveInt) + 2147483647L, + // Just above int range -> exercises long split path + 2147483648L, + // Exactly at 10^16 boundary (writePositiveLong two-chunk vs three-chunk split) + 10000000000000000L, + // 11 digits + 10000000000L, + // 12 digits + 100000000001L, + // 13 digits + 1000000000001L, + // 14 digits + 10000000000001L, + // 15 digits + 100000000000001L, + // 16 digits + 1000000000000001L, + // 17 digits -> three-chunk split in writePositiveLong + 10000000000000001L, + // 18 digits + 100000000000000001L, + // 19 digits + 1000000000000000001L + }) + void writeLong(long value) { + String expected = Long.toString(value); + byte[] buf = new byte[21]; + int end = NumberCodec.writeLong(buf, 0, value); + assertEquals(expected, new String(buf, 0, end, StandardCharsets.US_ASCII)); + } + + // --- writeBigInteger --- + + static Stream writeBigIntegerCases() { + return Stream.of( + BigInteger.ZERO, + BigInteger.ONE, + BigInteger.ONE.negate(), + BigInteger.valueOf(Long.MAX_VALUE), + // Exactly 10^18 (boundary: compareTo(TEN_TO_18) == 0) + new BigInteger("1000000000000000000"), + // Just above 10^18 (triggers two-chunk path: high < 10^18) + new BigInteger("1000000000000000001"), + // Exactly 10^36 (boundary: high == TEN_TO_18, triggers three-chunk) + new BigInteger("1000000000000000000000000000000000000"), + // Just above 10^36 (triggers three-chunk path: qr2[0] < 10^18) + new BigInteger("1000000000000000000000000000000000000001"), + // Exactly 10^54 (boundary: qr2[0] == TEN_TO_18, triggers toString fallback) + new BigInteger("1" + "0".repeat(54)), + // Way above 10^54 (deep in toString() fallback) + new BigInteger("1" + "0".repeat(55)), + new BigInteger("-99999999999999999999")) + .map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("writeBigIntegerCases") + void writeBigInteger(BigInteger value) { + String expected = value.toString(); + byte[] buf = new byte[80]; + int end = NumberCodec.writeBigInteger(buf, 0, value); + assertEquals(expected, new String(buf, 0, end, StandardCharsets.US_ASCII)); + } + + // --- writeDouble --- + + @ParameterizedTest + @ValueSource(doubles = { + 0.0, + 1.0, + -1.0, + 42.0, + 1000000.0, + 3.14, + 1.0E-7, + 1.7976931348623157E308, + 5.0E-324, + -0.1 + }) + void writeDouble(double value) { + byte[] buf = new byte[32]; + int end = NumberCodec.writeDouble(buf, 0, value); + String actual = new String(buf, 0, end, StandardCharsets.US_ASCII); + assertEquals(value, Double.parseDouble(actual)); + } + + // --- writeFloat --- + + @ParameterizedTest + @ValueSource(floats = { + 0.0f, + 1.0f, + -1.0f, + 100.0f, + 3.14f, + Float.MAX_VALUE, + Float.MIN_VALUE + }) + void writeFloat(float value) { + byte[] buf = new byte[20]; + int end = NumberCodec.writeFloat(buf, 0, value); + String actual = new String(buf, 0, end, StandardCharsets.US_ASCII); + assertEquals(value, Float.parseFloat(actual)); + } + + // --- writeNonFiniteFloat --- + + static Stream nonFiniteFloatCases() { + return Stream.of( + Arguments.of(Float.NaN, "NaN"), + Arguments.of(Float.POSITIVE_INFINITY, "Infinity"), + Arguments.of(Float.NEGATIVE_INFINITY, "-Infinity")); + } + + @ParameterizedTest + @MethodSource("nonFiniteFloatCases") + void writeNonFiniteFloat(float value, String expected) { + byte[] buf = new byte[24]; + int end = NumberCodec.writeNonFiniteFloat(buf, 0, value); + assertEquals(expected, new String(buf, 0, end, StandardCharsets.US_ASCII)); + } + + @ParameterizedTest + @MethodSource("nonFiniteFloatCases") + void writeNonFiniteFloatQuoted(float value, String expected) { + byte[] buf = new byte[24]; + int end = NumberCodec.writeNonFiniteFloatQuoted(buf, 0, value); + assertEquals("\"" + expected + "\"", new String(buf, 0, end, StandardCharsets.US_ASCII)); + } + + // --- writeNonFiniteDouble --- + + static Stream nonFiniteDoubleCases() { + return Stream.of( + Arguments.of(Double.NaN, "NaN"), + Arguments.of(Double.POSITIVE_INFINITY, "Infinity"), + Arguments.of(Double.NEGATIVE_INFINITY, "-Infinity")); + } + + @ParameterizedTest + @MethodSource("nonFiniteDoubleCases") + void writeNonFiniteDouble(double value, String expected) { + byte[] buf = new byte[24]; + int end = NumberCodec.writeNonFiniteDouble(buf, 0, value); + assertEquals(expected, new String(buf, 0, end, StandardCharsets.US_ASCII)); + } + + @ParameterizedTest + @MethodSource("nonFiniteDoubleCases") + void writeNonFiniteDoubleQuoted(double value, String expected) { + byte[] buf = new byte[24]; + int end = NumberCodec.writeNonFiniteDoubleQuoted(buf, 0, value); + assertEquals("\"" + expected + "\"", new String(buf, 0, end, StandardCharsets.US_ASCII)); + } + + // --- writeFloatFull / writeDoubleFull --- + + @ParameterizedTest + @ValueSource(floats = {0.0f, 1.5f, -3.14f, Float.NaN, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY}) + void writeFloatFull(float value) { + byte[] buf = new byte[24]; + int end = NumberCodec.writeFloatFull(buf, 0, value); + String actual = new String(buf, 0, end, StandardCharsets.US_ASCII); + if (Float.isNaN(value)) { + assertEquals("NaN", actual); + } else if (value == Float.POSITIVE_INFINITY) { + assertEquals("Infinity", actual); + } else if (value == Float.NEGATIVE_INFINITY) { + assertEquals("-Infinity", actual); + } else { + assertEquals(value, Float.parseFloat(actual)); + } + } + + @ParameterizedTest + @ValueSource(doubles = {0.0, 1.5, -3.14, Double.NaN, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY}) + void writeDoubleFull(double value) { + byte[] buf = new byte[24]; + int end = NumberCodec.writeDoubleFull(buf, 0, value); + String actual = new String(buf, 0, end, StandardCharsets.US_ASCII); + if (Double.isNaN(value)) { + assertEquals("NaN", actual); + } else if (value == Double.POSITIVE_INFINITY) { + assertEquals("Infinity", actual); + } else if (value == Double.NEGATIVE_INFINITY) { + assertEquals("-Infinity", actual); + } else { + assertEquals(value, Double.parseDouble(actual)); + } + } + + @ParameterizedTest + @ValueSource(floats = {0.0f, 1.5f, Float.NaN, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY}) + void writeFloatFullQuoted(float value) { + byte[] buf = new byte[24]; + int end = NumberCodec.writeFloatFullQuoted(buf, 0, value); + String actual = new String(buf, 0, end, StandardCharsets.US_ASCII); + if (Float.isFinite(value)) { + assertEquals(value, Float.parseFloat(actual)); + } else if (Float.isNaN(value)) { + assertEquals("\"NaN\"", actual); + } else if (value > 0) { + assertEquals("\"Infinity\"", actual); + } else { + assertEquals("\"-Infinity\"", actual); + } + } + + @ParameterizedTest + @ValueSource(doubles = {0.0, 1.5, Double.NaN, Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY}) + void writeDoubleFullQuoted(double value) { + byte[] buf = new byte[24]; + int end = NumberCodec.writeDoubleFullQuoted(buf, 0, value); + String actual = new String(buf, 0, end, StandardCharsets.US_ASCII); + if (Double.isFinite(value)) { + assertEquals(value, Double.parseDouble(actual)); + } else if (Double.isNaN(value)) { + assertEquals("\"NaN\"", actual); + } else if (value > 0) { + assertEquals("\"Infinity\"", actual); + } else { + assertEquals("\"-Infinity\"", actual); + } + } + +} diff --git a/codecs/codec-commons/src/test/java/software/amazon/smithy/java/codecs/commons/TimestampCodecTest.java b/codecs/codec-commons/src/test/java/software/amazon/smithy/java/codecs/commons/TimestampCodecTest.java new file mode 100644 index 0000000000..10aa164dd6 --- /dev/null +++ b/codecs/codec-commons/src/test/java/software/amazon/smithy/java/codecs/commons/TimestampCodecTest.java @@ -0,0 +1,480 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.codecs.commons; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +import java.math.BigDecimal; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.stream.Stream; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +class TimestampCodecTest { + + private static final DateTimeFormatter HTTP_FORMATTER = + DateTimeFormatter.ofPattern("EEE, dd MMM yyyy HH:mm:ss 'GMT'").withZone(ZoneOffset.UTC); + + private static Instant parse8601(String s) { + byte[] buf = s.getBytes(StandardCharsets.US_ASCII); + return TimestampCodec.parseIso8601(buf, 0, buf.length); + } + + private static Instant parseHttp(String s) { + byte[] buf = s.getBytes(StandardCharsets.US_ASCII); + return TimestampCodec.parseHttpDate(buf, 0, buf.length); + } + + private static Instant parseEpoch(String s) { + byte[] buf = s.getBytes(StandardCharsets.US_ASCII); + return TimestampCodec.parseEpochSeconds(buf, 0, buf.length); + } + + private static Instant jdkParseEpochSeconds(String s) { + try { + BigDecimal epoch = new BigDecimal(s); + if (Math.max(epoch.scale(), 0) > 9) { + return null; + } + BigDecimal[] parts = epoch.divideAndRemainder(BigDecimal.ONE); + long seconds = parts[0].longValueExact(); + long nanos = parts[1].movePointRight(9).longValueExact(); + return Instant.ofEpochSecond(seconds, nanos); + } catch (Exception e) { + return null; + } + } + + @Nested + class ParseIso8601 { + + @ParameterizedTest + @ValueSource(strings = { + "2024-01-15T10:30:00Z", + "2024-01-15T10:30:00.1Z", + "2024-01-15T10:30:00.123Z", + "2024-01-15T10:30:00.123456789Z", + // Midnight and end of day + "2024-01-01T00:00:00Z", + "2024-12-31T23:59:59Z", + // Leap year Feb 29 + "2024-02-29T00:00:00Z", + // Non-leap year Feb 28 + "2023-02-28T12:00:00Z", + // Year boundaries + "0001-01-01T00:00:00Z", + "9999-12-31T23:59:59Z", + // Century-divisible non-leap year + "1900-02-28T00:00:00Z", + // Century-divisible leap year (400 rule) + "2000-02-29T00:00:00Z", + // Single fractional digit + "2024-01-01T00:00:00.5Z", + // 9 fractional digits (max allowed) + "2024-01-01T00:00:00.000000001Z", + // Various time boundaries that kill hour/min/sec conditional mutations + "2024-01-01T23:59:59Z", + "2024-01-01T00:00:01Z", + "2024-01-01T01:01:01Z", + // Month 12, day 31 (boundary for month > 12 check) + "2024-12-01T00:00:00Z", + // Month 01 (boundary for month < 1 check) + "2024-01-31T00:00:00Z", + // Day 01 (boundary for day < 1 check) + "2024-06-01T00:00:00Z", + // Before epoch with negative year math in computeEpochDay + "0100-01-01T00:00:00Z" + }) + void parsesValid(String input) { + Instant expected = Instant.parse(input); + assertEquals(expected, parse8601(input)); + } + + @ParameterizedTest + @ValueSource(strings = { + // Too short + "2024-01-15T10:30:0Z", + // Missing Z + "2024-01-15T10:30:00", + // Trailing garbage after Z + "2024-01-15T10:30:00Zx", + // Invalid month + "2024-13-15T10:30:00Z", + "2024-00-15T10:30:00Z", + // Invalid day + "2024-01-32T10:30:00Z", + "2024-01-00T10:30:00Z", + // Feb 29 in non-leap year + "2023-02-29T00:00:00Z", + // Feb 29 in century non-leap year + "1900-02-29T00:00:00Z", + // Invalid hour/min/sec + "2024-01-15T24:00:00Z", + "2024-01-15T10:60:00Z", + "2024-01-15T10:30:60Z", + // Bad separators + "2024/01/15T10:30:00Z", + "2024-01-15 10:30:00Z", + // Empty fractional + "2024-01-15T10:30:00.Z", + // Over 9 fractional digits + "2024-01-15T10:30:00.1234567890Z", + // Non-digit in date (year position 0) + "x024-01-15T10:30:00Z", + // Non-digit in month + "2024-x1-15T10:30:00Z", + // Non-digit in day + "2024-01-x5T10:30:00Z", + // Non-digit in hour + "2024-01-15Tx0:30:00Z", + // Non-digit in minute + "2024-01-15T10:x0:00Z", + // Non-digit in second + "2024-01-15T10:30:x0Z", + // Apr 31 (30-day month boundary) + "2024-04-31T00:00:00Z", + // Jun 31 + "2024-06-31T00:00:00Z" + }) + void rejectsInvalid(String input) { + assertNull(parse8601(input)); + } + } + + @Nested + class ParseHttpDate { + + @ParameterizedTest + @ValueSource(strings = { + "Mon, 01 Jan 2024 12:00:00 GMT", + "Thu, 29 Feb 2024 00:00:00 GMT", + "Sat, 01 Jan 2000 00:00:00 GMT", + "Sun, 31 Dec 2023 23:59:59 GMT", + // All month names + "Wed, 15 Jan 2025 00:00:00 GMT", + "Sat, 15 Feb 2025 00:00:00 GMT", + "Sat, 15 Mar 2025 00:00:00 GMT", + "Tue, 15 Apr 2025 00:00:00 GMT", + "Thu, 15 May 2025 00:00:00 GMT", + "Sun, 15 Jun 2025 00:00:00 GMT", + "Tue, 15 Jul 2025 00:00:00 GMT", + "Fri, 15 Aug 2025 00:00:00 GMT", + "Mon, 15 Sep 2025 00:00:00 GMT", + "Wed, 15 Oct 2025 00:00:00 GMT", + "Sat, 15 Nov 2025 00:00:00 GMT", + "Mon, 15 Dec 2025 00:00:00 GMT", + // Boundary time values + "Mon, 01 Jan 2024 00:00:00 GMT", + "Mon, 01 Jan 2024 23:59:59 GMT" + }) + void parsesValid(String input) { + Instant expected = HTTP_FORMATTER.parse(input, Instant::from); + assertEquals(expected, parseHttp(input)); + } + + @ParameterizedTest + @ValueSource(strings = { + // Too short + "Mon, 01 Jan 2024 12:00:00 GM", + // Bad month -- 'J' first char, 'a' second but wrong third (not 'n') + "Mon, 01 Jax 2024 12:00:00 GMT", + // Bad month -- 'J' first char, 'u' second but wrong third (not 'n' or 'l') + "Mon, 01 Jux 2024 12:00:00 GMT", + // Bad month -- 'J' first char, wrong second char entirely + "Mon, 01 Jxx 2024 12:00:00 GMT", + // Bad month -- 'F' first char but wrong combo + "Mon, 01 Fxx 2024 12:00:00 GMT", + // Bad month -- 'M' first char, 'a' second but wrong third (not 'r' or 'y') + "Mon, 01 Max 2024 12:00:00 GMT", + // Bad month -- 'M' first char, wrong second char + "Mon, 01 Mxx 2024 12:00:00 GMT", + // Bad month -- 'A' first char, 'p' second but wrong third + "Mon, 01 Apx 2024 12:00:00 GMT", + // Bad month -- 'A' first char, 'u' second but wrong third + "Mon, 01 Aux 2024 12:00:00 GMT", + // Bad month -- 'A' first char, wrong second char + "Mon, 01 Axx 2024 12:00:00 GMT", + // Bad month -- 'S' first char but wrong combo + "Mon, 01 Sxx 2024 12:00:00 GMT", + // Bad month -- 'O' first char but wrong combo + "Mon, 01 Oxx 2024 12:00:00 GMT", + // Bad month -- 'N' first char but wrong combo + "Mon, 01 Nxx 2024 12:00:00 GMT", + // Bad month -- 'D' first char but wrong combo + "Mon, 01 Dxx 2024 12:00:00 GMT", + // Bad month -- wrong first char entirely + "Mon, 01 Xxx 2024 12:00:00 GMT", + // Missing comma + "Mon 01 Jan 2024 12:00:00 GMT", + // Invalid day (Feb 29 non-leap) + "Wed, 29 Feb 2023 00:00:00 GMT", + // Invalid time + "Mon, 01 Jan 2024 25:00:00 GMT", + "Mon, 01 Jan 2024 12:60:00 GMT", + "Mon, 01 Jan 2024 12:00:60 GMT", + // Not GMT + "Mon, 01 Jan 2024 12:00:00 UTC", + // Day 00 + "Mon, 00 Jan 2024 12:00:00 GMT", + // Non-digit in day + "Mon, x1 Jan 2024 12:00:00 GMT", + // Non-digit in year + "Mon, 01 Jan x024 12:00:00 GMT", + // Non-digit in hour + "Mon, 01 Jan 2024 x2:00:00 GMT", + // Non-digit in minute + "Mon, 01 Jan 2024 12:x0:00 GMT", + // Non-digit in second + "Mon, 01 Jan 2024 12:00:x0 GMT" + }) + void rejectsInvalid(String input) { + assertNull(parseHttp(input)); + } + } + + @Nested + class ParseEpochSeconds { + + @ParameterizedTest + @ValueSource(strings = { + "0", + "1", + "1234567890", + // Negative + "-1", + "-62135596800", + // Fractional seconds + "1.5", + "1.001", + "1.000000001", + "0.1", + // Negative with fraction + "-1.5", + "-0.001", + // Scientific notation -- positive exponent + "1e0", + "1e3", + "1E3", + // Scientific notation -- negative exponent (fractional result) + "15e-1", + // Scientific notation -- with decimal point + "1.5e9", + "1.5e3", + // Scientific notation -- explicit + sign on exponent + "1e+3", + // Negative with scientific notation + "-1.5e3", + "-15e-1", + // Large mantissa with exponent + "123456789012345678e-8", + // Fractional with varying digit counts + "1.1e1", + "1.12e2", + "1.123e3", + // Exponent that exactly cancels fractional + "1.5e0", + // fracDigits exactly 9 (boundary for fracDigits > 9 rejection) + "1e-9", + // Shift=0 produces integer from sci notation with decimal + "1.5e1", + // Negative with sci notation and fractional nanos + "-1e0", + "-1.5e0", + // Exponent with explicit + sign and multi-digit exponent + "1e+10", + // Integer-only mantissa with negative exponent + "123e-2", + "1234567e-7" + }) + void parsesValid(String input) { + Instant expected = jdkParseEpochSeconds(input); + assertNotNull(expected, "JDK reference failed to parse: " + input); + assertEquals(expected, parseEpoch(input)); + } + + @ParameterizedTest + @ValueSource(strings = { + "", + "-", + "abc", + // Overflow + "99999999999999999999999", + // More than 9 fractional digits (non-exponent) + "1.0000000001", + // Exponent too large + "1e21", + // Trailing garbage + "123abc", + // Just exponent sign, no digits + "1e", + "1e+", + "1e-", + // Negative with nothing after + "-e1" + }) + void rejectsInvalid(String input) { + assertNull(parseEpoch(input)); + } + } + + @Nested + class WriteEpochSeconds { + + static Stream cases() { + return Stream.of( + Instant.EPOCH, + Instant.ofEpochSecond(1), + Instant.ofEpochSecond(-1), + Instant.ofEpochSecond(1234567890), + // With nanos -- various trailing-zero patterns + Instant.ofEpochSecond(1, 500000000), + Instant.ofEpochSecond(1, 100000000), + Instant.ofEpochSecond(1, 1), + Instant.ofEpochSecond(1, 123456789), + Instant.ofEpochSecond(1, 120000000), + Instant.ofEpochSecond(0, 100000000), + // Negative with nanos -- exercises the epochSecond+1 / 10^9-nano path + Instant.ofEpochSecond(-1, 500000000), + Instant.ofEpochSecond(-2, 500000000), + Instant.ofEpochSecond(-1, 999000000), + Instant.ofEpochSecond(-1, 1), + // Negative zero seconds with nano (writes "-0.xxx") + Instant.ofEpochSecond(-1, 999999999)) + .map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("cases") + void writeThenParseMatchesOriginal(Instant instant) { + byte[] buf = new byte[64]; + int end = TimestampCodec.writeEpochSeconds(buf, + 0, + instant.getEpochSecond(), + instant.getNano()); + Instant reparsed = TimestampCodec.parseEpochSeconds(buf, 0, end); + assertEquals(instant, reparsed); + } + } + + @Nested + class WriteIso8601 { + + static Stream cases() { + return Stream.of( + Instant.parse("2024-01-15T10:30:00Z"), + Instant.parse("2024-01-15T10:30:00.1Z"), + Instant.parse("2024-01-15T10:30:00.123Z"), + Instant.parse("2024-01-15T10:30:00.123456789Z"), + Instant.EPOCH, + // Before epoch (exercises negative epochDay in civil date conversion) + Instant.parse("1969-12-31T23:59:59Z"), + // Trailing zeros stripped from fractional + Instant.parse("2024-01-01T00:00:00.120Z"), + // Year 0100 -- exercises early-year civil date conversion + Instant.parse("0100-03-01T12:00:00Z")) + .map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("cases") + void writeMatchesJdkAndRoundTrips(Instant instant) { + byte[] buf = new byte[64]; + int end = TimestampCodec.writeIso8601(buf, 0, instant); + String written = new String(buf, 0, end, StandardCharsets.US_ASCII); + + assertEquals(instant, Instant.parse(written)); + + Instant selfParsed = TimestampCodec.parseIso8601(buf, 0, end); + if (selfParsed != null) { + assertEquals(instant, selfParsed); + } + } + } + + @Nested + class WriteHttpDate { + + static Stream cases() { + return Stream.of( + Instant.parse("2024-01-01T12:00:00Z"), + Instant.parse("2024-02-29T00:00:00Z"), + Instant.EPOCH, + Instant.parse("2023-12-31T23:59:59Z"), + // Before epoch (exercises negative epochDay in dow/civil date calc) + Instant.parse("1969-06-15T08:30:00Z")) + .map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("cases") + void writeMatchesJdkFormatter(Instant instant) { + byte[] buf = new byte[64]; + int end = TimestampCodec.writeHttpDate(buf, 0, instant); + String written = new String(buf, 0, end, StandardCharsets.US_ASCII); + + String expected = HTTP_FORMATTER.format(instant); + assertEquals(expected, written); + + Instant selfParsed = TimestampCodec.parseHttpDate(buf, 0, end); + assertEquals(instant, selfParsed); + } + } + + @Nested + class RoundTrips { + + @Test + void iso8601RoundTrip() { + Instant original = Instant.parse("2024-06-15T14:30:45.123456789Z"); + byte[] buf = new byte[64]; + int end = TimestampCodec.writeIso8601(buf, 0, original); + Instant parsed = TimestampCodec.parseIso8601(buf, 0, end); + assertEquals(original, parsed); + } + + @Test + void httpDateRoundTrip() { + Instant original = Instant.parse("2024-06-15T14:30:45Z"); + byte[] buf = new byte[64]; + int end = TimestampCodec.writeHttpDate(buf, 0, original); + Instant parsed = TimestampCodec.parseHttpDate(buf, 0, end); + assertEquals(original, parsed); + } + + @Test + void epochSecondsRoundTrip() { + Instant original = Instant.ofEpochSecond(1718458245L, 123456789); + byte[] buf = new byte[64]; + int end = TimestampCodec.writeEpochSeconds(buf, + 0, + original.getEpochSecond(), + original.getNano()); + Instant parsed = TimestampCodec.parseEpochSeconds(buf, 0, end); + assertEquals(original, parsed); + } + + @Test + void negativeEpochSecondsRoundTrip() { + Instant original = Instant.ofEpochSecond(-100, 500000000); + byte[] buf = new byte[64]; + int end = TimestampCodec.writeEpochSeconds(buf, + 0, + original.getEpochSecond(), + original.getNano()); + Instant parsed = TimestampCodec.parseEpochSeconds(buf, 0, end); + assertEquals(original, parsed); + } + } +} diff --git a/codecs/json-codec/build.gradle.kts b/codecs/json-codec/build.gradle.kts index 274aa19819..548a0a922b 100644 --- a/codecs/json-codec/build.gradle.kts +++ b/codecs/json-codec/build.gradle.kts @@ -2,7 +2,7 @@ plugins { id("smithy-java.module-conventions") id("smithy-java.fuzz-test") id("software.amazon.smithy.gradle.smithy-base") - alias(libs.plugins.shadow) + id("com.gradleup.shadow") } description = "This module provides json functionality" @@ -12,10 +12,9 @@ extra["moduleName"] = "software.amazon.smithy.java.json" dependencies { api(project(":core")) + api(project(":codecs:codec-commons", configuration = "shadow")) compileOnly(libs.jackson.core) - compileOnly(libs.fastdoubleparser) testRuntimeOnly(libs.jackson.core) - testRuntimeOnly(libs.fastdoubleparser) smithyBuild(project(":codegen:codegen-plugin")) } @@ -32,15 +31,7 @@ tasks { .toString(), ), ) - include( - dependency( - libs.fastdoubleparser - .get() - .toString(), - ), - ) relocate("tools.jackson.core", "software.amazon.smithy.java.internal.shaded.tools.jackson.core") - relocate("ch.randelshofer", "software.amazon.smithy.java.internal.shaded.ch.randelshofer") } } jar { diff --git a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/JsonReadUtils.java b/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/JsonReadUtils.java index 9f71944d1f..31e69ae047 100644 --- a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/JsonReadUtils.java +++ b/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/JsonReadUtils.java @@ -5,14 +5,16 @@ package software.amazon.smithy.java.json.smithy; -import ch.randelshofer.fastdoubleparser.JavaDoubleParser; import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; +import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; import java.time.Instant; import java.util.Arrays; import java.util.Base64; +import software.amazon.smithy.java.codecs.commons.NumberCodec; +import software.amazon.smithy.java.codecs.commons.TimestampCodec; import software.amazon.smithy.java.core.serde.SerializationException; /** @@ -25,6 +27,8 @@ final class JsonReadUtils { private JsonReadUtils() {} + private static final Base64.Decoder BASE64_DECODER = Base64.getDecoder(); + // VarHandle for reading 8 bytes at a time from byte arrays (SWAR technique) private static final VarHandle LONG_HANDLE = MethodHandles.byteArrayViewVarHandle(long[].class, ByteOrder.LITTLE_ENDIAN); @@ -109,7 +113,6 @@ static void parseLong(byte[] buf, int pos, int end, SmithyJsonDeserializer deser static void parseDouble(byte[] buf, int pos, int end, SmithyJsonDeserializer deser) { int start = pos; - // Optional minus sign if (pos < end && buf[pos] == '-') { pos++; } @@ -118,7 +121,6 @@ static void parseDouble(byte[] buf, int pos, int end, SmithyJsonDeserializer des throw new SerializationException("Unexpected end of input while parsing number"); } - // Integer part — no leading zeros byte first = buf[pos]; if (first < '0' || first > '9') { throw new SerializationException("Expected digit, found: " + describeChar(first)); @@ -135,7 +137,6 @@ static void parseDouble(byte[] buf, int pos, int end, SmithyJsonDeserializer des } } - // Optional fractional part if (pos < end && buf[pos] == '.') { pos++; if (pos >= end || buf[pos] < '0' || buf[pos] > '9') { @@ -146,7 +147,6 @@ static void parseDouble(byte[] buf, int pos, int end, SmithyJsonDeserializer des } } - // Optional exponent if (pos < end && (buf[pos] == 'e' || buf[pos] == 'E')) { pos++; if (pos < end && (buf[pos] == '+' || buf[pos] == '-')) { @@ -160,8 +160,7 @@ static void parseDouble(byte[] buf, int pos, int end, SmithyJsonDeserializer des } } - // Parse directly from byte array — no String allocation. - deser.parsedDouble = JavaDoubleParser.parseDouble(buf, start, pos - start); + deser.parsedDouble = NumberCodec.parseDouble(buf, start, pos - start); deser.parsedEndPos = pos; } @@ -207,13 +206,13 @@ static void parseString(byte[] buf, int pos, int end, SmithyJsonDeserializer des while (pos < end) { byte b = buf[pos]; if (b == '"') { - // No escapes found — fast path + // No escapes found -- fast path deser.parsedString = new String(buf, start, pos - start, StandardCharsets.UTF_8); deser.parsedEndPos = pos + 1; return; } if (b == '\\') { - // Has escapes — slow path + // Has escapes -- slow path parseStringWithEscapes(buf, start, pos, end, deser); return; } @@ -254,9 +253,7 @@ private static void parseStringWithEscapes( int end, SmithyJsonDeserializer deser ) { - // Build a StringBuilder from what we've read so far + escaped content StringBuilder sb = new StringBuilder(escapePos - start + 16); - // Append everything before the first escape as UTF-8 sb.append(new String(buf, start, escapePos - start, StandardCharsets.UTF_8)); int pos = escapePos; @@ -295,7 +292,6 @@ private static void parseStringWithEscapes( char c = parseHex4(buf, pos); pos += 4; if (Character.isHighSurrogate(c)) { - // Expect low surrogate escape if (pos + 6 > end || buf[pos] != '\\' || buf[pos + 1] != 'u') { throw new SerializationException("Expected low surrogate after high surrogate"); } @@ -319,7 +315,6 @@ private static void parseStringWithEscapes( "Invalid escape character: \\" + (char) escaped); } } else { - // Regular UTF-8 byte — decode if ((b & 0x80) == 0) { sb.append((char) b); pos++; @@ -337,7 +332,7 @@ private static char parseHex4(byte[] buf, int pos) { int value = 0; for (int i = 0; i < 4; i++) { byte b = buf[pos + i]; - if (b < 0 || b >= HEX_VALUES.length || HEX_VALUES[b] == -1) { + if (b < 0 || HEX_VALUES[b] == -1) { throw new SerializationException( "Invalid hex digit in \\u escape: " + (char) (b & 0xFF)); } @@ -357,7 +352,6 @@ private static int[] decodeUtf8Char(byte[] buf, int pos, int end) { if ((b & 0x80) == 0) { return new int[] {b, pos + 1}; } else if ((b & 0xE0) == 0xC0) { - // 2-byte if (pos + 2 > end) { throw new SerializationException("Truncated UTF-8 sequence"); } @@ -368,7 +362,6 @@ private static int[] decodeUtf8Char(byte[] buf, int pos, int end) { } return new int[] {cp, pos + 2}; } else if ((b & 0xF0) == 0xE0) { - // 3-byte if (pos + 3 > end) { throw new SerializationException("Truncated UTF-8 sequence"); } @@ -384,7 +377,6 @@ private static int[] decodeUtf8Char(byte[] buf, int pos, int end) { } return new int[] {cp, pos + 3}; } else if ((b & 0xF8) == 0xF0) { - // 4-byte if (pos + 4 > end) { throw new SerializationException("Truncated UTF-8 sequence"); } @@ -415,11 +407,9 @@ private static void validateContinuationByte(byte b) { * The fast check at the top covers the common case (next byte is not whitespace). */ static int skipWhitespace(byte[] buf, int pos, int end) { - // Fast check: most common case is next byte is not whitespace if (pos < end && buf[pos] > ' ') { return pos; } - // Scalar loop for remaining bytes while (pos < end) { byte b = buf[pos]; if (b != ' ' && b != '\n' && b != '\r' && b != '\t') { @@ -430,307 +420,66 @@ static int skipWhitespace(byte[] buf, int pos, int end) { return pos; } - // Month lookup: index by first two bytes of 3-letter month abbreviation - // Jan=1, Feb=2, ..., Dec=12. Used by parseHttpDate. - private static final int[] MONTH_LOOKUP = new int[128 * 128]; - - static { - // Populate month lookup: key = first_char * 128 + second_char - MONTH_LOOKUP['J' * 128 + 'a'] = 1; // Jan - MONTH_LOOKUP['F' * 128 + 'e'] = 2; // Feb - MONTH_LOOKUP['M' * 128 + 'a'] = 3; // Mar (also May — disambiguate with 3rd char) - MONTH_LOOKUP['A' * 128 + 'p'] = 4; // Apr - MONTH_LOOKUP['J' * 128 + 'u'] = 6; // Jun (also Jul — disambiguate with 3rd char) - MONTH_LOOKUP['A' * 128 + 'u'] = 8; // Aug - MONTH_LOOKUP['S' * 128 + 'e'] = 9; // Sep - MONTH_LOOKUP['O' * 128 + 'c'] = 10; // Oct - MONTH_LOOKUP['N' * 128 + 'o'] = 11; // Nov - MONTH_LOOKUP['D' * 128 + 'e'] = 12; // Dec - } - - // Full 3-letter month name validation table: third character for each month - private static final byte[] MONTH_THIRD_CHAR = { - 0, - 'n', - 'b', - 'r', - 'r', - 'y', - 'n', - 'l', - 'g', - 'p', - 't', - 'v', - 'c' - }; - // Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec - - /** - * Looks up month number (1-12) from a 3-letter abbreviation at buf[pos..pos+3). - * Validates all three characters: uses a two-char lookup table for the first two, - * then disambiguates Mar/May and Jun/Jul with the third character, and validates - * the third character for all other months. - */ - private static int lookupMonth(byte[] buf, int pos) { - int key = (buf[pos] & 0x7F) * 128 + (buf[pos + 1] & 0x7F); - int month = key < MONTH_LOOKUP.length ? MONTH_LOOKUP[key] : 0; - if (month == 3) { - // Mar or May — disambiguate on third char - if (buf[pos + 2] == 'y') { - return 5; // May - } - if (buf[pos + 2] != 'r') { - throw new SerializationException("Invalid month: " - + (char) buf[pos] + (char) buf[pos + 1] + (char) buf[pos + 2]); - } - return 3; // Mar - } - if (month == 6) { - // Jun or Jul — disambiguate on third char - if (buf[pos + 2] == 'l') { - return 7; // Jul - } - if (buf[pos + 2] != 'n') { - throw new SerializationException("Invalid month: " - + (char) buf[pos] + (char) buf[pos + 1] + (char) buf[pos + 2]); - } - return 6; // Jun - } - if (month == 0 || buf[pos + 2] != MONTH_THIRD_CHAR[month]) { - throw new SerializationException("Invalid month: " - + (char) buf[pos] + (char) buf[pos + 1] + (char) buf[pos + 2]); - } - return month; - } - - /** - * Computes epoch day (days since 1970-01-01) from year/month/day using the - * proleptic Gregorian calendar algorithm. Pure integer arithmetic — no java.time overhead. - */ - private static long computeEpochDay(int year, int month, int day) { - long y = year; - long m = month; - // Shift March=0..Feb=11 so Feb (the leap-day month) is at the end - if (m <= 2) { - y--; - m += 9; - } else { - m -= 3; - } - long era = (y >= 0 ? y : y - 399) / 400; - long yoe = y - era * 400; - long doy = (153 * m + 2) / 5 + day - 1; - long doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; - return era * 146097 + doe - 719468; - } - - private static int digit(byte b) { - int d = b - '0'; - if (d < 0 || d > 9) { - throw new SerializationException("Expected digit, found: " + describeChar(b)); - } - return d; - } - /** * Parses an ISO-8601 timestamp directly from a JSON quoted string in the byte buffer. - * Expects pos to be at the opening quote. Avoids String allocation and DateTimeFormatter. - * - *

On success, stores result in deser.parsedEndPos (position after closing quote) - * and returns the Instant. On failure (non-standard format), returns null and - * deser state is unchanged caller should fall back to DateTimeFormatter. + * Expects pos to be at the opening quote. */ static Instant parseIso8601(byte[] buf, int pos, int end, SmithyJsonDeserializer deser) { - // Minimum: "YYYY-MM-DDThh:mm:ssZ" = 22 bytes (including quotes) if (pos >= end || buf[pos] != '"' || pos + 22 > end) { return null; } - pos++; // skip opening quote - - // Parse YYYY-MM-DD - int year = digit(buf[pos]) * 1000 + digit(buf[pos + 1]) * 100 - + digit(buf[pos + 2]) * 10 - + digit(buf[pos + 3]); - pos += 4; - if (buf[pos++] != '-') { - return null; - } - int month = digit(buf[pos]) * 10 + digit(buf[pos + 1]); - pos += 2; - if (month < 1 || month > 12) { - throw new SerializationException("Invalid ISO-8601 month: " + month); - } - if (buf[pos++] != '-') { - return null; - } - int day = digit(buf[pos]) * 10 + digit(buf[pos + 1]); - pos += 2; - if (day < 1 || day > 31) { - throw new SerializationException("Invalid ISO-8601 day: " + day); - } - - if (buf[pos] != 'T' && buf[pos] != 't') { - return null; + int contentStart = pos + 1; + int closeQuote = contentStart; + while (closeQuote < end && buf[closeQuote] != '"') { + closeQuote++; } - pos++; - - // Parse hh:mm:ss - int hour = digit(buf[pos]) * 10 + digit(buf[pos + 1]); - pos += 2; - if (buf[pos++] != ':') { + if (closeQuote >= end) { return null; } - int minute = digit(buf[pos]) * 10 + digit(buf[pos + 1]); - pos += 2; - if (buf[pos++] != ':') { - return null; - } - int second = digit(buf[pos]) * 10 + digit(buf[pos + 1]); - pos += 2; - if (hour > 23 || minute > 59 || second > 59) { - throw new SerializationException( - "Invalid ISO-8601 time: " + hour + ":" + minute + ":" + second); - } - - // Optional fractional seconds - int nano = 0; - if (pos < end && buf[pos] == '.') { - pos++; - int fracStart = pos; - while (pos < end && buf[pos] >= '0' && buf[pos] <= '9') { - pos++; - } - int fracLen = pos - fracStart; - if (fracLen == 0) { - return null; - } - // Parse up to 9 fractional digits, zero-padding on the right - for (int i = 0; i < 9; i++) { - nano *= 10; - if (i < fracLen) { - nano += buf[fracStart + i] - '0'; - } - } - } - - // Must end with 'Z' for UTC fast path - if (pos >= end || buf[pos] != 'Z') { - return null; // timezone offset — fall back to DateTimeFormatter - } - pos++; // skip Z - - // Expect closing quote - if (pos >= end || buf[pos] != '"') { + Instant result = TimestampCodec.parseIso8601(buf, contentStart, closeQuote); + if (result == null) { return null; } - pos++; // skip closing quote - - deser.parsedEndPos = pos; - long epochDay = computeEpochDay(year, month, day); - long epochSecond = epochDay * 86400 + hour * 3600 + minute * 60 + second; - return Instant.ofEpochSecond(epochSecond, nano); + deser.parsedEndPos = closeQuote + 1; + return result; } /** - * Parses an HTTP-date ("EEE, dd MMM yyyy HH:mm:ss GMT") directly from a JSON - * quoted string. Expects pos to be at the opening quote. - * - *

On success, stores deser.parsedEndPos (after closing quote) and returns Instant. - * On failure, returns null — caller should fall back to DateTimeFormatter. + * Parses an HTTP-date directly from a JSON quoted string in the byte buffer. + * Expects pos to be at the opening quote. */ static Instant parseHttpDate(byte[] buf, int pos, int end, SmithyJsonDeserializer deser) { - // Minimum: "Thu, 01 Jan 2026 00:00:00 GMT" = 31 bytes (including quotes) if (pos >= end || buf[pos] != '"' || pos + 31 > end) { return null; } - pos++; // skip opening quote - - // Skip day name — find comma - while (pos < end && buf[pos] != ',') { - pos++; + int contentStart = pos + 1; + int closeQuote = contentStart; + while (closeQuote < end && buf[closeQuote] != '"') { + closeQuote++; } - if (pos >= end) { + if (closeQuote >= end) { return null; } - pos++; // skip comma - if (pos >= end || buf[pos] != ' ') { - return null; - } - pos++; // skip space - - // Parse dd - int day = digit(buf[pos]) * 10 + digit(buf[pos + 1]); - pos += 2; - if (buf[pos++] != ' ') { + Instant result = TimestampCodec.parseHttpDate(buf, contentStart, closeQuote); + if (result == null) { return null; } - - // Parse MMM - int month = lookupMonth(buf, pos); - pos += 3; - if (buf[pos++] != ' ') { - return null; - } - - // Parse yyyy - int year = digit(buf[pos]) * 1000 + digit(buf[pos + 1]) * 100 - + digit(buf[pos + 2]) * 10 - + digit(buf[pos + 3]); - pos += 4; - if (buf[pos++] != ' ') { - return null; - } - - // Parse HH:mm:ss - int hour = digit(buf[pos]) * 10 + digit(buf[pos + 1]); - pos += 2; - if (buf[pos++] != ':') { - return null; - } - int minute = digit(buf[pos]) * 10 + digit(buf[pos + 1]); - pos += 2; - if (buf[pos++] != ':') { - return null; - } - int second = digit(buf[pos]) * 10 + digit(buf[pos + 1]); - pos += 2; - - // Expect " GMT" - if (pos + 4 > end || buf[pos] != ' ' - || buf[pos + 1] != 'G' - || buf[pos + 2] != 'M' - || buf[pos + 3] != 'T') { - return null; - } - pos += 4; // skip " GMT" - - // Expect closing quote - if (pos >= end || buf[pos] != '"') { - return null; - } - pos++; // skip closing quote - - deser.parsedEndPos = pos; - long epochDay = computeEpochDay(year, month, day); - long epochSecond = epochDay * 86400 + hour * 3600 + minute * 60 + second; - return Instant.ofEpochSecond(epochSecond); + deser.parsedEndPos = closeQuote + 1; + return result; } - private static final Base64.Decoder BASE64_DECODER = Base64.getDecoder(); - /** * Decodes a base64-encoded JSON string from the byte buffer, bypassing String allocation. - * Scans for the closing quote to find the base64 content boundaries, then uses the JDK - * Base64 decoder which is backed by @IntrinsicCandidate SIMD on HotSpot. + * Scans for the closing quote to find the base64 content boundaries, then decodes + * directly from the span via ByteBuffer.wrap (zero-copy input, JDK SIMD intrinsic). * *

Expects {@code pos} at the opening quote. Stores the position after the closing * quote in {@code deser.parsedEndPos}. * - * @return the decoded bytes + * @return the decoded ByteBuffer * @throws SerializationException on unterminated string or invalid base64 */ - static byte[] decodeBase64String(byte[] buf, int pos, int end, SmithyJsonDeserializer deser) { + static ByteBuffer decodeBase64String(byte[] buf, int pos, int end, SmithyJsonDeserializer deser) { if (pos >= end || buf[pos] != '"') { throw new SerializationException("Expected '\"', found: " + describePos(buf, pos, end)); } @@ -757,12 +506,11 @@ static byte[] decodeBase64String(byte[] buf, int pos, int end, SmithyJsonDeseria deser.parsedEndPos = pos + 1; // after closing quote if (contentStart == contentEnd) { - return new byte[0]; + return ByteBuffer.allocate(0); } - byte[] base64Bytes = Arrays.copyOfRange(buf, contentStart, contentEnd); try { - return BASE64_DECODER.decode(base64Bytes); + return BASE64_DECODER.decode(ByteBuffer.wrap(buf, contentStart, contentEnd - contentStart)); } catch (IllegalArgumentException e) { throw new SerializationException("Invalid base64 in blob value", e); } diff --git a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/JsonWriteUtils.java b/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/JsonWriteUtils.java index ef463aa043..36b0637f58 100644 --- a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/JsonWriteUtils.java +++ b/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/JsonWriteUtils.java @@ -5,11 +5,12 @@ package software.amazon.smithy.java.json.smithy; -import java.math.BigInteger; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.time.Instant; -import java.util.Arrays; -import java.util.Base64; +import software.amazon.smithy.java.codecs.commons.NumberCodec; +import software.amazon.smithy.java.codecs.commons.TimestampCodec; +import software.amazon.smithy.java.io.ByteBufferUtils; /** * Low-level utilities for writing JSON primitives directly to byte arrays. @@ -20,24 +21,7 @@ final class JsonWriteUtils { private JsonWriteUtils() {} - // Pre-computed byte arrays for JSON literals - static final byte[] TRUE_BYTES = {'t', 'r', 'u', 'e'}; - static final byte[] FALSE_BYTES = {'f', 'a', 'l', 's', 'e'}; static final byte[] NULL_BYTES = {'n', 'u', 'l', 'l'}; - static final byte[] NAN_BYTES = {'"', 'N', 'a', 'N', '"'}; - static final byte[] INF_BYTES = {'"', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y', '"'}; - static final byte[] NEG_INF_BYTES = {'"', '-', 'I', 'n', 'f', 'i', 'n', 'i', 't', 'y', '"'}; - - // Pre-computed digit pairs: DIGIT_PAIRS[i*2] and DIGIT_PAIRS[i*2+1] give the two ASCII - // digits for the number i (00-99). - private static final byte[] DIGIT_PAIRS = new byte[200]; - - static { - for (int i = 0; i < 100; i++) { - DIGIT_PAIRS[i * 2] = (byte) ('0' + i / 10); - DIGIT_PAIRS[i * 2 + 1] = (byte) ('0' + i % 10); - } - } private static final byte[] HEX = { '0', @@ -79,206 +63,12 @@ private JsonWriteUtils() {} ESCAPE_TABLE['\t'] = 't'; } - private static final Base64.Encoder BASE64_ENCODER = Base64.getEncoder(); - - /** - * Writes an integer value as JSON number bytes. Returns new position. - * - *

Handles Integer.MIN_VALUE correctly. - */ static int writeInt(byte[] buf, int pos, int value) { - if (value == 0) { - buf[pos] = '0'; - return pos + 1; - } - - if (value < 0) { - buf[pos++] = '-'; - if (value == Integer.MIN_VALUE) { - // -2147483648 — can't negate, write directly - return writePositiveLong(buf, pos, 2147483648L); - } - value = -value; - } - - return writePositiveInt(buf, pos, value); + return NumberCodec.writeInt(buf, pos, value); } - private static int writePositiveInt(byte[] buf, int pos, int value) { - int digits = digitCount(value); - int end = pos + digits; - int p = end; - - while (value >= 100) { - int q = value / 100; - int r = (value - q * 100) * 2; - value = q; - buf[--p] = DIGIT_PAIRS[r + 1]; - buf[--p] = DIGIT_PAIRS[r]; - } - - if (value >= 10) { - int r = value * 2; - buf[--p] = DIGIT_PAIRS[r + 1]; - buf[--p] = DIGIT_PAIRS[r]; - } else { - buf[--p] = (byte) ('0' + value); - } - - return end; - } - - /** - * Writes a long value as JSON number bytes. Returns new position. - */ static int writeLong(byte[] buf, int pos, long value) { - if (value == 0) { - buf[pos] = '0'; - return pos + 1; - } - - if (value < 0) { - buf[pos++] = '-'; - if (value == Long.MIN_VALUE) { - // -9223372036854775808 — can't negate - byte[] minBytes = "9223372036854775808".getBytes(StandardCharsets.US_ASCII); - System.arraycopy(minBytes, 0, buf, pos, minBytes.length); - return pos + minBytes.length; - } - value = -value; - } - - return writePositiveLong(buf, pos, value); - } - - private static int writePositiveLong(byte[] buf, int pos, long value) { - if (value <= Integer.MAX_VALUE) { - return writePositiveInt(buf, pos, (int) value); - } - - int digits = digitCountLong(value); - int end = pos + digits; - int p = end; - - while (value >= 100) { - long q = value / 100; - int r = (int) (value - q * 100) * 2; - value = q; - buf[--p] = DIGIT_PAIRS[r + 1]; - buf[--p] = DIGIT_PAIRS[r]; - } - - if (value >= 10) { - int r = (int) value * 2; - buf[--p] = DIGIT_PAIRS[r + 1]; - buf[--p] = DIGIT_PAIRS[r]; - } else { - buf[--p] = (byte) ('0' + value); - } - - return end; - } - - private static int digitCount(int value) { - if (value < 10) - return 1; - if (value < 100) - return 2; - if (value < 1000) - return 3; - if (value < 10000) - return 4; - if (value < 100000) - return 5; - if (value < 1000000) - return 6; - if (value < 10000000) - return 7; - if (value < 100000000) - return 8; - if (value < 1000000000) - return 9; - return 10; - } - - private static int digitCountLong(long value) { - if (value < 10000000000L) { - return digitCount((int) Math.min(value, Integer.MAX_VALUE)); - } - if (value < 100000000000L) - return 11; - if (value < 1000000000000L) - return 12; - if (value < 10000000000000L) - return 13; - if (value < 100000000000000L) - return 14; - if (value < 1000000000000000L) - return 15; - if (value < 10000000000000000L) - return 16; - if (value < 100000000000000000L) - return 17; - if (value < 1000000000000000000L) - return 18; - return 19; - } - - private static final BigInteger TEN_TO_18 = BigInteger.valueOf(1_000_000_000_000_000_000L); - - /** - * Writes a BigInteger directly to the byte buffer by splitting into 18-digit groups. - * Avoids BigInteger.toString() which does expensive recursive division and String allocation. - */ - static int writeBigInteger(byte[] buf, int pos, BigInteger value) { - if (value.signum() < 0) { - buf[pos++] = '-'; - value = value.negate(); - } - - // Split into groups of up to 18 decimal digits (each fits in a long) - if (value.compareTo(TEN_TO_18) < 0) { - return writePositiveLong(buf, pos, value.longValue()); - } - - BigInteger[] qr = value.divideAndRemainder(TEN_TO_18); - BigInteger high = qr[0]; - long low = qr[1].longValue(); - - if (high.compareTo(TEN_TO_18) < 0) { - // Two groups: high (no padding) + low (18-digit padded) - pos = writePositiveLong(buf, pos, high.longValue()); - return writePaddedLong18(buf, pos, low); - } - - // Three or more groups (handles numbers up to ~54 digits) - BigInteger[] qr2 = high.divideAndRemainder(TEN_TO_18); - long mid = qr2[1].longValue(); - - if (qr2[0].compareTo(TEN_TO_18) < 0) { - pos = writePositiveLong(buf, pos, qr2[0].longValue()); - pos = writePaddedLong18(buf, pos, mid); - return writePaddedLong18(buf, pos, low); - } - - // Extremely large: fall back to toString for safety - String s = value.toString(); - return writeAsciiString(buf, pos, s); - } - - /** - * Writes a long value zero-padded to exactly 18 digits. - */ - private static int writePaddedLong18(byte[] buf, int pos, long value) { - int end = pos + 18; - int p = end; - for (int i = 0; i < 9; i++) { - int r = (int) (value % 100) * 2; - value /= 100; - buf[--p] = DIGIT_PAIRS[r + 1]; - buf[--p] = DIGIT_PAIRS[r]; - } - return end; + return NumberCodec.writeLong(buf, pos, value); } /** @@ -322,28 +112,22 @@ private static int writeStringSlowPath(byte[] buf, int pos, String value, int st char c = value.charAt(i); if (c < 0x80) { - // ASCII range if (c >= 0x20 && !NEEDS_ESCAPE[c]) { buf[pos++] = (byte) c; } else if (ESCAPE_TABLE[c] != 0) { - // Two-character escape: \n, \t, \\, \", etc. buf[pos++] = '\\'; buf[pos++] = ESCAPE_TABLE[c]; } else { - // Unicode escape for control characters pos = writeUnicodeEscape(buf, pos, c); } } else if (c < 0x800) { - // 2-byte UTF-8 buf[pos++] = (byte) (0xC0 | (c >> 6)); buf[pos++] = (byte) (0x80 | (c & 0x3F)); } else if (!Character.isSurrogate(c)) { - // 3-byte UTF-8 (BMP, non-surrogate) buf[pos++] = (byte) (0xE0 | (c >> 12)); buf[pos++] = (byte) (0x80 | ((c >> 6) & 0x3F)); buf[pos++] = (byte) (0x80 | (c & 0x3F)); } else { - // Surrogate pair -> 4-byte UTF-8 if (Character.isHighSurrogate(c) && i + 1 < len) { char low = value.charAt(++i); if (Character.isLowSurrogate(low)) { @@ -353,12 +137,12 @@ private static int writeStringSlowPath(byte[] buf, int pos, String value, int st buf[pos++] = (byte) (0x80 | ((cp >> 6) & 0x3F)); buf[pos++] = (byte) (0x80 | (cp & 0x3F)); } else { - // Lone high surrogate followed by non-low — escape both + // Lone high surrogate followed by non-low -- escape both pos = writeUnicodeEscape(buf, pos, c); i--; // re-process the non-low char } } else { - // Lone surrogate — escape as unicode + // Lone surrogate -- escape as unicode pos = writeUnicodeEscape(buf, pos, c); } } @@ -381,374 +165,35 @@ private static int writeUnicodeEscape(byte[] buf, int pos, int c) { * Returns new position. */ static int writeDouble(byte[] buf, int pos, double value) { - // Avoid writing 1.0 when 1 suffices - long longValue = (long) value; - if (value == (double) longValue) { - return writeLong(buf, pos, longValue); - } - return Schubfach.writeDouble(buf, pos, value); - } - - /** - * Writes a double using a reusable Schubfach instance to avoid per-call allocation. - */ - static int writeDouble(byte[] buf, int pos, double value, Schubfach.DoubleToDecimal dtd) { - long longValue = (long) value; - if (value == (double) longValue) { - return writeLong(buf, pos, longValue); - } - return Schubfach.writeDouble(buf, pos, value, dtd); + return NumberCodec.writeDouble(buf, pos, value); } - /** - * Writes an epoch-seconds timestamp directly from an Instant using integer arithmetic. - * Writes "seconds" for whole seconds or "seconds.nanos" - * for fractional, with full nanosecond precision and trailing zeros stripped. - * - *

For negative epoch seconds with non-zero nanos, the Instant contract is: - * {@code Instant.ofEpochSecond(-1, 500_000_000)} = -0.5 seconds (not -1.5). - * The nano field is always non-negative and added to the epoch second. - */ static int writeEpochSeconds(byte[] buf, int pos, long epochSecond, int nano) { - if (nano == 0) { - return writeLong(buf, pos, epochSecond); - } - - int fraction = nano; - if (epochSecond < 0) { - // Instant(-1, 500_000_000) means -1 + 0.5 = -0.5 seconds. - // Adjust: seconds part becomes epochSecond+1, fraction becomes 1e9-nano. - epochSecond += 1; - fraction = 1_000_000_000 - nano; - if (epochSecond == 0) { - // Special case: -0.xxx (epochSecond was -1, adjusted to 0, but value is negative) - buf[pos++] = '-'; - buf[pos++] = '0'; - } else { - pos = writeLong(buf, pos, epochSecond); - } - } else { - pos = writeLong(buf, pos, epochSecond); - } - - buf[pos++] = '.'; - - // Write all 9 fractional digits (zero-padded), then strip trailing zeros. - int hi = fraction / 1_000_000; - int mid = (fraction / 1_000) % 1_000; - int lo = fraction % 1_000; - buf[pos++] = (byte) ('0' + hi / 100); - buf[pos++] = (byte) ('0' + (hi / 10) % 10); - buf[pos++] = (byte) ('0' + hi % 10); - buf[pos++] = (byte) ('0' + mid / 100); - buf[pos++] = (byte) ('0' + (mid / 10) % 10); - buf[pos++] = (byte) ('0' + mid % 10); - buf[pos++] = (byte) ('0' + lo / 100); - buf[pos++] = (byte) ('0' + (lo / 10) % 10); - buf[pos++] = (byte) ('0' + lo % 10); - // Strip trailing zeros - while (buf[pos - 1] == '0') { - pos--; - } - return pos; + return TimestampCodec.writeEpochSeconds(buf, pos, epochSecond, nano); } - /** - * Writes a float value as JSON. Handles integer-valued floats optimization. - * Returns new position. - */ static int writeFloat(byte[] buf, int pos, float value) { - int intValue = (int) value; - if (value == (float) intValue) { - return writeInt(buf, pos, intValue); - } - return Schubfach.writeFloat(buf, pos, value); - } - - /** - * Writes a float using a reusable Schubfach instance to avoid per-call allocation. - */ - static int writeFloat(byte[] buf, int pos, float value, Schubfach.FloatToDecimal ftd) { - int intValue = (int) value; - if (value == (float) intValue) { - return writeInt(buf, pos, intValue); - } - return Schubfach.writeFloat(buf, pos, value, ftd); + return NumberCodec.writeFloat(buf, pos, value); } - /** - * Writes an ISO-8601 timestamp directly to the byte buffer as a quoted JSON string. - * Produces output like {@code "2025-01-15T10:30:00Z"} or {@code "2025-01-15T10:30:00.123Z"} - * for timestamps with sub-second precision. - * - *

Uses pure integer arithmetic from epoch seconds to compute date/time components, - * avoiding the 4 object allocations from {@code Instant.atOffset(ZoneOffset.UTC)}. - */ static int writeIso8601Timestamp(byte[] buf, int pos, Instant value) { - long epochSecond = value.getEpochSecond(); - int nano = value.getNano(); - - // Compute time-of-day from epoch second - int secondOfDay = (int) Math.floorMod(epochSecond, 86400); - int hour = secondOfDay / 3600; - int minute = (secondOfDay % 3600) / 60; - int second = secondOfDay % 60; - - // Compute date from epoch day using civil calendar algorithm (inlined to avoid allocation) - long epochDay = Math.floorDiv(epochSecond, 86400); - long z = epochDay + 719468; - long era = (z >= 0 ? z : z - 146096) / 146097; - long doe = z - era * 146097; - long yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; - long doy = doe - (365 * yoe + yoe / 4 - yoe / 100); - long mp = (5 * doy + 2) / 153; - int day = (int) (doy - (153 * mp + 2) / 5 + 1); - int month = (int) (mp < 10 ? mp + 3 : mp - 9); - int year = (int) (yoe + era * 400 + (month <= 2 ? 1 : 0)); - buf[pos++] = '"'; - - // Year (4 digits, with sign for years outside 0000-9999) - if (year >= 0 && year <= 9999) { - int hi = year / 100; - int lo = year - hi * 100; - buf[pos++] = DIGIT_PAIRS[hi * 2]; - buf[pos++] = DIGIT_PAIRS[hi * 2 + 1]; - buf[pos++] = DIGIT_PAIRS[lo * 2]; - buf[pos++] = DIGIT_PAIRS[lo * 2 + 1]; - } else { - // Fall back for years outside 0000-9999 - String yearStr = String.format("%04d", year); - for (int i = 0; i < yearStr.length(); i++) { - buf[pos++] = (byte) yearStr.charAt(i); - } - } - - buf[pos++] = '-'; - buf[pos++] = DIGIT_PAIRS[month * 2]; - buf[pos++] = DIGIT_PAIRS[month * 2 + 1]; - buf[pos++] = '-'; - buf[pos++] = DIGIT_PAIRS[day * 2]; - buf[pos++] = DIGIT_PAIRS[day * 2 + 1]; - buf[pos++] = 'T'; - buf[pos++] = DIGIT_PAIRS[hour * 2]; - buf[pos++] = DIGIT_PAIRS[hour * 2 + 1]; - buf[pos++] = ':'; - buf[pos++] = DIGIT_PAIRS[minute * 2]; - buf[pos++] = DIGIT_PAIRS[minute * 2 + 1]; - buf[pos++] = ':'; - buf[pos++] = DIGIT_PAIRS[second * 2]; - buf[pos++] = DIGIT_PAIRS[second * 2 + 1]; - - if (nano != 0) { - buf[pos++] = '.'; - // Write up to 9 fractional digits, stripping trailing zeros - int frac = nano; - int digits = 9; - while (frac % 10 == 0) { - frac /= 10; - digits--; - } - // Write digits left-to-right - int scale = 1; - for (int i = 1; i < digits; i++) { - scale *= 10; - } - while (scale > 0) { - buf[pos++] = (byte) ('0' + frac / scale); - frac %= scale; - scale /= 10; - } - } - - buf[pos++] = 'Z'; + pos = TimestampCodec.writeIso8601(buf, pos, value); buf[pos++] = '"'; return pos; } - // Day-of-week names as pre-computed byte arrays (Mon=1..Sun=7 per java.time DayOfWeek) - private static final byte[][] DAY_NAMES = { - null, // 0 unused - {'M', 'o', 'n'}, // 1 = Monday - {'T', 'u', 'e'}, // 2 = Tuesday - {'W', 'e', 'd'}, // 3 = Wednesday - {'T', 'h', 'u'}, // 4 = Thursday - {'F', 'r', 'i'}, // 5 = Friday - {'S', 'a', 't'}, // 6 = Saturday - {'S', 'u', 'n'}, // 7 = Sunday - }; - - // Month abbreviations as pre-computed byte arrays (Jan=1..Dec=12) - private static final byte[][] MONTH_NAMES = { - null, // 0 unused - {'J', 'a', 'n'}, - {'F', 'e', 'b'}, - {'M', 'a', 'r'}, - {'A', 'p', 'r'}, - {'M', 'a', 'y'}, - {'J', 'u', 'n'}, - {'J', 'u', 'l'}, - {'A', 'u', 'g'}, - {'S', 'e', 'p'}, - {'O', 'c', 't'}, - {'N', 'o', 'v'}, - {'D', 'e', 'c'}, - }; - - /** - * Writes an HTTP-date timestamp directly to the byte buffer as a quoted JSON string. - * Produces output like {@code "Sat, 01 Jan 2026 00:00:00 GMT"}. - * - *

Uses pure integer arithmetic from epoch seconds, avoiding the 4 object allocations - * from {@code Instant.atOffset(ZoneOffset.UTC)} and the heavy DateTimeFormatter machinery. - */ static int writeHttpDate(byte[] buf, int pos, Instant value) { - long epochSecond = value.getEpochSecond(); - - // Compute time-of-day - int secondOfDay = (int) Math.floorMod(epochSecond, 86400); - int hour = secondOfDay / 3600; - int minute = (secondOfDay % 3600) / 60; - int second = secondOfDay % 60; - - // Compute date from epoch day using civil calendar algorithm (inlined to avoid allocation) - long epochDay = Math.floorDiv(epochSecond, 86400); - long z = epochDay + 719468; - long era = (z >= 0 ? z : z - 146096) / 146097; - long doe = z - era * 146097; - long yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; - long doy = doe - (365 * yoe + yoe / 4 - yoe / 100); - long mp = (5 * doy + 2) / 153; - int day = (int) (doy - (153 * mp + 2) / 5 + 1); - int month = (int) (mp < 10 ? mp + 3 : mp - 9); - int year = (int) (yoe + era * 400 + (month <= 2 ? 1 : 0)); - - // Day of week: epoch day 0 (1970-01-01) was Thursday (4). - // 1=Monday..7=Sunday per java.time convention. - int dow = (int) Math.floorMod(epochDay + 3, 7) + 1; - buf[pos++] = '"'; - - // Day name - byte[] dayName = DAY_NAMES[dow]; - buf[pos++] = dayName[0]; - buf[pos++] = dayName[1]; - buf[pos++] = dayName[2]; - buf[pos++] = ','; - buf[pos++] = ' '; - - // Day of month (2 digits) - buf[pos++] = DIGIT_PAIRS[day * 2]; - buf[pos++] = DIGIT_PAIRS[day * 2 + 1]; - buf[pos++] = ' '; - - // Month name - byte[] monthName = MONTH_NAMES[month]; - buf[pos++] = monthName[0]; - buf[pos++] = monthName[1]; - buf[pos++] = monthName[2]; - buf[pos++] = ' '; - - // Year (4 digits) - if (year >= 0 && year <= 9999) { - int hi = year / 100; - int lo = year - hi * 100; - buf[pos++] = DIGIT_PAIRS[hi * 2]; - buf[pos++] = DIGIT_PAIRS[hi * 2 + 1]; - buf[pos++] = DIGIT_PAIRS[lo * 2]; - buf[pos++] = DIGIT_PAIRS[lo * 2 + 1]; - } else { - String yearStr = String.format("%04d", year); - for (int i = 0; i < yearStr.length(); i++) { - buf[pos++] = (byte) yearStr.charAt(i); - } - } - - buf[pos++] = ' '; - buf[pos++] = DIGIT_PAIRS[hour * 2]; - buf[pos++] = DIGIT_PAIRS[hour * 2 + 1]; - buf[pos++] = ':'; - buf[pos++] = DIGIT_PAIRS[minute * 2]; - buf[pos++] = DIGIT_PAIRS[minute * 2 + 1]; - buf[pos++] = ':'; - buf[pos++] = DIGIT_PAIRS[second * 2]; - buf[pos++] = DIGIT_PAIRS[second * 2 + 1]; - buf[pos++] = ' '; - buf[pos++] = 'G'; - buf[pos++] = 'M'; - buf[pos++] = 'T'; + pos = TimestampCodec.writeHttpDate(buf, pos, value); buf[pos++] = '"'; return pos; } - // Pre-computed powers of 10 for BigDecimal fast path - private static final long[] POWERS_OF_10 = { - 1L, - 10L, - 100L, - 1000L, - 10000L, - 100000L, - 1000000L, - 10000000L, - 100000000L, - 1000000000L, - 10000000000L, - 100000000000L, - 1000000000000L, - 10000000000000L, - 100000000000000L, - 1000000000000000L, - 10000000000000000L, - 100000000000000000L, - 1000000000000000000L - }; - - /** - * Writes a BigDecimal with known unscaled long value and positive scale directly. - * E.g., unscaled=9999999999, scale=5 writes "99999.99999". - */ - static int writeBigDecimalFromLong(byte[] buf, int pos, long unscaled, int scale) { - if (unscaled < 0) { - buf[pos++] = '-'; - if (unscaled == Long.MIN_VALUE) { - // Extremely unlikely edge case — fall through won't work, but caller - // checks bitLength < 64 so this can't happen - throw new ArithmeticException("Cannot negate Long.MIN_VALUE"); - } - unscaled = -unscaled; - } - - if (scale < POWERS_OF_10.length) { - long divisor = POWERS_OF_10[scale]; - long intPart = unscaled / divisor; - long fracPart = unscaled - intPart * divisor; - - // Write integer part (or 0 if unscaled < divisor) - pos = writePositiveLong(buf, pos, intPart); - buf[pos++] = '.'; - - // Write fractional part with leading zeros - // e.g., scale=5 and fracPart=99 -> "00099" - for (int i = scale - 1; i >= 0; i--) { - long p10 = POWERS_OF_10[i]; - int d = (int) (fracPart / p10); - buf[pos++] = (byte) ('0' + d); - fracPart -= d * p10; - } - } else { - // Scale too large for our table — shouldn't happen for practical values - pos = writePositiveLong(buf, pos, unscaled); - } - - return pos; - } - /** * Writes an ASCII string directly to the buffer without quoting. * Used for number-to-string conversions (Double.toString, BigDecimal.toString, etc). */ - @SuppressWarnings("deprecation") static int writeAsciiString(byte[] buf, int pos, String s) { int len = s.length(); s.getBytes(0, len, buf, pos); @@ -760,11 +205,12 @@ static int writeAsciiString(byte[] buf, int pos, String s) { * Returns the new write position. */ static int writeBase64String(byte[] buf, int pos, byte[] data, int off, int len) { + return writeBase64String(buf, pos, ByteBuffer.wrap(data, off, len)); + } + + static int writeBase64String(byte[] buf, int pos, ByteBuffer data) { buf[pos++] = '"'; - // Use JDK Base64 encoder — produces standard base64 with +/ alphabet, no line breaks. - // This matches Jackson's MIME_NO_LINEFEEDS variant for JSON. - byte[] encoded = BASE64_ENCODER.encode( - off == 0 && len == data.length ? data : Arrays.copyOfRange(data, off, off + len)); + byte[] encoded = ByteBufferUtils.base64EncodeToBytes(data); System.arraycopy(encoded, 0, buf, pos, encoded.length); pos += encoded.length; buf[pos++] = '"'; @@ -776,7 +222,6 @@ static int writeBase64String(byte[] buf, int pos, byte[] data, int off, int len) * Used for buffer capacity estimation. */ static int maxQuotedStringBytes(String value) { - // Worst case: every char is a control char needing unicode escape (6 bytes) + 2 quotes return value.length() * 6 + 2; } @@ -784,7 +229,6 @@ static int maxQuotedStringBytes(String value) { * Returns the maximum number of bytes needed for a base64-encoded string. */ static int maxBase64Bytes(int dataLen) { - // Base64: 4 bytes per 3 input bytes, rounded up, plus 2 quotes return ((dataLen + 2) / 3) * 4 + 2; } @@ -795,8 +239,7 @@ static int maxBase64Bytes(int dataLen) { */ static byte[] precomputeFieldNameBytes(String fieldName) { byte[] nameUtf8 = fieldName.getBytes(StandardCharsets.UTF_8); - // "fieldName": - byte[] result = new byte[nameUtf8.length + 3]; // quote + name + quote + colon + byte[] result = new byte[nameUtf8.length + 3]; result[0] = '"'; System.arraycopy(nameUtf8, 0, result, 1, nameUtf8.length); result[nameUtf8.length + 1] = '"'; diff --git a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/SmithyJsonDeserializer.java b/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/SmithyJsonDeserializer.java index da331dcbbe..9c8fb76c11 100644 --- a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/SmithyJsonDeserializer.java +++ b/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/SmithyJsonDeserializer.java @@ -16,6 +16,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import software.amazon.smithy.java.codecs.commons.NumberCodec; import software.amazon.smithy.java.core.schema.Schema; import software.amazon.smithy.java.core.serde.SerializationException; import software.amazon.smithy.java.core.serde.ShapeDeserializer; @@ -71,8 +72,6 @@ public void close() { } } - // ---- Primitive readers ---- - @Override public boolean readBoolean(Schema schema) { skipWhitespace(); @@ -242,9 +241,9 @@ private String readStringValue() { @Override public ByteBuffer readBlob(Schema schema) { skipWhitespace(); - byte[] decoded = JsonReadUtils.decodeBase64String(buf, pos, end, this); + ByteBuffer decoded = JsonReadUtils.decodeBase64String(buf, pos, end, this); pos = parsedEndPos; - return ByteBuffer.wrap(decoded); + return decoded; } @Override @@ -292,9 +291,9 @@ public Instant readTimestamp(Schema schema) { throw new SerializationException("Epoch seconds out of range: " + parsedLong, e); } } - // No digits after dot — fall through to double parsing + // No digits after dot -- fall through to double parsing } else if (endPos >= end || (buf[endPos] != 'e' && buf[endPos] != 'E')) { - // Pure integer — no fractional part + // Pure integer -- no fractional part pos = endPos; try { return Instant.ofEpochSecond(parsedLong); @@ -302,7 +301,7 @@ public Instant readTimestamp(Schema schema) { throw new SerializationException("Epoch seconds out of range: " + parsedLong, e); } } - // Has exponent or unparseable fraction — fall through to double parsing + // Has exponent or unparseable fraction -- fall through to double parsing JsonReadUtils.parseDouble(buf, pos, end, this); pos = parsedEndPos; return format.readFromNumber(parsedDouble); @@ -354,8 +353,6 @@ private String describeCurrentToken() { }; } - // ---- Struct deserialization ---- - @Override public void readStruct(Schema schema, T state, StructMemberConsumer structMemberConsumer) { // Localize hot fields to registers. The JIT cannot promote instance fields across @@ -365,7 +362,11 @@ public void readStruct(Schema schema, T state, StructMemberConsumer struc final int localEnd = this.end; int p = JsonReadUtils.skipWhitespace(localBuf, this.pos, localEnd); - if (p >= localEnd || localBuf[p] != '{') { + if (p >= localEnd) { + this.pos = p; + return; + } + if (localBuf[p] != '{') { this.pos = p; throw new SerializationException( "Expected '{', found: " + JsonReadUtils.describePos(localBuf, p, localEnd)); @@ -379,14 +380,12 @@ public void readStruct(Schema schema, T state, StructMemberConsumer struc p = JsonReadUtils.skipWhitespace(localBuf, p, localEnd); - // Empty object if (p < localEnd && localBuf[p] == '}') { this.pos = p + 1; depth--; return; } - // Get the member lookup for this struct. Schema structSchema = schema.isMember() ? schema.memberTarget() : schema; var ext = structSchema.getExtension(SmithyJsonSchemaExtensions.KEY); SmithyMemberLookup lookup = null; @@ -419,7 +418,6 @@ public void readStruct(Schema schema, T state, StructMemberConsumer struc p = JsonReadUtils.skipWhitespace(localBuf, p, localEnd); - // Parse field name if (p >= localEnd || localBuf[p] != '"') { this.pos = p; throw new SerializationException( @@ -463,7 +461,6 @@ public void readStruct(Schema schema, T state, StructMemberConsumer struc p++; // skip closing quote } - // Skip colon p = JsonReadUtils.skipWhitespace(localBuf, p, localEnd); if (p >= localEnd || localBuf[p] != ':') { this.pos = p; @@ -498,7 +495,7 @@ public void readStruct(Schema schema, T state, StructMemberConsumer struc p = this.pos; } } else { - // Unknown field — validate field name bytes per RFC 8259 + // Unknown field -- validate field name bytes per RFC 8259 // (control chars, escape sequences). This is the cold path only. validateSkippedString(localBuf, nameStart, nameEnd); this.pos = p; @@ -520,8 +517,6 @@ public void readStruct(Schema schema, T state, StructMemberConsumer struc } } - // ---- List deserialization ---- - @Override public void readList(Schema schema, T state, ListMemberConsumer listMemberConsumer) { final byte[] buf = this.buf; @@ -542,7 +537,6 @@ public void readList(Schema schema, T state, ListMemberConsumer listMembe p = JsonReadUtils.skipWhitespace(buf, p, end); - // Empty array if (p < end && buf[p] == ']') { this.pos = p + 1; depth--; @@ -573,8 +567,6 @@ public void readList(Schema schema, T state, ListMemberConsumer listMembe } } - // ---- Map deserialization ---- - @Override public void readStringMap(Schema schema, T state, MapMemberConsumer mapMemberConsumer) { final byte[] buf = this.buf; @@ -595,7 +587,6 @@ public void readStringMap(Schema schema, T state, MapMemberConsumer void readStringMap(Schema schema, T state, MapMemberConsumer void readStringMap(Schema schema, T state, MapMemberConsumer= Integer.MIN_VALUE && lv <= Integer.MAX_VALUE) { number = (int) lv; } else { number = lv; } - } catch (NumberFormatException e) { - number = new BigInteger(numStr); + } else { + String numStr = new String(buf, pos, len, StandardCharsets.US_ASCII); + try { + long lv = Long.parseLong(numStr); + if (lv >= Integer.MIN_VALUE && lv <= Integer.MAX_VALUE) { + number = (int) lv; + } else { + number = lv; + } + } catch (NumberFormatException e) { + number = new BigInteger(numStr); + } } + pos = newPos; } return JsonDocuments.of(number, settings); } - // ---- Null handling ---- - @Override public boolean isNull() { skipWhitespace(); @@ -776,8 +778,6 @@ public T readNull() { return null; } - // ---- Validation for skipped content ---- - /** * Validates string bytes (between quotes) for RFC 8259 compliance without building a String. * Checks for unescaped control characters and valid escape sequences. @@ -812,8 +812,6 @@ private static void validateSkippedString(byte[] buf, int start, int end) { } } - // ---- Value skipping for unknown fields ---- - private void skipValue() { skipWhitespace(); if (pos >= end) { @@ -871,7 +869,7 @@ private void skipString() { case '"', '\\', '/', 'b', 'f', 'n', 'r', 't' -> { } case 'u' -> { - // \\uXXXX — skip 4 hex digits + // \\uXXXX -- skip 4 hex digits if (p + 4 >= end) { throw new SerializationException("Unterminated \\u escape"); } @@ -960,8 +958,6 @@ private void skipArray() { } } - // ---- Utility methods ---- - private void skipWhitespace() { pos = JsonReadUtils.skipWhitespace(buf, pos, end); } diff --git a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/SmithyJsonSerializer.java b/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/SmithyJsonSerializer.java index cc578a5217..2cb316fa14 100644 --- a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/SmithyJsonSerializer.java +++ b/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/SmithyJsonSerializer.java @@ -11,8 +11,9 @@ import java.nio.ByteBuffer; import java.time.Instant; import java.util.Arrays; -import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.BiConsumer; +import software.amazon.smithy.java.codecs.commons.NumberCodec; +import software.amazon.smithy.java.codecs.commons.StripedPool; import software.amazon.smithy.java.core.schema.Schema; import software.amazon.smithy.java.core.schema.SerializableStruct; import software.amazon.smithy.java.core.serde.MapSerializer; @@ -21,7 +22,6 @@ import software.amazon.smithy.java.core.serde.SpecificShapeSerializer; import software.amazon.smithy.java.core.serde.TimestampFormatter; import software.amazon.smithy.java.core.serde.document.Document; -import software.amazon.smithy.java.io.ByteBufferUtils; import software.amazon.smithy.java.json.JsonFieldMapper; import software.amazon.smithy.java.json.JsonSettings; import software.amazon.smithy.model.shapes.ShapeType; @@ -38,19 +38,7 @@ final class SmithyJsonSerializer implements ShapeSerializer { private static final int DEFAULT_BUF_SIZE = 8192; private static final int MAX_CACHEABLE_BUF = DEFAULT_BUF_SIZE * 4; - // Striped serializer pool. - private static final int POOL_SLOTS; - private static final int POOL_MASK; - private static final AtomicReferenceArray POOL; - private static final int MAX_PROBE = 3; - - static { - int processors = Runtime.getRuntime().availableProcessors(); - int raw = processors * 4; - POOL_SLOTS = Integer.highestOneBit(raw - 1) << 1; - POOL_MASK = POOL_SLOTS - 1; - POOL = new AtomicReferenceArray<>(POOL_SLOTS); - } + private static final StripedPool POOL = new JsonStripedPool(); private byte[] buf; private int pos; @@ -66,10 +54,6 @@ final class SmithyJsonSerializer implements ShapeSerializer { // Resolved once per writeStruct call, then used for all member writes. private byte[][] currentFieldNameTable; - // Reusable Schubfach instances for double/float write - private final Schubfach.DoubleToDecimal doubleToDecimal = Schubfach.createDoubleToDecimal(); - private final Schubfach.FloatToDecimal floatToDecimal = Schubfach.createFloatToDecimal(); - private final ShapeSerializer structSerializer = new StructSerializer(); private final ShapeSerializer listElementSerializer = new ListElementSerializer(); private final MapSerializer mapSerializer = new SmithyMapSerializer(); @@ -97,83 +81,21 @@ final class SmithyJsonSerializer implements ShapeSerializer { this.depth = 0; } - /** - * Acquires a serializer from the pool, or creates a new one. - * The returned serializer is ready for use with a fresh buffer. - * - *

Uses getPlain to peek at slots cheaply (plain read, no ordering), then - * compareAndExchangeAcquire to atomically claim a non-null entry (acquire - * semantics ensure we see the serializer's fully-written state). This pays - * the atomic price only once per acquire - empty slots are skipped with a - * plain read instead of a full getAndSet. - */ static SmithyJsonSerializer acquire(JsonSettings settings) { - //TODO Have a different strat for VTs, - // we still some sort of pooling for VTs but the current strategy won't work. - if (!Thread.currentThread().isVirtual()) { - int base = poolProbe(); - for (int i = 0; i < MAX_PROBE; i++) { - int idx = (base + i) & POOL_MASK; - SmithyJsonSerializer s = POOL.getPlain(idx); - if (s != null && POOL.compareAndExchangeAcquire(idx, s, null) == s) { - if (s.settings.equals(settings)) { - s.pos = 0; - s.depth = 0; - s.currentFieldNameTable = null; - return s; - } - POOL.setRelease(idx, s); // wrong settings, put back - } - } - } - return new SmithyJsonSerializer(settings); + return POOL.acquire(settings); } - /** - * Returns a serializer to the pool for reuse. If the pool is full, the - * buffer is oversized, or we're on a virtual thread, the serializer is discarded. - * - *

Uses getPlain to peek for empty slots, then compareAndExchangeRelease to - * store the serializer with release semantics (ensures all serializer state is - * visible to the thread that later acquires it). - */ static void release(SmithyJsonSerializer serializer, boolean exception) { - if (serializer.buf == null || Thread.currentThread().isVirtual()) { - return; - } - // If an exception occurred, the needsComma array may be in an inconsistent state. - // Clear it before pooling so the next acquire gets a clean serializer. if (exception) { Arrays.fill(serializer.needsComma, false); } - // Downsize oversized buffers before pooling to bound memory - if (serializer.buf.length > MAX_CACHEABLE_BUF) { - serializer.buf = new byte[DEFAULT_BUF_SIZE]; - } - int base = poolProbe(); - for (int i = 0; i < MAX_PROBE; i++) { - int idx = (base + i) & POOL_MASK; - if (POOL.getPlain(idx) == null - && POOL.compareAndExchangeRelease(idx, null, serializer) == null) { - return; - } - } - // Pool full — let GC collect + POOL.release(serializer); } - /** - * Extracts the serialized JSON as a ByteBuffer without releasing the internal - * buffer. Used with {@link #acquire}/{@link #release} for pooled serializers. - */ ByteBuffer extractResult() { return ByteBuffer.wrap(Arrays.copyOf(buf, pos)); } - private static int poolProbe() { - long id = Thread.currentThread().threadId(); - return (int) (id ^ (id >>> 16)) & POOL_MASK; - } - private void ensureCapacity(int needed) { if (pos + needed > buf.length) { grow(needed); @@ -211,14 +133,10 @@ public void close() { } } - // ---- Value writers (no field name prefix) ---- - @Override public void writeBoolean(Schema schema, boolean value) { - byte[] bytes = value ? JsonWriteUtils.TRUE_BYTES : JsonWriteUtils.FALSE_BYTES; - ensureCapacity(bytes.length); - System.arraycopy(bytes, 0, buf, pos, bytes.length); - pos += bytes.length; + ensureCapacity(5); + pos = NumberCodec.writeBoolean(buf, pos, value); } @Override @@ -247,81 +165,38 @@ public void writeLong(Schema schema, long value) { @Override public void writeFloat(Schema schema, float value) { - if (Float.isFinite(value)) { - ensureCapacity(24); - pos = JsonWriteUtils.writeFloat(buf, pos, value, floatToDecimal); - } else if (Float.isNaN(value)) { - ensureCapacity(JsonWriteUtils.NAN_BYTES.length); - System.arraycopy(JsonWriteUtils.NAN_BYTES, 0, buf, pos, JsonWriteUtils.NAN_BYTES.length); - pos += JsonWriteUtils.NAN_BYTES.length; - } else { - byte[] bytes = value > 0 ? JsonWriteUtils.INF_BYTES : JsonWriteUtils.NEG_INF_BYTES; - ensureCapacity(bytes.length); - System.arraycopy(bytes, 0, buf, pos, bytes.length); - pos += bytes.length; - } + ensureCapacity(24); + pos = NumberCodec.writeFloatFullQuoted(buf, pos, value); } @Override public void writeDouble(Schema schema, double value) { - if (Double.isFinite(value)) { - ensureCapacity(24); - pos = JsonWriteUtils.writeDouble(buf, pos, value, doubleToDecimal); - } else if (Double.isNaN(value)) { - ensureCapacity(JsonWriteUtils.NAN_BYTES.length); - System.arraycopy(JsonWriteUtils.NAN_BYTES, 0, buf, pos, JsonWriteUtils.NAN_BYTES.length); - pos += JsonWriteUtils.NAN_BYTES.length; - } else { - byte[] bytes = value > 0 ? JsonWriteUtils.INF_BYTES : JsonWriteUtils.NEG_INF_BYTES; - ensureCapacity(bytes.length); - System.arraycopy(bytes, 0, buf, pos, bytes.length); - pos += bytes.length; - } + ensureCapacity(24); + pos = NumberCodec.writeDoubleFullQuoted(buf, pos, value); } @Override public void writeBigInteger(Schema schema, BigInteger value) { + int maxLen = value.bitLength() / 3 + 2; + ensureCapacity(maxLen + 2); if (settings.useStringForArbitraryPrecision()) { - String s = value.toString(); - ensureCapacity(JsonWriteUtils.maxQuotedStringBytes(s)); - pos = JsonWriteUtils.writeQuotedString(buf, pos, s); - return; - } - if (value.bitLength() < 64) { - ensureCapacity(20); - pos = JsonWriteUtils.writeLong(buf, pos, value.longValue()); - return; + buf[pos++] = '"'; + pos = NumberCodec.writeBigInteger(buf, pos, value); + buf[pos++] = '"'; + } else { + pos = NumberCodec.writeBigInteger(buf, pos, value); } - ensureCapacity(value.bitLength() / 3 + 2); - pos = JsonWriteUtils.writeBigInteger(buf, pos, value); } @Override public void writeBigDecimal(Schema schema, BigDecimal value) { - int scale = value.scale(); - if (value.unscaledValue().bitLength() < 64) { - if (scale == 0) { - ensureCapacity(20); - pos = JsonWriteUtils.writeLong(buf, pos, value.longValueExact()); - return; - } - if (scale > 0) { - // Fast path: write "integerPart.fractionalPart" directly. - // E.g., BigDecimal("99999.99999") -> unscaled=9999999999, scale=5 - long unscaled = value.unscaledValue().longValue(); - ensureCapacity(22 + scale); // sign + 20 digits + dot + scale digits - pos = JsonWriteUtils.writeBigDecimalFromLong(buf, pos, unscaled, scale); - return; - } - } - String s = value.toString(); - // Preempt the quotes wrapping, as a BigDecimal write will almost - // always have at least 1 additional character after it. - ensureCapacity(s.length() + 2); + ensureCapacity(NumberCodec.maxBigDecimalLength(value) + 2); if (settings.useStringForArbitraryPrecision()) { - pos = JsonWriteUtils.writeQuotedString(buf, pos, s); + buf[pos++] = '"'; + pos = NumberCodec.writeBigDecimal(buf, pos, value); + buf[pos++] = '"'; } else { - pos = JsonWriteUtils.writeAsciiString(buf, pos, s); + pos = NumberCodec.writeBigDecimal(buf, pos, value); } } @@ -339,18 +214,8 @@ public void writeBlob(Schema schema, byte[] value) { @Override public void writeBlob(Schema schema, ByteBuffer value) { - int len = value.remaining(); - byte[] data; - int off; - if (value.hasArray()) { - data = value.array(); - off = value.arrayOffset() + value.position(); - } else { - data = ByteBufferUtils.getBytes(value.duplicate()); - off = 0; - } - ensureCapacity(JsonWriteUtils.maxBase64Bytes(len)); - pos = JsonWriteUtils.writeBase64String(buf, pos, data, off, len); + ensureCapacity(JsonWriteUtils.maxBase64Bytes(value.remaining())); + pos = JsonWriteUtils.writeBase64String(buf, pos, value); } @Override @@ -453,8 +318,6 @@ public void writeNull(Schema schema) { pos += JsonWriteUtils.NULL_BYTES.length; } - // ---- Comma management ---- - private void writeCommaIfNeeded() { if (needsComma[depth]) { if (pos >= buf.length) { @@ -466,8 +329,6 @@ private void writeCommaIfNeeded() { } } - // ---- Field name writing ---- - /** * Resolves the pre-computed field name bytes for a schema member. */ @@ -508,7 +369,35 @@ private void writeFieldNameBytes(Schema schema) { writeFieldNameBytesUnchecked(nameBytes); } - // ---- Inner struct serializer: writes field name + value ---- + private static class JsonStripedPool extends StripedPool { + @Override + protected SmithyJsonSerializer create(JsonSettings settings) { + return new SmithyJsonSerializer(settings); + } + + @Override + protected boolean canPool(SmithyJsonSerializer s) { + return s.buf != null; + } + + @Override + protected void prepareForPool(SmithyJsonSerializer s) { + if (s.buf.length > MAX_CACHEABLE_BUF) { + s.buf = new byte[DEFAULT_BUF_SIZE]; + } + } + + @Override + protected boolean reset(SmithyJsonSerializer s, JsonSettings settings) { + if (!s.settings.equals(settings)) { + return false; + } + s.pos = 0; + s.depth = 0; + s.currentFieldNameTable = null; + return true; + } + } private final class StructSerializer implements ShapeSerializer { @@ -518,11 +407,9 @@ private final class StructSerializer implements ShapeSerializer { @Override public void writeBoolean(Schema schema, boolean value) { byte[] nameBytes = resolveFieldNameBytes(schema); - ensureCapacity(nameBytes.length + 1 + 5); // +5 for "false" + ensureCapacity(nameBytes.length + 1 + 5); writeFieldNameBytesUnchecked(nameBytes); - byte[] bytes = value ? JsonWriteUtils.TRUE_BYTES : JsonWriteUtils.FALSE_BYTES; - System.arraycopy(bytes, 0, buf, pos, bytes.length); - pos += bytes.length; + pos = NumberCodec.writeBoolean(buf, pos, value); } @Override @@ -562,16 +449,7 @@ public void writeFloat(Schema schema, float value) { byte[] nameBytes = resolveFieldNameBytes(schema); ensureCapacity(nameBytes.length + 1 + 24); writeFieldNameBytesUnchecked(nameBytes); - if (Float.isFinite(value)) { - pos = JsonWriteUtils.writeFloat(buf, pos, value, floatToDecimal); - } else if (Float.isNaN(value)) { - System.arraycopy(JsonWriteUtils.NAN_BYTES, 0, buf, pos, JsonWriteUtils.NAN_BYTES.length); - pos += JsonWriteUtils.NAN_BYTES.length; - } else { - byte[] bytes = value > 0 ? JsonWriteUtils.INF_BYTES : JsonWriteUtils.NEG_INF_BYTES; - System.arraycopy(bytes, 0, buf, pos, bytes.length); - pos += bytes.length; - } + pos = NumberCodec.writeFloatFullQuoted(buf, pos, value); } @Override @@ -579,16 +457,7 @@ public void writeDouble(Schema schema, double value) { byte[] nameBytes = resolveFieldNameBytes(schema); ensureCapacity(nameBytes.length + 1 + 24); writeFieldNameBytesUnchecked(nameBytes); - if (Double.isFinite(value)) { - pos = JsonWriteUtils.writeDouble(buf, pos, value, doubleToDecimal); - } else if (Double.isNaN(value)) { - System.arraycopy(JsonWriteUtils.NAN_BYTES, 0, buf, pos, JsonWriteUtils.NAN_BYTES.length); - pos += JsonWriteUtils.NAN_BYTES.length; - } else { - byte[] bytes = value > 0 ? JsonWriteUtils.INF_BYTES : JsonWriteUtils.NEG_INF_BYTES; - System.arraycopy(bytes, 0, buf, pos, bytes.length); - pos += bytes.length; - } + pos = NumberCodec.writeDoubleFullQuoted(buf, pos, value); } @Override @@ -659,8 +528,6 @@ public void writeDocument(Schema schema, Document value) { } } - // ---- List element serializer: handles comma separation between elements ---- - private final class ListElementSerializer implements ShapeSerializer { private void beforeElement() { writeCommaIfNeeded(); @@ -769,8 +636,6 @@ public void writeNull(Schema schema) { } } - // ---- Map serializer ---- - private final class SmithyMapSerializer implements MapSerializer { @Override public void writeEntry( @@ -787,8 +652,6 @@ public void writeEntry( } } - // ---- Document struct serializer (writes __type) ---- - private static final class SerializeDocumentContents extends SpecificShapeSerializer { private final SmithyJsonSerializer parent; diff --git a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/package-info.java b/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/package-info.java deleted file mode 100644 index 5698a6f35e..0000000000 --- a/codecs/json-codec/src/main/java/software/amazon/smithy/java/json/smithy/package-info.java +++ /dev/null @@ -1,13 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -/** - * High-performance native JSON serialization/deserialization for smithy-java. - * - *

Writes and parses JSON bytes directly without Jackson on the hot path, - * exploiting Smithy schema knowledge for pre-computed field names and hash-based - * field matching. - */ -package software.amazon.smithy.java.json.smithy; diff --git a/codecs/json-codec/src/test/java/software/amazon/smithy/java/json/GeneratedModelSerdeTest.java b/codecs/json-codec/src/test/java/software/amazon/smithy/java/json/GeneratedModelSerdeTest.java index ee9a484773..8b3773e8c8 100644 --- a/codecs/json-codec/src/test/java/software/amazon/smithy/java/json/GeneratedModelSerdeTest.java +++ b/codecs/json-codec/src/test/java/software/amazon/smithy/java/json/GeneratedModelSerdeTest.java @@ -506,6 +506,13 @@ void nullOptionalMembersRoundtrip(JsonSerdeProvider ser, JsonSerdeProvider de) { assertThat(roundtrip(ser, de, original, NumericStruct.builder())).isEqualTo(original); } + @PerProvider + void emptyBodyDeserializesToDefaultStruct(JsonSerdeProvider provider) { + var codec = codec(provider); + var result = codec.deserializeShape(new byte[0], NumericStruct.builder()); + assertThat(result).isEqualTo(NumericStruct.builder().build()); + } + // --- Fields in non-schema order (exercises hash lookup slow path) --- @PerProvider diff --git a/codecs/xml-codec/build.gradle.kts b/codecs/xml-codec/build.gradle.kts index 4a88a21225..b9c8565055 100644 --- a/codecs/xml-codec/build.gradle.kts +++ b/codecs/xml-codec/build.gradle.kts @@ -1,6 +1,7 @@ plugins { id("smithy-java.module-conventions") id("smithy-java.fuzz-test") + id("software.amazon.smithy.gradle.smithy-base") } description = "This module provides XML functionality" @@ -10,4 +11,30 @@ extra["moduleName"] = "software.amazon.smithy.java.xml" dependencies { api(project(":core")) + api(project(":codecs:codec-commons", configuration = "shadow")) + smithyBuild(project(":codegen:codegen-plugin")) +} + +tasks.named("test") { + systemProperty("smithy-java.xml-provider", "smithy") +} + +afterEvaluate { + val typePath = smithy.getPluginProjectionPath(smithy.sourceProjection.get(), "java-codegen").get() + sourceSets.named("test") { + java { + srcDir("$typePath/java") + } + resources { + srcDir("$typePath/resources") + } + } +} + +tasks.named("compileTestJava") { + dependsOn("smithyBuild") +} + +tasks.named("processTestResources") { + dependsOn("smithyBuild") } diff --git a/codecs/xml-codec/model/test.smithy b/codecs/xml-codec/model/test.smithy new file mode 100644 index 0000000000..2e31d216a2 --- /dev/null +++ b/codecs/xml-codec/model/test.smithy @@ -0,0 +1,273 @@ +$version: "2" + +namespace smithy.java.xml.test + +/// A simple structure with scalar fields. +structure SimpleStruct { + @required + name: String + + @required + age: Integer + + active: Boolean + + score: Double + + @timestampFormat("date-time") + createdAt: Timestamp +} + +/// A complex structure that exercises many XML features. +structure ComplexStruct { + @required + id: String + + @required + count: Integer + + @required + enabled: PrimitiveBoolean = false + + @required + ratio: PrimitiveDouble = 0 + + @required + score: PrimitiveFloat = 0 + + @required + bigCount: PrimitiveLong = 0 + + optionalString: String + + optionalInt: Integer + + @timestampFormat("date-time") + createdAt: Timestamp + + payload: Blob + + tags: StringList + + intList: IntegerList + + metadata: StringMap + + intMap: IntegerMap + + @required + nested: NestedStruct + + optionalNested: NestedStruct + + structList: NestedStructList + + structMap: NestedStructMap + + color: Color + + colorList: ColorList + + bigIntValue: BigInteger + + bigDecValue: BigDecimal +} + +structure NestedStruct { + @required + field1: String + + @required + field2: Integer + + inner: InnerStruct +} + +structure InnerStruct { + value: String + numbers: IntegerList +} + +enum Color { + RED + GREEN + BLUE + YELLOW +} + +list StringList { + member: String +} + +list IntegerList { + member: Integer +} + +list NestedStructList { + member: NestedStruct +} + +list ColorList { + member: Color +} + +map StringMap { + key: String + value: String +} + +map IntegerMap { + key: String + value: Integer +} + +map NestedStructMap { + key: String + value: NestedStruct +} + +/// Structure focused on numeric boundary testing +structure NumericStruct { + byteVal: Byte + shortVal: Short + intVal: Integer + longVal: Long + floatVal: Float + doubleVal: Double + bigIntVal: BigInteger + bigDecVal: BigDecimal +} + +/// Structure focused on string edge cases +structure StringStruct { + @required + value: String +} + +/// Structure with all three timestamp formats +structure TimestampStruct { + epochSeconds: Timestamp + + @timestampFormat("date-time") + dateTime: Timestamp + + @timestampFormat("http-date") + httpDate: Timestamp +} + +/// Structure with xmlName trait on fields +@xmlName("CustomRoot") +structure XmlNameStruct { + @xmlName("ID") + id: String + + @xmlName("DisplayName") + displayName: String + + normalField: String +} + +/// Structure with xmlAttribute trait +structure XmlAttributeStruct { + @xmlAttribute + version: String + + @xmlAttribute + @xmlName("id") + identifier: String + + content: String +} + +/// Structure with xmlFlattened list +structure FlattenedListStruct { + @xmlFlattened + items: StringList + + @xmlFlattened + numbers: IntegerList + + normalList: StringList +} + +/// Structure with xmlFlattened map +structure FlattenedMapStruct { + @xmlFlattened + entries: StringMap + + normalMap: StringMap +} + +/// Structure with xmlNamespace +@xmlNamespace(uri: "https://example.com/test") +structure NamespacedStruct { + name: String + value: Integer +} + +/// Structure that nests itself for depth testing +structure RecursiveStruct { + value: String + child: RecursiveStruct +} + +/// Structure for blob testing +structure BlobStruct { + data: Blob +} + +list DoubleList { + member: Double +} + +list BigDecimalList { + member: BigDecimal +} + +list BooleanList { + member: Boolean +} + +list ByteList { + member: Byte +} + +list ShortList { + member: Short +} + +list LongList { + member: Long +} + +list FloatList { + member: Float +} + +list BigIntegerList { + member: BigInteger +} + +list BlobList { + member: Blob +} + +list TimestampList { + member: Timestamp +} + +/// Structure containing lists of all types +structure AllListsStruct { + booleans: BooleanList + bytes: ByteList + shorts: ShortList + ints: IntegerList + longs: LongList + floats: FloatList + doubles: DoubleList + bigInts: BigIntegerList + bigDecs: BigDecimalList + strings: StringList + blobs: BlobList + timestamps: TimestampList +} diff --git a/codecs/xml-codec/smithy-build.json b/codecs/xml-codec/smithy-build.json new file mode 100644 index 0000000000..2e0c0c96aa --- /dev/null +++ b/codecs/xml-codec/smithy-build.json @@ -0,0 +1,9 @@ +{ + "version": "1.0", + "plugins": { + "java-codegen": { + "namespace": "smithy.java.xml.test", + "modes": ["types"] + } + } +} diff --git a/codecs/xml-codec/src/fuzz/java/software/amazon/smithy/java/xml/DifferentialXmlFuzzTest.java b/codecs/xml-codec/src/fuzz/java/software/amazon/smithy/java/xml/DifferentialXmlFuzzTest.java new file mode 100644 index 0000000000..c1a49801db --- /dev/null +++ b/codecs/xml-codec/src/fuzz/java/software/amazon/smithy/java/xml/DifferentialXmlFuzzTest.java @@ -0,0 +1,91 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +import java.nio.charset.StandardCharsets; +import software.amazon.smithy.java.core.schema.SerializableShape; +import software.amazon.smithy.java.core.serde.Codec; +import software.amazon.smithy.java.fuzz.DifferentialCodecFuzzTestBase; + +class DifferentialXmlFuzzTest extends DifferentialCodecFuzzTestBase { + + private static final XmlCodec STAX_CODEC = XmlCodec.builder().useNative(false).build(); + private static final XmlCodec NATIVE_CODEC = XmlCodec.builder().useNative(true).build(); + + @Override + protected Codec referenceCodec() { + return STAX_CODEC; + } + + @Override + protected Codec testCodec() { + return NATIVE_CODEC; + } + + @Override + protected boolean isAcceptableDivergence( + SerializableShape referenceShape, + SerializableShape testShape, + byte[] input + ) { + return false; + } + + @Override + protected boolean isAcceptableTestFailure(SerializableShape referenceShape, Exception testError, byte[] input) { + // When input starts with , , or , both codecs enter + // an error-response parsing path. If the content isn't actually an error response, + // the native parser may reject with stricter tag validation while StAX recovers + // via error correction (producing empty struct output). + if (input.length < 12) { + return false; + } + String prefix = new String(input, 0, Math.min(input.length, 20), StandardCharsets.UTF_8); + return prefix.startsWith("//) are handled differently. + String msg = getFullErrorMessage(referenceError); + if (msg.contains("XMLStreamException") + || (msg.contains("element type") && msg.contains("must be terminated"))) { + return true; + } + // Inputs that trigger error-response parsing may diverge in behavior + if (input.length >= 6) { + String prefix = new String(input, + 0, + Math.min(input.length, 20), + StandardCharsets.UTF_8); + if (prefix.startsWith(" void writeList(Schema schema, T listState, int size, BiConsumer consumer) { + delegate().writeList(schema, listState, size, consumer); + } + + @Override + public void writeMap(Schema schema, T mapState, int size, BiConsumer consumer) { + delegate().writeMap(schema, mapState, size, consumer); + } + + @Override + public void writeBoolean(Schema schema, boolean value) { + delegate().writeBoolean(schema, value); + } + + @Override + public void writeByte(Schema schema, byte value) { + delegate().writeByte(schema, value); + } + + @Override + public void writeShort(Schema schema, short value) { + delegate().writeShort(schema, value); + } + + @Override + public void writeInteger(Schema schema, int value) { + delegate().writeInteger(schema, value); + } + + @Override + public void writeLong(Schema schema, long value) { + delegate().writeLong(schema, value); + } + + @Override + public void writeFloat(Schema schema, float value) { + delegate().writeFloat(schema, value); + } + + @Override + public void writeDouble(Schema schema, double value) { + delegate().writeDouble(schema, value); + } + + @Override + public void writeBigInteger(Schema schema, BigInteger value) { + delegate().writeBigInteger(schema, value); + } + + @Override + public void writeBigDecimal(Schema schema, BigDecimal value) { + delegate().writeBigDecimal(schema, value); + } + + @Override + public void writeString(Schema schema, String value) { + delegate().writeString(schema, value); + } + + @Override + public void writeBlob(Schema schema, ByteBuffer value) { + delegate().writeBlob(schema, value); + } + + @Override + public void writeTimestamp(Schema schema, Instant value) { + delegate().writeTimestamp(schema, value); + } + + @Override + public void writeDocument(Schema schema, Document value) { + delegate().writeDocument(schema, value); + } + + @Override + public void writeNull(Schema schema) { + delegate().writeNull(schema); + } + + @Override + public void flush() { + if (delegate != null) { + delegate.flush(); + } + } + + @Override + public void close() { + if (delegate != null) { + delegate.close(); + } + } +} diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/SmithyXmlDeserializer.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/SmithyXmlDeserializer.java new file mode 100644 index 0000000000..2d9f0f3410 --- /dev/null +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/SmithyXmlDeserializer.java @@ -0,0 +1,1984 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.format.DateTimeParseException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import software.amazon.smithy.java.codecs.commons.NumberCodec; +import software.amazon.smithy.java.codecs.commons.TimestampCodec; +import software.amazon.smithy.java.core.schema.Schema; +import software.amazon.smithy.java.core.schema.TraitKey; +import software.amazon.smithy.java.core.serde.SerializationException; +import software.amazon.smithy.java.core.serde.ShapeDeserializer; +import software.amazon.smithy.java.core.serde.SpecificShapeDeserializer; +import software.amazon.smithy.java.core.serde.TimestampFormatter; +import software.amazon.smithy.java.core.serde.document.Document; +import software.amazon.smithy.model.traits.TimestampFormatTrait; + +/** + * High-performance XML deserializer that inlines all byte-level parsing directly. + * + *

This class combines the ShapeDeserializer contract with a custom byte-level XML parser, + * avoiding the overhead of javax.xml.stream and intermediate String allocations for numeric types. + */ +final class SmithyXmlDeserializer implements ShapeDeserializer, XmlErrorCodeParser { + + private static final Base64.Decoder BASE64_DECODER = Base64.getDecoder(); + private static final int START_ELEMENT = 1; + private static final int END_ELEMENT = 2; + private static final int CHARACTERS = 3; + private static final int EOF = -1; + + private static final byte[] ERROR_RESPONSE_BYTES = "ErrorResponse".getBytes(StandardCharsets.UTF_8); + private static final byte[] ERROR_BYTES = "Error".getBytes(StandardCharsets.UTF_8); + private static final byte[] RESPONSE_BYTES = "Response".getBytes(StandardCharsets.UTF_8); + + private static final boolean[] NAME_CHAR = new boolean[256]; + + static { + for (int i = 'a'; i <= 'z'; i++) { + NAME_CHAR[i] = true; + } + for (int i = 'A'; i <= 'Z'; i++) { + NAME_CHAR[i] = true; + } + for (int i = '0'; i <= '9'; i++) { + NAME_CHAR[i] = true; + } + NAME_CHAR['-'] = true; + NAME_CHAR['_'] = true; + NAME_CHAR['.'] = true; + NAME_CHAR[':'] = true; + } + + private final byte[] buf; + private int pos; + private final int limit; + private int nameStart; + private int nameLen; + private boolean selfClosing; + private int attrCount; + private int[] attrNameStarts; + private int[] attrNameLens; + private int[] attrValueStarts; + private int[] attrValueLens; + + private int endNameStart; + private int endNameLen; + + private int textSpanStart; + private int textSpanEnd; + private String textFallback; + + private final MemberDeserializer memberDeserializer = new MemberDeserializer(); + + private final int[] lookupHint = new int[1]; + + private int errorWrapperDepth; + + private final XmlInfo xmlInfo; + private final boolean isTopLevel; + private final List wrapperElements; + + SmithyXmlDeserializer( + byte[] buf, + int offset, + int length, + XmlInfo xmlInfo, + boolean isTopLevel, + List wrapperElements + ) { + this.buf = buf; + this.pos = offset; + this.limit = offset + length; + this.xmlInfo = xmlInfo; + this.isTopLevel = isTopLevel; + this.wrapperElements = wrapperElements; + skipProlog(); + } + + private int next() { + if (selfClosing) { + selfClosing = false; + endNameStart = nameStart; + endNameLen = nameLen; + return END_ELEMENT; + } + while (pos < limit) { + byte b = buf[pos]; + if (b == '<') { + pos++; + if (pos >= limit) { + return EOF; + } + b = buf[pos]; + if (b == '/') { + pos++; + parseEndElement(); + return END_ELEMENT; + } else if (b == '!') { + pos++; + if (pos + 1 < limit && buf[pos] == '-' && buf[pos + 1] == '-') { + pos += 2; + skipComment(); + continue; + } else if (pos + 6 < limit && buf[pos] == '[' + && buf[pos + 1] == 'C' + && buf[pos + 2] == 'D' + && buf[pos + 3] == 'A' + && buf[pos + 4] == 'T' + && buf[pos + 5] == 'A' + && buf[pos + 6] == '[') { + pos += 7; + while (pos + 2 < limit) { + if (buf[pos] == ']' && buf[pos + 1] == ']' && buf[pos + 2] == '>') { + pos += 3; + break; + } + pos++; + } + continue; + } + throw new SerializationException( + "Unsupported markup declaration at byte offset " + (pos - 1)); + } else if (b == '?') { + pos++; + skipPastString((byte) '?', (byte) '>'); + continue; + } else { + parseStartElement(); + return START_ELEMENT; + } + } else if (isWhitespace(b)) { + pos++; + continue; + } else { + while (pos < limit && buf[pos] != '<') { + pos++; + } + return CHARACTERS; + } + } + return EOF; + } + + private void skipProlog() { + while (pos < limit) { + skipWhitespace(); + if (pos >= limit) { + return; + } + if (buf[pos] != '<') { + throw new SerializationException("Content is not allowed in prolog"); + } + int peek = pos + 1; + if (peek >= limit) { + return; + } + byte b = buf[peek]; + if (b == '?') { + pos = peek + 1; + skipPastString((byte) '?', (byte) '>'); + } else if (b == '!') { + pos = peek + 1; + if (pos + 1 < limit && buf[pos] == '-' && buf[pos + 1] == '-') { + pos += 2; + skipComment(); + } else { + while (pos < limit && buf[pos] != '>') { + pos++; + } + if (pos < limit) { + pos++; // skip '>' + } + } + } else { + return; + } + } + } + + private void parseStartElement() { + int start = pos; + boolean hasColon = false; + int colonPos = 0; + while (pos < limit) { + byte b = buf[pos]; + if (!NAME_CHAR[b & 0xFF]) { + break; + } + if (b == ':') { + hasColon = true; + colonPos = pos; + } + pos++; + } + + if (hasColon) { + nameStart = colonPos + 1; + } else { + nameStart = start; + } + nameLen = pos - nameStart; + + attrCount = 0; + selfClosing = false; + + while (pos < limit) { + byte b = buf[pos]; + if (b == '>') { + pos++; + return; + } else if (b == '/') { + pos++; + if (pos < limit && buf[pos] == '>') { + pos++; + selfClosing = true; + return; + } + } else if (b <= ' ') { + pos++; + } else if (NAME_CHAR[b & 0xFF]) { + parseAttribute(); + } else { + throw new SerializationException( + "Malformed XML: unexpected character '" + (char) b + "' in element tag"); + } + } + } + + private void parseEndElement() { + int localStart = pos; + while (pos < limit) { + byte b = buf[pos]; + if (!NAME_CHAR[b & 0xFF]) { + break; + } + if (b == ':') { + localStart = pos + 1; + } + pos++; + } + endNameStart = localStart; + endNameLen = pos - localStart; + while (pos < limit && buf[pos] != '>') { + pos++; + } + if (pos >= limit) { + throw new SerializationException("Unexpected end of input in end tag"); + } + pos++; + } + + private void parseAttribute() { + int aNameStart = pos; + while (pos < limit && isNameChar(buf[pos])) { + pos++; + } + int aNameLen = pos - aNameStart; + + skipWhitespace(); + if (pos >= limit || buf[pos] != '=') { + throw new SerializationException("Malformed attribute: expected '=' after attribute name"); + } + pos++; + skipWhitespace(); + + int aValueStart = 0; + int aValueLen = 0; + if (pos >= limit || (buf[pos] != '"' && buf[pos] != '\'')) { + throw new SerializationException("Malformed attribute: expected quoted value"); + } + if (pos < limit && (buf[pos] == '"' || buf[pos] == '\'')) { + byte quote = buf[pos]; + pos++; + aValueStart = pos; + while (pos < limit && buf[pos] != quote) { + pos++; + } + aValueLen = pos - aValueStart; + if (pos < limit) { + pos++; // skip closing quote + } + } + + if (aNameLen >= 5 && buf[aNameStart] == 'x' + && buf[aNameStart + 1] == 'm' + && buf[aNameStart + 2] == 'l' + && buf[aNameStart + 3] == 'n' + && buf[aNameStart + 4] == 's') { + return; + } + + int localNameStart = aNameStart; + for (int i = aNameStart; i < aNameStart + aNameLen; i++) { + if (buf[i] == ':') { + localNameStart = i + 1; + break; + } + } + int localNameLen = (aNameStart + aNameLen) - localNameStart; + + ensureAttrCapacity(); + attrNameStarts[attrCount] = localNameStart; + attrNameLens[attrCount] = localNameLen; + attrValueStarts[attrCount] = aValueStart; + attrValueLens[attrCount] = aValueLen; + attrCount++; + } + + private void ensureAttrCapacity() { + if (attrNameStarts == null) { + attrNameStarts = new int[4]; + attrNameLens = new int[4]; + attrValueStarts = new int[4]; + attrValueLens = new int[4]; + } else if (attrCount >= attrNameStarts.length) { + int newCap = attrNameStarts.length * 2; + attrNameStarts = Arrays.copyOf(attrNameStarts, newCap); + attrNameLens = Arrays.copyOf(attrNameLens, newCap); + attrValueStarts = Arrays.copyOf(attrValueStarts, newCap); + attrValueLens = Arrays.copyOf(attrValueLens, newCap); + } + } + + private void skipElement() { + skipElement(nameStart, nameLen); + } + + private void skipElement(int expectedNameStart, int expectedNameLen) { + if (selfClosing) { + selfClosing = false; + return; + } + int depth = 1; + while (depth > 0 && pos < limit) { + int event = next(); + if (event == START_ELEMENT) { + depth++; + } else if (event == END_ELEMENT) { + depth--; + if (depth == 0) { + if (endNameLen != expectedNameLen + || !Arrays.equals(buf, + endNameStart, + endNameStart + endNameLen, + buf, + expectedNameStart, + expectedNameStart + expectedNameLen)) { + throw new SerializationException( + "Mismatched end tag: expected '' but found ''"); + } + } + } else if (event == EOF) { + throw new SerializationException( + "Unexpected end of input while looking for end tag ''"); + } + } + } + + private void consumeEndElement() { + if (pos + 1 < limit && buf[pos] == '<' && buf[pos + 1] == '/') { + pos += 2; + int expectedStart = nameStart; + int expectedLen = nameLen; + if (pos + expectedLen <= limit + && Arrays.equals(buf, pos, pos + expectedLen, buf, expectedStart, expectedStart + expectedLen)) { + pos += expectedLen; + while (pos < limit && buf[pos] != '>') { + pos++; + } + if (pos < limit) { + pos++; + } + return; + } + pos -= 2; + } + consumeEndElementSlow(); + } + + private void consumeEndElementSlow() { + int expectedStart = nameStart; + int expectedLen = nameLen; + while (true) { + int event = next(); + if (event == END_ELEMENT) { + if (endNameLen != expectedLen + || !Arrays.equals(buf, + endNameStart, + endNameStart + endNameLen, + buf, + expectedStart, + expectedStart + expectedLen)) { + throw new SerializationException( + "Mismatched end tag: expected '' but found ''"); + } + return; + } else if (event == START_ELEMENT) { + skipElement(); + } else if (event == CHARACTERS) { + continue; + } else { + throw new SerializationException( + "Unexpected end of input while looking for end tag ''"); + } + } + } + + private void consumeEndElement(byte[] expectedName) { + while (true) { + int event = next(); + if (event == END_ELEMENT) { + if (endNameLen != expectedName.length + || !Arrays.equals(buf, + endNameStart, + endNameStart + endNameLen, + expectedName, + 0, + expectedName.length)) { + throw new SerializationException( + "Mismatched end tag: expected '' but found ''"); + } + return; + } else if (event == CHARACTERS) { + continue; + } else if (event == EOF) { + throw new SerializationException( + "Unexpected end of input while looking for end tag ''"); + } else { + throw new SerializationException( + "Expected end element for '" + + new String(expectedName, StandardCharsets.UTF_8) + + "' but found other content"); + } + } + } + + private void validateContainerEndTag(int expectedStart, int expectedLen) { + if (endNameLen != expectedLen + || !Arrays.equals(buf, + endNameStart, + endNameStart + endNameLen, + buf, + expectedStart, + expectedStart + expectedLen)) { + throw new SerializationException( + "Mismatched end tag: expected '' but found ''"); + } + } + + private void skipComment() { + while (pos + 2 < limit) { + if (buf[pos] == '-' && buf[pos + 1] == '-' && buf[pos + 2] == '>') { + pos += 3; + return; + } + pos++; + } + pos = limit; + } + + private void skipWhitespace() { + while (pos < limit && isWhitespace(buf[pos])) { + pos++; + } + } + + private void skipPastString(byte b1, byte b2) { + while (pos + 1 < limit) { + if (buf[pos] == b1 && buf[pos + 1] == b2) { + pos += 2; + return; + } + pos++; + } + pos = limit; + } + + private boolean nextStartElement() { + if (selfClosing) { + selfClosing = false; + endNameStart = nameStart; + endNameLen = nameLen; + return false; + } + while (pos < limit) { + byte b = buf[pos]; + if (b <= ' ') { + pos++; + continue; + } + if (b != '<') { + while (pos < limit && buf[pos] != '<') { + pos++; + } + continue; + } + pos++; + if (pos >= limit) { + return false; + } + b = buf[pos]; + if (b == '/') { + pos++; + parseEndElement(); + return false; + } else if (b == '!' || b == '?') { + skipMarkup(b); + continue; + } else { + parseStartElement(); + return true; + } + } + return false; + } + + private void skipMarkup(byte type) { + if (type == '!') { + pos++; + if (pos + 1 < limit && buf[pos] == '-' && buf[pos + 1] == '-') { + pos += 2; + skipComment(); + } else if (pos + 6 < limit && buf[pos] == '[' + && buf[pos + 1] == 'C' + && buf[pos + 2] == 'D' + && buf[pos + 3] == 'A' + && buf[pos + 4] == 'T' + && buf[pos + 5] == 'A' + && buf[pos + 6] == '[') { + pos += 7; + while (pos + 2 < limit) { + if (buf[pos] == ']' && buf[pos + 1] == ']' && buf[pos + 2] == '>') { + pos += 3; + return; + } + pos++; + } + } else { + throw new SerializationException( + "Unsupported markup declaration at byte offset " + (pos - 1)); + } + } else { + pos++; + skipPastString((byte) '?', (byte) '>'); + } + } + + private static boolean isWhitespace(byte b) { + return b == ' ' || b == '\n' || b == '\r' || b == '\t'; + } + + private static boolean isNameChar(byte b) { + return NAME_CHAR[b & 0xFF]; + } + + /** + * Scans text content from current pos until the end-tag '= limit || buf[pos + 1] != '!') { + int len = pos - start; + if (len == 0) { + return ""; + } + boolean clean = true; + for (int i = start; i < pos; i++) { + int b = buf[i] & 0xFF; + if (b < 0x20 || b > 0x7E || b == '&' || b == ']') { + clean = false; + break; + } + } + if (clean) { + @SuppressWarnings("deprecation") + String s = new String(buf, 0, start, len); + return s; + } + return readTextContentSlow(start, len); + } + + StringBuilder sb = new StringBuilder(); + appendUnescaped(sb, start, pos - start); + + while (pos < limit && buf[pos] == '<' && pos + 1 < limit && buf[pos + 1] == '!') { + if (pos + 8 < limit && buf[pos + 2] == '[' + && buf[pos + 3] == 'C' + && buf[pos + 4] == 'D' + && buf[pos + 5] == 'A' + && buf[pos + 6] == 'T' + && buf[pos + 7] == 'A' + && buf[pos + 8] == '[') { + pos += 9; // skip ') { + sb.append(new String(buf, cdataStart, pos - cdataStart, StandardCharsets.UTF_8)); + pos += 3; // skip ]]> + break; + } + pos++; + } + } else { + break; + } + start = pos; + while (pos < limit && buf[pos] != '<') { + pos++; + } + if (pos > start) { + appendUnescaped(sb, start, pos - start); + } + } + + return sb.toString(); + } + + private String readTextContentSlow(int start, int len) { + int end = start + len; + boolean needsUnescape = false; + boolean needsCrNormalization = false; + for (int i = start; i < end; i++) { + byte b = buf[i]; + if (b == '&') { + needsUnescape = true; + break; + } + if (b == '\r') { + needsCrNormalization = true; + } else if ((b & 0xFF) < 0x20 && b != '\t' && b != '\n') { + throw new SerializationException( + "Invalid XML character U+" + String.format("%04X", b & 0xFF) + + " at byte offset " + (i - start)); + } else if ((b & 0xC0) == 0x80) { + if (i == start || (buf[i - 1] & 0x80) == 0) { + throw new SerializationException("Invalid UTF-8 encoding at byte offset " + (i - start)); + } + } + if (b == ']' && i + 2 < end && buf[i + 1] == ']' && buf[i + 2] == '>') { + throw new SerializationException("The sequence ']]>' is not allowed in text content"); + } + } + if (needsUnescape) { + return unescapeXml(start, len); + } + if (needsCrNormalization) { + return normalizeCr(start, len); + } + return new String(buf, start, len, StandardCharsets.UTF_8); + } + + /** + * Fused read: scans text content, constructs String, and verifies/consumes the end tag + * in a single pass. Avoids separate readTextContent + consumeEndElement method calls. + */ + private String readStringAndConsumeEndTag() { + if (selfClosing) { + selfClosing = false; + endNameStart = nameStart; + endNameLen = nameLen; + return ""; + } + int start = pos; + while (pos < limit && buf[pos] != '<') { + pos++; + } + int textEnd = pos; + int textLen = textEnd - start; + + int expectedStart = nameStart; + int expectedLen = nameLen; + if (pos + 1 < limit && buf[pos] == '<' && buf[pos + 1] == '/') { + int tagNamePos = pos + 2; + if (tagNamePos + expectedLen < limit + && Arrays.equals(buf, + tagNamePos, + tagNamePos + expectedLen, + buf, + expectedStart, + expectedStart + expectedLen)) { + int p = tagNamePos + expectedLen; + while (p < limit && buf[p] != '>') { + p++; + } + if (p < limit) { + pos = p + 1; + if (textLen == 0) { + return ""; + } + boolean clean = true; + for (int i = start; i < textEnd; i++) { + int b = buf[i] & 0xFF; + if (b < 0x20 || b > 0x7E || b == '&' || b == ']') { + clean = false; + break; + } + } + if (clean) { + @SuppressWarnings("deprecation") + String s = new String(buf, 0, start, textLen); + return s; + } + return readTextContentSlow(start, textLen); + } + } + } + pos = start; + String result = readTextContent(); + consumeEndElement(); + return result; + } + + private void appendUnescaped(StringBuilder sb, int start, int len) { + int end = start + len; + int i = start; + while (i < end) { + byte b = buf[i]; + if (b == '&') { + sb.append(unescapeXml(start, len)); + return; + } + i++; + } + sb.append(new String(buf, start, len, StandardCharsets.UTF_8)); + } + + private String unescapeXml(int start, int len) { + int end = start + len; + StringBuilder sb = new StringBuilder(len); + int i = start; + while (i < end) { + byte b = buf[i]; + if (b == '&') { + i++; + int entityStart = i; + while (i < end && buf[i] != ';') { + i++; + } + if (i >= end) { + throw new SerializationException("Unterminated entity reference"); + } + int entityLen = i - entityStart; + i++; // skip ';' + + if (entityLen == 2 && buf[entityStart] == 'l' && buf[entityStart + 1] == 't') { + sb.append('<'); + } else if (entityLen == 2 && buf[entityStart] == 'g' && buf[entityStart + 1] == 't') { + sb.append('>'); + } else if (entityLen == 3 && buf[entityStart] == 'a' + && buf[entityStart + 1] == 'm' + && buf[entityStart + 2] == 'p') { + sb.append('&'); + } else if (entityLen == 4 && buf[entityStart] == 'a' + && buf[entityStart + 1] == 'p' + && buf[entityStart + 2] == 'o' + && buf[entityStart + 3] == 's') { + sb.append('\''); + } else if (entityLen == 4 && buf[entityStart] == 'q' + && buf[entityStart + 1] == 'u' + && buf[entityStart + 2] == 'o' + && buf[entityStart + 3] == 't') { + sb.append('"'); + } else if (buf[entityStart] == '#') { + int codePoint; + if (entityLen > 1 && buf[entityStart + 1] == 'x') { + codePoint = 0; + for (int j = entityStart + 2; j < entityStart + entityLen; j++) { + codePoint = codePoint * 16 + hexDigit(buf[j]); + } + } else { + codePoint = 0; + for (int j = entityStart + 1; j < entityStart + entityLen; j++) { + codePoint = codePoint * 10 + (buf[j] - '0'); + } + } + if (!isValidXmlChar(codePoint)) { + throw new SerializationException( + "Invalid XML character reference &#" + + (entityLen > 1 && buf[entityStart + 1] == 'x' ? "x" : "") + + Integer.toString(codePoint, + entityLen > 1 && buf[entityStart + 1] == 'x' ? 16 : 10) + + ";"); + } + sb.appendCodePoint(codePoint); + } else { + throw new SerializationException( + "Undeclared entity reference '&" + + new String(buf, entityStart, entityLen, StandardCharsets.UTF_8) + ";'"); + } + } else { + if (b == '\r') { + sb.append('\n'); + i++; + if (i < end && buf[i] == '\n') { + i++; // \r\n -> \n + } + } else if ((b & 0x80) == 0) { + sb.append((char) b); + i++; + } else { + int remaining = end - i; + int seqLen; + if ((b & 0xE0) == 0xC0) { + seqLen = 2; + } else if ((b & 0xF0) == 0xE0) { + seqLen = 3; + } else { + seqLen = 4; + } + seqLen = Math.min(seqLen, remaining); + sb.append(new String(buf, i, seqLen, StandardCharsets.UTF_8)); + i += seqLen; + } + } + } + return sb.toString(); + } + + private String normalizeCr(int start, int len) { + int end = start + len; + StringBuilder sb = new StringBuilder(len); + int i = start; + while (i < end) { + byte b = buf[i]; + if (b == '\r') { + sb.append('\n'); + i++; + if (i < end && buf[i] == '\n') { + i++; // \r\n -> \n (skip the \n) + } + } else if ((b & 0x80) == 0) { + sb.append((char) b); + i++; + } else { + int seqLen; + if ((b & 0xE0) == 0xC0) { + seqLen = 2; + } else if ((b & 0xF0) == 0xE0) { + seqLen = 3; + } else { + seqLen = 4; + } + seqLen = Math.min(seqLen, end - i); + sb.append(new String(buf, i, seqLen, StandardCharsets.UTF_8)); + i += seqLen; + } + } + return sb.toString(); + } + + private static boolean isValidXmlChar(int cp) { + return cp == 0x9 || cp == 0xA + || cp == 0xD + || (cp >= 0x20 && cp <= 0xD7FF) + || (cp >= 0xE000 && cp <= 0xFFFD) + || (cp >= 0x10000 && cp <= 0x10FFFF); + } + + private static int hexDigit(byte b) { + if (b >= '0' && b <= '9') { + return b - '0'; + } + if (b >= 'a' && b <= 'f') { + return b - 'a' + 10; + } + if (b >= 'A' && b <= 'F') { + return b - 'A' + 10; + } + return 0; + } + + private boolean nameEquals(byte[] expected) { + if (nameLen != expected.length) { + return false; + } + return Arrays.equals(buf, nameStart, nameStart + nameLen, expected, 0, expected.length); + } + + private String elementName() { + return new String(buf, nameStart, nameLen, StandardCharsets.UTF_8); + } + + private String getAttributeValue(String localName) { + byte[] nameBytes = localName.getBytes(StandardCharsets.UTF_8); + return getAttributeValueByBytes(nameBytes); + } + + private String getAttributeValueByBytes(byte[] nameBytes) { + for (int i = 0; i < attrCount; i++) { + if (attrNameLens[i] == nameBytes.length + && Arrays.equals(buf, + attrNameStarts[i], + attrNameStarts[i] + attrNameLens[i], + nameBytes, + 0, + nameBytes.length)) { + int vStart = attrValueStarts[i]; + int vLen = attrValueLens[i]; + for (int j = vStart; j < vStart + vLen; j++) { + if (buf[j] == '&') { + return unescapeXml(vStart, vLen); + } + } + return new String(buf, vStart, vLen, StandardCharsets.UTF_8); + } + } + return null; + } + + private void enter(Schema schema) { + if (!isTopLevel) { + return; + } + + if (!wrapperElements.isEmpty()) { + skipWrapperElements(); + return; + } + + if (!nextStartElement()) { + throw new SerializationException("Expected start element but found end of document"); + } + + String expected; + var trait = schema.getTrait(TraitKey.XML_NAME_TRAIT); + if (trait != null) { + expected = trait.getValue(); + } else if (schema.isMember()) { + expected = schema.memberTarget().id().getName(); + } else { + expected = schema.id().getName(); + } + + if (nameEquals(ERROR_RESPONSE_BYTES)) { + if (!nextStartElement()) { + throw new SerializationException("Expected inside "); + } + errorWrapperDepth = 1; + return; + } else if (nameEquals(RESPONSE_BYTES)) { + nextStartElement(); + if (!nextStartElement()) { + throw new SerializationException("Expected inside "); + } + errorWrapperDepth = 2; + return; + } + + if (!nameEqualsString(expected)) { + if (!nameEquals(ERROR_BYTES)) { + throw new SerializationException( + "Expected XML element named '" + expected + "', found '" + elementName() + "'"); + } + } + } + + private boolean nameEqualsString(String expected) { + if (nameLen != expected.length()) { + return false; + } + for (int i = 0; i < nameLen; i++) { + if ((buf[nameStart + i] & 0xFF) != expected.charAt(i)) { + return false; + } + } + return true; + } + + private void skipWrapperElements() { + for (String wrapperName : wrapperElements) { + if (!nextStartElement()) { + return; + } + String name = elementName(); + if (!name.equals(wrapperName)) { + throw new SerializationException( + "Expected wrapper element '" + wrapperName + "', found '" + name + "'"); + } + } + } + + public String parseErrorCodeName() { + if (!nextStartElement()) { + throw new SerializationException( + "Expected element , , or for XML error response"); + } + String name = elementName(); + if (!name.equals("ErrorResponse") && !name.equals("Error") && !name.equals("Response")) { + throw new SerializationException( + "Expected element , , or for XML error response"); + } + + if (name.equals("ErrorResponse")) { + if (!nextStartElement()) { + throw new SerializationException("Expected element inside "); + } + if (!elementName().equals("Error")) { + throw new SerializationException("Expected element inside "); + } + } else if (name.equals("Response")) { + if (!nextStartElement()) { + throw new SerializationException("Expected element inside "); + } + if (!elementName().equals("Errors")) { + throw new SerializationException("Expected element inside "); + } + if (!nextStartElement()) { + throw new SerializationException("Expected element inside "); + } + if (!elementName().equals("Error")) { + throw new SerializationException("Expected element inside "); + } + } + + while (nextStartElement()) { + if (elementName().equals("Code")) { + String code = readTextContent(); + consumeEndElement(); + return code; + } + skipElement(); + } + throw new SerializationException("Expected element inside "); + } + + private void validateNoTrailingContent() { + if (!isTopLevel) { + return; + } + int depth = wrapperElements.size() + errorWrapperDepth; + if (depth > 0) { + while (depth > 0 && pos < limit) { + int event = next(); + if (event == END_ELEMENT) { + depth--; + } else if (event == START_ELEMENT) { + skipElement(); + } else if (event == EOF) { + break; + } + } + } + skipWhitespace(); + if (pos < limit) { + throw new SerializationException("Content is not allowed after root element closes"); + } + } + + @Override + public boolean readBoolean(Schema schema) { + enter(schema); + readTextSpan(); + boolean result = parseBooleanFromSpan(); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public byte readByte(Schema schema) { + enter(schema); + readTextSpan(); + int value = parseIntFromSpan(); + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new SerializationException("Value out of range for byte: " + value); + } + consumeEndElement(); + validateNoTrailingContent(); + return (byte) value; + } + + @Override + public short readShort(Schema schema) { + enter(schema); + readTextSpan(); + int value = parseIntFromSpan(); + if (value < Short.MIN_VALUE || value > Short.MAX_VALUE) { + throw new SerializationException("Value out of range for short: " + value); + } + consumeEndElement(); + validateNoTrailingContent(); + return (short) value; + } + + @Override + public int readInteger(Schema schema) { + enter(schema); + readTextSpan(); + int result = parseIntFromSpan(); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public long readLong(Schema schema) { + enter(schema); + readTextSpan(); + long result = parseLongFromSpan(); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public float readFloat(Schema schema) { + enter(schema); + readTextSpan(); + float result = parseFloatFromSpan(); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public double readDouble(Schema schema) { + enter(schema); + readTextSpan(); + double result = parseDoubleFromSpan(); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public BigInteger readBigInteger(Schema schema) { + enter(schema); + readTextSpan(); + BigInteger result = parseBigIntegerFromSpan(); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public BigDecimal readBigDecimal(Schema schema) { + enter(schema); + readTextSpan(); + BigDecimal result = parseBigDecimalFromSpan(); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public String readString(Schema schema) { + enter(schema); + String result = readTextContent(); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public ByteBuffer readBlob(Schema schema) { + enter(schema); + readTextSpan(); + ByteBuffer result = parseBlobFromSpan(); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public Instant readTimestamp(Schema schema) { + enter(schema); + readTextSpan(); + Instant result = parseTimestampFromSpan(schema); + consumeEndElement(); + validateNoTrailingContent(); + return result; + } + + @Override + public void readStruct(Schema schema, T state, StructMemberConsumer consumer) { + enter(schema); + readStructContent(schema, state, consumer); + validateNoTrailingContent(); + } + + @Override + public void readList(Schema schema, T state, ListMemberConsumer consumer) { + enter(schema); + readListContent(schema, state, consumer); + validateNoTrailingContent(); + } + + @Override + public void readStringMap(Schema schema, T state, MapMemberConsumer consumer) { + enter(schema); + readMapContent(schema, state, consumer); + validateNoTrailingContent(); + } + + @Override + public boolean isNull() { + int saved = pos; + while (saved < limit && buf[saved] != '<' && isWhitespace(buf[saved])) { + saved++; + } + return saved < limit && buf[saved] == '<'; + } + + @Override + public T readNull() { + return null; + } + + @Override + public Document readDocument() { + return null; + } + + private void readStructContent(Schema schema, T state, StructMemberConsumer consumer) { + int containerNameStart = nameStart; + int containerNameLen = nameLen; + var decoder = xmlInfo.getStructInfo(schema); + + if (!decoder.attributes.isEmpty() && attrCount > 0) { + var structExt = decoder.schema.getExtension(XmlSchemaExtensions.KEY); + byte[][] nameTable = (structExt instanceof XmlSchemaExtensions.StructExtension se) + ? se.nameTable() + : null; + + for (var entry : decoder.attributes.entrySet()) { + Schema attributeSchema = entry.getValue(); + int idx = attributeSchema.memberIndex(); + String attrValue; + if (nameTable != null && idx >= 0 && idx < nameTable.length && nameTable[idx] != null) { + attrValue = getAttributeValueByBytes(nameTable[idx]); + } else { + String attributeName = entry.getKey(); + int colonIdx = attributeName.indexOf(':'); + String lookupName = colonIdx >= 0 ? attributeName.substring(colonIdx + 1) : attributeName; + attrValue = getAttributeValue(lookupName); + } + if (attrValue != null) { + consumer.accept(state, attributeSchema, new AttributeDeserializer(attrValue)); + } + } + } + + XmlMemberLookup lookup = decoder.memberLookup; + lookupHint[0] = 0; + + if (decoder.hasFlattened) { + readStructContentWithFlattened(state, consumer, lookup, containerNameStart, containerNameLen); + } else if (lookup != null) { + while (nextStartElement()) { + Schema memberSchema = lookup.findMember(buf, nameStart, nameLen, lookupHint); + if (memberSchema != null) { + memberDeserializer.consumed = false; + consumer.accept(state, memberSchema, memberDeserializer); + if (!memberDeserializer.consumed) { + skipElement(); + } + } else { + consumer.unknownMember(state, elementName()); + skipElement(); + } + } + validateContainerEndTag(containerNameStart, containerNameLen); + } else { + while (nextStartElement()) { + skipElement(); + } + validateContainerEndTag(containerNameStart, containerNameLen); + } + } + + private void readStructContentWithFlattened( + T state, + StructMemberConsumer consumer, + XmlMemberLookup lookup, + int containerNameStart, + int containerNameLen + ) { + Map> flattenedSpans = new LinkedHashMap<>(); + while (nextStartElement()) { + Schema memberSchema = lookup.findMember(buf, nameStart, nameLen, lookupHint); + if (memberSchema != null) { + if (memberSchema.hasTrait(TraitKey.XML_FLATTENED_TRAIT)) { + int spanStart = nameStart - 1; + skipElement(); + int spanEnd = pos; + flattenedSpans.computeIfAbsent(memberSchema, k -> new ArrayList<>()) + .add(new int[] {spanStart, spanEnd}); + } else { + memberDeserializer.consumed = false; + consumer.accept(state, memberSchema, memberDeserializer); + if (!memberDeserializer.consumed) { + skipElement(); + } + } + } else { + consumer.unknownMember(state, elementName()); + skipElement(); + } + } + validateContainerEndTag(containerNameStart, containerNameLen); + for (var entry : flattenedSpans.entrySet()) { + consumer.accept(state, entry.getKey(), new FlattenedReplayDeserializer(entry.getValue())); + } + } + + private void readListContent(Schema schema, T state, ListMemberConsumer consumer) { + int containerNameStart = nameStart; + int containerNameLen = nameLen; + var ext = schema.getExtension(XmlSchemaExtensions.KEY); + byte[] expectedMemberName; + if (ext instanceof XmlSchemaExtensions.ListExtension le) { + expectedMemberName = le.memberNameBytes(); + } else { + var info = xmlInfo.getListInfo(schema); + expectedMemberName = info.memberName.getBytes(StandardCharsets.UTF_8); + } + + MemberDeserializer itemDeser = new MemberDeserializer(); + while (nextStartElement()) { + if (!nameEquals(expectedMemberName)) { + throw new SerializationException( + "Expected list item '" + new String(expectedMemberName, StandardCharsets.UTF_8) + + "' but found '" + elementName() + "'"); + } + itemDeser.consumed = false; + consumer.accept(state, itemDeser); + if (!itemDeser.consumed) { + skipElement(); + } + } + + validateContainerEndTag(containerNameStart, containerNameLen); + } + + private void readMapContent(Schema schema, T state, MapMemberConsumer consumer) { + int containerNameStart = nameStart; + int containerNameLen = nameLen; + byte[] entryNameBytes; + byte[] keyNameBytes; + byte[] valueNameBytes; + boolean flattened; + var ext = schema.getExtension(XmlSchemaExtensions.KEY); + if (ext instanceof XmlSchemaExtensions.MapExtension me) { + entryNameBytes = me.entryNameBytes(); + keyNameBytes = me.keyNameBytes(); + valueNameBytes = me.valueNameBytes(); + var decoder = xmlInfo.getMapInfo(schema); + flattened = decoder.flattened; + } else { + var decoder = xmlInfo.getMapInfo(schema); + entryNameBytes = decoder.entryName.getBytes(StandardCharsets.UTF_8); + keyNameBytes = decoder.keyName.getBytes(StandardCharsets.UTF_8); + valueNameBytes = decoder.valueName.getBytes(StandardCharsets.UTF_8); + flattened = decoder.flattened; + } + + while (nextStartElement()) { + if (!nameEquals(entryNameBytes)) { + if (!flattened) { + break; + } else { + throw new SerializationException("Unexpected element in map: " + elementName()); + } + } + + if (!nextStartElement()) { + throw new SerializationException("Expected map key, but map unexpectedly closed"); + } + if (!nameEquals(keyNameBytes)) { + throw new SerializationException("Expected map key but found '" + elementName() + "'"); + } + String key = readTextContent(); + consumeEndElement(); // consume + + if (!nextStartElement()) { + throw new SerializationException("Expected map value, but map unexpectedly closed"); + } + if (!nameEquals(valueNameBytes)) { + throw new SerializationException("Expected map value but found '" + elementName() + "'"); + } + MemberDeserializer valueDeser = new MemberDeserializer(); + consumer.accept(state, key, valueDeser); + if (!valueDeser.consumed) { + skipElement(); + } + consumeEndElement(entryNameBytes); + } + + validateContainerEndTag(containerNameStart, containerNameLen); + } + + private boolean parseBooleanFromSpan() { + if (spanHasCdataFallback()) { + return "true".equals(textFallback); + } + int len = textSpanEnd - textSpanStart; + if (len == 4 && buf[textSpanStart] == 't' + && buf[textSpanStart + 1] == 'r' + && buf[textSpanStart + 2] == 'u' + && buf[textSpanStart + 3] == 'e') { + return true; + } + if (len == 5 && buf[textSpanStart] == 'f' + && buf[textSpanStart + 1] == 'a' + && buf[textSpanStart + 2] == 'l' + && buf[textSpanStart + 3] == 's' + && buf[textSpanStart + 4] == 'e') { + return false; + } + throw new SerializationException( + "Expected boolean 'true' or 'false', found '" + + new String(buf, textSpanStart, len, StandardCharsets.UTF_8) + "'"); + } + + private static final byte[] INFINITY_BYTES = "Infinity".getBytes(StandardCharsets.UTF_8); + + private float parseFloatFromSpan() { + if (spanHasCdataFallback()) { + return Float.parseFloat(textFallback); + } + int len = textSpanEnd - textSpanStart; + if (len == 0) { + throw new SerializationException("Empty float value"); + } + if (len == 3 && buf[textSpanStart] == 'N' && buf[textSpanStart + 1] == 'a' && buf[textSpanStart + 2] == 'N') { + return Float.NaN; + } else if (len == 8 && Arrays.equals(buf, textSpanStart, textSpanEnd, INFINITY_BYTES, 0, 8)) { + return Float.POSITIVE_INFINITY; + } else if (len == 9 && buf[textSpanStart] == '-' + && Arrays.equals(buf, textSpanStart + 1, textSpanEnd, INFINITY_BYTES, 0, 8)) { + return Float.NEGATIVE_INFINITY; + } + return NumberCodec.parseFloat(buf, textSpanStart, len); + } + + private double parseDoubleFromSpan() { + if (spanHasCdataFallback()) { + return Double.parseDouble(textFallback); + } + int len = textSpanEnd - textSpanStart; + if (len == 0) { + throw new SerializationException("Empty double value"); + } + if (len == 3 && buf[textSpanStart] == 'N' && buf[textSpanStart + 1] == 'a' && buf[textSpanStart + 2] == 'N') { + return Double.NaN; + } else if (len == 8 && Arrays.equals(buf, textSpanStart, textSpanEnd, INFINITY_BYTES, 0, 8)) { + return Double.POSITIVE_INFINITY; + } else if (len == 9 && buf[textSpanStart] == '-' + && Arrays.equals(buf, textSpanStart + 1, textSpanEnd, INFINITY_BYTES, 0, 8)) { + return Double.NEGATIVE_INFINITY; + } + return NumberCodec.parseDouble(buf, textSpanStart, len); + } + + private Instant parseTimestampFromSpan(Schema schema) { + if (spanHasCdataFallback()) { + try { + return TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME) + .readFromString(textFallback, false); + } catch (TimestampFormatter.TimestampSyntaxError | DateTimeParseException e) { + throw new SerializationException("Failed to read timestamp: " + e.getMessage(), e); + } + } + try { + Instant result = TimestampCodec.parseIso8601(buf, textSpanStart, textSpanEnd); + if (result != null) { + return result; + } + String value = new String(buf, textSpanStart, textSpanEnd - textSpanStart, StandardCharsets.UTF_8); + return TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME).readFromString(value, false); + } catch (TimestampFormatter.TimestampSyntaxError | DateTimeParseException e) { + throw new SerializationException("Failed to read timestamp: " + e.getMessage(), e); + } + } + + private int parseIntFromSpan() { + if (spanHasCdataFallback()) { + return Integer.parseInt(textFallback.trim()); + } + return NumberCodec.parseInt(buf, textSpanStart, textSpanEnd - textSpanStart); + } + + private long parseLongFromSpan() { + if (spanHasCdataFallback()) { + return Long.parseLong(textFallback.trim()); + } + return NumberCodec.parseLong(buf, textSpanStart, textSpanEnd - textSpanStart); + } + + private BigInteger parseBigIntegerFromSpan() { + if (spanHasCdataFallback()) { + return new BigInteger(textFallback.trim()); + } + return new BigInteger(new String(buf, textSpanStart, textSpanEnd - textSpanStart, StandardCharsets.US_ASCII)); + } + + private BigDecimal parseBigDecimalFromSpan() { + if (spanHasCdataFallback()) { + return new BigDecimal(textFallback.trim()); + } + return new BigDecimal(new String(buf, textSpanStart, textSpanEnd - textSpanStart, StandardCharsets.US_ASCII)); + } + + private ByteBuffer parseBlobFromSpan() { + if (spanHasCdataFallback()) { + return ByteBuffer.wrap(BASE64_DECODER.decode(textFallback.trim())); + } + return BASE64_DECODER.decode(ByteBuffer.wrap(buf, textSpanStart, textSpanEnd - textSpanStart)); + } + + private final class MemberDeserializer implements ShapeDeserializer { + boolean consumed; + + @Override + public boolean readBoolean(Schema schema) { + consumed = true; + readTextSpan(); + boolean result = parseBooleanFromSpan(); + consumeEndElement(); + return result; + } + + @Override + public byte readByte(Schema schema) { + consumed = true; + readTextSpan(); + int value = parseIntFromSpan(); + if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) { + throw new SerializationException("Value out of range for byte: " + value); + } + consumeEndElement(); + return (byte) value; + } + + @Override + public short readShort(Schema schema) { + consumed = true; + readTextSpan(); + int value = parseIntFromSpan(); + if (value < Short.MIN_VALUE || value > Short.MAX_VALUE) { + throw new SerializationException("Value out of range for short: " + value); + } + consumeEndElement(); + return (short) value; + } + + @Override + public int readInteger(Schema schema) { + consumed = true; + readTextSpan(); + int result = parseIntFromSpan(); + consumeEndElement(); + return result; + } + + @Override + public long readLong(Schema schema) { + consumed = true; + readTextSpan(); + long result = parseLongFromSpan(); + consumeEndElement(); + return result; + } + + @Override + public float readFloat(Schema schema) { + consumed = true; + readTextSpan(); + float result = parseFloatFromSpan(); + consumeEndElement(); + return result; + } + + @Override + public double readDouble(Schema schema) { + consumed = true; + readTextSpan(); + double result = parseDoubleFromSpan(); + consumeEndElement(); + return result; + } + + @Override + public BigInteger readBigInteger(Schema schema) { + consumed = true; + readTextSpan(); + BigInteger result = parseBigIntegerFromSpan(); + consumeEndElement(); + return result; + } + + @Override + public BigDecimal readBigDecimal(Schema schema) { + consumed = true; + readTextSpan(); + BigDecimal result = parseBigDecimalFromSpan(); + consumeEndElement(); + return result; + } + + @Override + public String readString(Schema schema) { + consumed = true; + return readStringAndConsumeEndTag(); + } + + @Override + public ByteBuffer readBlob(Schema schema) { + consumed = true; + readTextSpan(); + ByteBuffer result = parseBlobFromSpan(); + consumeEndElement(); + return result; + } + + @Override + public Instant readTimestamp(Schema schema) { + consumed = true; + readTextSpan(); + Instant result = parseTimestampFromSpan(schema); + consumeEndElement(); + return result; + } + + @Override + public void readStruct(Schema schema, T state, StructMemberConsumer consumer) { + consumed = true; + readStructContent(schema, state, consumer); + } + + @Override + public void readList(Schema schema, T state, ListMemberConsumer consumer) { + consumed = true; + readListContent(schema, state, consumer); + } + + @Override + public void readStringMap(Schema schema, T state, MapMemberConsumer consumer) { + consumed = true; + readMapContent(schema, state, consumer); + } + + @Override + public boolean isNull() { + if (selfClosing) { + return true; + } + int saved = pos; + while (saved < limit && isWhitespace(buf[saved])) { + saved++; + } + return saved + 1 < limit && buf[saved] == '<' && buf[saved + 1] == '/'; + } + + @Override + public T readNull() { + consumed = true; + skipElement(); + return null; + } + + @Override + public Document readDocument() { + consumed = true; + skipElement(); + return null; + } + } + + private static final class AttributeDeserializer extends SpecificShapeDeserializer { + + private final String value; + + AttributeDeserializer(String value) { + this.value = value; + } + + @Override + public boolean readBoolean(Schema schema) { + return switch (value) { + case "true" -> true; + case "false" -> false; + default -> throw new SerializationException( + "Expected boolean 'true' or 'false', found '" + value + "'"); + }; + } + + @Override + public String readString(Schema schema) { + return value; + } + + @Override + public Instant readTimestamp(Schema schema) { + try { + return TimestampFormatter.of(schema, TimestampFormatTrait.Format.DATE_TIME) + .readFromString(value, false); + } catch (TimestampFormatter.TimestampSyntaxError e) { + throw new SerializationException("Failed to read timestamp: " + e.getMessage(), e); + } + } + + @Override + public byte readByte(Schema schema) { + return Byte.parseByte(value); + } + + @Override + public short readShort(Schema schema) { + return Short.parseShort(value); + } + + @Override + public int readInteger(Schema schema) { + return Integer.parseInt(value); + } + + @Override + public long readLong(Schema schema) { + return Long.parseLong(value); + } + + @Override + public float readFloat(Schema schema) { + return Float.parseFloat(value); + } + + @Override + public double readDouble(Schema schema) { + return Double.parseDouble(value); + } + + @Override + public BigDecimal readBigDecimal(Schema schema) { + return new BigDecimal(value); + } + + @Override + public BigInteger readBigInteger(Schema schema) { + return new BigInteger(value); + } + } + + private final class FlattenedReplayDeserializer implements ShapeDeserializer { + + private final List spans; + + FlattenedReplayDeserializer(List spans) { + this.spans = spans; + } + + @Override + public boolean readBoolean(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public ByteBuffer readBlob(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public byte readByte(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public short readShort(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public int readInteger(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public long readLong(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public float readFloat(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public double readDouble(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public BigInteger readBigInteger(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public BigDecimal readBigDecimal(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public String readString(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public Instant readTimestamp(Schema schema) { + throw new SerializationException("Flattened replay does not support scalar reads"); + } + + @Override + public void readStruct(Schema schema, T state, StructMemberConsumer consumer) { + throw new SerializationException("Flattened replay does not support struct reads"); + } + + @Override + public void readList(Schema schema, T state, ListMemberConsumer consumer) { + for (int[] span : spans) { + int spanStart = span[0]; + int spanLen = span[1] - span[0]; + SmithyXmlDeserializer sub = new SmithyXmlDeserializer( + buf, + spanStart, + spanLen, + xmlInfo, + false, + List.of()); + if (sub.nextStartElement()) { + MemberDeserializer memberDeser = sub.new MemberDeserializer(); + consumer.accept(state, memberDeser); + } + } + } + + @Override + public void readStringMap(Schema schema, T state, MapMemberConsumer consumer) { + var decoder = xmlInfo.getMapInfo(schema); + byte[] keyNameBytes = decoder.keyName.getBytes(StandardCharsets.UTF_8); + byte[] valueNameBytes = decoder.valueName.getBytes(StandardCharsets.UTF_8); + + for (int[] span : spans) { + int spanStart = span[0]; + int spanLen = span[1] - span[0]; + SmithyXmlDeserializer sub = new SmithyXmlDeserializer( + buf, + spanStart, + spanLen, + xmlInfo, + false, + List.of()); + if (sub.nextStartElement()) { + if (!sub.nextStartElement()) { + throw new SerializationException("Expected map key in flattened entry"); + } + if (!sub.nameEquals(keyNameBytes)) { + throw new SerializationException("Expected map key '" + decoder.keyName + "'"); + } + String key = sub.readTextContent(); + sub.skipElement(); + + if (!sub.nextStartElement()) { + throw new SerializationException("Expected map value in flattened entry"); + } + if (!sub.nameEquals(valueNameBytes)) { + throw new SerializationException("Expected map value '" + decoder.valueName + "'"); + } + MemberDeserializer valueDeser = sub.new MemberDeserializer(); + consumer.accept(state, key, valueDeser); + } + } + } + + @Override + public boolean isNull() { + return spans.isEmpty(); + } + + @Override + public Document readDocument() { + return null; + } + } +} diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/SmithyXmlSerializer.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/SmithyXmlSerializer.java new file mode 100644 index 0000000000..d2e4768a82 --- /dev/null +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/SmithyXmlSerializer.java @@ -0,0 +1,729 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +import java.io.IOException; +import java.io.OutputStream; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Arrays; +import java.util.function.BiConsumer; +import software.amazon.smithy.java.codecs.commons.NumberCodec; +import software.amazon.smithy.java.codecs.commons.StripedPool; +import software.amazon.smithy.java.core.schema.Schema; +import software.amazon.smithy.java.core.schema.SerializableStruct; +import software.amazon.smithy.java.core.schema.TraitKey; +import software.amazon.smithy.java.core.serde.InterceptingSerializer; +import software.amazon.smithy.java.core.serde.MapSerializer; +import software.amazon.smithy.java.core.serde.SerializationException; +import software.amazon.smithy.java.core.serde.ShapeSerializer; +import software.amazon.smithy.java.core.serde.SpecificShapeSerializer; +import software.amazon.smithy.java.core.serde.TimestampFormatter; +import software.amazon.smithy.java.core.serde.document.Document; +import software.amazon.smithy.java.io.ByteBufferUtils; +import software.amazon.smithy.model.traits.TimestampFormatTrait; +import software.amazon.smithy.model.traits.XmlNamespaceTrait; + +final class SmithyXmlSerializer extends InterceptingSerializer { + + private static final TimestampFormatTrait.Format DEFAULT_FORMAT = TimestampFormatTrait.Format.DATE_TIME; + private static final int DEFAULT_BUF_SIZE = 4096; + private static final int MAX_CACHEABLE_BUF = DEFAULT_BUF_SIZE * 4; + + record AcquireContext(XmlNamespaceTrait defaultNamespace, XmlInfo xmlInfo, OutputStream sink) {} + + private static final StripedPool POOL = new XmlStripedPool(); + + private byte[] buf; + private int pos; + private OutputStream sink; + private XmlNamespaceTrait defaultNamespace; + private XmlInfo xmlInfo; + + private final StructElementSerializer structElementSerializer = new StructElementSerializer(); + private final StructAttributeSerializer structAttributeSerializer = new StructAttributeSerializer(); + private final ValueSerializer valueSerializer = new ValueSerializer(); + + private byte[][] currentNameTable; + private boolean[] currentIsAttributeTable; + private boolean[] currentIsFlattenedTable; + + // When true, the opening tag is not yet closed with '>'. + // writeStruct will close it after writing attributes; scalar writes close it immediately. + private boolean pendingClose; + + private SmithyXmlSerializer(XmlNamespaceTrait defaultNamespace, XmlInfo xmlInfo) { + this.defaultNamespace = defaultNamespace; + this.xmlInfo = xmlInfo; + this.buf = new byte[DEFAULT_BUF_SIZE]; + this.pos = 0; + } + + static SmithyXmlSerializer acquire(XmlNamespaceTrait defaultNamespace, XmlInfo xmlInfo, OutputStream sink) { + return POOL.acquire(new AcquireContext(defaultNamespace, xmlInfo, sink)); + } + + static void release(SmithyXmlSerializer serializer) { + POOL.release(serializer); + } + + @Override + public void flush() { + try { + if (sink != null && pos > 0) { + sink.write(buf, 0, pos); + pos = 0; + sink.flush(); + } + } catch (IOException e) { + throw new SerializationException(e); + } + } + + @Override + public void close() { + try { + if (sink != null && pos > 0) { + sink.write(buf, 0, pos); + pos = 0; + } + } catch (IOException e) { + throw new SerializationException(e); + } finally { + release(this); + } + } + + private void ensureCapacity(int needed) { + if (pos + needed > buf.length) { + grow(needed); + } + } + + private void grow(int needed) { + buf = Arrays.copyOf(buf, Math.max(buf.length * 2, pos + needed)); + } + + @Override + protected ShapeSerializer before(Schema schema) { + byte[] nameBytes = resolveTopLevelName(schema); + byte[] nsBytes = resolveTopLevelNamespace(schema); + + int nsLen = nsBytes != null ? nsBytes.length : 0; + ensureCapacity(1 + nameBytes.length + nsLen); + buf[pos++] = '<'; + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + if (nsBytes != null) { + System.arraycopy(nsBytes, 0, buf, pos, nsBytes.length); + pos += nsBytes.length; + } + pendingClose = true; + return valueSerializer; + } + + @Override + protected void after(Schema schema) { + byte[] nameBytes = resolveTopLevelName(schema); + writeCloseTag(nameBytes); + } + + private byte[] resolveTopLevelName(Schema schema) { + var trait = schema.getTrait(TraitKey.XML_NAME_TRAIT); + if (trait != null) { + return trait.getValue().getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + var ext = getStructExtension(schema); + if (ext != null) { + return ext.structNameBytes(); + } + if (schema.isMember()) { + return schema.memberTarget().id().getName().getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + return schema.id().getName().getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + + private byte[] resolveTopLevelNamespace(Schema schema) { + var ext = getStructExtension(schema); + if (ext != null && ext.namespaceBytes() != null) { + return ext.namespaceBytes(); + } + var ns = schema.getTrait(TraitKey.XML_NAMESPACE_TRAIT); + if (ns == null) { + ns = defaultNamespace; + } + if (ns != null) { + return buildNamespaceBytes(ns); + } + return null; + } + + private static XmlSchemaExtensions.StructExtension getStructExtension(Schema schema) { + Schema target = schema.isMember() ? schema.memberTarget() : schema; + var ext = target.getExtension(XmlSchemaExtensions.KEY); + if (ext instanceof XmlSchemaExtensions.StructExtension se) { + return se; + } + return null; + } + + private static byte[] buildNamespaceBytes(XmlNamespaceTrait ns) { + String prefix = ns.getPrefix().orElse(null); + String uri = ns.getUri(); + String escapedUri = uri.replace("&", "&") + .replace("\"", """) + .replace("<", "<") + .replace(">", ">"); + if (prefix == null || prefix.isEmpty()) { + return (" xmlns=\"" + escapedUri + "\"").getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + return (" xmlns:" + prefix + "=\"" + escapedUri + "\"").getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + + private void writeOpenTag(byte[] nameBytes) { + ensureCapacity(1 + nameBytes.length + 1); + buf[pos++] = '<'; + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + buf[pos++] = '>'; + } + + private void writeCloseTag(byte[] nameBytes) { + ensureCapacity(3 + nameBytes.length); + buf[pos++] = '<'; + buf[pos++] = '/'; + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + buf[pos++] = '>'; + } + + private byte[] resolveMemberNameBytes(Schema schema) { + byte[][] table = currentNameTable; + int idx = schema.memberIndex(); + if (table != null && idx >= 0 && idx < table.length && table[idx] != null) { + return table[idx]; + } + var ext = schema.getExtension(XmlSchemaExtensions.KEY); + if (ext instanceof XmlSchemaExtensions.MemberExtension me) { + return me.nameBytes(); + } + String name = schema.memberName(); + return name.getBytes(StandardCharsets.UTF_8); + } + + private void closePendingTag() { + if (pendingClose) { + ensureCapacity(1); + buf[pos++] = '>'; + pendingClose = false; + } + } + + private byte[] resolveMemberNamespaceBytes(Schema schema) { + var ext = schema.getExtension(XmlSchemaExtensions.KEY); + if (ext instanceof XmlSchemaExtensions.MemberExtension me) { + return me.namespaceBytes(); + } + var ns = schema.getDirectTrait(TraitKey.XML_NAMESPACE_TRAIT); + if (ns != null) { + return buildNamespaceBytes(ns); + } + return null; + } + + private static String formatTimestamp(Schema schema, Instant value) { + return TimestampFormatter.of( + schema.getTrait(TraitKey.TIMESTAMP_FORMAT_TRAIT), + DEFAULT_FORMAT).writeString(value); + } + + private static class XmlStripedPool extends StripedPool { + @Override + protected SmithyXmlSerializer create(AcquireContext ctx) { + var s = new SmithyXmlSerializer(ctx.defaultNamespace(), ctx.xmlInfo()); + s.sink = ctx.sink(); + return s; + } + + @Override + protected void cleanup(SmithyXmlSerializer s) { + s.sink = null; + } + + @Override + protected boolean canPool(SmithyXmlSerializer s) { + return s.buf != null; + } + + @Override + protected void prepareForPool(SmithyXmlSerializer s) { + if (s.buf.length > MAX_CACHEABLE_BUF) { + s.buf = new byte[DEFAULT_BUF_SIZE]; + } + } + + @Override + protected boolean reset(SmithyXmlSerializer s, AcquireContext ctx) { + s.pos = 0; + s.sink = ctx.sink(); + s.defaultNamespace = ctx.defaultNamespace(); + s.xmlInfo = ctx.xmlInfo(); + s.currentNameTable = null; + s.currentIsAttributeTable = null; + s.currentIsFlattenedTable = null; + return true; + } + } + + private final class StructElementSerializer extends InterceptingSerializer { + @Override + protected ShapeSerializer before(Schema schema) { + int idx = schema.memberIndex(); + if (currentIsAttributeTable != null && idx >= 0 + && idx < currentIsAttributeTable.length + && currentIsAttributeTable[idx]) { + return ShapeSerializer.nullSerializer(); + } + if (currentIsFlattenedTable != null && idx >= 0 + && idx < currentIsFlattenedTable.length + && currentIsFlattenedTable[idx]) { + return valueSerializer; + } + byte[] nameBytes = resolveMemberNameBytes(schema); + ensureCapacity(1 + nameBytes.length); + buf[pos++] = '<'; + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + byte[] nsBytes = resolveMemberNamespaceBytes(schema); + if (nsBytes != null) { + ensureCapacity(nsBytes.length); + System.arraycopy(nsBytes, 0, buf, pos, nsBytes.length); + pos += nsBytes.length; + } + pendingClose = true; + return valueSerializer; + } + + @Override + protected void after(Schema schema) { + int idx = schema.memberIndex(); + if (currentIsAttributeTable != null && idx >= 0 + && idx < currentIsAttributeTable.length + && currentIsAttributeTable[idx]) { + return; + } + if (currentIsFlattenedTable != null && idx >= 0 + && idx < currentIsFlattenedTable.length + && currentIsFlattenedTable[idx]) { + return; + } + byte[] nameBytes = resolveMemberNameBytes(schema); + writeCloseTag(nameBytes); + } + } + + private final class StructAttributeSerializer extends InterceptingSerializer { + @Override + protected ShapeSerializer before(Schema schema) { + int idx = schema.memberIndex(); + if (currentIsAttributeTable != null && idx >= 0 + && idx < currentIsAttributeTable.length + && currentIsAttributeTable[idx]) { + return new InlineAttributeSerializer(schema); + } + return ShapeSerializer.nullSerializer(); + } + } + + private final class InlineAttributeSerializer extends SpecificShapeSerializer { + private final byte[] nameBytes; + + InlineAttributeSerializer(Schema schema) { + this.nameBytes = resolveMemberNameBytes(schema); + } + + private void writeAttrPrefix() { + ensureCapacity(1 + nameBytes.length + 2); + buf[pos++] = ' '; + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + buf[pos++] = '='; + buf[pos++] = '"'; + } + + private void writeAttrSuffix() { + ensureCapacity(1); + buf[pos++] = '"'; + } + + @Override + public void writeString(Schema schema, String value) { + writeAttrPrefix(); + ensureCapacity(XmlWriteUtils.maxEscapedAttributeBytes(value)); + pos = XmlWriteUtils.writeEscapedAttribute(buf, pos, value); + writeAttrSuffix(); + } + + @Override + public void writeBoolean(Schema schema, boolean value) { + writeAttrPrefix(); + ensureCapacity(5); + pos = NumberCodec.writeBoolean(buf, pos, value); + writeAttrSuffix(); + } + + @Override + public void writeInteger(Schema schema, int value) { + writeAttrPrefix(); + ensureCapacity(11); + pos = NumberCodec.writeInt(buf, pos, value); + writeAttrSuffix(); + } + + @Override + public void writeLong(Schema schema, long value) { + writeAttrPrefix(); + ensureCapacity(20); + pos = NumberCodec.writeLong(buf, pos, value); + writeAttrSuffix(); + } + + @Override + public void writeFloat(Schema schema, float value) { + if (Float.isFinite(value)) { + writeAttrPrefix(); + ensureCapacity(24); + pos = NumberCodec.writeFloat(buf, pos, value); + writeAttrSuffix(); + } else { + writeAttrPrefix(); + ensureCapacity(9); + pos = NumberCodec.writeNonFiniteFloat(buf, pos, value); + writeAttrSuffix(); + } + } + + @Override + public void writeDouble(Schema schema, double value) { + if (Double.isFinite(value)) { + writeAttrPrefix(); + ensureCapacity(24); + pos = NumberCodec.writeDouble(buf, pos, value); + writeAttrSuffix(); + } else { + writeAttrPrefix(); + ensureCapacity(9); + pos = NumberCodec.writeNonFiniteDouble(buf, pos, value); + writeAttrSuffix(); + } + } + + @Override + public void writeTimestamp(Schema schema, Instant value) { + writeString(schema, formatTimestamp(schema, value)); + } + + @Override + public void writeNull(Schema schema) {} + } + + private final class ValueSerializer implements ShapeSerializer { + @Override + public void writeStruct(Schema schema, SerializableStruct struct) { + var ext = getStructExtension(schema); + + byte[][] savedNameTable = currentNameTable; + boolean[] savedIsAttrTable = currentIsAttributeTable; + boolean[] savedIsFlatTable = currentIsFlattenedTable; + + if (ext != null) { + currentNameTable = ext.nameTable(); + currentIsAttributeTable = ext.isAttributeTable(); + currentIsFlattenedTable = ext.isFlattenedTable(); + + if (ext.hasAttributes()) { + struct.serializeMembers(structAttributeSerializer); + } + } else { + currentNameTable = null; + currentIsAttributeTable = null; + currentIsFlattenedTable = null; + } + + closePendingTag(); + struct.serializeMembers(structElementSerializer); + + currentNameTable = savedNameTable; + currentIsAttributeTable = savedIsAttrTable; + currentIsFlattenedTable = savedIsFlatTable; + } + + @Override + public void writeList(Schema schema, T listState, int size, BiConsumer consumer) { + closePendingTag(); + var ext = schema.getExtension(XmlSchemaExtensions.KEY); + byte[] memberNameBytes; + if (ext instanceof XmlSchemaExtensions.ListExtension le) { + memberNameBytes = le.memberNameBytes(); + } else { + var info = xmlInfo.getListInfo(schema); + memberNameBytes = info.memberName.getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + Schema listMember = schema.listMember(); + byte[] memberNsBytes = null; + var ns = listMember.getDirectTrait(TraitKey.XML_NAMESPACE_TRAIT); + if (ns != null) { + memberNsBytes = buildNamespaceBytes(ns); + } + consumer.accept(listState, new ListItemSerializer(memberNameBytes, memberNsBytes)); + } + + @Override + public void writeMap(Schema schema, T mapState, int size, BiConsumer consumer) { + closePendingTag(); + var ext = schema.getExtension(XmlSchemaExtensions.KEY); + byte[] entryBytes, keyBytes, valueBytes; + if (ext instanceof XmlSchemaExtensions.MapExtension(byte[] entryNameBytes, byte[] keyNameBytes, byte[] valueNameBytes)) { + entryBytes = entryNameBytes; + keyBytes = keyNameBytes; + valueBytes = valueNameBytes; + } else { + var info = xmlInfo.getMapInfo(schema); + entryBytes = info.entryName.getBytes(java.nio.charset.StandardCharsets.UTF_8); + keyBytes = info.keyName.getBytes(java.nio.charset.StandardCharsets.UTF_8); + valueBytes = info.valueName.getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + byte[] keyNsBytes = null; + byte[] valueNsBytes = null; + Schema keyMember = schema.mapKeyMember(); + Schema valueMember = schema.mapValueMember(); + var keyNs = keyMember.getDirectTrait(TraitKey.XML_NAMESPACE_TRAIT); + if (keyNs != null) { + keyNsBytes = buildNamespaceBytes(keyNs); + } + var valueNs = valueMember.getDirectTrait(TraitKey.XML_NAMESPACE_TRAIT); + if (valueNs != null) { + valueNsBytes = buildNamespaceBytes(valueNs); + } + consumer.accept(mapState, + new XmlMapEntrySerializer(entryBytes, + keyBytes, + valueBytes, + keyNsBytes, + valueNsBytes)); + } + + @Override + public void writeBoolean(Schema schema, boolean value) { + closePendingTag(); + ensureCapacity(5); + pos = NumberCodec.writeBoolean(buf, pos, value); + } + + @Override + public void writeByte(Schema schema, byte value) { + closePendingTag(); + ensureCapacity(4); + pos = NumberCodec.writeInt(buf, pos, value); + } + + @Override + public void writeShort(Schema schema, short value) { + closePendingTag(); + ensureCapacity(6); + pos = NumberCodec.writeInt(buf, pos, value); + } + + @Override + public void writeInteger(Schema schema, int value) { + closePendingTag(); + ensureCapacity(11); + pos = NumberCodec.writeInt(buf, pos, value); + } + + @Override + public void writeLong(Schema schema, long value) { + closePendingTag(); + ensureCapacity(20); + pos = NumberCodec.writeLong(buf, pos, value); + } + + @Override + public void writeFloat(Schema schema, float value) { + closePendingTag(); + ensureCapacity(24); + pos = NumberCodec.writeFloatFull(buf, pos, value); + } + + @Override + public void writeDouble(Schema schema, double value) { + closePendingTag(); + ensureCapacity(24); + pos = NumberCodec.writeDoubleFull(buf, pos, value); + } + + @Override + public void writeBigInteger(Schema schema, BigInteger value) { + closePendingTag(); + ensureCapacity(value.bitLength() / 3 + 2); + pos = NumberCodec.writeBigInteger(buf, pos, value); + } + + @Override + public void writeBigDecimal(Schema schema, BigDecimal value) { + closePendingTag(); + String s = value.toString(); + ensureCapacity(s.length()); + pos = writeAsciiString(buf, pos, s); + } + + @SuppressWarnings("deprecation") + private static int writeAsciiString(byte[] buf, int pos, String s) { + int len = s.length(); + s.getBytes(0, len, buf, pos); + return pos + len; + } + + @Override + public void writeString(Schema schema, String value) { + closePendingTag(); + ensureCapacity(XmlWriteUtils.maxEscapedTextBytes(value)); + pos = XmlWriteUtils.writeEscapedText(buf, pos, value); + } + + @Override + public void writeBlob(Schema schema, ByteBuffer value) { + closePendingTag(); + byte[] encoded = ByteBufferUtils.base64EncodeToBytes(value); + ensureCapacity(encoded.length); + System.arraycopy(encoded, 0, buf, pos, encoded.length); + pos += encoded.length; + } + + @Override + public void writeTimestamp(Schema schema, Instant value) { + closePendingTag(); + writeTextContent(formatTimestamp(schema, value)); + } + + @Override + public void writeDocument(Schema schema, Document value) { + closePendingTag(); + } + + @Override + public void writeNull(Schema schema) { + closePendingTag(); + } + + private void writeTextContent(String value) { + ensureCapacity(XmlWriteUtils.maxEscapedTextBytes(value)); + pos = XmlWriteUtils.writeEscapedText(buf, pos, value); + } + } + + private final class ListItemSerializer extends InterceptingSerializer { + private final byte[] memberNameBytes; + private final byte[] nsBytes; + + ListItemSerializer(byte[] memberNameBytes, byte[] nsBytes) { + this.memberNameBytes = memberNameBytes; + this.nsBytes = nsBytes; + } + + @Override + protected ShapeSerializer before(Schema schema) { + int nsLen = nsBytes != null ? nsBytes.length : 0; + ensureCapacity(1 + memberNameBytes.length + nsLen); + buf[pos++] = '<'; + System.arraycopy(memberNameBytes, 0, buf, pos, memberNameBytes.length); + pos += memberNameBytes.length; + if (nsBytes != null) { + System.arraycopy(nsBytes, 0, buf, pos, nsBytes.length); + pos += nsBytes.length; + } + pendingClose = true; + return valueSerializer; + } + + @Override + protected void after(Schema schema) { + writeCloseTag(memberNameBytes); + } + } + + private final class XmlMapEntrySerializer implements MapSerializer { + private final byte[] entryBytes; + private final byte[] keyBytes; + private final byte[] valueBytes; + private final byte[] keyNsBytes; + private final byte[] valueNsBytes; + + XmlMapEntrySerializer( + byte[] entryBytes, + byte[] keyBytes, + byte[] valueBytes, + byte[] keyNsBytes, + byte[] valueNsBytes + ) { + this.entryBytes = entryBytes; + this.keyBytes = keyBytes; + this.valueBytes = valueBytes; + this.keyNsBytes = keyNsBytes; + this.valueNsBytes = valueNsBytes; + } + + @Override + public void writeEntry( + Schema keySchema, + String key, + T state, + BiConsumer valueSerializer + ) { + writeOpenTag(entryBytes); + + writeTagWithNsClosed(keyBytes, keyNsBytes); + ensureCapacity(XmlWriteUtils.maxEscapedTextBytes(key)); + pos = XmlWriteUtils.writeEscapedText(buf, pos, key); + writeCloseTag(keyBytes); + + writeTagWithNsOpen(valueBytes, valueNsBytes); + pendingClose = true; + valueSerializer.accept(state, SmithyXmlSerializer.this.valueSerializer); + writeCloseTag(valueBytes); + + writeCloseTag(entryBytes); + } + + private void writeTagWithNsClosed(byte[] nameBytes, byte[] nsBytes) { + int nsLen = nsBytes != null ? nsBytes.length : 0; + ensureCapacity(1 + nameBytes.length + nsLen + 1); + buf[pos++] = '<'; + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + if (nsBytes != null) { + System.arraycopy(nsBytes, 0, buf, pos, nsBytes.length); + pos += nsBytes.length; + } + buf[pos++] = '>'; + } + + private void writeTagWithNsOpen(byte[] nameBytes, byte[] nsBytes) { + int nsLen = nsBytes != null ? nsBytes.length : 0; + ensureCapacity(1 + nameBytes.length + nsLen); + buf[pos++] = '<'; + System.arraycopy(nameBytes, 0, buf, pos, nameBytes.length); + pos += nameBytes.length; + if (nsBytes != null) { + System.arraycopy(nsBytes, 0, buf, pos, nsBytes.length); + pos += nsBytes.length; + } + } + } +} diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlCodec.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlCodec.java index e22e45fccc..ad01e47b63 100644 --- a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlCodec.java +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlCodec.java @@ -13,6 +13,7 @@ import javax.xml.stream.XMLOutputFactory; import javax.xml.stream.XMLStreamException; import software.amazon.smithy.java.core.serde.Codec; +import software.amazon.smithy.java.core.serde.SerializationException; import software.amazon.smithy.java.core.serde.ShapeDeserializer; import software.amazon.smithy.java.core.serde.ShapeSerializer; import software.amazon.smithy.java.io.ByteBufferUtils; @@ -25,22 +26,34 @@ */ public final class XmlCodec implements Codec { - private final XMLInputFactory xmlInputFactory; - private final XMLOutputFactory xmlOutputFactory; + private static final boolean USE_SMITHY_NATIVE = + "smithy".equals(System.getProperty("smithy-java.xml-provider")); + + private volatile XMLInputFactory xmlInputFactory; + private volatile XMLOutputFactory xmlOutputFactory; + private volatile XMLEventFactory eventFactory; private final XmlInfo xmlInfo = new XmlInfo(); - private final XMLEventFactory eventFactory = XMLEventFactory.newInstance(); private final List wrapperElements; private final XmlNamespaceTrait defaultNamespace; + private final boolean useNative; private XmlCodec(Builder builder) { + this.wrapperElements = builder.wrapperElements; + this.defaultNamespace = builder.defaultNamespace; + this.useNative = builder.useNative != null ? builder.useNative : USE_SMITHY_NATIVE; + if (!useNative) { + initStax(); + } + } + + private void initStax() { xmlInputFactory = XMLInputFactory.newInstance(); xmlInputFactory.setProperty(XMLInputFactory.SUPPORT_DTD, false); xmlInputFactory.setProperty("javax.xml.stream.isSupportingExternalEntities", false); xmlInputFactory.setProperty(XMLInputFactory.IS_REPLACING_ENTITY_REFERENCES, false); xmlInputFactory.setProperty(XMLInputFactory.IS_COALESCING, false); xmlOutputFactory = XMLOutputFactory.newInstance(); - this.wrapperElements = builder.wrapperElements; - this.defaultNamespace = builder.defaultNamespace; + eventFactory = XMLEventFactory.newInstance(); } /** @@ -54,6 +67,9 @@ public static Builder builder() { @Override public ShapeSerializer createSerializer(OutputStream sink) { + if (useNative) { + return new LazyXmlSerializer(defaultNamespace, xmlInfo, sink); + } try { return new XmlSerializer(xmlOutputFactory.createXMLStreamWriter(sink), xmlInfo, defaultNamespace); } catch (XMLStreamException e) { @@ -67,6 +83,20 @@ public ShapeDeserializer createDeserializer(ByteBuffer source) { return EmptyXmlDeserializer.INSTANCE; } + if (useNative) { + byte[] bytes; + int offset; + int length = source.remaining(); + if (source.hasArray()) { + bytes = source.array(); + offset = source.arrayOffset() + source.position(); + } else { + bytes = ByteBufferUtils.getBytes(source); + offset = 0; + } + return new SmithyXmlDeserializer(bytes, offset, length, xmlInfo, true, wrapperElements); + } + try { var reader = xmlInputFactory.createXMLStreamReader(ByteBufferUtils.byteBufferInputStream(source)); return XmlDeserializer.topLevel( @@ -75,7 +105,7 @@ public ShapeDeserializer createDeserializer(ByteBuffer source) { new XmlReader.StreamReader(reader, xmlInputFactory), wrapperElements); } catch (XMLStreamException e) { - throw new RuntimeException(e); + throw new SerializationException(e); } } @@ -85,6 +115,7 @@ public ShapeDeserializer createDeserializer(ByteBuffer source) { public static final class Builder { private List wrapperElements = List.of(); private XmlNamespaceTrait defaultNamespace; + private Boolean useNative; private Builder() {} @@ -118,6 +149,16 @@ public Builder defaultNamespace(XmlNamespaceTrait defaultNamespace) { return this; } + /** + * Override the native provider selection for testing. When set to true, the native + * (high-performance) implementation is used regardless of system property. When false, + * the StAX implementation is used. + */ + Builder useNative(boolean useNative) { + this.useNative = useNative; + return this; + } + /** * Create the codec and ensure all required settings are present. * diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java index 86566717bf..9ef767446d 100644 --- a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlDeserializer.java @@ -23,7 +23,7 @@ import software.amazon.smithy.java.core.serde.document.Document; import software.amazon.smithy.model.traits.TimestampFormatTrait; -final class XmlDeserializer implements ShapeDeserializer { +final class XmlDeserializer implements ShapeDeserializer, XmlErrorCodeParser { private final XmlInfo xmlInfo; private final XmlReader reader; @@ -173,7 +173,7 @@ private void skipToCodeElement(String name) throws XMLStreamException { } } - String parseErrorCodeName() { + public String parseErrorCodeName() { try { var element = reader.nextMemberElement(); if (element == null diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlErrorCodeParser.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlErrorCodeParser.java new file mode 100644 index 0000000000..ef338ace30 --- /dev/null +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlErrorCodeParser.java @@ -0,0 +1,13 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +/** + * Interface for XML deserializers that can parse error code names from error response bodies. + */ +interface XmlErrorCodeParser { + String parseErrorCodeName(); +} diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlInfo.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlInfo.java index aebc96393a..ab12381729 100644 --- a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlInfo.java +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlInfo.java @@ -61,6 +61,7 @@ static final class StructInfo { final Map attributes; final Map elements; final boolean hasFlattened; + final XmlMemberLookup memberLookup; private StructInfo(Schema schema) { if (schema.type() != ShapeType.STRUCTURE && schema.type() != ShapeType.UNION) { @@ -96,6 +97,7 @@ private StructInfo(Schema schema) { this.hasFlattened = hasFlattened; this.attributes = attributes == null ? Collections.emptyMap() : attributes; this.elements = elements == null ? Collections.emptyMap() : elements; + this.memberLookup = elements != null ? new XmlMemberLookup(this.elements) : null; } // If the shape has flattened members, then prepare a map to store buffered state. diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlMemberLookup.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlMemberLookup.java new file mode 100644 index 0000000000..206a02b673 --- /dev/null +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlMemberLookup.java @@ -0,0 +1,70 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Map; +import software.amazon.smithy.java.core.schema.Schema; + +/** + * Fast element name lookup using byte-span comparison with speculative fast path. + * + *

When XML arrives in definition order (common for smithy-to-smithy communication), + * the speculative path matches on the first try. For out-of-order elements, falls back + * to linear scan with length pre-filter. + * + *

This class is thread-safe and immutable. Per-invocation speculation state is tracked + * externally via the int[] hint parameter. + */ +final class XmlMemberLookup { + + private final byte[][] memberNames; + private final int[] memberLens; + private final Schema[] memberSchemas; + + XmlMemberLookup(Map elements) { + int size = elements.size(); + memberNames = new byte[size][]; + memberLens = new int[size]; + memberSchemas = new Schema[size]; + int i = 0; + for (var entry : elements.entrySet()) { + byte[] nameBytes = entry.getKey().getBytes(StandardCharsets.UTF_8); + memberNames[i] = nameBytes; + memberLens[i] = nameBytes.length; + memberSchemas[i] = entry.getValue(); + i++; + } + } + + Schema findMember(byte[] buf, int nameStart, int nameLen, int[] hint) { + int len = memberNames.length; + if (len == 0) + return null; + + int idx = hint[0]; + if (idx < len && memberLens[idx] == nameLen) { + byte[] expected = memberNames[idx]; + if (Arrays.equals(buf, nameStart, nameStart + nameLen, expected, 0, nameLen)) { + hint[0] = idx + 1; + return memberSchemas[idx]; + } + } + + for (int i = 0; i < len; i++) { + if (i == idx) + continue; + if (memberLens[i] == nameLen + && Arrays.equals(buf, nameStart, nameStart + nameLen, memberNames[i], 0, nameLen)) { + hint[0] = i + 1; + return memberSchemas[i]; + } + } + + return null; + } +} diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlSchemaExtensions.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlSchemaExtensions.java new file mode 100644 index 0000000000..978ca455b3 --- /dev/null +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlSchemaExtensions.java @@ -0,0 +1,206 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +import java.nio.charset.StandardCharsets; +import software.amazon.smithy.java.core.schema.Schema; +import software.amazon.smithy.java.core.schema.SchemaExtensionKey; +import software.amazon.smithy.java.core.schema.SchemaExtensionProvider; +import software.amazon.smithy.java.core.schema.TraitKey; +import software.amazon.smithy.model.shapes.ShapeType; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Pre-computes XML element name bytes and member metadata at schema load time. + */ +@SmithyInternalApi +public final class XmlSchemaExtensions implements SchemaExtensionProvider { + + public static final SchemaExtensionKey KEY = new SchemaExtensionKey<>(); + + @Override + public SchemaExtensionKey key() { + return KEY; + } + + @Override + public XmlExtension provide(Schema schema) { + var type = schema.type(); + if (schema.isMember()) { + return provideMember(schema); + } else if (type == ShapeType.STRUCTURE || type == ShapeType.UNION) { + return provideStruct(schema); + } else if (type == ShapeType.LIST) { + return provideList(schema); + } else if (type == ShapeType.MAP) { + return provideMap(schema); + } + return null; + } + + private XmlExtension provideMember(Schema schema) { + String name = resolveMemberXmlName(schema); + byte[] nameBytes = name.getBytes(StandardCharsets.UTF_8); + boolean isAttribute = schema.hasTrait(TraitKey.XML_ATTRIBUTE_TRAIT); + boolean isFlattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + byte[] namespaceBytes = null; + var ns = schema.getDirectTrait(TraitKey.XML_NAMESPACE_TRAIT); + if (ns != null) { + namespaceBytes = buildNamespaceBytes(ns.getPrefix().orElse(null), ns.getUri()); + } + return new MemberExtension(nameBytes, isAttribute, isFlattened, namespaceBytes); + } + + private XmlExtension provideStruct(Schema schema) { + var members = schema.members(); + int maxIndex = 0; + for (var member : members) { + maxIndex = Math.max(maxIndex, member.memberIndex()); + } + + byte[][] nameTable = new byte[maxIndex + 1][]; + byte[][] nsTable = new byte[maxIndex + 1][]; + boolean[] isAttributeTable = new boolean[maxIndex + 1]; + boolean[] isFlattenedTable = new boolean[maxIndex + 1]; + boolean hasAttributes = false; + boolean hasNamespaces = false; + + for (var member : members) { + int idx = member.memberIndex(); + String name = resolveMemberXmlName(member); + nameTable[idx] = name.getBytes(StandardCharsets.UTF_8); + isAttributeTable[idx] = member.hasTrait(TraitKey.XML_ATTRIBUTE_TRAIT); + isFlattenedTable[idx] = member.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + if (isAttributeTable[idx]) { + hasAttributes = true; + } + var ns = member.getDirectTrait(TraitKey.XML_NAMESPACE_TRAIT); + if (ns != null) { + nsTable[idx] = buildNamespaceBytes(ns.getPrefix().orElse(null), ns.getUri()); + hasNamespaces = true; + } + } + + String structName; + var xmlNameTrait = schema.getTrait(TraitKey.XML_NAME_TRAIT); + if (xmlNameTrait != null) { + structName = xmlNameTrait.getValue(); + } else { + structName = schema.id().getName(); + } + byte[] structNameBytes = structName.getBytes(StandardCharsets.UTF_8); + + byte[] namespaceBytes = null; + var nsTrait = schema.getTrait(TraitKey.XML_NAMESPACE_TRAIT); + if (nsTrait != null) { + namespaceBytes = buildNamespaceBytes(nsTrait.getPrefix().orElse(null), nsTrait.getUri()); + } + + return new StructExtension(nameTable, + hasNamespaces ? nsTable : null, + isAttributeTable, + isFlattenedTable, + structNameBytes, + namespaceBytes, + hasAttributes); + } + + private XmlExtension provideList(Schema schema) { + boolean flattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + String memberName; + if (flattened) { + memberName = resolveXmlName(schema); + } else { + var member = schema.listMember(); + var memberXmlName = member.getTrait(TraitKey.XML_NAME_TRAIT); + if (memberXmlName != null) { + memberName = memberXmlName.getValue(); + } else { + memberName = "member"; + } + } + return new ListExtension(memberName.getBytes(StandardCharsets.UTF_8)); + } + + private XmlExtension provideMap(Schema schema) { + boolean flattened = schema.hasTrait(TraitKey.XML_FLATTENED_TRAIT); + String xmlName = resolveXmlName(schema); + String entryName = flattened ? xmlName : "entry"; + String keyName = resolveMapMemberName(schema.mapKeyMember()); + String valueName = resolveMapMemberName(schema.mapValueMember()); + return new MapExtension( + entryName.getBytes(StandardCharsets.UTF_8), + keyName.getBytes(StandardCharsets.UTF_8), + valueName.getBytes(StandardCharsets.UTF_8)); + } + + private static String resolveMemberXmlName(Schema schema) { + var trait = schema.getDirectTrait(TraitKey.XML_NAME_TRAIT); + if (trait != null) { + return trait.getValue(); + } else if (schema.isMember()) { + return schema.memberName(); + } else { + return schema.id().getName(); + } + } + + private static String resolveXmlName(Schema schema) { + var trait = schema.getDirectTrait(TraitKey.XML_NAME_TRAIT); + if (trait != null) { + return trait.getValue(); + } else if (schema.isMember()) { + return schema.memberName(); + } else { + return schema.id().getName(); + } + } + + private static String resolveMapMemberName(Schema schema) { + var trait = schema.getDirectTrait(TraitKey.XML_NAME_TRAIT); + if (trait != null) { + return trait.getValue(); + } + return schema.memberName(); + } + + private static byte[] buildNamespaceBytes(String prefix, String uri) { + String escapedUri = uri.replace("&", "&") + .replace("\"", """) + .replace("<", "<") + .replace(">", ">"); + if (prefix == null || prefix.isEmpty()) { + return (" xmlns=\"" + escapedUri + "\"").getBytes(StandardCharsets.UTF_8); + } else { + return (" xmlns:" + prefix + "=\"" + escapedUri + "\"").getBytes(StandardCharsets.UTF_8); + } + } + + public sealed interface XmlExtension permits MemberExtension, StructExtension, ListExtension, MapExtension {} + + public record MemberExtension( + byte[] nameBytes, + boolean isAttribute, + boolean isFlattened, + byte[] namespaceBytes) implements XmlExtension {} + + public record StructExtension( + byte[][] nameTable, + byte[][] memberNamespaceTable, + boolean[] isAttributeTable, + boolean[] isFlattenedTable, + byte[] structNameBytes, + byte[] namespaceBytes, + boolean hasAttributes) implements XmlExtension {} + + public record ListExtension( + byte[] memberNameBytes) implements XmlExtension {} + + public record MapExtension( + byte[] entryNameBytes, + byte[] keyNameBytes, + byte[] valueNameBytes) implements XmlExtension {} +} diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlUtil.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlUtil.java index 630560d6a0..262de7ab9d 100644 --- a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlUtil.java +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlUtil.java @@ -18,8 +18,10 @@ public final class XmlUtil { * @return String value of the Code element if found */ public static String parseErrorCodeName(ShapeDeserializer deserializer) { - try (var xmlDeserializer = (XmlDeserializer) deserializer) { - return xmlDeserializer.parseErrorCodeName(); + try { + return ((XmlErrorCodeParser) deserializer).parseErrorCodeName(); + } finally { + deserializer.close(); } } } diff --git a/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlWriteUtils.java b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlWriteUtils.java new file mode 100644 index 0000000000..ad07a8bc8c --- /dev/null +++ b/codecs/xml-codec/src/main/java/software/amazon/smithy/java/xml/XmlWriteUtils.java @@ -0,0 +1,124 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +/** + * Low-level utilities for writing XML content directly to byte arrays with proper escaping. + */ +final class XmlWriteUtils { + + private XmlWriteUtils() {} + + private static final byte[] AMP_ESC = {'&', 'a', 'm', 'p', ';'}; + private static final byte[] LT_ESC = {'&', 'l', 't', ';'}; + private static final byte[] GT_ESC = {'&', 'g', 't', ';'}; + private static final byte[] QUOT_ESC = {'&', 'q', 'u', 'o', 't', ';'}; + private static final byte[] APOS_ESC = {'&', 'a', 'p', 'o', 's', ';'}; + + static int writeEscapedText(byte[] buf, int pos, String value) { + int len = value.length(); + for (int i = 0; i < len; i++) { + char c = value.charAt(i); + if (c < 0x80) { + if (c == '&') { + System.arraycopy(AMP_ESC, 0, buf, pos, 5); + pos += 5; + } else if (c == '<') { + System.arraycopy(LT_ESC, 0, buf, pos, 4); + pos += 4; + } else if (c == '>') { + System.arraycopy(GT_ESC, 0, buf, pos, 4); + pos += 4; + } else { + buf[pos++] = (byte) c; + } + } else if (c < 0x800) { + buf[pos++] = (byte) (0xC0 | (c >> 6)); + buf[pos++] = (byte) (0x80 | (c & 0x3F)); + } else if (!Character.isSurrogate(c)) { + buf[pos++] = (byte) (0xE0 | (c >> 12)); + buf[pos++] = (byte) (0x80 | ((c >> 6) & 0x3F)); + buf[pos++] = (byte) (0x80 | (c & 0x3F)); + } else { + if (Character.isHighSurrogate(c) && i + 1 < len) { + char low = value.charAt(++i); + if (Character.isLowSurrogate(low)) { + int cp = Character.toCodePoint(c, low); + buf[pos++] = (byte) (0xF0 | (cp >> 18)); + buf[pos++] = (byte) (0x80 | ((cp >> 12) & 0x3F)); + buf[pos++] = (byte) (0x80 | ((cp >> 6) & 0x3F)); + buf[pos++] = (byte) (0x80 | (cp & 0x3F)); + } else { + buf[pos++] = '?'; + i--; + } + } else { + buf[pos++] = '?'; + } + } + } + return pos; + } + + static int writeEscapedAttribute(byte[] buf, int pos, String value) { + int len = value.length(); + for (int i = 0; i < len; i++) { + char c = value.charAt(i); + if (c < 0x80) { + if (c == '&') { + System.arraycopy(AMP_ESC, 0, buf, pos, 5); + pos += 5; + } else if (c == '<') { + System.arraycopy(LT_ESC, 0, buf, pos, 4); + pos += 4; + } else if (c == '>') { + System.arraycopy(GT_ESC, 0, buf, pos, 4); + pos += 4; + } else if (c == '"') { + System.arraycopy(QUOT_ESC, 0, buf, pos, 6); + pos += 6; + } else if (c == '\'') { + System.arraycopy(APOS_ESC, 0, buf, pos, 6); + pos += 6; + } else { + buf[pos++] = (byte) c; + } + } else if (c < 0x800) { + buf[pos++] = (byte) (0xC0 | (c >> 6)); + buf[pos++] = (byte) (0x80 | (c & 0x3F)); + } else if (!Character.isSurrogate(c)) { + buf[pos++] = (byte) (0xE0 | (c >> 12)); + buf[pos++] = (byte) (0x80 | ((c >> 6) & 0x3F)); + buf[pos++] = (byte) (0x80 | (c & 0x3F)); + } else { + if (Character.isHighSurrogate(c) && i + 1 < len) { + char low = value.charAt(++i); + if (Character.isLowSurrogate(low)) { + int cp = Character.toCodePoint(c, low); + buf[pos++] = (byte) (0xF0 | (cp >> 18)); + buf[pos++] = (byte) (0x80 | ((cp >> 12) & 0x3F)); + buf[pos++] = (byte) (0x80 | ((cp >> 6) & 0x3F)); + buf[pos++] = (byte) (0x80 | (cp & 0x3F)); + } else { + buf[pos++] = '?'; + i--; + } + } else { + buf[pos++] = '?'; + } + } + } + return pos; + } + + static int maxEscapedTextBytes(String value) { + return value.length() * 5; + } + + static int maxEscapedAttributeBytes(String value) { + return value.length() * 6; + } +} diff --git a/codecs/xml-codec/src/main/resources/META-INF/services/software.amazon.smithy.java.core.schema.SchemaExtensionProvider b/codecs/xml-codec/src/main/resources/META-INF/services/software.amazon.smithy.java.core.schema.SchemaExtensionProvider new file mode 100644 index 0000000000..77881969f0 --- /dev/null +++ b/codecs/xml-codec/src/main/resources/META-INF/services/software.amazon.smithy.java.core.schema.SchemaExtensionProvider @@ -0,0 +1 @@ +software.amazon.smithy.java.xml.XmlSchemaExtensions diff --git a/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/GeneratedModelSerdeTest.java b/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/GeneratedModelSerdeTest.java new file mode 100644 index 0000000000..0ae228d7d1 --- /dev/null +++ b/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/GeneratedModelSerdeTest.java @@ -0,0 +1,866 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.LocalDate; +import java.time.ZoneOffset; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import smithy.java.xml.test.model.AllListsStruct; +import smithy.java.xml.test.model.BlobStruct; +import smithy.java.xml.test.model.Color; +import smithy.java.xml.test.model.ComplexStruct; +import smithy.java.xml.test.model.FlattenedListStruct; +import smithy.java.xml.test.model.FlattenedMapStruct; +import smithy.java.xml.test.model.InnerStruct; +import smithy.java.xml.test.model.NamespacedStruct; +import smithy.java.xml.test.model.NestedStruct; +import smithy.java.xml.test.model.NumericStruct; +import smithy.java.xml.test.model.RecursiveStruct; +import smithy.java.xml.test.model.SimpleStruct; +import smithy.java.xml.test.model.StringStruct; +import smithy.java.xml.test.model.TimestampStruct; +import smithy.java.xml.test.model.XmlAttributeStruct; +import smithy.java.xml.test.model.XmlNameStruct; +import software.amazon.smithy.java.core.schema.SerializableShape; +import software.amazon.smithy.java.core.schema.ShapeBuilder; +import software.amazon.smithy.java.core.serde.SerializationException; + +public class GeneratedModelSerdeTest extends ProviderTestBase { + + private static T roundtrip( + boolean useNative, + SerializableShape original, + ShapeBuilder builder + ) { + return roundtrip(useNative, useNative, original, builder); + } + + private static T roundtrip( + boolean ser, + boolean de, + SerializableShape original, + ShapeBuilder builder + ) { + try (var serCodec = XmlCodec.builder().useNative(ser).build(); + var deCodec = XmlCodec.builder().useNative(de).build()) { + ByteBuffer serialized = serCodec.serialize(original); + return deCodec.deserializeShape(serialized, builder); + } + } + + // --- Simple Struct --- + + @PerProvider + void simpleStructRoundtrip(boolean useNative) { + var original = SimpleStruct.builder() + .name("test") + .age(42) + .active(true) + .score(98.6) + .createdAt(Instant.parse("2025-01-15T10:30:00Z")) + .build(); + assertThat(roundtrip(useNative, original, SimpleStruct.builder())).isEqualTo(original); + } + + // --- Complex Struct --- + + private static ComplexStruct buildComplexStruct() { + var inner = InnerStruct.builder() + .value("inner-value") + .numbers(List.of(1, 2, 3, 4, 5)) + .build(); + var nested = NestedStruct.builder() + .field1("nested-field") + .field2(100) + .inner(inner) + .build(); + return ComplexStruct.builder() + .id("bench-001") + .count(999) + .enabled(true) + .ratio(1.618) + .score(2.718f) + .bigCount(1_000_000L) + .optionalString("optional-value") + .optionalInt(42) + .createdAt(Instant.parse("2025-06-01T12:00:00Z")) + .payload(ByteBuffer.wrap("binary-payload-data".getBytes(StandardCharsets.UTF_8))) + .tags(List.of("alpha", "beta", "gamma", "delta")) + .intList(List.of(10, 20, 30, 40, 50)) + .metadata(Map.of("key1", "value1", "key2", "value2")) + .intMap(Map.of("a", 1, "b", 2)) + .nested(nested) + .optionalNested(NestedStruct.builder() + .field1("opt-nested") + .field2(200) + .build()) + .structList(List.of(nested, nested)) + .structMap(Map.of("first", nested)) + .color(Color.GREEN) + .colorList(List.of(Color.RED, Color.BLUE, Color.YELLOW)) + .bigIntValue(new BigInteger("123456789012345678901234567890")) + .bigDecValue(new BigDecimal("99999.99999")) + .build(); + } + + @Test + void complexStructRoundtrip() { + var original = buildComplexStruct(); + assertThat(roundtrip(NATIVE, original, ComplexStruct.builder())).isEqualTo(original); + } + + // --- Numeric Boundary Roundtrips --- + + @ParameterizedTest + @MethodSource("crossProviders") + void numericMinValuesRoundtrip(boolean ser, boolean de) { + var original = NumericStruct.builder() + .byteVal(Byte.MIN_VALUE) + .shortVal(Short.MIN_VALUE) + .intVal(Integer.MIN_VALUE) + .longVal(Long.MIN_VALUE) + .floatVal(Float.MIN_VALUE) + .doubleVal(Double.MIN_VALUE) + .bigIntVal(new BigInteger("-99999999999999999999999999999")) + .bigDecVal(new BigDecimal("-99999.99999")) + .build(); + assertThat(roundtrip(ser, de, original, NumericStruct.builder())).isEqualTo(original); + } + + @ParameterizedTest + @MethodSource("crossProviders") + void numericMaxValuesRoundtrip(boolean ser, boolean de) { + var original = NumericStruct.builder() + .byteVal(Byte.MAX_VALUE) + .shortVal(Short.MAX_VALUE) + .intVal(Integer.MAX_VALUE) + .longVal(Long.MAX_VALUE) + .floatVal(Float.MAX_VALUE) + .doubleVal(Double.MAX_VALUE) + .bigIntVal(new BigInteger("99999999999999999999999999999")) + .bigDecVal(new BigDecimal("123456789.123456789")) + .build(); + assertThat(roundtrip(ser, de, original, NumericStruct.builder())).isEqualTo(original); + } + + @PerProvider + void numericZeroValues(boolean useNative) { + var original = NumericStruct.builder() + .byteVal((byte) 0) + .shortVal((short) 0) + .intVal(0) + .longVal(0L) + .floatVal(0.0f) + .doubleVal(0.0) + .bigIntVal(BigInteger.ZERO) + .bigDecVal(BigDecimal.ZERO) + .build(); + assertThat(roundtrip(useNative, original, NumericStruct.builder())).isEqualTo(original); + } + + @PerProvider + void bigDecimalVariousScales(boolean useNative) { + for (var bd : List.of( + new BigDecimal("42"), + new BigDecimal("1.005"), + new BigDecimal("1E+10"), + new BigDecimal("12345678901234567890.12345"))) { + var original = NumericStruct.builder().bigDecVal(bd).build(); + assertThat(roundtrip(useNative, original, NumericStruct.builder()) + .getBigDecVal()).isEqualByComparingTo(bd); + } + } + + @PerProvider + void bigIntegerSmallAndLarge(boolean useNative) { + for (var bi : List.of( + BigInteger.valueOf(42), + BigInteger.valueOf(-42), + new BigInteger("123456789012345678901234567890123456789"))) { + var original = NumericStruct.builder().bigIntVal(bi).build(); + assertThat(roundtrip(useNative, original, NumericStruct.builder())).isEqualTo(original); + } + } + + // --- String Edge Cases --- + + @ParameterizedTest + @MethodSource("crossProviders") + void stringEdgeCasesRoundtrip(boolean ser, boolean de) { + // Note: \r is excluded because XML spec Section 2.11 normalizes CR to LF + for (var s : List.of( + "", + "abcdefghijklmnopqrstuvwxyz0123456789", + "quote:\" backslash:\\ tab:\t newline:\n", + "é中ü", + "😀🎉", + "hello\tworld\né😀\"quoted\"")) { + var original = StringStruct.builder().value(s).build(); + assertThat(roundtrip(ser, de, original, StringStruct.builder())).isEqualTo(original); + } + } + + @ParameterizedTest + @MethodSource("crossProviders") + void xmlSpecialCharactersRoundtrip(boolean ser, boolean de) { + for (var s : List.of( + "value", + "a & b", + "1 < 2 > 0", + "'single' and \"double\" quotes", + "")) { + var original = StringStruct.builder().value(s).build(); + assertThat(roundtrip(ser, de, original, StringStruct.builder())).isEqualTo(original); + } + } + + // --- Timestamp Format Roundtrips --- + + @ParameterizedTest + @MethodSource("crossProviders") + void timestampAllFormatsRoundtrip(boolean ser, boolean de) { + var epoch = Instant.EPOCH; + var original = TimestampStruct.builder() + .epochSeconds(epoch) + .dateTime(epoch) + .httpDate(epoch) + .build(); + assertThat(roundtrip(ser, de, original, TimestampStruct.builder())).isEqualTo(original); + } + + @ParameterizedTest + @MethodSource("crossProviders") + void timestampWithNanoseconds(boolean ser, boolean de) { + var withNanos = Instant.parse("2025-01-15T10:30:00.123456789Z"); + var original = TimestampStruct.builder().dateTime(withNanos).build(); + assertThat(roundtrip(ser, de, original, TimestampStruct.builder())).isEqualTo(original); + } + + @PerProvider + void timestampPreEpoch(boolean useNative) { + var preEpoch = Instant.parse("1969-12-31T23:59:59Z"); + var original = TimestampStruct.builder() + .epochSeconds(preEpoch) + .dateTime(preEpoch) + .httpDate(preEpoch) + .build(); + assertThat(roundtrip(useNative, original, TimestampStruct.builder())).isEqualTo(original); + } + + @PerProvider + void timestampLeapYear(boolean useNative) { + var leapDay = LocalDate.of(2024, 2, 29).atStartOfDay(ZoneOffset.UTC).toInstant(); + var original = TimestampStruct.builder().dateTime(leapDay).build(); + assertThat(roundtrip(useNative, original, TimestampStruct.builder())).isEqualTo(original); + } + + @PerProvider + void timestampAllMonthsHttpDate(boolean useNative) { + for (int month = 1; month <= 12; month++) { + var instant = LocalDate.of(2025, month, 15).atStartOfDay(ZoneOffset.UTC).toInstant(); + var original = TimestampStruct.builder().httpDate(instant).build(); + assertThat(roundtrip(useNative, original, TimestampStruct.builder())).isEqualTo(original); + } + } + + // --- Collection Roundtrips --- + + @ParameterizedTest + @MethodSource("crossProviders") + void emptyCollections(boolean ser, boolean de) { + var original = ComplexStruct.builder() + .id("empty") + .count(0) + .nested(NestedStruct.builder().field1("f").field2(0).build()) + .tags(List.of()) + .intList(List.of()) + .metadata(Map.of()) + .intMap(Map.of()) + .build(); + assertThat(roundtrip(ser, de, original, ComplexStruct.builder())).isEqualTo(original); + } + + @Test + void nestedStructCollections() { + var nested1 = NestedStruct.builder().field1("first").field2(1).build(); + var nested2 = NestedStruct.builder().field1("second").field2(2).build(); + var original = ComplexStruct.builder() + .id("nested-collections") + .count(0) + .nested(nested1) + .structList(List.of(nested1, nested2)) + .structMap(Map.of("a", nested1, "b", nested2)) + .build(); + assertThat(roundtrip(NATIVE, original, ComplexStruct.builder())).isEqualTo(original); + } + + // --- Enum Roundtrips --- + + @ParameterizedTest + @MethodSource("crossProviders") + void enumAllVariantsRoundtrip(boolean ser, boolean de) { + for (var color : List.of(Color.RED, Color.GREEN, Color.BLUE, Color.YELLOW)) { + var original = ComplexStruct.builder() + .id("enum") + .count(0) + .nested(NestedStruct.builder().field1("f").field2(0).build()) + .color(color) + .build(); + assertThat(roundtrip(ser, de, original, ComplexStruct.builder())).isEqualTo(original); + } + } + + // --- XML-specific: xmlName Trait --- + + @PerProvider + void xmlNameTraitRoundtrip(boolean useNative) { + var original = XmlNameStruct.builder() + .id("test-id") + .displayName("Test Name") + .normalField("normal") + .build(); + assertThat(roundtrip(useNative, original, XmlNameStruct.builder())).isEqualTo(original); + } + + @PerProvider + void xmlNameTraitSerialization(boolean useNative) { + var original = XmlNameStruct.builder() + .id("test-id") + .displayName("Test Name") + .normalField("normal") + .build(); + + try (var codec = XmlCodec.builder().useNative(useNative).build()) { + String xml = codec.serializeToString(original); + assertThat(xml).contains(""); + assertThat(xml).contains(""); + assertThat(xml).contains(""); + assertThat(xml).contains(""); + } + } + + // --- XML-specific: xmlAttribute Trait --- + + @ParameterizedTest + @MethodSource("crossProviders") + void xmlAttributeRoundtrip(boolean ser, boolean de) { + var original = XmlAttributeStruct.builder() + .version("1.0") + .identifier("abc-123") + .content("hello world") + .build(); + assertThat(roundtrip(ser, de, original, XmlAttributeStruct.builder())).isEqualTo(original); + } + + @PerProvider + void xmlAttributeSerialization(boolean useNative) { + var original = XmlAttributeStruct.builder() + .version("1.0") + .identifier("abc-123") + .content("hello world") + .build(); + + try (var codec = XmlCodec.builder().useNative(useNative).build()) { + String xml = codec.serializeToString(original); + assertThat(xml).contains("version=\"1.0\""); + assertThat(xml).contains("id=\"abc-123\""); + assertThat(xml).contains("hello world"); + } + } + + // --- XML-specific: xmlFlattened Trait --- + + @ParameterizedTest + @MethodSource("crossProviders") + void flattenedListRoundtrip(boolean ser, boolean de) { + var original = FlattenedListStruct.builder() + .items(List.of("a", "b", "c")) + .numbers(List.of(1, 2, 3)) + .normalList(List.of("x", "y")) + .build(); + assertThat(roundtrip(ser, de, original, FlattenedListStruct.builder())).isEqualTo(original); + } + + @PerProvider + void flattenedListSerialization(boolean useNative) { + var original = FlattenedListStruct.builder() + .items(List.of("a", "b", "c")) + .normalList(List.of("x", "y")) + .build(); + + try (var codec = XmlCodec.builder().useNative(useNative).build()) { + String xml = codec.serializeToString(original); + assertThat(xml).contains("a"); + assertThat(xml).contains("b"); + assertThat(xml).contains("c"); + assertThat(xml).contains("xy"); + } + } + + @ParameterizedTest + @MethodSource("crossProviders") + void flattenedMapRoundtrip(boolean ser, boolean de) { + var original = FlattenedMapStruct.builder() + .entries(Map.of("k1", "v1", "k2", "v2")) + .normalMap(Map.of("a", "b")) + .build(); + assertThat(roundtrip(ser, de, original, FlattenedMapStruct.builder())).isEqualTo(original); + } + + // --- XML-specific: xmlNamespace Trait --- + + @PerProvider + void namespacedStructRoundtrip(boolean useNative) { + var original = NamespacedStruct.builder() + .name("test") + .value(42) + .build(); + assertThat(roundtrip(useNative, original, NamespacedStruct.builder())).isEqualTo(original); + } + + @PerProvider + void namespacedStructSerialization(boolean useNative) { + var original = NamespacedStruct.builder() + .name("test") + .value(42) + .build(); + + try (var codec = XmlCodec.builder().useNative(useNative).build()) { + String xml = codec.serializeToString(original); + assertThat(xml).contains("xmlns=\"https://example.com/test\""); + } + } + + // --- Recursive / Nested Depth --- + + @ParameterizedTest + @MethodSource("crossProviders") + void recursiveStructRoundtrip(boolean ser, boolean de) { + RecursiveStruct current = RecursiveStruct.builder().value("leaf").build(); + for (int i = 9; i >= 1; i--) { + current = RecursiveStruct.builder().value("level-" + i).child(current).build(); + } + assertThat(roundtrip(ser, de, current, RecursiveStruct.builder())).isEqualTo(current); + } + + // --- Blob Roundtrips --- + + @ParameterizedTest + @MethodSource("crossProviders") + void blobRoundtrip(boolean ser, boolean de) { + var empty = BlobStruct.builder().data(ByteBuffer.wrap(new byte[0])).build(); + assertThat(roundtrip(ser, de, empty, BlobStruct.builder())).isEqualTo(empty); + + var small = BlobStruct.builder() + .data(ByteBuffer.wrap("hello".getBytes(StandardCharsets.UTF_8))) + .build(); + assertThat(roundtrip(ser, de, small, BlobStruct.builder())).isEqualTo(small); + + byte[] largeData = new byte[1000]; + for (int i = 0; i < largeData.length; i++) { + largeData[i] = (byte) (i % 256); + } + var large = BlobStruct.builder().data(ByteBuffer.wrap(largeData)).build(); + assertThat(roundtrip(ser, de, large, BlobStruct.builder())).isEqualTo(large); + } + + // --- Null optional members --- + + @ParameterizedTest + @MethodSource("crossProviders") + void nullOptionalMembersRoundtrip(boolean ser, boolean de) { + var original = NumericStruct.builder().build(); + assertThat(roundtrip(ser, de, original, NumericStruct.builder())).isEqualTo(original); + } + + // --- All-types list roundtrip --- + + @ParameterizedTest + @MethodSource("crossProviders") + void allListTypesRoundtrip(boolean ser, boolean de) { + var original = AllListsStruct.builder() + .booleans(List.of(true, false, true)) + .bytes(List.of((byte) 1, (byte) -128, (byte) 127)) + .shorts(List.of((short) 1, (short) -32768, (short) 32767)) + .ints(List.of(1, Integer.MIN_VALUE, Integer.MAX_VALUE)) + .longs(List.of(1L, Long.MIN_VALUE, Long.MAX_VALUE)) + .floats(List.of(1.5f, Float.MIN_VALUE, Float.MAX_VALUE)) + .doubles(List.of(1.5, Double.MIN_VALUE, Double.MAX_VALUE)) + .bigInts(List.of(BigInteger.ZERO, new BigInteger("99999999999999999999"))) + .bigDecs(List.of(BigDecimal.ZERO, new BigDecimal("99999.99999"))) + .strings(List.of("hello", "world")) + .blobs(List.of( + ByteBuffer.wrap("a".getBytes(StandardCharsets.UTF_8)), + ByteBuffer.wrap("b".getBytes(StandardCharsets.UTF_8)))) + .timestamps(List.of(Instant.EPOCH, Instant.parse("2025-01-15T10:30:00Z"))) + .build(); + assertThat(roundtrip(ser, de, original, AllListsStruct.builder())).isEqualTo(original); + } + + // --- Malformed XML validation (parity between native and StAX) --- + + private static T deserialize( + boolean useNative, + String xml, + ShapeBuilder builder + ) { + try (var codec = XmlCodec.builder().useNative(useNative).build()) { + return codec.deserializeShape(xml, builder); + } + } + + private static T deserialize( + boolean useNative, + byte[] xml, + ShapeBuilder builder + ) { + try (var codec = XmlCodec.builder().useNative(useNative).build()) { + return codec.deserializeShape(ByteBuffer.wrap(xml), builder); + } + } + + @PerProvider + void invalidUtf8IsRejected(boolean useNative) { + byte[] input = new byte[] { + '<', + 'S', + 'i', + 'm', + 'p', + 'l', + 'e', + 'S', + 't', + 'r', + 'u', + 'c', + 't', + '>', + '<', + 'n', + 'a', + 'm', + 'e', + '>', + 'h', + 'i', + '<', + '/', + 'n', + (byte) 0x97, + 'm', + 'e', + '>', + '<', + '/', + 'S', + 'i', + 'm', + 'p', + 'l', + 'e', + 'S', + 't', + 'r', + 'u', + 'c', + 't', + '>' + }; + assertThatThrownBy(() -> deserialize(useNative, input, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void invalidUtf8ContinuationByteAlone(boolean useNative) { + byte[] input = new byte[] { + '<', + 'S', + 'i', + 'm', + 'p', + 'l', + 'e', + 'S', + 't', + 'r', + 'u', + 'c', + 't', + '>', + '<', + 'n', + 'a', + 'm', + 'e', + '>', + (byte) 0x80, + '<', + '/', + 'n', + 'a', + 'm', + 'e', + '>', + '<', + '/', + 'S', + 'i', + 'm', + 'p', + 'l', + 'e', + 'S', + 't', + 'r', + 'u', + 'c', + 't', + '>' + }; + assertThatThrownBy(() -> deserialize(useNative, input, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void nullByteInContentIsRejected(boolean useNative) { + byte[] input = "hi\0there".getBytes(); + assertThatThrownBy(() -> deserialize(useNative, input, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void numericCharRefToControlCharIsRejected(boolean useNative) { + String xml = "hithere"; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void mismatchedEndTagOnScalar(boolean useNative) { + String xml = "hello"; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void mismatchedOuterEndTag(boolean useNative) { + String xml = "hi"; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void mismatchedEndTagInList(boolean useNative) { + String xml = "11true" + + "1.01.01" + + "a1" + + "hello"; + assertThatThrownBy(() -> deserialize(useNative, xml, ComplexStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void mismatchedEndTagInMapKey(boolean useNative) { + String xml = "11true" + + "1.01.01" + + "a1" + + "kv"; + assertThatThrownBy(() -> deserialize(useNative, xml, ComplexStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @Test + void emptyListMembersHandledConsistently() { + String xml = "11true" + + "1.01.01" + + "a1" + + "helloworld" + + ""; + + var staxResult = deserialize(STAX, xml, ComplexStruct.builder()); + var nativeResult = deserialize(NATIVE, xml, ComplexStruct.builder()); + assertThat(nativeResult.getTags()).isEqualTo(staxResult.getTags()); + } + + @PerProvider + void selfClosingElementDeserializesAsEmpty(boolean useNative) { + String xml = "5"; + var result = deserialize(useNative, xml, SimpleStruct.builder()); + assertThat(result.getName()).isEmpty(); + assertThat(result.getAge()).isEqualTo(5); + } + + @PerProvider + void selfClosingElementDeserializesAsEmptyInStruct(boolean useNative) { + String xml = "5"; + var result = deserialize(useNative, xml, SimpleStruct.builder()); + assertThat(result.getAge()).isEqualTo(5); + } + + @Test + void selfClosingElementsInListSkippedAsNull() { + // Self-closing and empty elements in lists are treated as null (skipped) by the generated code + String xml = "11true" + + "1.01.01" + + "a1" + + "hello" + + ""; + var staxResult = deserialize(STAX, xml, ComplexStruct.builder()); + var nativeResult = deserialize(NATIVE, xml, ComplexStruct.builder()); + assertThat(nativeResult.getTags()).isEqualTo(staxResult.getTags()); + assertThat(nativeResult.getTags()).containsExactly("hello"); + } + + @PerProvider + void trailingContentAfterRootIsRejected(boolean useNative) { + String xml = "hi1extra"; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void unknownElementWithMismatchedInnerTags(boolean useNative) { + String xml = ""; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void unclosedElementRejectsAtEof(boolean useNative) { + String xml = "text"; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void truncatedEndTagRejectsAtEof(boolean useNative) { + String xml = " deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void unterminatedEntityReferenceIsRejected(boolean useNative) { + String xml = "hello&world"; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void undeclaredEntityReferenceIsRejected(boolean useNative) { + String xml = "&foo;"; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void cdataCloseInTextContentIsRejected(boolean useNative) { + String xml = "a]]>b"; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @PerProvider + void invalidMarkupDeclarationIsRejected(boolean useNative) { + String xml = ""; + assertThatThrownBy(() -> deserialize(useNative, xml, SimpleStruct.builder())) + .isInstanceOf(SerializationException.class); + } + + @Test + void infinityParsedCorrectlyNotFalsePositive() { + String xml = "x1ISO-885"; + assertThatThrownBy(() -> deserialize(STAX, xml, SimpleStruct.builder())).isInstanceOf(Exception.class); + assertThatThrownBy(() -> deserialize(NATIVE, xml, SimpleStruct.builder())).isInstanceOf(Exception.class); + } + + @Test + void actualInfinityParsedCorrectly() { + String xml = "x1Infinity"; + var stax = deserialize(STAX, xml, SimpleStruct.builder()); + var nat = deserialize(NATIVE, xml, SimpleStruct.builder()); + assertThat(nat.getScore()).isEqualTo(stax.getScore()); + assertThat(nat.getScore()).isEqualTo(Double.POSITIVE_INFINITY); + } + + @Test + void carriageReturnNormalization() { + String xml = "a\rb\r\nc1"; + var staxResult = deserialize(STAX, xml, SimpleStruct.builder()); + var nativeResult = deserialize(NATIVE, xml, SimpleStruct.builder()); + assertThat(nativeResult.getName()).isEqualTo(staxResult.getName()); + assertThat(nativeResult.getName()).isEqualTo("a\nb\nc"); + } + + @Test + void crossCodecRoundtripWithListsAndMaps() { + String xml = "x5false" + + "2.51.5100" + + "n110" + + "ab" + + "k1v1" + + ""; + + var staxResult = deserialize(STAX, xml, ComplexStruct.builder()); + var nativeResult = deserialize(NATIVE, xml, ComplexStruct.builder()); + + assertThat(nativeResult.getId()).isEqualTo(staxResult.getId()); + assertThat(nativeResult.getTags()).isEqualTo(staxResult.getTags()); + assertThat(nativeResult.getMetadata()).isEqualTo(staxResult.getMetadata()); + } + + // --- Error code parsing --- + + @PerProvider + void parseErrorCodeFromErrorResponse(boolean useNative) { + String xml = "SenderInvalidGreeting" + + "Hi"; + var codec = XmlCodec.builder().useNative(useNative).build(); + var deser = codec.createDeserializer(ByteBuffer.wrap(xml.getBytes(StandardCharsets.UTF_8))); + String code = XmlUtil.parseErrorCodeName(deser); + assertThat(code).isEqualTo("InvalidGreeting"); + } + + @PerProvider + void parseErrorCodeFromBareError(boolean useNative) { + String xml = "ComplexErrorSomething"; + var codec = XmlCodec.builder().useNative(useNative).build(); + var deser = codec.createDeserializer(ByteBuffer.wrap(xml.getBytes(StandardCharsets.UTF_8))); + String code = XmlUtil.parseErrorCodeName(deser); + assertThat(code).isEqualTo("ComplexError"); + } + + @PerProvider + void parseErrorCodeFromEc2Response(boolean useNative) { + String xml = "AuthFailure" + + "Unauthorized"; + var codec = XmlCodec.builder().useNative(useNative).build(); + var deser = codec.createDeserializer(ByteBuffer.wrap(xml.getBytes(StandardCharsets.UTF_8))); + String code = XmlUtil.parseErrorCodeName(deser); + assertThat(code).isEqualTo("AuthFailure"); + } + + @Test + void errorStructDeserializationThroughErrorResponse() { + String xml = "test42"; + var result = deserialize(NATIVE, xml, SimpleStruct.builder()); + assertThat(result.getName()).isEqualTo("test"); + assertThat(result.getAge()).isEqualTo(42); + } +} diff --git a/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/PerProvider.java b/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/PerProvider.java new file mode 100644 index 0000000000..808679112d --- /dev/null +++ b/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/PerProvider.java @@ -0,0 +1,19 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@ParameterizedTest +@MethodSource("providers") +@interface PerProvider {} diff --git a/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/ProviderTestBase.java b/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/ProviderTestBase.java new file mode 100644 index 0000000000..bfa0888bb8 --- /dev/null +++ b/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/ProviderTestBase.java @@ -0,0 +1,34 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +package software.amazon.smithy.java.xml; + +import java.util.List; +import org.junit.jupiter.api.Named; +import org.junit.jupiter.params.provider.Arguments; + +abstract class ProviderTestBase { + + static final boolean STAX = false; + static final boolean NATIVE = true; + + static List providers() { + return List.of( + Arguments.of(Named.of("stax", STAX)), + Arguments.of(Named.of("native", NATIVE))); + } + + static List crossProviders() { + return List.of( + Arguments.of(Named.of("stax->stax", STAX), STAX), + Arguments.of(Named.of("native->native", NATIVE), NATIVE), + Arguments.of(Named.of("stax->native", STAX), NATIVE), + Arguments.of(Named.of("native->stax", NATIVE), STAX)); + } + + static XmlCodec codec(boolean useNative) { + return XmlCodec.builder().useNative(useNative).build(); + } +} diff --git a/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/XmlCodecTest.java b/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/XmlCodecTest.java index 623fc831d8..651a971650 100644 --- a/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/XmlCodecTest.java +++ b/codecs/xml-codec/src/test/java/software/amazon/smithy/java/xml/XmlCodecTest.java @@ -8,11 +8,14 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.nio.ByteBuffer; import java.time.Instant; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.Test; import software.amazon.smithy.java.core.schema.PreludeSchemas; import software.amazon.smithy.java.core.schema.Schema; @@ -25,9 +28,10 @@ import software.amazon.smithy.model.traits.JsonNameTrait; import software.amazon.smithy.model.traits.TimestampFormatTrait; -public class XmlCodecTest { - @Test - public void deserializesXml() { +public class XmlCodecTest extends ProviderTestBase { + + @PerProvider + public void deserializesXml(boolean useNative) { var xml = """ 2006-03-01T00:00:00Z @@ -40,7 +44,7 @@ public void deserializesXml() { """; - try (var codec = XmlCodec.builder().build()) { + try (var codec = XmlCodec.builder().useNative(useNative).build()) { var pojo = codec.deserializeShape(xml, new TestPojo.Builder()); assertThat(pojo.name, equalTo("Hello")); assertThat(pojo.date, equalTo(Instant.parse("2006-03-01T00:00:00Z"))); @@ -48,10 +52,9 @@ public void deserializesXml() { } } - @Test - public void deserializesEmptyBody() { - try (var codec = XmlCodec.builder().build()) { - // Empty ByteBuffer: top-level readStruct should be a no-op + @PerProvider + public void deserializesEmptyBody(boolean useNative) { + try (var codec = XmlCodec.builder().useNative(useNative).build()) { var pojo = codec.deserializeShape(ByteBuffer.allocate(0), new TestPojo.Builder()); assertThat(pojo.name, equalTo(null)); assertThat(pojo.date, equalTo(null)); @@ -59,9 +62,9 @@ public void deserializesEmptyBody() { } } - @Test - public void serializesXml() { - try (var codec = XmlCodec.builder().build()) { + @PerProvider + public void serializesXml(boolean useNative) { + try (var codec = XmlCodec.builder().useNative(useNative).build()) { var builder = new TestPojo.Builder(); builder.name = "Hello"; builder.date = Instant.parse("2006-03-01T00:00:00Z"); @@ -76,6 +79,115 @@ public void serializesXml() { } } + @PerProvider + public void deserializesMapWithEmptyBlobValue(boolean useNative) { + var xml = "k1" + + "k2AQID"; + + try (var codec = XmlCodec.builder().useNative(useNative).build()) { + var pojo = codec.deserializeShape(xml, new MapPojo.Builder()); + assertThat(pojo.value.size(), equalTo(2)); + assertThat(pojo.value.get("k1"), equalTo(ByteBuffer.wrap(new byte[0]))); + assertThat(pojo.value.get("k2"), equalTo(ByteBuffer.wrap(new byte[] {1, 2, 3}))); + } + } + + @Test + public void nativeAndStaxProduceSameResultForBlobMap() { + var xml = "" + + "k1" + + "k2AQID" + + "k3BAUG" + + "k4BwgJ" + + ""; + + try (var stax = XmlCodec.builder().useNative(false).build(); + var native_ = XmlCodec.builder().useNative(true).build()) { + var staxResult = stax.deserializeShape(xml, new MapPojo.Builder()); + var nativeResult = native_.deserializeShape(xml, new MapPojo.Builder()); + assertEquals(staxResult.value.size(), + nativeResult.value.size(), + "Map sizes differ: stax=" + staxResult.value + " native=" + nativeResult.value); + for (var entry : staxResult.value.entrySet()) { + assertEquals(entry.getValue(), + nativeResult.value.get(entry.getKey()), + "Mismatch for key=" + entry.getKey()); + } + } + } + + private static final class MapPojo implements SerializableStruct { + + private static final ShapeId ID = ShapeId.from("smithy.example#MapPojo"); + + private static final Schema BLOB_MAP = Schema.mapBuilder(ShapeId.from("smithy.example#BlobMap")) + .putMember("key", PreludeSchemas.STRING) + .putMember("value", PreludeSchemas.BLOB) + .build(); + + private static final Schema SCHEMA = Schema.structureBuilder(ID) + .putMember("value", BLOB_MAP) + .build(); + + private static final Schema VALUE = SCHEMA.member("value"); + + private final Map value; + + MapPojo(Builder builder) { + this.value = builder.value; + } + + @Override + public Schema schema() { + return SCHEMA; + } + + @Override + public void serializeMembers(ShapeSerializer serializer) { + if (value != null) { + serializer.writeMap(VALUE, value, value.size(), (map, ms) -> { + for (var e : map.entrySet()) { + ms.writeEntry(BLOB_MAP.mapKeyMember(), + e.getKey(), + e.getValue(), + (v, s) -> s.writeBlob(BLOB_MAP.mapValueMember(), v)); + } + }); + } + } + + @Override + public T getMemberValue(Schema member) { + throw new UnsupportedOperationException(); + } + + private static final class Builder implements ShapeBuilder { + private Map value = new LinkedHashMap<>(); + + @Override + public Schema schema() { + return SCHEMA; + } + + @Override + public Builder deserialize(ShapeDeserializer decoder) { + decoder.readStruct(SCHEMA, this, (pojo, member, deser) -> { + if (member.memberName().equals("value")) { + deser.readStringMap(VALUE, pojo.value, (map, key, de) -> { + map.put(key, de.readBlob(BLOB_MAP.mapValueMember())); + }); + } + }); + return this; + } + + @Override + public MapPojo build() { + return new MapPojo(this); + } + } + } + private static final class TestPojo implements SerializableStruct { private static final ShapeId ID = ShapeId.from("smithy.example#Foo"); diff --git a/codegen/codegen-plugin/build.gradle.kts b/codegen/codegen-plugin/build.gradle.kts index 6e9ae18c73..5734c3177a 100644 --- a/codegen/codegen-plugin/build.gradle.kts +++ b/codegen/codegen-plugin/build.gradle.kts @@ -3,7 +3,7 @@ import java.io.Serializable plugins { id("smithy-java.codegen-plugin-conventions") id("smithy-java.publishing-conventions") - alias(libs.plugins.jmh) + id("smithy-java.jmh-conventions") } description = "Smithy Java code generation plugin" @@ -82,8 +82,6 @@ tasks.named("compileJmhJava") { dependsOn("compileItJava") } -jmh {} - // Ensure generate tasks that use it source set resources depend on base generateSources listOf("generateSourcesClient", "generateSourcesServer", "generateSourcesTypes").forEach { taskName -> tasks.named(taskName) { diff --git a/config/spotbugs/filter.xml b/config/spotbugs/filter.xml index 2e83aab98b..66fae4faef 100644 --- a/config/spotbugs/filter.xml +++ b/config/spotbugs/filter.xml @@ -119,10 +119,10 @@ - + - + diff --git a/context/build.gradle.kts b/context/build.gradle.kts index cbab530029..620298e96e 100644 --- a/context/build.gradle.kts +++ b/context/build.gradle.kts @@ -1,6 +1,6 @@ plugins { id("smithy-java.module-conventions") - id("me.champeau.jmh") version "0.7.3" + id("smithy-java.jmh-conventions") } description = "This module provides a typed identity based collection" @@ -9,10 +9,5 @@ extra["displayName"] = "Smithy :: Java :: Context" extra["moduleName"] = "software.amazon.smithy.java.context" jmh { - warmupIterations = 3 - iterations = 5 - fork = 1 profilers.add("async:output=flamegraph") - // profilers.add("gc") - duplicateClassesStrategy = DuplicatesStrategy.EXCLUDE // don't dump a bunch of warnings. } diff --git a/core/build.gradle.kts b/core/build.gradle.kts index edd404b1f7..345ea3138b 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -2,7 +2,7 @@ import org.apache.tools.ant.filters.ReplaceTokens plugins { id("smithy-java.module-conventions") - alias(libs.plugins.jmh) + id("smithy-java.jmh-conventions") } description = "This module provides the core functionality for Smithy java" @@ -18,15 +18,6 @@ dependencies { implementation(project(":logging")) } -jmh { - includes.addAll( - providers - .gradleProperty("jmh.includes") - .map { listOf(it) } - .orElse(emptyList()), - ) -} - // Run all tests with a different locale to ensure we are not doing anything locale specific. val localeTest = tasks.register("localeTest") { diff --git a/dynamic-schemas/build.gradle.kts b/dynamic-schemas/build.gradle.kts index 30bc5fad11..3f8d0474a3 100644 --- a/dynamic-schemas/build.gradle.kts +++ b/dynamic-schemas/build.gradle.kts @@ -1,6 +1,6 @@ plugins { id("smithy-java.module-conventions") - alias(libs.plugins.jmh) + id("smithy-java.jmh-conventions") } description = "This module provides a way to dynamically create Smithy Java schemas from a model" @@ -8,12 +8,6 @@ description = "This module provides a way to dynamically create Smithy Java sche extra["displayName"] = "Smithy :: Java :: Dynamic Schemas" extra["moduleName"] = "software.amazon.smithy.java.dynamicschemas" -jmh { - warmupIterations = 3 - iterations = 5 - fork = 1 -} - dependencies { api(project(":core")) diff --git a/examples/event-streaming-client/build.gradle.kts b/examples/event-streaming-client/build.gradle.kts index c32bdf7c10..dd91fa22d0 100644 --- a/examples/event-streaming-client/build.gradle.kts +++ b/examples/event-streaming-client/build.gradle.kts @@ -1,7 +1,7 @@ plugins { `java-library` id("software.amazon.smithy.gradle.smithy-base") - id("me.champeau.jmh") version "0.7.3" + id("smithy-java.jmh-conventions") } dependencies { @@ -49,10 +49,6 @@ tasks { jmh { warmupIterations = 2 - iterations = 5 - fork = 1 - // profilers.add("async:output=flamegraph") - // profilers.add("gc") } // Helps Intellij IDE's discover smithy models diff --git a/examples/mcp-server/build.gradle.kts b/examples/mcp-server/build.gradle.kts index 9b4f1d4c02..2c93cba474 100644 --- a/examples/mcp-server/build.gradle.kts +++ b/examples/mcp-server/build.gradle.kts @@ -3,7 +3,7 @@ import com.github.jengelman.gradle.plugins.shadow.transformers.AppendingTransfor plugins { `java-library` id("software.amazon.smithy.gradle.smithy-base") - id("com.gradleup.shadow").version("8.3.5") + id("com.gradleup.shadow") } dependencies { diff --git a/examples/restjson-client/build.gradle.kts b/examples/restjson-client/build.gradle.kts index ef87255bd9..4e1727ab2c 100644 --- a/examples/restjson-client/build.gradle.kts +++ b/examples/restjson-client/build.gradle.kts @@ -2,7 +2,7 @@ plugins { `java-library` application id("software.amazon.smithy.gradle.smithy-base") - id("me.champeau.jmh") version "0.7.3" + id("smithy-java.jmh-conventions") } dependencies { @@ -58,15 +58,6 @@ tasks { jmh { warmupIterations = 4 - iterations = 5 - fork = 1 - // Allow filtering for specific benchmarks - includes.addAll(providers.gradleProperty("jmh.includes") - .map { listOf(it) } - .orElse(emptyList())) - // profilers.add("async:output=flamegraph;direction=forward") - // profilers.add("async:output=collapsed;dir=build/jmh-profiler") - // profilers.add("gc") } repositories { diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index ca720144d8..8aafad3de0 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -3,6 +3,7 @@ junit5 = "6.1.0" hamcrest = "3.0" smithy = "1.71.0" jmh = "0.7.3" +jmhCore = "1.37" test-logger-plugin = "4.0.0" spotbugs = "6.0.22" spotless = "8.6.0" @@ -20,6 +21,8 @@ picocli = "4.7.7" graalvm-native = "1.1.1" shadow = "8.3.11" jazzer = "0.30.0" +pitest-junit5 = "1.2.3" +pitest-gradle = "1.19.0" json-schema-validator = "3.0.0" opentelemetry = "1.62.0" jspecify = "1.0.0" @@ -88,6 +91,10 @@ spotbugs = { module = "com.github.spotbugs.snom:spotbugs-gradle-plugin", version spotless = { module = "com.diffplug.spotless:spotless-plugin-gradle", version.ref = "spotless" } smithy-gradle-base = { module = "software.amazon.smithy.gradle:smithy-base", version.ref = "smithy-gradle-plugins" } dependency-analysis = { module = "com.autonomousapps:dependency-analysis-gradle-plugin", version.ref = "dep-analysis" } +pitest-gradle-plugin = { module = "info.solidsoft.gradle.pitest:gradle-pitest-plugin", version.ref = "pitest-gradle" } +pitest-junit5-plugin = { module = "org.pitest:pitest-junit5-plugin", version.ref = "pitest-junit5" } +jmh-gradle-plugin = { module = "me.champeau.jmh:jmh-gradle-plugin", version.ref = "jmh" } +shadow-gradle-plugin = { module = "com.gradleup.shadow:shadow-gradle-plugin", version.ref = "shadow" } [plugins] jmh = { id = "me.champeau.jmh", version.ref = "jmh" } diff --git a/http/http-api/build.gradle.kts b/http/http-api/build.gradle.kts index 02fbb74932..2768940e62 100644 --- a/http/http-api/build.gradle.kts +++ b/http/http-api/build.gradle.kts @@ -1,6 +1,6 @@ plugins { id("smithy-java.module-conventions") - id("me.champeau.jmh") version "0.7.3" + id("smithy-java.jmh-conventions") } description = "This module provides the Smithy Java HTTP API" @@ -11,16 +11,3 @@ extra["moduleName"] = "software.amazon.smithy.java.http.api" dependencies { api(project(":io")) } - -jmh { - warmupIterations = 3 - iterations = 5 - fork = 1 - duplicateClassesStrategy = DuplicatesStrategy.EXCLUDE - includes.addAll( - providers - .gradleProperty("jmh.includes") - .map { listOf(it) } - .orElse(emptyList()), - ) -} diff --git a/http/http-client/build.gradle.kts b/http/http-client/build.gradle.kts index f2dfe41adf..98b1f1b175 100644 --- a/http/http-client/build.gradle.kts +++ b/http/http-client/build.gradle.kts @@ -2,7 +2,7 @@ import java.net.Socket plugins { java - id("me.champeau.jmh") version "0.7.3" + id("smithy-java.jmh-conventions") } repositories { @@ -89,31 +89,17 @@ val stopBenchmarkServer by tasks.registering { } jmh { - val includesProp = project.findProperty("jmh.includes")?.toString() - val jvmArgsProp = project.findProperty("jmh.jvmArgsAppend")?.toString() - val profilersProp = project.findProperty("jmh.profilers")?.toString() - val defaultJvmArgs = listOf("-Djdk.httpclient.allowRestrictedHeaders=host") - - includes = if (includesProp != null) listOf(includesProp) else listOf(".*") - warmupIterations = 3 iterations = 3 - fork = 1 + includes.set( + providers.gradleProperty("jmh.includes") + .map { listOf(it) } + .orElse(listOf(".*")), + ) resultFormat = "CSV" resultsFile = project.file("build/reports/jmh/results.csv") - - if (jvmArgsProp != null) { - jvmArgsAppend = defaultJvmArgs + jvmArgsProp.split(Regex("\\s*;\\s*")).filter { it.isNotEmpty() } - } else { - jvmArgsAppend = defaultJvmArgs - } - if (profilersProp != null) { - val profilerSpecs = - if (profilersProp.contains(";;")) { - profilersProp.split(Regex("\\s*;;\\s*")).filter { it.isNotEmpty() } - } else { - listOf(profilersProp) - } - profilers.addAll(profilerSpecs) + jvmArgsAppend.addAll("-Djdk.httpclient.allowRestrictedHeaders=host") + providers.gradleProperty("jmh.jvmArgsAppend").orNull?.let { args -> + jvmArgsAppend.addAll(args.split(Regex("\\s*;\\s*")).filter { it.isNotEmpty() }) } } diff --git a/io/src/main/java/software/amazon/smithy/java/io/ByteBufferUtils.java b/io/src/main/java/software/amazon/smithy/java/io/ByteBufferUtils.java index 9b0be82e74..be482aefde 100644 --- a/io/src/main/java/software/amazon/smithy/java/io/ByteBufferUtils.java +++ b/io/src/main/java/software/amazon/smithy/java/io/ByteBufferUtils.java @@ -17,14 +17,12 @@ public final class ByteBufferUtils { private ByteBufferUtils() {} public static String base64Encode(ByteBuffer buffer) { - byte[] bytes; - if (isExact(buffer)) { - bytes = buffer.array(); - } else { - bytes = new byte[buffer.remaining()]; - buffer.asReadOnlyBuffer().get(bytes); - } - return Base64.getEncoder().encodeToString(bytes); + byte[] encoded = base64EncodeToBytes(buffer); + return new String(encoded, StandardCharsets.ISO_8859_1); + } + + public static byte[] base64EncodeToBytes(ByteBuffer buffer) { + return Base64.getEncoder().encode(buffer.duplicate()).array(); } public static String getUTF8String(ByteBuffer buffer) { diff --git a/logging/build.gradle.kts b/logging/build.gradle.kts index 6a0c0ee7f8..399d73fcce 100644 --- a/logging/build.gradle.kts +++ b/logging/build.gradle.kts @@ -2,7 +2,7 @@ import org.gradle.language.base.plugins.LifecycleBasePlugin.VERIFICATION_GROUP plugins { id("smithy-java.module-conventions") - alias(libs.plugins.jmh) + id("smithy-java.jmh-conventions") } description = "This module provides the Logging functionality for Smithy java" diff --git a/scripts/run-remote-benchmarks.sh b/scripts/run-remote-benchmarks.sh index 8eb579e23c..7ebbe4b03d 100755 --- a/scripts/run-remote-benchmarks.sh +++ b/scripts/run-remote-benchmarks.sh @@ -82,7 +82,7 @@ ssh "$SSH_HOST" "mkdir -p $REMOTE_DIR" scp -q "$JAR" "$SSH_HOST:$REMOTE_DIR/$JAR_NAME" # --- Build JMH CLI args --- -JVM_ARGS="-Xms1g -Xmx1g -XX:+UseG1GC -XX:+AlwaysPreTouch -Dsmithy-java.json-provider=smithy" +JVM_ARGS="-Xms1g -Xmx1g -XX:+UseG1GC -XX:+AlwaysPreTouch -Dsmithy-java.json-provider=smithy -Dsmithy-java.xml-provider=smithy" JMH_ARGS="-bm sample -tu ns -f 1 -rf json -rff $REMOTE_DIR/results.json" JMH_ARGS="$JMH_ARGS -jvmArgs \"$JVM_ARGS\"" diff --git a/server/server-rpcv2-json/build.gradle.kts b/server/server-rpcv2-json/build.gradle.kts index 0f2c02a543..b42db086fa 100644 --- a/server/server-rpcv2-json/build.gradle.kts +++ b/server/server-rpcv2-json/build.gradle.kts @@ -22,5 +22,10 @@ dependencies { testImplementation(libs.smithy.protocol.tests) } +protocolTestRuns { + run("native") { systemProperty("smithy-java.json-provider", "smithy") } + run("jackson") { systemProperty("smithy-java.json-provider", "jackson") } +} + val generator = "software.amazon.smithy.java.protocoltests.generators.ProtocolTestGenerator" addGenerateSrcsTask(generator, "rpcv2Json", "smithy.protocoltests.rpcv2Json#RpcV2JsonProtocol", "server") diff --git a/settings.gradle.kts b/settings.gradle.kts index ba398180bf..23dcd1572c 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -37,6 +37,7 @@ include(":retries-api") include(":retries") // Codecs +include(":codecs:codec-commons") include(":codecs:cbor-codec") include(":codecs:json-codec") include(":codecs:xml-codec")