diff --git a/compiler/fory_compiler/generators/cpp.py b/compiler/fory_compiler/generators/cpp.py index a5bc8b418c..f1b6f4de6b 100644 --- a/compiler/fory_compiler/generators/cpp.py +++ b/compiler/fory_compiler/generators/cpp.py @@ -66,6 +66,7 @@ class CppGenerator(BaseGenerator): PrimitiveKind.FLOAT64: "double", PrimitiveKind.STRING: "std::string", PrimitiveKind.BYTES: "std::vector", + PrimitiveKind.DECIMAL: "fory::serialization::Decimal", PrimitiveKind.DATE: "fory::serialization::Date", PrimitiveKind.TIMESTAMP: "fory::serialization::Timestamp", PrimitiveKind.ANY: "std::any", @@ -1779,6 +1780,8 @@ def collect_includes( includes.add("") elif field_type.kind == PrimitiveKind.BYTES: includes.add("") + elif field_type.kind == PrimitiveKind.DECIMAL: + includes.add('"fory/serialization/decimal_serializers.h"') elif field_type.kind in (PrimitiveKind.DATE, PrimitiveKind.TIMESTAMP): includes.add('"fory/serialization/temporal_serializers.h"') elif field_type.kind == PrimitiveKind.ANY: diff --git a/compiler/fory_compiler/generators/dart.py b/compiler/fory_compiler/generators/dart.py index 943660cae2..e13ea0d5a3 100644 --- a/compiler/fory_compiler/generators/dart.py +++ b/compiler/fory_compiler/generators/dart.py @@ -67,6 +67,7 @@ class DartGenerator(BaseGenerator): PrimitiveKind.BYTES: "Uint8List", PrimitiveKind.DATE: "LocalDate", PrimitiveKind.TIMESTAMP: "Timestamp", + PrimitiveKind.DECIMAL: "Decimal", PrimitiveKind.ANY: "Object?", } @@ -610,6 +611,7 @@ def _default_value_for_type( PrimitiveKind.BYTES: "Uint8List(0)", PrimitiveKind.DATE: "const LocalDate(1970, 1, 1)", PrimitiveKind.TIMESTAMP: "Timestamp(0, 0)", + PrimitiveKind.DECIMAL: "const Decimal.zero()", PrimitiveKind.ANY: "null", }[t.kind] if isinstance(t, ListType): diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index 0fd9d4b0c1..a8a392df49 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -197,6 +197,7 @@ def message_has_unions(self, message: Message) -> bool: PrimitiveKind.BYTES: "[]byte", PrimitiveKind.DATE: "fory.Date", PrimitiveKind.TIMESTAMP: "time.Time", + PrimitiveKind.DECIMAL: "fory.Decimal", PrimitiveKind.ANY: "any", } @@ -666,6 +667,7 @@ def get_union_case_type_id_expr( PrimitiveKind.BYTES: "fory.BINARY", PrimitiveKind.DATE: "fory.DATE", PrimitiveKind.TIMESTAMP: "fory.TIMESTAMP", + PrimitiveKind.DECIMAL: "fory.DECIMAL", PrimitiveKind.ANY: "fory.UNKNOWN", } return primitive_type_ids.get(kind, "fory.UNKNOWN") diff --git a/compiler/fory_compiler/generators/javascript.py b/compiler/fory_compiler/generators/javascript.py index 34aa5b8fb0..395302b8cb 100644 --- a/compiler/fory_compiler/generators/javascript.py +++ b/compiler/fory_compiler/generators/javascript.py @@ -147,7 +147,7 @@ class JavaScriptGenerator(BaseGenerator): PrimitiveKind.DATE: "Date", PrimitiveKind.TIMESTAMP: "Date", PrimitiveKind.DURATION: "number", - # DECIMAL is not supported by the JS runtime; rejected in _field_type_expr. + PrimitiveKind.DECIMAL: "Decimal", PrimitiveKind.ANY: "any", } @@ -177,7 +177,7 @@ class JavaScriptGenerator(BaseGenerator): PrimitiveKind.DATE: "Type.date()", PrimitiveKind.TIMESTAMP: "Type.timestamp()", PrimitiveKind.DURATION: "Type.duration()", - # DECIMAL is not yet supported by the JS runtime; omitted intentionally. + PrimitiveKind.DECIMAL: "Type.decimal()", PrimitiveKind.ANY: "Type.any()", } @@ -531,6 +531,9 @@ def generate_imports(self) -> List[str]: lines: List[str] = [] imported_regs = self._collect_imported_registrations() + if self._schema_uses_primitive_kind(PrimitiveKind.DECIMAL): + lines.append("import { Decimal } from '@apache-fory/core';") + # Collect all imported types used in this schema imported_types_by_module: Dict[str, Set[str]] = {} @@ -577,6 +580,39 @@ def generate_imports(self) -> List[str]: return lines + def _schema_uses_primitive_kind(self, primitive_kind: PrimitiveKind) -> bool: + def uses_field_type(field_type: FieldType) -> bool: + if isinstance(field_type, PrimitiveType): + return field_type.kind == primitive_kind + if isinstance(field_type, NamedType): + return field_type.name.lower() == primitive_kind.value + if isinstance(field_type, ListType): + return uses_field_type(field_type.element_type) + if isinstance(field_type, MapType): + return uses_field_type(field_type.key_type) or uses_field_type( + field_type.value_type + ) + return False + + def uses_message(message: Message) -> bool: + for field in message.fields: + if uses_field_type(field.field_type): + return True + for nested_union in message.nested_unions: + if any( + uses_field_type(field.field_type) for field in nested_union.fields + ): + return True + return any(uses_message(nested) for nested in message.nested_messages) + + if any( + uses_field_type(field.field_type) + for union in self.schema.unions + for field in union.fields + ): + return True + return any(uses_message(message) for message in self.schema.messages) + def generate(self) -> List[GeneratedFile]: """Generate JavaScript files for the schema.""" return [self.generate_file()] @@ -806,11 +842,6 @@ def _field_type_expr( """Return the Fory JS runtime ``Type.xxx()`` expression for a field type.""" parent_stack = parent_stack or [] if isinstance(field_type, PrimitiveType): - if field_type.kind == PrimitiveKind.DECIMAL: - raise ValueError( - "decimal is not supported by the JavaScript runtime. " - "Use a different type or wait for runtime support." - ) expr = self.PRIMITIVE_RUNTIME_MAP.get(field_type.kind) if expr is None: return "Type.any()" @@ -826,11 +857,6 @@ def _field_type_expr( return self.PRIMITIVE_RUNTIME_MAP[shorthand_map[lower]] for pk in PrimitiveKind: if pk.value == lower: - if pk == PrimitiveKind.DECIMAL: - raise ValueError( - "decimal is not supported by the JavaScript runtime. " - "Use a different type or wait for runtime support." - ) expr = self.PRIMITIVE_RUNTIME_MAP.get(pk) if expr is None: raise ValueError( diff --git a/compiler/fory_compiler/generators/python.py b/compiler/fory_compiler/generators/python.py index 8919462637..2065b8b5f0 100644 --- a/compiler/fory_compiler/generators/python.py +++ b/compiler/fory_compiler/generators/python.py @@ -68,6 +68,7 @@ class PythonGenerator(BaseGenerator): PrimitiveKind.BYTES: "bytes", PrimitiveKind.DATE: "datetime.date", PrimitiveKind.TIMESTAMP: "datetime.datetime", + PrimitiveKind.DECIMAL: "decimal.Decimal", PrimitiveKind.ANY: "Any", } @@ -138,6 +139,7 @@ class PythonGenerator(BaseGenerator): PrimitiveKind.BYTES: 'b""', PrimitiveKind.DATE: "None", PrimitiveKind.TIMESTAMP: "None", + PrimitiveKind.DECIMAL: 'decimal.Decimal("0")', PrimitiveKind.ANY: "None", } @@ -962,6 +964,8 @@ def collect_imports( if isinstance(field_type, PrimitiveType): if field_type.kind in (PrimitiveKind.DATE, PrimitiveKind.TIMESTAMP): imports.add("import datetime") + elif field_type.kind == PrimitiveKind.DECIMAL: + imports.add("import decimal") elif field_type.kind == PrimitiveKind.ANY: imports.add("from typing import Any") diff --git a/compiler/fory_compiler/generators/rust.py b/compiler/fory_compiler/generators/rust.py index f9c8ae852f..47bb087f44 100644 --- a/compiler/fory_compiler/generators/rust.py +++ b/compiler/fory_compiler/generators/rust.py @@ -67,6 +67,7 @@ class RustGenerator(BaseGenerator): PrimitiveKind.BYTES: "Vec", PrimitiveKind.DATE: "chrono::NaiveDate", PrimitiveKind.TIMESTAMP: "chrono::NaiveDateTime", + PrimitiveKind.DECIMAL: "fory::Decimal", PrimitiveKind.ANY: "Box", } diff --git a/compiler/fory_compiler/generators/swift.py b/compiler/fory_compiler/generators/swift.py index 5b1c715dd6..a5093f3229 100644 --- a/compiler/fory_compiler/generators/swift.py +++ b/compiler/fory_compiler/generators/swift.py @@ -65,8 +65,9 @@ class SwiftGenerator(BaseGenerator): PrimitiveKind.FLOAT64: "Double", PrimitiveKind.STRING: "String", PrimitiveKind.BYTES: "Data", - PrimitiveKind.DATE: "ForyDate", - PrimitiveKind.TIMESTAMP: "ForyTimestamp", + PrimitiveKind.DATE: "LocalDate", + PrimitiveKind.TIMESTAMP: "Date", + PrimitiveKind.DECIMAL: "Decimal", PrimitiveKind.ANY: "Any", } @@ -963,12 +964,13 @@ def generate_message_fields( encoding = self.field_encoding_argument(field) field_id = self.message_field_id_argument(field) - if field_id is not None and encoding is not None: - lines.append(f"{ind}@ForyField(id: {field_id}, encoding: {encoding})") - elif field_id is not None: - lines.append(f"{ind}@ForyField(id: {field_id})") - elif encoding is not None: - lines.append(f"{ind}@ForyField(encoding: {encoding})") + attr_parts: List[str] = [] + if field_id is not None: + attr_parts.append(f"id: {field_id}") + if encoding is not None: + attr_parts.append(f"encoding: {encoding}") + if attr_parts: + lines.append(f"{ind}@ForyField({', '.join(attr_parts)})") field_type = self.field_swift_type(field, lineage) weak_prefix = "weak " if self.is_weak_ref_field(field) else "" diff --git a/compiler/fory_compiler/tests/test_dart_generator.py b/compiler/fory_compiler/tests/test_dart_generator.py index c24b01e7d8..3311adb985 100644 --- a/compiler/fory_compiler/tests/test_dart_generator.py +++ b/compiler/fory_compiler/tests/test_dart_generator.py @@ -143,6 +143,26 @@ def test_dart_generator_uses_typed_lists_for_non_nullable_primitive_lists(): assert "factory ValueUnion.values(Uint32List value)" in file.content +def test_dart_generator_supports_decimal_fields_and_unions(): + file = generate_dart( + """ + package demo; + + message Money [id=100] { + decimal amount = 1; + } + + union ValueUnion [id=101] { + decimal amount = 1; + Money money = 2; + } + """ + ) + + assert "Decimal amount = const Decimal.zero();" in file.content + assert "factory ValueUnion.amount(Decimal value)" in file.content + + def test_dart_generator_emits_container_ref_annotations_for_builder_metadata(): file = generate_dart( """ diff --git a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index 4385e07ed1..0856320d2c 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -543,6 +543,31 @@ def test_java_repeated_float16_generation_uses_float16_list(): assert "private Float16List vals;" in java_output +def test_cpp_generator_supports_decimal_fields_and_unions(): + schema = parse_fdl( + dedent( + """ + package gen; + + message Money { + decimal amount = 1; + } + + union Value { + decimal amount = 1; + Money money = 2; + } + """ + ) + ) + + cpp_output = render_files(generate_files(schema, CppGenerator)) + assert '#include "fory/serialization/decimal_serializers.h"' in cpp_output + assert "const fory::serialization::Decimal& amount() const" in cpp_output + assert "std::variant value_" in cpp_output + assert "(fory::serialization::Decimal, amount, fory::F(1))" in cpp_output + + def test_java_enum_generation_uses_fory_enum_ids(): schema = parse_fdl( dedent( diff --git a/compiler/fory_compiler/tests/test_javascript_codegen.py b/compiler/fory_compiler/tests/test_javascript_codegen.py index 30b3563ca7..fa428d602d 100644 --- a/compiler/fory_compiler/tests/test_javascript_codegen.py +++ b/compiler/fory_compiler/tests/test_javascript_codegen.py @@ -219,6 +219,29 @@ def test_javascript_collection_types(): assert "config: Map;" in output +def test_javascript_decimal_generation_uses_runtime_decimal_type(): + source = dedent( + """ + package example; + + message Money [id=100] { + decimal amount = 1; + } + + union Value [id=101] { + decimal amount = 1; + Money money = 2; + } + """ + ) + output = generate_javascript(source) + + assert "import { Decimal } from '@apache-fory/core';" in output + assert "amount: Decimal;" in output + assert "{ case: ValueCase.AMOUNT; value: Decimal }" in output + assert "amount: Type.decimal()" in output + + def test_javascript_map_key_fallback_to_map(): """Test that map keys not valid for Record use Map instead.""" source = dedent( diff --git a/compiler/fory_compiler/tests/test_swift_generator.py b/compiler/fory_compiler/tests/test_swift_generator.py index 384d95ad28..fdcfb82057 100644 --- a/compiler/fory_compiler/tests/test_swift_generator.py +++ b/compiler/fory_compiler/tests/test_swift_generator.py @@ -85,6 +85,47 @@ def test_swift_generator_emits_tagged_union_case_ids(): assert "fory.register(Demo.Animal.self, id: 101)" in content +def test_swift_generator_supports_decimal_fields_and_unions(): + source = """ + package demo; + + message Money [id=100] { + decimal amount = 1; + } + + union Value [id=101] { + decimal amount = 1; + Money money = 2; + } + """ + content = generate_swift(source) + assert "public var amount: Decimal = Decimal.foryDefault()" in content + assert "case amount(Decimal)" in content + + +def test_swift_generator_maps_date_to_local_date(): + source = """ + package demo; + + message Temporal [id=100] { + date day = 1; + timestamp instant = 2; + } + + union Value [id=101] { + date day = 1; + timestamp instant = 2; + } + """ + content = generate_swift(source) + assert "@ForyField(id: 1)" in content + assert "@ForyField(id: 2)" in content + assert "public var day: LocalDate = LocalDate.foryDefault()" in content + assert "public var instant: Date = Date.foryDefault()" in content + assert "case day(LocalDate)" in content + assert "case instant(Date)" in content + + def test_swift_generator_uses_class_for_ref_targets_and_weak_fields(): source = """ package tree; diff --git a/cpp/fory/serialization/BUILD b/cpp/fory/serialization/BUILD index c3dfe5bd58..77f15a5d12 100644 --- a/cpp/fory/serialization/BUILD +++ b/cpp/fory/serialization/BUILD @@ -14,6 +14,7 @@ cc_library( "collection_serializer.h", "config.h", "context.h", + "decimal_serializers.h", "enum_serializer.h", "fory.h", "map_serializer.h", diff --git a/cpp/fory/serialization/decimal_serializers.h b/cpp/fory/serialization/decimal_serializers.h new file mode 100644 index 0000000000..64d99f49d6 --- /dev/null +++ b/cpp/fory/serialization/decimal_serializers.h @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include "fory/serialization/serializer.h" + +#include +#include +#include +#include +#include +#include + +namespace fory { +namespace serialization { + +inline void normalize_decimal_magnitude(std::vector &magnitude_le) { + while (!magnitude_le.empty() && magnitude_le.back() == 0) { + magnitude_le.pop_back(); + } +} + +/// Exact decimal value represented as `unscaled_value * 10^-scale`. +class Decimal { +public: + Decimal() : scale_(0), negative_(false) {} + + Decimal(int32_t scale, bool negative, std::vector magnitude_le) + : scale_(scale), negative_(negative), + magnitude_le_(std::move(magnitude_le)) { + normalize_decimal_magnitude(magnitude_le_); + if (magnitude_le_.empty()) { + negative_ = false; + } + } + + static Decimal from_int64(int64_t value, int32_t scale = 0) { + if (value == 0) { + return Decimal(scale, false, {}); + } + const bool negative = value < 0; + uint64_t magnitude = negative + ? (value == std::numeric_limits::min() + ? (uint64_t{1} << 63) + : static_cast(-value)) + : static_cast(value); + std::vector magnitude_le; + while (magnitude != 0) { + magnitude_le.push_back(static_cast(magnitude & 0xFF)); + magnitude >>= 8; + } + return Decimal(scale, negative, std::move(magnitude_le)); + } + + static Decimal from_bytes(int32_t scale, bool negative, + std::initializer_list magnitude_le) { + return Decimal(scale, negative, std::vector(magnitude_le)); + } + + int32_t scale() const { return scale_; } + + bool negative() const { return negative_; } + + const std::vector &magnitude_le() const { return magnitude_le_; } + + bool is_zero() const { return magnitude_le_.empty(); } + + bool operator==(const Decimal &other) const { + return scale_ == other.scale_ && negative_ == other.negative_ && + magnitude_le_ == other.magnitude_le_; + } + + bool operator!=(const Decimal &other) const { return !(*this == other); } + +private: + int32_t scale_; + bool negative_; + std::vector magnitude_le_; +}; + +inline uint64_t encode_decimal_zigzag64(int64_t value) { + if (value >= 0) { + return static_cast(value) << 1; + } + return (static_cast(~value) << 1) | 1ULL; +} + +inline int64_t decode_decimal_zigzag64(uint64_t value) { + if ((value & 1ULL) == 0) { + return static_cast(value >> 1); + } + return ~static_cast(value >> 1); +} + +inline bool +decimal_magnitude_to_uint64(const std::vector &magnitude_le, + uint64_t &value) { + if (magnitude_le.size() > sizeof(uint64_t)) { + return false; + } + value = 0; + for (size_t i = 0; i < magnitude_le.size(); ++i) { + value |= static_cast(magnitude_le[i]) << (i * 8); + } + return true; +} + +inline bool decimal_try_get_int64(const Decimal &decimal, int64_t &value) { + if (decimal.is_zero()) { + value = 0; + return true; + } + + uint64_t magnitude = 0; + if (!decimal_magnitude_to_uint64(decimal.magnitude_le(), magnitude)) { + return false; + } + + if (!decimal.negative()) { + if (magnitude > + static_cast(std::numeric_limits::max())) { + return false; + } + value = static_cast(magnitude); + return true; + } + + if (magnitude == (uint64_t{1} << 63)) { + value = std::numeric_limits::min(); + return true; + } + if (magnitude > static_cast(std::numeric_limits::max())) { + return false; + } + value = -static_cast(magnitude); + return true; +} + +inline bool can_use_small_decimal_encoding(const Decimal &decimal, + int64_t &small_value) { + if (!decimal_try_get_int64(decimal, small_value)) { + return false; + } + return encode_decimal_zigzag64(small_value) <= + static_cast(std::numeric_limits::max()); +} + +template <> struct Serializer { + static constexpr TypeId type_id = TypeId::DECIMAL; + + static inline void write_type_info(WriteContext &ctx) { + ctx.write_uint8(static_cast(type_id)); + } + + static inline void read_type_info(ReadContext &ctx) { + uint32_t actual = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return; + } + if (actual != static_cast(type_id)) { + ctx.set_error( + Error::type_mismatch(actual, static_cast(type_id))); + } + } + + static inline void write(const Decimal &value, WriteContext &ctx, + RefMode ref_mode, bool write_type, + bool has_generics = false) { + (void)has_generics; + write_not_null_ref_flag(ctx, ref_mode); + if (write_type) { + ctx.write_uint8(static_cast(type_id)); + } + write_data(value, ctx); + } + + static inline void write_data(const Decimal &value, WriteContext &ctx) { + ctx.write_var_int32(value.scale()); + int64_t small_value = 0; + if (can_use_small_decimal_encoding(value, small_value)) { + ctx.write_var_uint64(encode_decimal_zigzag64(small_value) << 1); + return; + } + + if (value.is_zero()) { + ctx.set_error( + Error::invalid_data("Zero must use the small decimal encoding")); + return; + } + if (value.magnitude_le().size() > + static_cast(std::numeric_limits::max())) { + ctx.set_error(Error::invalid_data( + "Decimal magnitude length exceeds uint32_t range")); + return; + } + + uint64_t meta = (static_cast(value.magnitude_le().size()) << 1) | + (value.negative() ? 1ULL : 0ULL); + ctx.write_var_uint64((meta << 1) | 1ULL); + ctx.write_bytes(value.magnitude_le().data(), + static_cast(value.magnitude_le().size())); + } + + static inline void write_data_generic(const Decimal &value, WriteContext &ctx, + bool has_generics) { + (void)has_generics; + write_data(value, ctx); + } + + static inline Decimal read(ReadContext &ctx, RefMode ref_mode, + bool read_type) { + bool has_value = read_null_only_flag(ctx, ref_mode); + if (ctx.has_error() || !has_value) { + return Decimal(); + } + if (read_type) { + uint32_t type_id_read = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return Decimal(); + } + if (type_id_read != static_cast(type_id)) { + ctx.set_error( + Error::type_mismatch(type_id_read, static_cast(type_id))); + return Decimal(); + } + } + return read_data(ctx); + } + + static inline Decimal read_data(ReadContext &ctx) { + int32_t scale = ctx.read_var_int32(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return Decimal(); + } + uint64_t header = ctx.read_var_uint64(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return Decimal(); + } + if ((header & 1ULL) == 0) { + return Decimal::from_int64(decode_decimal_zigzag64(header >> 1), scale); + } + + uint64_t meta = header >> 1; + uint64_t length64 = meta >> 1; + if (length64 == 0) { + ctx.set_error(Error::invalid_data("Invalid decimal magnitude length 0")); + return Decimal(); + } + if (length64 > ctx.config().max_binary_size) { + ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); + return Decimal(); + } + if (length64 > std::numeric_limits::max()) { + ctx.set_error(Error::invalid_data("Invalid decimal magnitude length " + + std::to_string(length64))); + return Decimal(); + } + + uint32_t length = static_cast(length64); + std::vector payload(length); + ctx.buffer().read_bytes(payload.data(), length, ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return Decimal(); + } + if (payload.back() == 0) { + ctx.set_error(Error::invalid_data( + "Non-canonical decimal payload: trailing zero byte")); + return Decimal(); + } + + return Decimal(scale, (meta & 1ULL) != 0, std::move(payload)); + } + + static inline Decimal read_data_generic(ReadContext &ctx, bool has_generics) { + (void)has_generics; + return read_data(ctx); + } + + static inline Decimal read_with_type_info(ReadContext &ctx, RefMode ref_mode, + const TypeInfo &) { + return read(ctx, ref_mode, false); + } +}; + +} // namespace serialization +} // namespace fory diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 25cd5ec255..fee0c71aa5 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -24,6 +24,7 @@ #include "fory/serialization/collection_serializer.h" #include "fory/serialization/config.h" #include "fory/serialization/context.h" +#include "fory/serialization/decimal_serializers.h" #include "fory/serialization/map_serializer.h" #include "fory/serialization/serializer.h" #include "fory/serialization/smart_ptr_serializers.h" diff --git a/cpp/fory/serialization/serialization_test.cc b/cpp/fory/serialization/serialization_test.cc index 479fcd038c..c5c51e3f77 100644 --- a/cpp/fory/serialization/serialization_test.cc +++ b/cpp/fory/serialization/serialization_test.cc @@ -212,6 +212,150 @@ TEST(SerializationTest, DurationRoundtrip) { } } +TEST(SerializationTest, DateExposesDaysSinceEpochAccessorAndRoundTrips) { + auto fory = Fory::builder().xlang(true).track_ref(false).build(); + Date original(-1); + + EXPECT_EQ(original.days_since_epoch(), -1); + + auto serialize_result = fory.serialize(original); + ASSERT_TRUE(serialize_result.ok()) + << "Serialization failed: " << serialize_result.error().to_string(); + + std::vector bytes = std::move(serialize_result).value(); + Buffer expected; + expected.write_uint8(0b10); + expected.write_int8(NOT_NULL_VALUE_FLAG); + expected.write_uint8(static_cast(TypeId::DATE)); + expected.write_var_int64(-1); + EXPECT_EQ(bytes, buffer_bytes(expected)); + auto deserialize_result = fory.deserialize(bytes.data(), bytes.size()); + ASSERT_TRUE(deserialize_result.ok()) + << "Deserialization failed: " << deserialize_result.error().to_string(); + EXPECT_EQ(deserialize_result.value(), original); + EXPECT_EQ(deserialize_result.value().days_since_epoch(), -1); +} + +TEST(SerializationTest, DateRejectsXlangDayCountsOutsideInt32Range) { + auto fory = Fory::builder().xlang(true).track_ref(false).build(); + WriteContext write_ctx(fory.config(), fory.type_resolver().clone()); + write_ctx.write_var_int64( + static_cast(std::numeric_limits::max()) + 1); + ASSERT_FALSE(write_ctx.has_error()) << write_ctx.error().to_string(); + + ReadContext read_ctx(fory.config(), fory.type_resolver().clone()); + read_ctx.attach(write_ctx.buffer()); + Date decoded = Serializer::read_data(read_ctx); + (void)decoded; + ASSERT_TRUE(read_ctx.has_error()); + EXPECT_NE(read_ctx.error().to_string().find("exceeds int32_t range"), + std::string::npos); +} + +TEST(SerializationTest, DateSkipConsumesVarInt64DayCount) { + auto fory = Fory::builder().xlang(true).track_ref(false).build(); + WriteContext write_ctx(fory.config(), fory.type_resolver().clone()); + Serializer::write_data(Date(18954), write_ctx); + ASSERT_FALSE(write_ctx.has_error()) << write_ctx.error().to_string(); + + ReadContext read_ctx(fory.config(), fory.type_resolver().clone()); + read_ctx.attach(write_ctx.buffer()); + skip_field_value(read_ctx, + FieldType(static_cast(TypeId::DATE), false), + RefMode::None); + ASSERT_FALSE(read_ctx.has_error()) << read_ctx.error().to_string(); + EXPECT_EQ(read_ctx.buffer().reader_index(), + write_ctx.buffer().writer_index()); +} + +TEST(SerializationTest, DecimalRoundTripsEdgeCases) { + auto fory = Fory::builder().xlang(true).track_ref(false).build(); + std::vector values = { + Decimal::from_int64(0, 0), + Decimal::from_int64(0, 3), + Decimal::from_int64(1, 0), + Decimal::from_int64(-1, 0), + Decimal::from_int64(12345, 2), + Decimal::from_int64(std::numeric_limits::max(), 0), + Decimal::from_int64(std::numeric_limits::min(), 0), + Decimal::from_int64(4611686018427387903LL, 0), + Decimal::from_int64(-4611686018427387904LL, 0), + Decimal::from_bytes(0, false, {0, 0, 0, 0, 0, 0, 0, 128}), + Decimal::from_bytes(0, true, {1, 0, 0, 0, 0, 0, 0, 128}), + Decimal::from_bytes(37, false, + {21, 129, 57, 174, 40, 163, 223, 170, 197, 254, 21, + 96, 165, 233, 224, 92}), + Decimal::from_bytes(-17, true, + {21, 129, 57, 174, 40, 163, 223, 170, 197, 254, 21, + 96, 165, 233, 224, 92}), + }; + + for (const Decimal &original : values) { + auto serialize_result = fory.serialize(original); + ASSERT_TRUE(serialize_result.ok()) + << "Serialization failed: " << serialize_result.error().to_string(); + + std::vector bytes = std::move(serialize_result).value(); + auto deserialize_result = + fory.deserialize(bytes.data(), bytes.size()); + ASSERT_TRUE(deserialize_result.ok()) + << "Deserialization failed: " << deserialize_result.error().to_string(); + EXPECT_EQ(deserialize_result.value(), original); + } +} + +TEST(SerializationTest, DecimalRejectsNonCanonicalBigPayloads) { + auto fory = Fory::builder().xlang(true).track_ref(false).build(); + + Buffer zero_big_encoding; + zero_big_encoding.write_uint8(0b10); + zero_big_encoding.write_int8(NOT_NULL_VALUE_FLAG); + zero_big_encoding.write_uint8(static_cast(TypeId::DECIMAL)); + zero_big_encoding.write_var_int32(0); + zero_big_encoding.write_var_uint64(1); + + auto zero_result = fory.deserialize( + zero_big_encoding.data(), zero_big_encoding.writer_index()); + ASSERT_FALSE(zero_result.ok()); + EXPECT_NE( + zero_result.error().to_string().find("Invalid decimal magnitude length"), + std::string::npos); + + Buffer trailing_zero_payload; + trailing_zero_payload.write_uint8(0b10); + trailing_zero_payload.write_int8(NOT_NULL_VALUE_FLAG); + trailing_zero_payload.write_uint8(static_cast(TypeId::DECIMAL)); + trailing_zero_payload.write_var_int32(0); + trailing_zero_payload.write_var_uint64(9); + trailing_zero_payload.write_bytes("\x01\x00", 2); + + auto trailing_zero_result = fory.deserialize( + trailing_zero_payload.data(), trailing_zero_payload.writer_index()); + ASSERT_FALSE(trailing_zero_result.ok()); + EXPECT_NE(trailing_zero_result.error().to_string().find("trailing zero byte"), + std::string::npos); +} + +TEST(SerializationTest, DecimalSkipConsumesScaleHeaderAndPayload) { + auto fory = Fory::builder().xlang(true).track_ref(false).build(); + WriteContext write_ctx(fory.config(), fory.type_resolver().clone()); + Serializer::write_data( + Decimal::from_bytes(37, false, + {21, 129, 57, 174, 40, 163, 223, 170, 197, 254, 21, + 96, 165, 233, 224, 92}), + write_ctx); + ASSERT_FALSE(write_ctx.has_error()) << write_ctx.error().to_string(); + + ReadContext read_ctx(fory.config(), fory.type_resolver().clone()); + read_ctx.attach(write_ctx.buffer()); + skip_field_value(read_ctx, + FieldType(static_cast(TypeId::DECIMAL), false), + RefMode::None); + ASSERT_FALSE(read_ctx.has_error()) << read_ctx.error().to_string(); + EXPECT_EQ(read_ctx.buffer().reader_index(), + write_ctx.buffer().writer_index()); +} + TEST(SerializationTest, DurationUsesSecondsAndNanosecondsPayload) { struct TestCase { Duration value; diff --git a/cpp/fory/serialization/skip.cc b/cpp/fory/serialization/skip.cc index 99c0b1beee..fdb06dd037 100644 --- a/cpp/fory/serialization/skip.cc +++ b/cpp/fory/serialization/skip.cc @@ -568,9 +568,41 @@ void skip_field_value(ReadContext &ctx, const FieldType &field_type, } case TypeId::DATE: { - // Date is stored as fixed 4-byte day count. - constexpr uint32_t k_bytes = static_cast(sizeof(int32_t)); - ctx.buffer().increase_reader_index(k_bytes, ctx.error()); + // Date is stored as a signed varint64 day count. + ctx.read_var_int64(ctx.error()); + return; + } + + case TypeId::DECIMAL: { + // Decimal is stored as signed varint32 scale + varuint64 header and + // optional little-endian magnitude payload. + ctx.read_var_int32(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return; + } + uint64_t header = ctx.read_var_uint64(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return; + } + if ((header & 1U) == 0) { + return; + } + uint64_t length64 = header >> 2; + if (length64 == 0) { + ctx.set_error(Error::invalid_data("Invalid decimal magnitude length 0")); + return; + } + if (length64 > ctx.config().max_binary_size) { + ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); + return; + } + if (length64 > std::numeric_limits::max()) { + ctx.set_error(Error::invalid_data("Invalid decimal magnitude length " + + std::to_string(length64))); + return; + } + ctx.buffer().increase_reader_index(static_cast(length64), + ctx.error()); return; } diff --git a/cpp/fory/serialization/temporal_serializers.h b/cpp/fory/serialization/temporal_serializers.h index ac167ff113..86737059ed 100644 --- a/cpp/fory/serialization/temporal_serializers.h +++ b/cpp/fory/serialization/temporal_serializers.h @@ -21,6 +21,7 @@ #include "fory/serialization/serializer.h" #include +#include namespace fory { namespace serialization { @@ -37,17 +38,21 @@ using Timestamp = std::chrono::time_point; /// Date: naive date without timezone as days since Unix epoch -struct Date { - int32_t days_since_epoch; // Days since Jan 1, 1970 UTC +class Date { +public: + Date() : days_since_epoch_(0) {} + explicit Date(int32_t days) : days_since_epoch_(days) {} - Date() : days_since_epoch(0) {} - explicit Date(int32_t days) : days_since_epoch(days) {} + int32_t days_since_epoch() const { return days_since_epoch_; } bool operator==(const Date &other) const { - return days_since_epoch == other.days_since_epoch; + return days_since_epoch_ == other.days_since_epoch_; } bool operator!=(const Date &other) const { return !(*this == other); } + +private: + int32_t days_since_epoch_; // Days since Jan 1, 1970 UTC }; // ============================================================================ @@ -229,7 +234,7 @@ template <> struct Serializer { // ============================================================================ /// Serializer for Date -/// Per xlang spec: serialized as int32 day count since Unix epoch +/// Per xlang spec: serialized as signed varint64 day count since Unix epoch template <> struct Serializer { static constexpr TypeId type_id = TypeId::DATE; @@ -259,7 +264,7 @@ template <> struct Serializer { } static inline void write_data(const Date &date, WriteContext &ctx) { - ctx.write_bytes(&date.days_since_epoch, sizeof(int32_t)); + ctx.write_var_int64(static_cast(date.days_since_epoch())); } static inline void write_data_generic(const Date &date, WriteContext &ctx, @@ -287,9 +292,17 @@ template <> struct Serializer { } static inline Date read_data(ReadContext &ctx) { - Date date; - ctx.read_bytes(&date.days_since_epoch, sizeof(int32_t), ctx.error()); - return date; + int64_t days = ctx.read_var_int64(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return Date(); + } + if (FORY_PREDICT_FALSE(days < std::numeric_limits::min() || + days > std::numeric_limits::max())) { + ctx.set_error(Error::invalid_data( + "Date day count " + std::to_string(days) + " exceeds int32_t range")); + return Date(); + } + return Date(static_cast(days)); } static inline Date read_with_type_info(ReadContext &ctx, RefMode ref_mode, diff --git a/cpp/fory/serialization/xlang_test_main.cc b/cpp/fory/serialization/xlang_test_main.cc index e58ec3e252..e8f4247de6 100644 --- a/cpp/fory/serialization/xlang_test_main.cc +++ b/cpp/fory/serialization/xlang_test_main.cc @@ -44,6 +44,7 @@ using ::fory::Buffer; using ::fory::Error; using ::fory::Result; using ::fory::serialization::Date; +using ::fory::serialization::Decimal; using ::fory::serialization::Fory; using ::fory::serialization::ForyBuilder; using ::fory::serialization::Serializer; @@ -939,6 +940,7 @@ void run_test_ref_compatible(const std::string &data_file); void run_test_collection_element_ref_override(const std::string &data_file); void run_test_circular_ref_schema_consistent(const std::string &data_file); void run_test_circular_ref_compatible(const std::string &data_file); +void run_test_decimal(const std::string &data_file); void run_test_unsigned_schema_consistent_simple(const std::string &data_file); void run_test_unsigned_schema_consistent(const std::string &data_file); void run_test_unsigned_schema_compatible(const std::string &data_file); @@ -987,6 +989,8 @@ int main(int argc, char **argv) { run_test_map(data_file); } else if (case_name == "test_integer") { run_test_integer(data_file); + } else if (case_name == "test_decimal") { + run_test_decimal(data_file); } else if (case_name == "test_item") { run_test_item(data_file); } else if (case_name == "test_color") { @@ -1653,6 +1657,44 @@ void run_test_integer(const std::string &data_file) { write_file(data_file, out); } +void run_test_decimal(const std::string &data_file) { + auto bytes = read_file(data_file); + auto fory = build_fory(true, true); + std::vector expected_values = { + Decimal::from_int64(0, 0), + Decimal::from_int64(0, 3), + Decimal::from_int64(1, 0), + Decimal::from_int64(-1, 0), + Decimal::from_int64(12345, 2), + Decimal::from_int64(std::numeric_limits::max(), 0), + Decimal::from_int64(std::numeric_limits::min(), 0), + Decimal::from_int64(4611686018427387903LL, 0), + Decimal::from_int64(-4611686018427387904LL, 0), + Decimal::from_bytes(0, false, {0, 0, 0, 0, 0, 0, 0, 128}), + Decimal::from_bytes(0, true, {1, 0, 0, 0, 0, 0, 0, 128}), + Decimal::from_bytes(37, false, + {21, 129, 57, 174, 40, 163, 223, 170, 197, 254, 21, + 96, 165, 233, 224, 92}), + Decimal::from_bytes(-17, true, + {21, 129, 57, 174, 40, 163, 223, 170, 197, 254, 21, + 96, 165, 233, 224, 92}), + }; + + Buffer buffer = make_buffer(bytes); + std::vector out; + for (size_t i = 0; i < expected_values.size(); ++i) { + Decimal actual = read_next(fory, buffer); + if (actual != expected_values[i]) { + fail("Decimal value mismatch at index " + std::to_string(i)); + } + append_serialized(fory, actual, out); + } + if (buffer.remaining_size() != 0) { + fail("Unexpected trailing bytes after decimal payload"); + } + write_file(data_file, out); +} + void run_test_item(const std::string &data_file) { auto bytes = read_file(data_file); auto fory = build_fory(true, true); diff --git a/csharp/src/Fory/DecimalSerializer.cs b/csharp/src/Fory/DecimalSerializer.cs new file mode 100644 index 0000000000..0e460db4b4 --- /dev/null +++ b/csharp/src/Fory/DecimalSerializer.cs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +using System.Numerics; + +namespace Apache.Fory; + +public readonly record struct ForyDecimal(BigInteger UnscaledValue, int Scale); + +internal sealed class ForyDecimalSerializer : Serializer +{ + public override ForyDecimal DefaultValue => default; + + public override void WriteData(WriteContext context, in ForyDecimal value, bool hasGenerics) + { + _ = hasGenerics; + DecimalCodec.Write(context.Writer, value.Scale, value.UnscaledValue); + } + + public override ForyDecimal ReadData(ReadContext context) + { + (int scale, BigInteger unscaled) = DecimalCodec.Read(context.Reader); + return new ForyDecimal(unscaled, scale); + } +} + +internal static class DecimalCodec +{ + private static readonly BigInteger LongMin = long.MinValue; + private static readonly BigInteger LongMax = long.MaxValue; + + public static void Write(ByteWriter buffer, int scale, BigInteger unscaled) + { + buffer.WriteVarInt32(scale); + if (CanUseSmallEncoding(unscaled)) + { + long smallValue = (long)unscaled; + ulong zigzag = EncodeZigZag64(smallValue); + buffer.WriteVarUInt64(zigzag << 1); + return; + } + + BigInteger magnitude = BigInteger.Abs(unscaled); + if (magnitude.IsZero) + { + throw new InvalidDataException("zero must use the small decimal encoding"); + } + + byte[] payload = magnitude.ToByteArray(isUnsigned: true, isBigEndian: false); + ulong meta = ((ulong)payload.Length << 1) | (unscaled.Sign < 0 ? 1UL : 0UL); + ulong header = (meta << 1) | 1UL; + buffer.WriteVarUInt64(header); + buffer.WriteBytes(payload); + } + + public static (int Scale, BigInteger Unscaled) Read(ByteReader buffer) + { + int scale = buffer.ReadVarInt32(); + ulong header = buffer.ReadVarUInt64(); + if ((header & 1UL) == 0UL) + { + return (scale, new BigInteger(DecodeZigZag64(header >> 1))); + } + + ulong meta = header >> 1; + ulong lenLong = meta >> 1; + if (lenLong == 0 || lenLong > int.MaxValue) + { + throw new InvalidDataException($"invalid decimal magnitude length {lenLong}"); + } + + int length = checked((int)lenLong); + byte[] payload = buffer.ReadBytes(length); + if (payload[^1] == 0) + { + throw new InvalidDataException("non-canonical decimal payload: trailing zero byte"); + } + + BigInteger magnitude = new(payload, isUnsigned: true, isBigEndian: false); + if (magnitude.IsZero) + { + throw new InvalidDataException("big decimal encoding must not represent zero"); + } + + return (scale, (meta & 1UL) == 0UL ? magnitude : BigInteger.Negate(magnitude)); + } + + private static bool CanUseSmallEncoding(BigInteger value) + { + if (value < LongMin || value > LongMax) + { + return false; + } + + ulong zigzag = EncodeZigZag64((long)value); + return (zigzag & (1UL << 63)) == 0; + } + + private static ulong EncodeZigZag64(long value) + { + return unchecked((ulong)((value << 1) ^ (value >> 63))); + } + + private static long DecodeZigZag64(ulong value) + { + return unchecked((long)((value >> 1) ^ (ulong)-(long)(value & 1UL))); + } +} diff --git a/csharp/src/Fory/FieldSkipper.cs b/csharp/src/Fory/FieldSkipper.cs index a2549f7e1d..30353bb66c 100644 --- a/csharp/src/Fory/FieldSkipper.cs +++ b/csharp/src/Fory/FieldSkipper.cs @@ -72,6 +72,8 @@ public static void SkipFieldValue(ReadContext context, TypeMetaFieldType fieldTy return context.TypeResolver.GetSerializer().Read(context, refMode, false); case (uint)TypeId.String: return context.TypeResolver.GetSerializer().Read(context, refMode, false); + case (uint)TypeId.Decimal: + return context.TypeResolver.GetSerializer().Read(context, refMode, false); case (uint)TypeId.List: { if (fieldType.Generics.Count != 1 || fieldType.Generics[0].TypeId != (uint)TypeId.String) diff --git a/csharp/src/Fory/TimeSerializers.cs b/csharp/src/Fory/TimeSerializers.cs index 61b2163155..ffc675a5b8 100644 --- a/csharp/src/Fory/TimeSerializers.cs +++ b/csharp/src/Fory/TimeSerializers.cs @@ -23,13 +23,13 @@ internal static class TimeCodec public static void WriteDate(WriteContext context, in DateOnly value) { - context.Writer.WriteInt32(value.DayNumber - Epoch.DayNumber); + context.Writer.WriteVarInt64(value.DayNumber - Epoch.DayNumber); } public static DateOnly ReadDate(ReadContext context) { - int days = context.Reader.ReadInt32(); - return DateOnly.FromDayNumber(Epoch.DayNumber + days); + long days = context.Reader.ReadVarInt64(); + return DateOnly.FromDayNumber(checked(Epoch.DayNumber + (int)days)); } public static DateTimeOffset ToDateTimeOffset(in DateTime value) diff --git a/csharp/src/Fory/TypeInfo.cs b/csharp/src/Fory/TypeInfo.cs index 71533672ba..31bc821c8d 100644 --- a/csharp/src/Fory/TypeInfo.cs +++ b/csharp/src/Fory/TypeInfo.cs @@ -317,6 +317,12 @@ private static bool TryResolveBuiltInTypeId(Type type, out TypeId typeId) return true; } + if (type == typeof(ForyDecimal)) + { + typeId = TypeId.Decimal; + return true; + } + if (type == typeof(byte[])) { typeId = TypeId.Binary; diff --git a/csharp/src/Fory/TypeResolver.cs b/csharp/src/Fory/TypeResolver.cs index 64f310fead..2c70e01286 100644 --- a/csharp/src/Fory/TypeResolver.cs +++ b/csharp/src/Fory/TypeResolver.cs @@ -989,6 +989,7 @@ private TypeInfo ResolveAnyBuiltInTypeInfo(TypeId wireTypeId) TypeId.Float32 => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.Float64 => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.String => GetTypeInfo().WithWireTypeInfo(wireTypeId), + TypeId.Decimal => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.Date => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.Timestamp => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.Duration => GetTypeInfo().WithWireTypeInfo(wireTypeId), @@ -1452,6 +1453,11 @@ private TypeInfo CreateBindingCore(Type type) return TypeInfo.Create(type, new StringSerializer()); } + if (type == typeof(ForyDecimal)) + { + return TypeInfo.Create(type, new ForyDecimalSerializer()); + } + if (type == typeof(byte[])) { return TypeInfo.Create(type, new BinarySerializer()); diff --git a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs index f2fc81e3a4..a74b21fd3a 100644 --- a/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs +++ b/csharp/tests/Fory.Tests/RuntimeEdgeCaseTests.cs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +using System.Numerics; using Apache.Fory; using ForyRuntime = Apache.Fory.Fory; @@ -49,6 +50,13 @@ public sealed class CustomPayload public string Marker { get; set; } = string.Empty; } +[ForyObject] +public sealed class DecimalEnvelope +{ + public ForyDecimal Exact { get; set; } + public List History { get; set; } = []; +} + public sealed class CustomPayloadSerializer : Serializer { public override CustomPayload DefaultValue => null!; @@ -160,6 +168,133 @@ public void TimeSpanUsesVarIntSeconds() Assert.Equal(0, reader.Remaining); } + [Fact] + public void DateOnlyUsesVarInt64Days() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + byte[] payload = fory.Serialize(new DateOnly(2021, 11, 23)); + + ByteReader reader = new(payload); + Assert.False(fory.ReadHead(reader)); + Assert.Equal((sbyte)RefFlag.NotNullValue, reader.ReadInt8()); + Assert.Equal((uint)TypeId.Date, reader.ReadUInt8()); + Assert.Equal(18_954L, reader.ReadVarInt64()); + Assert.Equal(0, reader.Remaining); + } + + [Fact] + public void DecimalRoundTripEdgeCases() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + ForyDecimal[] values = + [ + new(BigInteger.Zero, 0), + new(BigInteger.Zero, 3), + new(BigInteger.One, 0), + new(BigInteger.MinusOne, 0), + new(new BigInteger(12_345), 2), + new(new BigInteger(long.MaxValue), 0), + new(new BigInteger(long.MinValue), 0), + new(new BigInteger(long.MaxValue) + BigInteger.One, 0), + new(new BigInteger(long.MinValue) - BigInteger.One, 0), + new(BigInteger.Parse("123456789012345678901234567890123456789"), 37), + new(BigInteger.Parse("-123456789012345678901234567890123456789"), -17), + ]; + + foreach (ForyDecimal value in values) + { + Assert.Equal(value, fory.Deserialize(fory.Serialize(value))); + } + } + + [Fact] + public void DecimalFieldsAndDynamicAnyRoundTrip() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + fory.Register(706); + fory.Register(707); + + DecimalEnvelope envelope = new() + { + Exact = new(BigInteger.Parse("987654321098765432109876543210"), 9), + History = + [ + new ForyDecimal(BigInteger.Zero, 2), + new ForyDecimal(BigInteger.Parse("-12345678901234567890"), 4), + new ForyDecimal(BigInteger.Parse("9223372036854775808"), 0), + ], + }; + DecimalEnvelope decodedEnvelope = fory.Deserialize(fory.Serialize(envelope)); + Assert.Equal(envelope.Exact, decodedEnvelope.Exact); + Assert.Equal(envelope.History, decodedEnvelope.History); + + DynamicAnyHolder anyHolder = new() + { + AnyValue = envelope.Exact, + AnySet = [envelope.History[1], "marker"], + AnyMap = new Dictionary + { + ["decimal"] = envelope.History[2], + [envelope.History[0]] = "scaled-zero", + }, + }; + DynamicAnyHolder decodedAny = fory.Deserialize(fory.Serialize(anyHolder)); + Assert.Equal(anyHolder.AnyValue, decodedAny.AnyValue); + Assert.Contains(envelope.History[1], decodedAny.AnySet); + Assert.Equal(envelope.History[2], decodedAny.AnyMap["decimal"]); + Assert.Equal("scaled-zero", decodedAny.AnyMap[envelope.History[0]]); + } + + [Fact] + public void DecimalUsesCanonicalWireFormat() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + byte[] payload = fory.Serialize(new ForyDecimal(BigInteger.Zero, 2)); + + ByteReader reader = new(payload); + Assert.False(fory.ReadHead(reader)); + Assert.Equal((sbyte)RefFlag.NotNullValue, reader.ReadInt8()); + Assert.Equal((uint)TypeId.Decimal, reader.ReadUInt8()); + Assert.Equal(2, reader.ReadVarInt32()); + Assert.Equal(0UL, reader.ReadVarUInt64()); + Assert.Equal(0, reader.Remaining); + + payload = fory.Serialize(new ForyDecimal(BigInteger.Parse("9223372036854775808"), 0)); + reader.Reset(payload); + Assert.False(fory.ReadHead(reader)); + Assert.Equal((sbyte)RefFlag.NotNullValue, reader.ReadInt8()); + Assert.Equal((uint)TypeId.Decimal, reader.ReadUInt8()); + Assert.Equal(0, reader.ReadVarInt32()); + ulong header = reader.ReadVarUInt64(); + Assert.Equal(1UL, header & 1UL); + Assert.True(reader.Remaining > 0); + } + + [Fact] + public void DecimalRejectsNonCanonicalBigPayload() + { + ForyRuntime fory = ForyRuntime.Builder().Build(); + ByteWriter writer = new(); + fory.WriteHead(writer, isNone: false); + writer.WriteInt8((sbyte)RefFlag.NotNullValue); + writer.WriteUInt8((byte)TypeId.Decimal); + writer.WriteVarInt32(0); + writer.WriteVarUInt64(1UL); + _ = Assert.Throws(() => fory.Deserialize(writer.ToArray())); + + writer.Reset(); + fory.WriteHead(writer, isNone: false); + writer.WriteInt8((sbyte)RefFlag.NotNullValue); + writer.WriteUInt8((byte)TypeId.Decimal); + writer.WriteVarInt32(0); + writer.WriteVarUInt64(((((ulong)2 << 1) | 0UL) << 1) | 1UL); + writer.WriteBytes([0x01, 0x00]); + + InvalidDataException trailingZeroException = + Assert.Throws(() => fory.Deserialize(writer.ToArray())); + Assert.Contains("trailing zero byte", trailingZeroException.Message); + } + [Fact] public void TimestampNormalizesNegativeFractionalSecond() { diff --git a/csharp/tests/Fory.XlangPeer/Program.cs b/csharp/tests/Fory.XlangPeer/Program.cs index c3d939a9fd..11809c4448 100644 --- a/csharp/tests/Fory.XlangPeer/Program.cs +++ b/csharp/tests/Fory.XlangPeer/Program.cs @@ -16,6 +16,7 @@ // under the License. using System.Buffers; +using System.Numerics; using System.Text; using Apache.Fory; using ForyRuntime = Apache.Fory.Fory; @@ -117,6 +118,23 @@ internal static class Program long.MaxValue, ]; + private static readonly ForyDecimal[] DecimalValues = + [ + new(BigInteger.Zero, 0), + new(BigInteger.Zero, 3), + new(BigInteger.One, 0), + new(BigInteger.MinusOne, 0), + new(new BigInteger(12345), 2), + new(BigInteger.Parse("9223372036854775807"), 0), + new(BigInteger.Parse("-9223372036854775808"), 0), + new(BigInteger.Parse("4611686018427387903"), 0), + new(BigInteger.Parse("-4611686018427387904"), 0), + new(BigInteger.Parse("9223372036854775808"), 0), + new(BigInteger.Parse("-9223372036854775809"), 0), + new(BigInteger.Parse("123456789012345678901234567890123456789"), 37), + new(BigInteger.Parse("-123456789012345678901234567890123456789"), -17), + ]; + private static int Main(string[] args) { try @@ -180,6 +198,7 @@ private static byte[] ExecuteCase(string caseName, byte[] input) "test_list" => CaseList(input), "test_map" => CaseMap(input), "test_integer" => CaseInteger(input), + "test_decimal" => CaseDecimal(input), "test_item" => CaseItem(input), "test_color" => CaseColor(input), "test_union_xlang" => CaseUnionXlang(input), @@ -515,6 +534,23 @@ private static byte[] CaseInteger(byte[] input) return output.ToArray(); } + private static byte[] CaseDecimal(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + ReadOnlySequence sequence = new(input); + List output = []; + + for (int i = 0; i < DecimalValues.Length; i++) + { + ForyDecimal value = fory.Deserialize(ref sequence); + Ensure(value == DecimalValues[i], $"decimal {i} mismatch"); + Append(output, fory.Serialize(value)); + } + + EnsureConsumed(sequence, nameof(CaseDecimal)); + return output.ToArray(); + } + private static byte[] CaseItem(byte[] input) { ForyRuntime fory = BuildFory(compatible: true); diff --git a/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart b/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart index 184f1f2da6..8d45d65380 100644 --- a/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart +++ b/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart @@ -39,11 +39,30 @@ void _writeFile(Uint8List bytes) { File(_dataFilePath()).writeAsBytesSync(bytes, flush: true); } +Decimal _decimal(String unscaled, int scale) { + return Decimal(BigInt.parse(unscaled), scale); +} + +List _decimalValues() { + return [ + Decimal.zero(), + Decimal.zero(3), + Decimal.fromInt(1), + Decimal.fromInt(-1), + Decimal.fromInt(12345, scale: 2), + _decimal('9223372036854775807', 0), + _decimal('-9223372036854775808', 0), + _decimal('4611686018427387903', 0), + _decimal('-4611686018427387904', 0), + _decimal('9223372036854775808', 0), + _decimal('-9223372036854775809', 0), + _decimal('123456789012345678901234567890123456789', 37), + _decimal('-123456789012345678901234567890123456789', -17), + ]; +} + Fory _newFory({bool compatible = false}) { - return Fory( - compatible: compatible, - checkStructVersion: !compatible, - ); + return Fory(compatible: compatible, checkStructVersion: !compatible); } void _roundTripFory(Fory fory, {bool trackRef = false}) { @@ -183,9 +202,7 @@ void _verifyVarBufferCase() { void _verifyMurmurCase() { final data = _readFile(); final shortHash = murmurHash3X64_128(const [1, 2, 8]); - final textHash = murmurHash3X64_128( - utf8.encode('01234567890123456789'), - ); + final textHash = murmurHash3X64_128(utf8.encode('01234567890123456789')); if (data.length == 32) { final expected = BytesBuilder(copy: false) ..add(_hashBytes(shortHash.$1, shortHash.$2)) @@ -207,6 +224,24 @@ void _verifyMurmurCase() { throw StateError('Unexpected MurmurHash3 payload size ${data.length}.'); } +void _verifyDecimalCase() { + final fory = _newFory(compatible: true); + final input = Buffer.wrap(_readFile()); + final output = BytesBuilder(copy: false); + final expectedValues = _decimalValues(); + for (var index = 0; index < expectedValues.length; index += 1) { + final actual = fory.deserializeFrom(input); + if (actual != expectedValues[index]) { + throw StateError('Unexpected decimal at index $index: $actual'); + } + output.add(fory.serialize(actual)); + } + if (input.readableBytes != 0) { + throw StateError('Unexpected trailing bytes after decimal payload.'); + } + _writeFile(output.takeBytes()); +} + Uint8List _hashBytes(int low, int high) { final buffer = Buffer(); buffer.writeInt64(low); @@ -274,10 +309,7 @@ void _runCollectionElementRefOverride() { final shared = container.listField.first; final output = RefOverrideContainer() ..listField = [shared, shared] - ..mapField = { - 'k1': shared, - 'k2': shared, - }; + ..mapField = {'k1': shared, 'k2': shared}; _writeFile(fory.serialize(output, trackRef: true)); } @@ -361,6 +393,9 @@ void _runCase(String caseName) { registerXlangType(fory, Item1, id: 101); _roundTripFory(fory); return; + case 'test_decimal': + _verifyDecimalCase(); + return; case 'test_color': final fory = _newFory(compatible: true); registerXlangType(fory, Color, id: 101); diff --git a/dart/packages/fory/lib/fory.dart b/dart/packages/fory/lib/fory.dart index 67cc27ab08..86232c609a 100644 --- a/dart/packages/fory/lib/fory.dart +++ b/dart/packages/fory/lib/fory.dart @@ -50,6 +50,7 @@ export 'src/serializer/enum_serializer.dart'; export 'src/serializer/serializer.dart'; export 'src/serializer/union_serializer.dart'; export 'src/types/bfloat16.dart'; +export 'src/types/decimal.dart'; export 'src/types/float16.dart'; export 'src/types/float32.dart'; export 'src/types/int16.dart'; diff --git a/dart/packages/fory/lib/src/codegen/fory_generator.dart b/dart/packages/fory/lib/src/codegen/fory_generator.dart index 86a22e9057..49ef309b8c 100644 --- a/dart/packages/fory/lib/src/codegen/fory_generator.dart +++ b/dart/packages/fory/lib/src/codegen/fory_generator.dart @@ -1182,8 +1182,9 @@ GeneratedFieldType( case TypeIds.int32: case TypeIds.uint32: case TypeIds.float32: - case TypeIds.date: return 4; + case TypeIds.date: + return 10; case TypeIds.int64: case TypeIds.uint64: case TypeIds.float64: @@ -1253,6 +1254,8 @@ GeneratedFieldType( return 'context.writeString($valueExpression)'; case TypeIds.binary: return 'writeGeneratedBinaryValue(context, $valueExpression)'; + case TypeIds.decimal: + return 'writeGeneratedDecimalValue(context, $valueExpression)'; case TypeIds.date: return 'writeGeneratedLocalDateValue(context, $valueExpression)'; case TypeIds.duration: @@ -1330,7 +1333,7 @@ GeneratedFieldType( case TypeIds.float64: return '$cursorExpression.writeFloat64(${_directGeneratedScalarExpression(field, valueExpression)})'; case TypeIds.date: - return '$cursorExpression.writeInt32($valueExpression.toEpochDay())'; + return '$cursorExpression.writeVarInt64($valueExpression.toEpochDay())'; case TypeIds.duration: return '$cursorExpression.writeVarInt64(generatedDurationWireSeconds($valueExpression)); $cursorExpression.writeInt32(generatedDurationWireNanoseconds($valueExpression))'; case TypeIds.timestamp: @@ -1414,6 +1417,8 @@ GeneratedFieldType( return 'context.readString()'; case TypeIds.binary: return 'readGeneratedBinaryValue(context)'; + case TypeIds.decimal: + return 'readGeneratedDecimalValue(context)'; case TypeIds.date: return 'readGeneratedLocalDateValue(context)'; case TypeIds.duration: @@ -1521,7 +1526,7 @@ GeneratedFieldType( case TypeIds.float64: return '$cursorExpression.readFloat64()'; case TypeIds.date: - return 'LocalDate.fromEpochDay($cursorExpression.readInt32())'; + return 'LocalDate.fromEpochDay($cursorExpression.readVarInt64())'; case TypeIds.duration: return 'readGeneratedDurationFromWire($cursorExpression.readVarInt64(), $cursorExpression.readInt32())'; case TypeIds.timestamp: @@ -1864,6 +1869,7 @@ GeneratedFieldType( switch (typeId) { case TypeIds.string: case TypeIds.binary: + case TypeIds.decimal: case TypeIds.date: case TypeIds.duration: case TypeIds.timestamp: @@ -2101,6 +2107,8 @@ GeneratedFieldType( return TypeIds.bfloat16; case 'Float32': return TypeIds.float32; + case 'Decimal': + return TypeIds.decimal; case 'Timestamp': case 'DateTime': return TypeIds.timestamp; diff --git a/dart/packages/fory/lib/src/codegen/generated_support.dart b/dart/packages/fory/lib/src/codegen/generated_support.dart index d22d4f135b..7fbc1de1bb 100644 --- a/dart/packages/fory/lib/src/codegen/generated_support.dart +++ b/dart/packages/fory/lib/src/codegen/generated_support.dart @@ -622,6 +622,16 @@ LocalDate readGeneratedLocalDateValue(ReadContext context) { return const LocalDateSerializer().read(context); } +@internal +void writeGeneratedDecimalValue(WriteContext context, Decimal value) { + const DecimalSerializer().write(context, value); +} + +@internal +Decimal readGeneratedDecimalValue(ReadContext context) { + return const DecimalSerializer().read(context); +} + @internal int generatedDurationWireSeconds(Duration value) { return durationWireSeconds(value); diff --git a/dart/packages/fory/lib/src/context/read_context.dart b/dart/packages/fory/lib/src/context/read_context.dart index 0d9ef02ecb..d2c5a6d549 100644 --- a/dart/packages/fory/lib/src/context/read_context.dart +++ b/dart/packages/fory/lib/src/context/read_context.dart @@ -304,6 +304,8 @@ final class ReadContext { return null; case TypeIds.string: return StringSerializer.readPayload(this); + case TypeIds.decimal: + return DecimalSerializer.readPayload(this); case TypeIds.binary: return BinarySerializer.readPayload(this); case TypeIds.boolArray: diff --git a/dart/packages/fory/lib/src/meta/type_ids.dart b/dart/packages/fory/lib/src/meta/type_ids.dart index b8c0e5c914..e841c8b2e4 100644 --- a/dart/packages/fory/lib/src/meta/type_ids.dart +++ b/dart/packages/fory/lib/src/meta/type_ids.dart @@ -58,6 +58,7 @@ abstract final class TypeIds { static const int duration = 37; static const int timestamp = 38; static const int date = 39; + static const int decimal = 40; static const int binary = 41; static const int boolArray = 43; static const int int8Array = 44; @@ -117,6 +118,7 @@ abstract final class TypeIds { typeId == binary || typeId == duration || typeId == timestamp || + typeId == decimal || typeId == date || typeId == boolArray || typeId == int8Array || @@ -141,6 +143,7 @@ abstract final class TypeIds { typeId == binary || typeId == duration || typeId == timestamp || + typeId == decimal || typeId == date) { return false; } diff --git a/dart/packages/fory/lib/src/meta/type_meta.dart b/dart/packages/fory/lib/src/meta/type_meta.dart index d9829b2cc0..5a40c85060 100644 --- a/dart/packages/fory/lib/src/meta/type_meta.dart +++ b/dart/packages/fory/lib/src/meta/type_meta.dart @@ -269,6 +269,7 @@ final class WireTypeMetaDecoder { wireTypeId == TypeIds.none || wireTypeId == TypeIds.binary || wireTypeId == TypeIds.duration || + wireTypeId == TypeIds.decimal || wireTypeId == TypeIds.date || wireTypeId == TypeIds.timestamp || wireTypeId == TypeIds.boolArray || diff --git a/dart/packages/fory/lib/src/resolver/type_resolver.dart b/dart/packages/fory/lib/src/resolver/type_resolver.dart index f090cf658d..3c39a06540 100644 --- a/dart/packages/fory/lib/src/resolver/type_resolver.dart +++ b/dart/packages/fory/lib/src/resolver/type_resolver.dart @@ -43,6 +43,7 @@ import 'package:fory/src/serializer/time_serializers.dart'; import 'package:fory/src/serializer/typed_array_serializers.dart'; import 'package:fory/src/serializer/union_serializer.dart'; import 'package:fory/src/types/bfloat16.dart'; +import 'package:fory/src/types/decimal.dart'; import 'package:fory/src/types/float16.dart'; import 'package:fory/src/types/float32.dart'; import 'package:fory/src/types/int16.dart'; @@ -352,6 +353,9 @@ final class TypeResolver { if (value is String) { return _builtin(String, TypeIds.string); } + if (value is Decimal) { + return _builtin(Decimal, TypeIds.decimal); + } if (value is Uint8List) { return _builtin(Uint8List, TypeIds.binary); } @@ -449,6 +453,7 @@ final class TypeResolver { case TypeIds.duration: case TypeIds.timestamp: case TypeIds.date: + case TypeIds.decimal: case TypeIds.boolArray: case TypeIds.int8Array: case TypeIds.int16Array: @@ -835,7 +840,8 @@ final class TypeResolver { fieldType.typeId == TypeIds.binary || fieldType.typeId == TypeIds.date || fieldType.typeId == TypeIds.duration || - fieldType.typeId == TypeIds.timestamp) { + fieldType.typeId == TypeIds.timestamp || + fieldType.typeId == TypeIds.decimal) { return fieldType.typeId; } return fieldType.ref ? TypeIds.unknown : fieldType.typeId; @@ -1050,6 +1056,8 @@ final class TypeResolver { return _builtin(Map, TypeIds.map); case TypeIds.none: return _builtin(Null, TypeIds.none); + case TypeIds.decimal: + return _builtin(Decimal, TypeIds.decimal); case TypeIds.binary: return _builtin(Uint8List, TypeIds.binary); case TypeIds.date: @@ -1158,6 +1166,8 @@ final class TypeResolver { return stringSerializer as Serializer; case TypeIds.none: return noneSerializer as Serializer; + case TypeIds.decimal: + return decimalSerializer as Serializer; case TypeIds.binary: case TypeIds.uint8Array: return binarySerializer as Serializer; @@ -1255,6 +1265,9 @@ final class TypeResolver { if (type == Float32) { return TypeIds.float32; } + if (type == Decimal) { + return TypeIds.decimal; + } if (type == Float16List) { return TypeIds.float16Array; } diff --git a/dart/packages/fory/lib/src/serializer/scalar_serializers.dart b/dart/packages/fory/lib/src/serializer/scalar_serializers.dart index 3fe0127f58..9d715cc2ca 100644 --- a/dart/packages/fory/lib/src/serializer/scalar_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/scalar_serializers.dart @@ -23,6 +23,38 @@ import 'package:fory/src/context/read_context.dart'; import 'package:fory/src/context/write_context.dart'; import 'package:fory/src/serializer/serializer.dart'; import 'package:fory/src/string_encoding.dart'; +import 'package:fory/src/types/decimal.dart'; + +// The small form reserves the low header bit to distinguish small/big +// encodings, so the zigzag value itself must still fit in 63 bits before the +// final << 1. +final BigInt _decimalSmallMin = -(BigInt.one << 62); +final BigInt _decimalSmallMax = (BigInt.one << 62) - BigInt.one; + +bool _canUseSmallDecimalEncoding(BigInt value) { + return value >= _decimalSmallMin && value <= _decimalSmallMax; +} + +Uint8List _decimalMagnitudeToCanonicalLittleEndian(BigInt magnitude) { + if (magnitude == BigInt.zero) { + throw StateError('Zero must use the small decimal encoding.'); + } + final bytes = []; + var remaining = magnitude; + while (remaining > BigInt.zero) { + bytes.add((remaining & BigInt.from(0xff)).toInt()); + remaining >>= 8; + } + return Uint8List.fromList(bytes); +} + +BigInt _decimalMagnitudeFromCanonicalLittleEndian(Uint8List payload) { + var magnitude = BigInt.zero; + for (var index = payload.length - 1; index >= 0; index -= 1) { + magnitude = (magnitude << 8) | BigInt.from(payload[index]); + } + return magnitude; +} final class NoneSerializer extends Serializer { const NoneSerializer(); @@ -104,6 +136,70 @@ final class BinarySerializer extends Serializer { } } +final class DecimalSerializer extends Serializer { + const DecimalSerializer(); + + @override + bool get supportsRef => false; + + @override + void write(WriteContext context, Decimal value) { + writePayload(context, value); + } + + @override + Decimal read(ReadContext context) { + return readPayload(context); + } + + static void writePayload(WriteContext context, Decimal value) { + final buffer = context.buffer; + final unscaled = value.unscaledValue; + buffer.writeVarInt32(value.scale); + if (_canUseSmallDecimalEncoding(unscaled)) { + final smallValue = unscaled.toInt(); + final zigZag = (smallValue << 1) ^ (smallValue >> 63); + buffer.writeVarUint64(zigZag << 1); + return; + } + + final payload = _decimalMagnitudeToCanonicalLittleEndian(unscaled.abs()); + final sign = unscaled.isNegative ? 1 : 0; + final meta = (payload.length << 1) | sign; + buffer.writeVarUint64((meta << 1) | 1); + buffer.writeBytes(payload); + } + + static Decimal readPayload(ReadContext context) { + final scale = context.buffer.readVarInt32(); + final header = context.buffer.readVarUint64(); + if ((header & 1) == 0) { + final zigZag = header >>> 1; + final unscaled = (zigZag >>> 1) ^ -(zigZag & 1); + return Decimal(BigInt.from(unscaled), scale); + } + + final meta = header >>> 1; + final length = meta >>> 1; + if (length <= 0) { + throw StateError('Invalid decimal magnitude length $length.'); + } + final payload = context.buffer.copyBytes(length); + if (payload[length - 1] == 0) { + throw StateError( + 'Non-canonical decimal payload: trailing zero byte.', + ); + } + final magnitude = _decimalMagnitudeFromCanonicalLittleEndian(payload); + if (magnitude == BigInt.zero) { + throw StateError('Big decimal encoding must not represent zero.'); + } + final sign = meta & 1; + return Decimal(sign == 0 ? magnitude : -magnitude, scale); + } +} + const NoneSerializer noneSerializer = NoneSerializer(); const StringSerializer stringSerializer = StringSerializer(); const BinarySerializer binarySerializer = BinarySerializer(); +const DecimalSerializer decimalSerializer = DecimalSerializer(); diff --git a/dart/packages/fory/lib/src/serializer/time_serializers.dart b/dart/packages/fory/lib/src/serializer/time_serializers.dart index 6ce2c8a9bf..985f1e57cd 100644 --- a/dart/packages/fory/lib/src/serializer/time_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/time_serializers.dart @@ -144,12 +144,12 @@ final class LocalDateSerializer extends Serializer { @override void write(WriteContext context, LocalDate value) { - context.buffer.writeInt32(value.toEpochDay()); + context.buffer.writeVarInt64(value.toEpochDay()); } @override LocalDate read(ReadContext context) { - return LocalDate.fromEpochDay(context.buffer.readInt32()); + return LocalDate.fromEpochDay(context.buffer.readVarInt64()); } } diff --git a/dart/packages/fory/lib/src/types/decimal.dart b/dart/packages/fory/lib/src/types/decimal.dart new file mode 100644 index 0000000000..cd2d8e7ecc --- /dev/null +++ b/dart/packages/fory/lib/src/types/decimal.dart @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/// Exact decimal value represented as `unscaledValue * 10^-scale`. +final class Decimal { + /// The exact integer significand. + final BigInt unscaledValue; + + /// The decimal scale. + final int scale; + + /// Creates a decimal from its unscaled integer value and decimal [scale]. + const Decimal(this.unscaledValue, this.scale); + + /// Creates a zero decimal with the given [scale]. + factory Decimal.zero([int scale = 0]) { + return Decimal(BigInt.zero, scale); + } + + /// Creates a decimal from a small integer value and optional [scale]. + factory Decimal.fromInt(int value, {int scale = 0}) { + return Decimal(BigInt.from(value), scale); + } + + @override + bool operator ==(Object other) => + identical(this, other) || + other is Decimal && + other.scale == scale && + other.unscaledValue == unscaledValue; + + @override + int get hashCode => Object.hash(unscaledValue, scale); + + @override + String toString() => '${unscaledValue}e${-scale}'; +} diff --git a/dart/packages/fory/lib/src/types/local_date.dart b/dart/packages/fory/lib/src/types/local_date.dart index 4f82868803..acd72fc3c4 100644 --- a/dart/packages/fory/lib/src/types/local_date.dart +++ b/dart/packages/fory/lib/src/types/local_date.dart @@ -40,11 +40,20 @@ final class LocalDate implements Comparable { return LocalDate(instant.year, instant.month, instant.day); } + /// Creates a date from a [DateTime] by taking its UTC calendar date. + factory LocalDate.fromDateTime(DateTime value) { + final utcValue = value.toUtc(); + return LocalDate(utcValue.year, utcValue.month, utcValue.day); + } + /// Converts this date to xlang epoch-day form. int toEpochDay() => DateTime.utc(year, month, day).millisecondsSinceEpoch ~/ Duration.millisecondsPerDay; + /// Converts this date to a UTC [DateTime] at midnight. + DateTime toDateTime() => DateTime.utc(year, month, day); + @override int compareTo(LocalDate other) { final yearCompare = year.compareTo(other.year); diff --git a/dart/packages/fory/test/decimal_serializer_test.dart b/dart/packages/fory/test/decimal_serializer_test.dart new file mode 100644 index 0000000000..ed009b5ad4 --- /dev/null +++ b/dart/packages/fory/test/decimal_serializer_test.dart @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import 'dart:typed_data'; + +import 'package:fory/fory.dart'; +import 'package:test/test.dart'; + +part 'decimal_serializer_test.fory.dart'; + +@ForyStruct() +class DecimalEnvelope { + DecimalEnvelope(); + + Decimal amount = Decimal.zero(); + String note = ''; +} + +Decimal _decimal(String unscaled, int scale) { + return Decimal(BigInt.parse(unscaled), scale); +} + +void _registerDecimalEnvelope(Fory fory) { + DecimalSerializerTestFory.register( + fory, + DecimalEnvelope, + namespace: 'test', + typeName: 'DecimalEnvelope', + ); +} + +void main() { + group('decimal serializer', () { + test('round-trips root decimal edge cases', () { + final fory = Fory(); + final values = [ + Decimal.zero(), + Decimal.zero(3), + Decimal.fromInt(1), + Decimal.fromInt(-1), + Decimal.fromInt(12345, scale: 2), + _decimal('9223372036854775807', 0), + _decimal('-9223372036854775808', 0), + _decimal('4611686018427387903', 0), + _decimal('-4611686018427387904', 0), + _decimal('9223372036854775808', 0), + _decimal('-9223372036854775809', 0), + _decimal('123456789012345678901234567890123456789', 37), + _decimal('-123456789012345678901234567890123456789', -17), + ]; + + for (final value in values) { + expect(fory.deserialize(fory.serialize(value)), equals(value)); + } + }); + + test('round-trips generated decimal fields', () { + final fory = Fory(); + _registerDecimalEnvelope(fory); + + final value = DecimalEnvelope() + ..amount = _decimal('123456789012345678901234567890123456789', 37) + ..note = 'principal'; + + final roundTrip = fory.deserialize( + fory.serialize(value), + ); + expect(roundTrip.amount, equals(value.amount)); + expect(roundTrip.note, equals('principal')); + }); + + test('rejects non-canonical big decimal payloads', () { + final fory = Fory(); + final zeroBigEncoding = Uint8List.fromList([ + 0x02, + 0xff, + TypeIds.decimal, + 0x00, + 0x01, + ]); + final trailingZeroPayload = Uint8List.fromList([ + 0x02, + 0xff, + TypeIds.decimal, + 0x00, + 0x09, + 0x01, + 0x00, + ]); + + expect( + () => fory.deserialize(zeroBigEncoding), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('Invalid decimal magnitude length'), + ), + ), + ); + expect( + () => fory.deserialize(trailingZeroPayload), + throwsA( + isA().having( + (error) => error.toString(), + 'message', + contains('trailing zero byte'), + ), + ), + ); + }); + }); +} diff --git a/dart/packages/fory/test/time_serializer_test.dart b/dart/packages/fory/test/time_serializer_test.dart index 585e670011..9ca5745e39 100644 --- a/dart/packages/fory/test/time_serializer_test.dart +++ b/dart/packages/fory/test/time_serializer_test.dart @@ -20,6 +20,7 @@ import 'dart:typed_data'; import 'package:fory/fory.dart'; +import 'package:fory/src/meta/type_ids.dart'; import 'package:test/test.dart'; part 'time_serializer_test.fory.dart'; @@ -89,6 +90,25 @@ void main() { } }); + test('encodes LocalDate as signed varint64 in xlang payloads', () { + final fory = Fory(); + final value = LocalDate.fromEpochDay(-1); + + final bytes = fory.serialize(value); + expect(bytes, equals(Uint8List.fromList([0x02, 0xff, TypeIds.date, 0x01]))); + expect(fory.deserialize(bytes), equals(value)); + }); + + test('LocalDate convenience methods bridge DateTime and epoch-day forms', + () { + final value = LocalDate.fromDateTime(DateTime.utc(2024, 1, 2, 3, 4, 5)); + + expect(value, equals(const LocalDate(2024, 1, 2))); + expect( + value.toEpochDay(), equals(const LocalDate(2024, 1, 2).toEpochDay())); + expect(value.toDateTime(), equals(DateTime.utc(2024, 1, 2))); + }); + test('decodes root Timestamp payloads to DateTime by default', () { final fory = Fory(); final cases = >[ diff --git a/docs/guide/dart/supported-types.md b/docs/guide/dart/supported-types.md index ad2b5c2733..3d309ee5b5 100644 --- a/docs/guide/dart/supported-types.md +++ b/docs/guide/dart/supported-types.md @@ -79,6 +79,12 @@ final birthday = LocalDate(1990, 12, 1); final timeout = const Duration(seconds: 30); ``` +The temporal wrappers expose conversion helpers: + +- `Timestamp.fromDateTime(...)` and `timestamp.toDateTime()` +- `LocalDate.fromEpochDay(...)`, `date.toEpochDay()` +- `LocalDate.fromDateTime(...)` and `date.toDateTime()` + `Duration` support in Dart is exact to microseconds. Incoming xlang duration payloads that use sub-microsecond nanoseconds are rejected instead of being silently truncated. diff --git a/docs/guide/swift/basic-serialization.md b/docs/guide/swift/basic-serialization.md index 7931831d93..4d7283be56 100644 --- a/docs/guide/swift/basic-serialization.md +++ b/docs/guide/swift/basic-serialization.md @@ -92,8 +92,12 @@ assert(fromBuffer == person) ### Date and time - `Date` -- `ForyDate` -- `ForyTimestamp` +- `LocalDate` +- `Duration` + +Use `Date` for timestamp values and `LocalDate` for day-only dates. `LocalDate` +supports epoch-day and `Date` conversions through `fromEpochDay(_:)`, +`toEpochDay()`, `init(date:)`, and `toDate()`. ### Collections diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md index da344a902e..5019c70d83 100644 --- a/docs/specification/xlang_serialization_spec.md +++ b/docs/specification/xlang_serialization_spec.md @@ -73,8 +73,8 @@ This specification defines the Fory xlang binary format. The format is dynamic r - duration: an absolute length of time, independent of any calendar/timezone, as a count of nanoseconds. - timestamp: a point in time, independent of any calendar/timezone, encoded as seconds (int64) and nanoseconds (uint32) since the epoch at UTC midnight on January 1, 1970. -- date: a naive date without timezone. The count is days relative to an epoch at UTC midnight on Jan 1, 1970. -- decimal: exact decimal value represented as an integer value in two's complement. +- date: a naive date without timezone, encoded as a signed varint64 count of days since the Unix epoch. +- decimal: an exact decimal value encoded as a signed `scale` and an exact `unscaled` integer. - binary: an variable-length array of bytes. - array: only allow 1d numeric components. Other arrays will be taken as List. The implementation should support the interoperability between array and list. @@ -203,8 +203,8 @@ Named types (`NAMED_*`) do not embed a user ID; their names are carried in metad | 36 | NONE | Empty/unit type (no data) | | 37 | DURATION | Time duration (seconds + nanoseconds) | | 38 | TIMESTAMP | Point in time (seconds + nanoseconds since epoch) | -| 39 | DATE | Date without timezone (days since epoch) | -| 40 | DECIMAL | Arbitrary precision decimal | +| 39 | DATE | Date without timezone (signed varint64 days) | +| 40 | DECIMAL | Arbitrary precision decimal (scale + unscaled) | | 41 | BINARY | Raw binary data | | 42 | ARRAY | Generic array type | | 43 | BOOL_ARRAY | 1D boolean array | @@ -1248,12 +1248,86 @@ This is a fixed-size 12-byte payload (8 bytes seconds + 4 bytes nanos). ### date -Date represents a date without timezone. It is encoded as an `int32` count of days since the Unix epoch -(1970-01-01). This is a fixed-size 4-byte payload. +Date represents a date without timezone. It is encoded as: + +- `days` (varint64): signed count of days since the Unix epoch (`1970-01-01`) + +The value is reconstructed as `LocalDate.ofEpochDay(days)` or the equivalent calendar-date constructor in +the target runtime. + +This `varint64` encoding applies to xlang serialization only. Native, language-specific local-date +encodings are unchanged. ### decimal -Not supported for now. +A decimal value is encoded as: + +1. `scale`: signed varint32 +2. `unscaledHeader`: unsigned varint64 +3. optional `payload`: present only for large unscaled values + +The mathematical value is: + +`value = unscaled × 10^-scale` + +#### Scale + +- `scale` is encoded as signed varint32. +- `scale` carries no extra flags or mode bits. + +#### Unscaled Header + +`unscaledHeader` selects the encoding of `unscaled`: + +- If `(unscaledHeader & 1) == 0`, the value uses the small encoding. +- If `(unscaledHeader & 1) == 1`, the value uses the big encoding. + +#### Small Encoding + +For small values, `unscaled` must fit in signed 64-bit range and the zigzag-encoded value must fit in 63 bits. + +Encoding: + +- `unscaledHeader = zigzag(unscaled) << 1` +- no payload is written + +Decoding: + +- `unscaled = zigzagDecode(unscaledHeader >>> 1)` + +#### Big Encoding + +For big values, `unscaled` is encoded as sign plus magnitude bytes. + +Encoding: + +- `sign = 0` if `unscaled >= 0`, otherwise `1` +- `magnitude = abs(unscaled)` +- `len = byte length of magnitude in canonical minimal little-endian form` +- `meta = (len << 1) | sign` +- `unscaledHeader = (meta << 1) | 1` +- `payload = magnitude as canonical minimal little-endian bytes` + +Decoding: + +- `meta = unscaledHeader >>> 1` +- `sign = meta & 1` +- `len = meta >>> 1` +- read `len` bytes as little-endian unsigned magnitude +- `unscaled = magnitude` if `sign == 0`, otherwise `-magnitude` + +#### Canonical Rules + +- Zero must use the small encoding. +- Big encoding must not be used for zero. +- In big encoding, `payload` must be the minimal little-endian representation. +- Therefore, for big encoding, `len > 0` and `payload[len - 1] != 0`. + +#### Final Value + +After decoding `scale` and `unscaled`, the decimal value is reconstructed as: + +`value = unscaled × 10^-scale` ### struct diff --git a/docs/specification/xlang_type_mapping.md b/docs/specification/xlang_type_mapping.md index 748cb4db9a..ee62b1423d 100644 --- a/docs/specification/xlang_type_mapping.md +++ b/docs/specification/xlang_type_mapping.md @@ -86,8 +86,8 @@ When reading type IDs: | none | 36 | null | None | null | `std::monostate` | nil | `()` | | duration | 37 | Duration | timedelta | Number | duration | Duration | Duration | | timestamp | 38 | Instant | datetime | Number | std::chrono::nanoseconds | Time | DateTime | -| date | 39 | Date | datetime | Number | fory::serialization::Date | Time | DateTime | -| decimal | 40 | BigDecimal | Decimal | bigint | / | / | / | +| date | 39 | LocalDate | datetime.date | Date | fory::serialization::Date | fory.Date | chrono::NaiveDate | +| decimal | 40 | BigDecimal | Decimal | Decimal | / | fory.Decimal | fory::Decimal | | binary | 41 | byte[] | bytes | / | `uint8_t[n]/vector` | `[n]uint8/[]T` | `Vec` | | array | 42 | array | np.ndarray | / | / | array/slice | Vec | | bool_array | 43 | bool[] | ndarray(np.bool\_) | / | `bool[n]` | `[n]bool/[]T` | `Vec` | diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 8e580b053d..560c0cd133 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -467,14 +467,26 @@ func generateSliceElementRead(buf *bytes.Buffer, elemType types.Type, elemAccess fmt.Fprintf(buf, "\t\t\t\t%s = fory.CreateTimeFromUnixSecondsAndNanos(seconds, nanos)\n", elemAccess) return nil case "github.com/apache/fory/go/fory.Date": - fmt.Fprintf(buf, "\t\t\t\tdays := buf.ReadInt32()\n") - fmt.Fprintf(buf, "\t\t\t\t// Handle zero date marker\n") - fmt.Fprintf(buf, "\t\t\t\tif days == int32(-2147483648) {\n") - fmt.Fprintf(buf, "\t\t\t\t\t%s = fory.Date{Year: 0, Month: 0, Day: 0}\n", elemAccess) + fmt.Fprintf(buf, "\t\t\t\tif ctx.TypeResolver().IsXlang() {\n") + fmt.Fprintf(buf, "\t\t\t\t\tdays := buf.ReadVarint64(err)\n") + fmt.Fprintf(buf, "\t\t\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\t\tdateValue, convErr := fory.DateFromEpochDay(days)\n") + fmt.Fprintf(buf, "\t\t\t\t\tif convErr != nil {\n") + fmt.Fprintf(buf, "\t\t\t\t\t\treturn convErr\n") + fmt.Fprintf(buf, "\t\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\t\t%s = dateValue\n", elemAccess) fmt.Fprintf(buf, "\t\t\t\t} else {\n") - fmt.Fprintf(buf, "\t\t\t\t\tdiff := time.Duration(days) * 24 * time.Hour\n") - fmt.Fprintf(buf, "\t\t\t\t\tt := time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local).Add(diff)\n") - fmt.Fprintf(buf, "\t\t\t\t\t%s = fory.Date{Year: t.Year(), Month: t.Month(), Day: t.Day()}\n", elemAccess) + fmt.Fprintf(buf, "\t\t\t\t\tdays := buf.ReadInt32()\n") + fmt.Fprintf(buf, "\t\t\t\t\t// Handle zero date marker\n") + fmt.Fprintf(buf, "\t\t\t\t\tif days == int32(-2147483648) {\n") + fmt.Fprintf(buf, "\t\t\t\t\t\t%s = fory.Date{Year: 0, Month: 0, Day: 0}\n", elemAccess) + fmt.Fprintf(buf, "\t\t\t\t\t} else {\n") + fmt.Fprintf(buf, "\t\t\t\t\t\tdiff := time.Duration(days) * 24 * time.Hour\n") + fmt.Fprintf(buf, "\t\t\t\t\t\tt := time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local).Add(diff)\n") + fmt.Fprintf(buf, "\t\t\t\t\t\t%s = fory.Date{Year: t.Year(), Month: t.Month(), Day: t.Day()}\n", elemAccess) + fmt.Fprintf(buf, "\t\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\t}\n") return nil } diff --git a/go/fory/decimal.go b/go/fory/decimal.go new file mode 100644 index 0000000000..d7acd27c9f --- /dev/null +++ b/go/fory/decimal.go @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "fmt" + "math/big" + "reflect" +) + +// Decimal is the Go carrier for Fory DECIMAL values. +// It preserves the exact value unscaled * 10^-scale. +type Decimal struct { + Unscaled big.Int + Scale int32 +} + +func NewDecimal(unscaled *big.Int, scale int32) Decimal { + var copied big.Int + if unscaled != nil { + copied.Set(unscaled) + } + return Decimal{Unscaled: copied, Scale: scale} +} + +func (d Decimal) Equal(other Decimal) bool { + return d.Scale == other.Scale && d.Unscaled.Cmp(&other.Unscaled) == 0 +} + +func (d Decimal) String() string { + return fmt.Sprintf("%se%d", d.Unscaled.String(), -d.Scale) +} + +var ( + decimalReflectType = reflect.TypeFor[Decimal]() + decimalLongMin = big.NewInt(MinInt64) + decimalLongMax = big.NewInt(MaxInt64) +) + +type decimalSerializer struct{} + +func (s decimalSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { + if refMode != RefModeNone { + ctx.buffer.WriteInt8(NotNullValueFlag) + } + if writeType { + ctx.buffer.WriteUint8(uint8(DECIMAL)) + } + s.WriteData(ctx, value) +} + +func (s decimalSerializer) WriteData(ctx *WriteContext, value reflect.Value) { + decimal := value.Interface().(Decimal) + writeDecimalParts(ctx.buffer, decimal.Scale, &decimal.Unscaled) +} + +func (s decimalSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { + err := ctx.Err() + if refMode != RefModeNone { + if ctx.buffer.ReadInt8(err) == NullFlag { + value.Set(reflect.Zero(value.Type())) + return + } + } + if readType { + _ = ctx.buffer.ReadUint8(err) + } + if ctx.HasError() { + return + } + s.ReadData(ctx, value) +} + +func (s decimalSerializer) ReadData(ctx *ReadContext, value reflect.Value) { + scale, unscaled := readDecimalParts(ctx) + if ctx.HasError() { + return + } + value.Set(reflect.ValueOf(NewDecimal(unscaled, scale))) +} + +func (s decimalSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { + s.Read(ctx, refMode, false, false, value) +} + +func writeDecimalParts(buffer *ByteBuffer, scale int32, unscaled *big.Int) { + if unscaled == nil { + unscaled = new(big.Int) + } + buffer.WriteVarint32(scale) + if canUseSmallDecimalEncoding(unscaled) { + smallValue := unscaled.Int64() + header := encodeDecimalZigZag64(smallValue) << 1 + buffer.WriteVarUint64(header) + return + } + + abs := new(big.Int).Abs(unscaled) + payload := abs.Bytes() + if len(payload) == 0 { + panic("decimal zero must use the small encoding") + } + reverseBytes(payload) + meta := (uint64(len(payload)) << 1) | uint64(signBit(unscaled.Sign())) + buffer.WriteVarUint64((meta << 1) | 1) + buffer.WriteBinary(payload) +} + +func readDecimalParts(ctx *ReadContext) (int32, *big.Int) { + err := ctx.Err() + scale := ctx.buffer.ReadVarint32(err) + header := ctx.buffer.ReadVarUint64(err) + if ctx.HasError() { + return 0, nil + } + if (header & 1) == 0 { + return scale, big.NewInt(decodeDecimalZigZag64(header >> 1)) + } + + meta := header >> 1 + length := meta >> 1 + if length == 0 || length > uint64(MaxInt32) { + ctx.SetError(DeserializationErrorf("invalid decimal magnitude length %d", length)) + return 0, nil + } + payload := ctx.buffer.ReadBytes(int(length), err) + if ctx.HasError() { + return 0, nil + } + if payload[len(payload)-1] == 0 { + ctx.SetError(DeserializationError("non-canonical decimal payload: trailing zero byte")) + return 0, nil + } + bigEndian := append([]byte(nil), payload...) + reverseBytes(bigEndian) + magnitude := new(big.Int).SetBytes(bigEndian) + if magnitude.Sign() == 0 { + ctx.SetError(DeserializationError("big decimal encoding must not represent zero")) + return 0, nil + } + if (meta & 1) != 0 { + magnitude.Neg(magnitude) + } + return scale, magnitude +} + +func canUseSmallDecimalEncoding(value *big.Int) bool { + if value == nil { + return true + } + if value.Cmp(decimalLongMin) < 0 || value.Cmp(decimalLongMax) > 0 { + return false + } + return (encodeDecimalZigZag64(value.Int64()) & (1 << 63)) == 0 +} + +func encodeDecimalZigZag64(value int64) uint64 { + return uint64((value << 1) ^ (value >> 63)) +} + +func decodeDecimalZigZag64(value uint64) int64 { + return int64((value >> 1) ^ uint64(-(int64(value & 1)))) +} + +func signBit(sign int) int { + if sign < 0 { + return 1 + } + return 0 +} + +func reverseBytes(values []byte) { + for i, j := 0, len(values)-1; i < j; i, j = i+1, j-1 { + values[i], values[j] = values[j], values[i] + } +} diff --git a/go/fory/decimal_test.go b/go/fory/decimal_test.go new file mode 100644 index 0000000000..0acd9fbcc0 --- /dev/null +++ b/go/fory/decimal_test.go @@ -0,0 +1,133 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "math/big" + "testing" + + "github.com/stretchr/testify/require" +) + +func mustDecimal(value string, scale int32) Decimal { + unscaled, ok := new(big.Int).SetString(value, 10) + if !ok { + panic("invalid decimal test value: " + value) + } + return NewDecimal(unscaled, scale) +} + +func TestDecimalRoundTrip(t *testing.T) { + values := []Decimal{ + NewDecimal(big.NewInt(0), 0), + NewDecimal(big.NewInt(0), 3), + NewDecimal(big.NewInt(1), 0), + NewDecimal(big.NewInt(-1), 0), + NewDecimal(big.NewInt(12345), 2), + NewDecimal(big.NewInt(MaxInt64), 0), + NewDecimal(big.NewInt(MinInt64), 0), + NewDecimal(new(big.Int).Add(big.NewInt(MaxInt64), big.NewInt(1)), 0), + NewDecimal(new(big.Int).Sub(big.NewInt(MinInt64), big.NewInt(1)), 0), + mustDecimal("123456789012345678901234567890123456789", 37), + mustDecimal("-123456789012345678901234567890123456789", -17), + } + for _, referenceTracking := range []bool{false, true} { + f := New(WithXlang(true), WithRefTracking(referenceTracking)) + for _, value := range values { + data, err := Serialize(f, value) + require.NoError(t, err) + var decoded Decimal + err = Deserialize(f, data, &decoded) + require.NoError(t, err) + require.True(t, value.Equal(decoded), "expected %v, got %v", value, decoded) + } + } +} + +func TestDecimalDynamicAnyRoundTrip(t *testing.T) { + f := New(WithXlang(true), WithRefTracking(true)) + value := mustDecimal("9223372036854775808", 4) + payload := []any{"marker", value, []any{value, mustDecimal("-12345678901234567890", 2)}} + data, err := Serialize(f, payload) + require.NoError(t, err) + + var decoded []any + err = Deserialize(f, data, &decoded) + require.NoError(t, err) + require.Len(t, decoded, 3) + require.Equal(t, "marker", decoded[0]) + gotDecimal, ok := decoded[1].(Decimal) + require.True(t, ok) + require.True(t, value.Equal(gotDecimal)) + nested, ok := decoded[2].([]any) + require.True(t, ok) + require.Len(t, nested, 2) + gotNested, ok := nested[0].(Decimal) + require.True(t, ok) + require.True(t, value.Equal(gotNested)) +} + +func TestDecimalWireEncoding(t *testing.T) { + f := New(WithXlang(true)) + data, err := Serialize(f, NewDecimal(big.NewInt(0), 2)) + require.NoError(t, err) + + buf := NewByteBuffer(data) + require.Equal(t, byte(XLangFlag), buf.ReadByte(nil)) + require.Equal(t, int8(NotNullValueFlag), buf.ReadInt8(nil)) + require.Equal(t, uint8(DECIMAL), buf.ReadUint8(nil)) + require.Equal(t, int32(2), buf.ReadVarint32(nil)) + require.Equal(t, uint64(0), buf.ReadVarUint64(nil)) + + data, err = Serialize(f, mustDecimal("9223372036854775808", 0)) + require.NoError(t, err) + buf = NewByteBuffer(data) + require.Equal(t, byte(XLangFlag), buf.ReadByte(nil)) + require.Equal(t, int8(NotNullValueFlag), buf.ReadInt8(nil)) + require.Equal(t, uint8(DECIMAL), buf.ReadUint8(nil)) + require.Equal(t, int32(0), buf.ReadVarint32(nil)) + require.Equal(t, uint64(1), buf.ReadVarUint64(nil)&1) +} + +func TestDecimalRejectsNonCanonicalBigPayload(t *testing.T) { + f := New(WithXlang(true)) + + buffer := NewByteBuffer(nil) + buffer.WriteByte_(XLangFlag) + buffer.WriteInt8(NotNullValueFlag) + buffer.WriteUint8(uint8(DECIMAL)) + buffer.WriteVarint32(0) + buffer.WriteVarUint64(1) + data := buffer.GetByteSlice(0, buffer.writerIndex) + var decoded Decimal + err := Deserialize(f, data, &decoded) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid decimal magnitude length") + + buffer = NewByteBuffer(nil) + buffer.WriteByte_(XLangFlag) + buffer.WriteInt8(NotNullValueFlag) + buffer.WriteUint8(uint8(DECIMAL)) + buffer.WriteVarint32(0) + buffer.WriteVarUint64((((uint64(2) << 1) | 0) << 1) | 1) + buffer.WriteBinary([]byte{0x01, 0x00}) + data = buffer.GetByteSlice(0, buffer.writerIndex) + err = Deserialize(f, data, &decoded) + require.Error(t, err) + require.Contains(t, err.Error(), "trailing zero byte") +} diff --git a/go/fory/field_info.go b/go/fory/field_info.go index fb92d0b1ee..246ba0e7f8 100644 --- a/go/fory/field_info.go +++ b/go/fory/field_info.go @@ -487,7 +487,7 @@ func isStructField(t reflect.Type) bool { return false } // Date/Timestamp are built-in types with dedicated encodings, not user structs. - if t == dateType || t == timestampType { + if t == dateType || t == timestampType || t == decimalType { return false } if t.Kind() == reflect.Struct { @@ -881,6 +881,9 @@ func typeIdFromKind(type_ reflect.Type) TypeId { if type_ == timestampType { return TIMESTAMP } + if type_ == decimalType { + return DECIMAL + } switch type_.Kind() { case reflect.Bool: return BOOL diff --git a/go/fory/fory.go b/go/fory/fory.go index 57da20c576..565dd1a5b3 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -903,6 +903,10 @@ func Serialize[T any](f *Fory, value T) ([]byte, error) { f.writeCtx.buffer.WriteInt8(NotNullValueFlag) f.writeCtx.WriteTypeId(FLOAT64) f.writeCtx.buffer.WriteFloat64(val) + case Decimal: + f.writeCtx.buffer.WriteInt8(NotNullValueFlag) + f.writeCtx.WriteTypeId(DECIMAL) + writeDecimalParts(f.writeCtx.buffer, val.Scale, &val.Unscaled) case string: f.writeCtx.buffer.WriteInt8(NotNullValueFlag) f.writeCtx.WriteTypeId(STRING) diff --git a/go/fory/skip.go b/go/fory/skip.go index 64660dfc36..3fde813532 100644 --- a/go/fory/skip.go +++ b/go/fory/skip.go @@ -633,7 +633,11 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo // Date/Time types case DATE: - _ = ctx.buffer.ReadUint8(err) + if ctx.TypeResolver().IsXlang() { + _ = ctx.buffer.ReadVarint64(err) + } else { + _ = ctx.buffer.ReadInt32(err) + } case TIMESTAMP: _ = ctx.buffer.ReadInt64(err) _ = ctx.buffer.ReadUint32(err) diff --git a/go/fory/tests/xlang/xlang_test_main.go b/go/fory/tests/xlang/xlang_test_main.go index 451e619fdf..b2600fce98 100644 --- a/go/fory/tests/xlang/xlang_test_main.go +++ b/go/fory/tests/xlang/xlang_test_main.go @@ -20,6 +20,7 @@ package main import ( "flag" "fmt" + "math/big" "os" "reflect" "runtime" @@ -60,6 +61,14 @@ func assertEqual(expected, actual any, name string) { } } +func mustDecimal(unscaled string, scale int32) fory.Decimal { + value, ok := new(big.Int).SetString(unscaled, 10) + if !ok { + panic(fmt.Sprintf("invalid decimal unscaled value %s", unscaled)) + } + return fory.NewDecimal(value, scale) +} + // getStructValue extracts the struct value from either a struct or a pointer to struct. // This handles the case where deserialization may return either type depending on // reference tracking settings. @@ -890,6 +899,51 @@ func testInteger() { writeFile(dataFile, outData) } +func testDecimal() { + dataFile := getDataFile() + data := readFile(dataFile) + + f := fory.New(fory.WithXlang(true), fory.WithCompatible(true)) + buf := fory.NewByteBuffer(data) + values := []fory.Decimal{ + mustDecimal("0", 0), + mustDecimal("0", 3), + mustDecimal("1", 0), + mustDecimal("-1", 0), + mustDecimal("12345", 2), + mustDecimal("9223372036854775807", 0), + mustDecimal("-9223372036854775808", 0), + mustDecimal("4611686018427387903", 0), + mustDecimal("-4611686018427387904", 0), + mustDecimal("9223372036854775808", 0), + mustDecimal("-9223372036854775809", 0), + mustDecimal("123456789012345678901234567890123456789", 37), + mustDecimal("-123456789012345678901234567890123456789", -17), + } + + for i, expected := range values { + var actual fory.Decimal + err := f.DeserializeWithCallbackBuffers(buf, &actual, nil) + if err != nil { + panic(fmt.Sprintf("Failed to deserialize decimal %d: %v", i, err)) + } + if !actual.Equal(expected) { + panic(fmt.Sprintf("decimal %d mismatch: expected %v, got %v", i, expected, actual)) + } + } + + var outData []byte + for _, value := range values { + serialized, err := fory.Serialize(f, value) + if err != nil { + panic(fmt.Sprintf("Failed to serialize decimal: %v", err)) + } + outData = append(outData, serialized...) + } + + writeFile(dataFile, outData) +} + func testItem() { dataFile := getDataFile() data := readFile(dataFile) @@ -2511,6 +2565,8 @@ func main() { testMap() case "test_integer": testInteger() + case "test_decimal": + testDecimal() case "test_item": testItem() case "test_color": diff --git a/go/fory/time.go b/go/fory/time.go index 3eb2086464..630b79f5b3 100644 --- a/go/fory/time.go +++ b/go/fory/time.go @@ -31,6 +31,98 @@ type Date struct { var dateReflectType = reflect.TypeFor[Date]() +func isLeapYear(year int64) bool { + return year%4 == 0 && (year%100 != 0 || year%400 == 0) +} + +func daysInMonth(year int64, month time.Month) int { + switch month { + case time.January, time.March, time.May, time.July, time.August, time.October, time.December: + return 31 + case time.April, time.June, time.September, time.November: + return 30 + case time.February: + if isLeapYear(year) { + return 29 + } + return 28 + default: + return 0 + } +} + +func floorDiv(value, divisor int64) int64 { + quotient := value / divisor + remainder := value % divisor + if remainder != 0 && ((remainder > 0) != (divisor > 0)) { + quotient-- + } + return quotient +} + +func daysFromCivil(year int64, month time.Month, day int64) int64 { + y := year + m := int64(month) + if m <= 2 { + y-- + } + era := floorDiv(y, 400) + yoe := y - era*400 + monthPrime := m - 3 + if m <= 2 { + monthPrime = m + 9 + } + doy := (153*monthPrime+2)/5 + day - 1 + doe := yoe*365 + yoe/4 - yoe/100 + doy + return era*146097 + doe - 719468 +} + +func civilFromDays(days int64) (int64, time.Month, int) { + z := days + 719468 + era := floorDiv(z, 146097) + doe := z - era*146097 + yoe := (doe - doe/1460 + doe/36524 - doe/146096) / 365 + year := yoe + era*400 + doy := doe - (365*yoe + yoe/4 - yoe/100) + monthPrime := (5*doy + 2) / 153 + day := int(doy - (153*monthPrime+2)/5 + 1) + month := monthPrime + 3 + if month > 12 { + month -= 12 + } + if month <= 2 { + year++ + } + return year, time.Month(month), day +} + +// DateToEpochDay converts a Date to its day offset from the Unix epoch. +func DateToEpochDay(date Date) (int64, error) { + year := int64(date.Year) + if date.Month < time.January || date.Month > time.December { + return 0, SerializationErrorf("invalid date month %d", date.Month) + } + maxDay := daysInMonth(year, date.Month) + if date.Day < 1 || date.Day > maxDay { + return 0, SerializationErrorf( + "invalid date day %d for %d-%02d", + date.Day, + date.Year, + int(date.Month), + ) + } + return daysFromCivil(year, date.Month, int64(date.Day)), nil +} + +// DateFromEpochDay converts a Unix-epoch day offset to a Date. +func DateFromEpochDay(days int64) (Date, error) { + year, month, day := civilFromDays(days) + if year < int64(MinInt) || year > int64(MaxInt) { + return Date{}, DeserializationErrorf("date year %d out of int range", year) + } + return Date{Year: int(year), Month: month, Day: day}, nil +} + type dateSerializer struct{} func (s dateSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool, hasGenerics bool, value reflect.Value) { @@ -45,9 +137,19 @@ func (s dateSerializer) Write(ctx *WriteContext, refMode RefMode, writeType bool func (s dateSerializer) WriteData(ctx *WriteContext, value reflect.Value) { date := value.Interface().(Date) + if ctx.TypeResolver().IsXlang() { + days, err := DateToEpochDay(date) + if err != nil { + ctx.SetError(FromError(err)) + return + } + ctx.buffer.WriteVarint64(days) + return + } diff := time.Date(date.Year, date.Month, date.Day, 0, 0, 0, 0, time.Local).Sub( time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local)) - ctx.buffer.WriteInt32(int32(diff.Hours() / 24)) + days := int64(diff.Hours() / 24) + ctx.buffer.WriteInt32(int32(days)) } func (s dateSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, hasGenerics bool, value reflect.Value) { @@ -68,6 +170,19 @@ func (s dateSerializer) Read(ctx *ReadContext, refMode RefMode, readType bool, h func (s dateSerializer) ReadData(ctx *ReadContext, value reflect.Value) { err := ctx.Err() + if ctx.TypeResolver().IsXlang() { + days := ctx.buffer.ReadVarint64(err) + if ctx.HasError() { + return + } + date, convErr := DateFromEpochDay(days) + if convErr != nil { + ctx.SetError(FromError(convErr)) + return + } + value.Set(reflect.ValueOf(date)) + return + } diff := time.Duration(ctx.buffer.ReadInt32(err)) * 24 * time.Hour date := time.Date(1970, 1, 1, 0, 0, 0, 0, time.Local).Add(diff) value.Set(reflect.ValueOf(Date{date.Year(), date.Month(), date.Day()})) diff --git a/go/fory/time_test.go b/go/fory/time_test.go new file mode 100644 index 0000000000..c49a4a68c5 --- /dev/null +++ b/go/fory/time_test.go @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package fory + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestDateUsesVarint64InXlangAndInt32InNative(t *testing.T) { + date := Date{Year: 1969, Month: time.December, Day: 31} + expectedDays, err := DateToEpochDay(date) + require.NoError(t, err) + + for _, tc := range []struct { + name string + fory *Fory + check func(*testing.T, *ByteBuffer) + }{ + { + name: "xlang", + fory: NewFory(WithTrackRef(false), WithXlang(true)), + check: func(t *testing.T, buf *ByteBuffer) { + var err Error + require.Equal(t, byte(XLangFlag), buf.ReadByte(&err)) + require.False(t, err.HasError(), err.Error()) + require.Equal(t, int8(NotNullValueFlag), buf.ReadInt8(&err)) + require.False(t, err.HasError(), err.Error()) + require.Equal(t, uint8(DATE), buf.ReadUint8(&err)) + require.False(t, err.HasError(), err.Error()) + require.Equal(t, expectedDays, buf.ReadVarint64(&err)) + require.False(t, err.HasError(), err.Error()) + }, + }, + { + name: "native", + fory: NewFory(WithTrackRef(false), WithXlang(false)), + check: func(t *testing.T, buf *ByteBuffer) { + var err Error + require.Equal(t, byte(0), buf.ReadByte(&err)) + require.False(t, err.HasError(), err.Error()) + require.Equal(t, int8(NotNullValueFlag), buf.ReadInt8(&err)) + require.False(t, err.HasError(), err.Error()) + require.Equal(t, uint8(DATE), buf.ReadUint8(&err)) + require.False(t, err.HasError(), err.Error()) + require.Equal(t, int32(expectedDays), buf.ReadInt32(&err)) + require.False(t, err.HasError(), err.Error()) + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + data, err := tc.fory.Serialize(date) + require.NoError(t, err) + + buf := NewByteBuffer(data) + tc.check(t, buf) + require.Equal(t, len(data), buf.ReaderIndex()) + }) + } +} + +func TestXlangDateSupportsWideRange(t *testing.T) { + fory := NewFory(WithTrackRef(false), WithXlang(true)) + date := Date{Year: 200000, Month: time.January, Day: 1} + + expectedDays, err := DateToEpochDay(date) + require.NoError(t, err) + + data, err := Serialize(fory, &date) + require.NoError(t, err) + + buf := NewByteBuffer(data) + var bufErr Error + require.Equal(t, byte(XLangFlag), buf.ReadByte(&bufErr)) + require.False(t, bufErr.HasError(), bufErr.Error()) + require.Equal(t, int8(NotNullValueFlag), buf.ReadInt8(&bufErr)) + require.False(t, bufErr.HasError(), bufErr.Error()) + require.Equal(t, uint8(DATE), buf.ReadUint8(&bufErr)) + require.False(t, bufErr.HasError(), bufErr.Error()) + require.Equal(t, expectedDays, buf.ReadVarint64(&bufErr)) + require.False(t, bufErr.HasError(), bufErr.Error()) + require.Equal(t, len(data), buf.ReaderIndex()) + + var decoded Date + err = Deserialize(fory, data, &decoded) + require.NoError(t, err) + require.Equal(t, date, decoded) +} diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index 19cd839978..540114735f 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -108,6 +108,7 @@ var ( bfloat16Type = reflect.TypeOf((*bfloat16.BFloat16)(nil)).Elem() dateType = reflect.TypeOf((*Date)(nil)).Elem() timestampType = reflect.TypeOf((*time.Time)(nil)).Elem() + decimalType = reflect.TypeOf((*Decimal)(nil)).Elem() genericSetType = reflect.TypeOf((*Set[any])(nil)).Elem() ) @@ -266,6 +267,7 @@ func newTypeResolver(fory *Fory) *TypeResolver { stringType, dateType, timestampType, + decimalType, interfaceType, genericSetType, } { @@ -448,6 +450,7 @@ func (r *TypeResolver) initialize() { {bfloat16Type, BFLOAT16, bfloat16Serializer{}}, {dateType, DATE, dateSerializer{}}, {timestampType, TIMESTAMP, timeSerializer{}}, + {decimalType, DECIMAL, decimalSerializer{}}, {genericSetType, SET, setSerializer{}}, } for _, elem := range serializers { diff --git a/go/fory/types.go b/go/fory/types.go index a0ee730b45..38b1b71b4b 100644 --- a/go/fory/types.go +++ b/go/fory/types.go @@ -192,7 +192,8 @@ func isPrimitiveType(typeID TypeId) bool { FLOAT16, BFLOAT16, FLOAT32, - FLOAT64: + FLOAT64, + DECIMAL: return true default: return false @@ -206,7 +207,7 @@ func NeedWriteRef(typeID TypeId) bool { switch typeID { case BOOL, INT8, INT16, INT32, INT64, VARINT32, VARINT64, TAGGED_INT64, FLOAT32, FLOAT64, FLOAT16, FLOAT8, BFLOAT16, - STRING, TIMESTAMP, DATE, DURATION, NONE: + STRING, TIMESTAMP, DATE, DURATION, DECIMAL, NONE: return false default: return true diff --git a/integration_tests/idl_tests/generate_idl.py b/integration_tests/idl_tests/generate_idl.py index 4cc32c4e52..2190963eed 100755 --- a/integration_tests/idl_tests/generate_idl.py +++ b/integration_tests/idl_tests/generate_idl.py @@ -28,6 +28,7 @@ IDL_DIR / "idl" / "addressbook.fdl", IDL_DIR / "idl" / "collection.fdl", IDL_DIR / "idl" / "optional_types.fdl", + IDL_DIR / "idl" / "basic.fdl", IDL_DIR / "idl" / "tree.fdl", IDL_DIR / "idl" / "graph.fdl", IDL_DIR / "idl" / "root.idl", @@ -55,6 +56,7 @@ GO_OUTPUT_OVERRIDES = { "addressbook.fdl": IDL_DIR / "go" / "addressbook" / "generated", + "basic.fdl": IDL_DIR / "go" / "basic" / "generated", "collection.fdl": IDL_DIR / "go" / "collection" / "generated", "monster.fbs": IDL_DIR / "go" / "monster" / "generated", "complex_fbs.fbs": IDL_DIR / "go" / "complex_fbs" / "generated", diff --git a/integration_tests/idl_tests/idl/basic.fdl b/integration_tests/idl_tests/idl/basic.fdl new file mode 100644 index 0000000000..89efc30415 --- /dev/null +++ b/integration_tests/idl_tests/idl/basic.fdl @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package basic; + +message Money [id=130] { + decimal amount = 1; + string currency = 2; +} + +union BasicValue [id=131] { + decimal amount = 1; + string note = 2; + Money money = 3; +} + +message BasicEnvelope [id=132] { + Money money = 1; + BasicValue value = 2; + optional decimal nullable_amount = 3; +} diff --git a/integration_tests/idl_tests/swift/idl_package/Tests/IdlRoundTripTests/IdlRoundTripTests.swift b/integration_tests/idl_tests/swift/idl_package/Tests/IdlRoundTripTests/IdlRoundTripTests.swift index 158e0060fe..56544ba960 100644 --- a/integration_tests/idl_tests/swift/idl_package/Tests/IdlRoundTripTests/IdlRoundTripTests.swift +++ b/integration_tests/idl_tests/swift/idl_package/Tests/IdlRoundTripTests/IdlRoundTripTests.swift @@ -421,8 +421,8 @@ final class IdlRoundTripTests: XCTestCase { float64Value: 3.5, stringValue: "optional", bytesValue: Data([1, 2, 3]), - dateValue: ForyDate(daysSinceEpoch: 19724), - timestampValue: ForyTimestamp(seconds: 1704164645, nanos: 0), + dateValue: LocalDate(daysSinceEpoch: 19724), + timestampValue: Date(timeIntervalSince1970: 1704164645), int32List: [1, 2, 3], stringList: ["alpha", "beta"], int64Map: ["alpha": 10, "beta": 20] @@ -437,7 +437,7 @@ final class IdlRoundTripTests: XCTestCase { AnyExample.AnyHolder( boolValue: true, stringValue: "hello", - dateValue: ForyDate(daysSinceEpoch: 19724), + dateValue: LocalDate(daysSinceEpoch: 19724), timestampValue: Date(timeIntervalSince1970: 1704164645), messageValue: AnyExample.AnyInner(name: "inner"), unionValue: AnyExample.AnyUnion.text("union"), @@ -450,7 +450,7 @@ final class IdlRoundTripTests: XCTestCase { AnyExamplePb.AnyHolder( boolValue: true, stringValue: "hello", - dateValue: ForyDate(daysSinceEpoch: 19724), + dateValue: LocalDate(daysSinceEpoch: 19724), timestampValue: Date(timeIntervalSince1970: 1704164645), messageValue: AnyExamplePb.AnyInner(name: "inner"), unionValue: AnyExamplePb.AnyUnion(kind: .text("proto-union")), @@ -462,7 +462,7 @@ final class IdlRoundTripTests: XCTestCase { private func assertAnyHolder(expected: AnyExample.AnyHolder, actual: AnyExample.AnyHolder) { XCTAssertEqual(actual.boolValue as? Bool, expected.boolValue as? Bool) XCTAssertEqual(actual.stringValue as? String, expected.stringValue as? String) - XCTAssertEqual(actual.dateValue as? ForyDate, expected.dateValue as? ForyDate) + XCTAssertEqual(actual.dateValue as? LocalDate, expected.dateValue as? LocalDate) XCTAssertEqual((actual.timestampValue as? Date)?.timeIntervalSince1970, (expected.timestampValue as? Date)?.timeIntervalSince1970) XCTAssertEqual(actual.messageValue as? AnyExample.AnyInner, expected.messageValue as? AnyExample.AnyInner) XCTAssertEqual(actual.unionValue as? AnyExample.AnyUnion, expected.unionValue as? AnyExample.AnyUnion) @@ -473,7 +473,7 @@ final class IdlRoundTripTests: XCTestCase { private func assertAnyProtoHolder(expected: AnyExamplePb.AnyHolder, actual: AnyExamplePb.AnyHolder) { XCTAssertEqual(actual.boolValue as? Bool, expected.boolValue as? Bool) XCTAssertEqual(actual.stringValue as? String, expected.stringValue as? String) - XCTAssertEqual(actual.dateValue as? ForyDate, expected.dateValue as? ForyDate) + XCTAssertEqual(actual.dateValue as? LocalDate, expected.dateValue as? LocalDate) XCTAssertEqual((actual.timestampValue as? Date)?.timeIntervalSince1970, (expected.timestampValue as? Date)?.timeIntervalSince1970) XCTAssertEqual(actual.messageValue as? AnyExamplePb.AnyInner, expected.messageValue as? AnyExamplePb.AnyInner) XCTAssertEqual(actual.unionValue as? AnyExamplePb.AnyUnion, expected.unionValue as? AnyExamplePb.AnyUnion) diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java index d02da96a82..06737b2c97 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java @@ -77,6 +77,8 @@ import org.apache.fory.meta.TypeDef; import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.serializer.ArraySerializers; +import org.apache.fory.serializer.BigIntegerSerializer; +import org.apache.fory.serializer.DecimalSerializer; import org.apache.fory.serializer.DeferedLazySerializer; import org.apache.fory.serializer.DeferedLazySerializer.DeferredLazyObjectSerializer; import org.apache.fory.serializer.EnumSerializer; @@ -958,8 +960,8 @@ private void registerDefaultTypes() { registerType(Types.DATE, LocalDate.class, new TimeSerializers.LocalDateSerializer(config)); // Decimal types - registerType(Types.DECIMAL, BigDecimal.class, new Serializers.BigDecimalSerializer(config)); - registerType(Types.DECIMAL, BigInteger.class, new Serializers.BigIntegerSerializer(config)); + registerType(Types.DECIMAL, BigDecimal.class, new DecimalSerializer(config)); + registerType(Types.DECIMAL, BigInteger.class, new BigIntegerSerializer(config)); // Binary types registerType(Types.BINARY, byte[].class, new ArraySerializers.ByteArraySerializer(this)); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java new file mode 100644 index 0000000000..9465011855 --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.serializer; + +import java.math.BigInteger; +import org.apache.fory.config.Config; +import org.apache.fory.context.ReadContext; +import org.apache.fory.context.WriteContext; +import org.apache.fory.memory.MemoryBuffer; + +/** Serializer for {@link BigInteger} in native and xlang modes. */ +public final class BigIntegerSerializer extends ImmutableSerializer + implements Shareable { + private final boolean xlang; + + public BigIntegerSerializer(Config config) { + super(config, BigInteger.class); + xlang = config.isXlang(); + } + + @Override + public void write(WriteContext writeContext, BigInteger value) { + if (xlang) { + writeXlang(writeContext, value); + } else { + writeNative(writeContext, value); + } + } + + @Override + public BigInteger read(ReadContext readContext) { + if (xlang) { + return readXlang(readContext); + } + return readNative(readContext); + } + + private void writeNative(WriteContext writeContext, BigInteger value) { + MemoryBuffer buffer = writeContext.getBuffer(); + byte[] bytes = value.toByteArray(); + buffer.writeVarUint32Small7(bytes.length); + buffer.writeBytes(bytes); + } + + private BigInteger readNative(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); + int len = buffer.readVarUint32Small7(); + byte[] bytes = buffer.readBytes(len); + return new BigInteger(bytes); + } + + private void writeXlang(WriteContext writeContext, BigInteger value) { + DecimalSerializer.writeXlangDecimal(writeContext.getBuffer(), 0, value); + } + + private BigInteger readXlang(ReadContext readContext) { + return DecimalSerializer.readXlangBigInteger(readContext.getBuffer()); + } +} diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java new file mode 100644 index 0000000000..4c19748398 --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.serializer; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; +import org.apache.fory.config.Config; +import org.apache.fory.context.ReadContext; +import org.apache.fory.context.WriteContext; +import org.apache.fory.memory.MemoryBuffer; + +/** Serializer for {@link BigDecimal} in native and xlang modes. */ +public final class DecimalSerializer extends ImmutableSerializer implements Shareable { + private static final BigInteger LONG_MIN = BigInteger.valueOf(Long.MIN_VALUE); + private static final BigInteger LONG_MAX = BigInteger.valueOf(Long.MAX_VALUE); + private final boolean xlang; + + public DecimalSerializer(Config config) { + super(config, BigDecimal.class); + xlang = config.isXlang(); + } + + @Override + public void write(WriteContext writeContext, BigDecimal value) { + if (xlang) { + writeXlang(writeContext, value); + } else { + writeNative(writeContext, value); + } + } + + @Override + public BigDecimal read(ReadContext readContext) { + if (xlang) { + return readXlang(readContext); + } + return readNative(readContext); + } + + private void writeNative(WriteContext writeContext, BigDecimal value) { + MemoryBuffer buffer = writeContext.getBuffer(); + byte[] bytes = value.unscaledValue().toByteArray(); + buffer.writeVarUint32Small7(value.scale()); + buffer.writeVarUint32Small7(value.precision()); + buffer.writeVarUint32Small7(bytes.length); + buffer.writeBytes(bytes); + } + + private BigDecimal readNative(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); + int scale = buffer.readVarUint32Small7(); + int precision = buffer.readVarUint32Small7(); + int len = buffer.readVarUint32Small7(); + byte[] bytes = buffer.readBytes(len); + BigInteger bigInteger = new BigInteger(bytes); + return new BigDecimal(bigInteger, scale, new MathContext(precision)); + } + + private void writeXlang(WriteContext writeContext, BigDecimal value) { + writeXlangDecimal(writeContext.getBuffer(), value.scale(), value.unscaledValue()); + } + + private BigDecimal readXlang(ReadContext readContext) { + return readXlangDecimal(readContext.getBuffer()); + } + + static void writeXlangDecimal(MemoryBuffer buffer, int scale, BigInteger unscaled) { + buffer.writeVarInt32(scale); + if (canUseSmallEncoding(unscaled)) { + long smallValue = unscaled.longValue(); + long header = encodeZigZag64(smallValue) << 1; + buffer.writeVarUint64(header); + return; + } + + int sign = unscaled.signum() < 0 ? 1 : 0; + byte[] payload = toCanonicalLittleEndianMagnitude(unscaled.abs()); + long meta = (((long) payload.length) << 1) | sign; + long header = (meta << 1) | 1L; + buffer.writeVarUint64(header); + buffer.writeBytes(payload); + } + + static BigDecimal readXlangDecimal(MemoryBuffer buffer) { + int scale = buffer.readVarInt32(); + return new BigDecimal(readXlangUnscaled(buffer), scale); + } + + static BigInteger readXlangBigInteger(MemoryBuffer buffer) { + int scale = buffer.readVarInt32(); + BigInteger unscaled = readXlangUnscaled(buffer); + if (scale != 0) { + throw new IllegalArgumentException( + "Cannot deserialize xlang decimal with scale " + scale + " into BigInteger"); + } + return unscaled; + } + + private static BigInteger readXlangUnscaled(MemoryBuffer buffer) { + long header = buffer.readVarUint64(); + if ((header & 1L) == 0L) { + return BigInteger.valueOf(decodeZigZag64(header >>> 1)); + } + long meta = header >>> 1; + int sign = (int) (meta & 1L); + long lenLong = meta >>> 1; + if (lenLong <= 0 || lenLong > Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Invalid decimal magnitude length " + lenLong + " in xlang payload"); + } + int len = (int) lenLong; + byte[] payload = buffer.readBytes(len); + if (payload[len - 1] == 0) { + throw new IllegalArgumentException("Non-canonical decimal payload: trailing zero byte"); + } + byte[] magnitude = toBigEndian(payload); + BigInteger abs = new BigInteger(1, magnitude); + if (abs.signum() == 0) { + throw new IllegalArgumentException("Big decimal encoding must not represent zero"); + } + return sign == 0 ? abs : abs.negate(); + } + + private static boolean canUseSmallEncoding(BigInteger value) { + if (value.compareTo(LONG_MIN) < 0 || value.compareTo(LONG_MAX) > 0) { + return false; + } + // The small form reserves the low header bit to distinguish small/big encodings, + // so the zigzag value itself must still fit in 63 bits before the final << 1. + long zigZag = encodeZigZag64(value.longValue()); + return (zigZag & Long.MIN_VALUE) == 0; + } + + private static long encodeZigZag64(long value) { + return (value << 1) ^ (value >> 63); + } + + private static long decodeZigZag64(long value) { + return (value >>> 1) ^ -(value & 1L); + } + + private static byte[] toCanonicalLittleEndianMagnitude(BigInteger abs) { + byte[] bigEndian = abs.toByteArray(); + int start = 0; + while (start < bigEndian.length - 1 && bigEndian[start] == 0) { + start++; + } + int len = bigEndian.length - start; + if (len <= 0) { + throw new IllegalArgumentException("Zero must use the small decimal encoding"); + } + byte[] littleEndian = new byte[len]; + for (int i = 0; i < len; i++) { + littleEndian[i] = bigEndian[bigEndian.length - 1 - i]; + } + return littleEndian; + } + + private static byte[] toBigEndian(byte[] littleEndian) { + byte[] bigEndian = new byte[littleEndian.length]; + for (int i = 0; i < littleEndian.length; i++) { + bigEndian[i] = littleEndian[littleEndian.length - 1 - i]; + } + return bigEndian; + } +} diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java index 3e1437a9cd..580932c612 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/Serializers.java @@ -27,7 +27,6 @@ import java.lang.reflect.Method; import java.math.BigDecimal; import java.math.BigInteger; -import java.math.MathContext; import java.net.URI; import java.nio.charset.Charset; import java.util.Currency; @@ -416,57 +415,6 @@ public StringBuffer read(ReadContext readContext) { } } - public static final class BigDecimalSerializer extends ImmutableSerializer - implements Shareable { - public BigDecimalSerializer(Config config) { - super(config, BigDecimal.class); - } - - @Override - public void write(WriteContext writeContext, BigDecimal value) { - MemoryBuffer buffer = writeContext.getBuffer(); - final byte[] bytes = value.unscaledValue().toByteArray(); - buffer.writeVarUint32Small7(value.scale()); - buffer.writeVarUint32Small7(value.precision()); - buffer.writeVarUint32Small7(bytes.length); - buffer.writeBytes(bytes); - } - - @Override - public BigDecimal read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int scale = buffer.readVarUint32Small7(); - int precision = buffer.readVarUint32Small7(); - int len = buffer.readVarUint32Small7(); - byte[] bytes = buffer.readBytes(len); - final BigInteger bigInteger = new BigInteger(bytes); - return new BigDecimal(bigInteger, scale, new MathContext(precision)); - } - } - - public static final class BigIntegerSerializer extends ImmutableSerializer - implements Shareable { - public BigIntegerSerializer(Config config) { - super(config, BigInteger.class); - } - - @Override - public void write(WriteContext writeContext, BigInteger value) { - MemoryBuffer buffer = writeContext.getBuffer(); - final byte[] bytes = value.toByteArray(); - buffer.writeVarUint32Small7(bytes.length); - buffer.writeBytes(bytes); - } - - @Override - public BigInteger read(ReadContext readContext) { - MemoryBuffer buffer = readContext.getBuffer(); - int len = buffer.readVarUint32Small7(); - byte[] bytes = buffer.readBytes(len); - return new BigInteger(bytes); - } - } - public static final class AtomicBooleanSerializer extends Serializer implements Shareable { @@ -698,7 +646,7 @@ public static void registerDefaultSerializers(TypeResolver resolver) { resolver.registerInternalSerializer(StringBuilder.class, new StringBuilderSerializer(config)); resolver.registerInternalSerializer(StringBuffer.class, new StringBufferSerializer(config)); resolver.registerInternalSerializer(BigInteger.class, new BigIntegerSerializer(config)); - resolver.registerInternalSerializer(BigDecimal.class, new BigDecimalSerializer(config)); + resolver.registerInternalSerializer(BigDecimal.class, new DecimalSerializer(config)); resolver.registerInternalSerializer(AtomicBoolean.class, new AtomicBooleanSerializer(config)); resolver.registerInternalSerializer(AtomicInteger.class, new AtomicIntegerSerializer(config)); resolver.registerInternalSerializer(AtomicLong.class, new AtomicLongSerializer(config)); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/TimeSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/TimeSerializers.java index a49da25d5e..d7e79b8b48 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/TimeSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/TimeSerializers.java @@ -43,7 +43,6 @@ import org.apache.fory.context.WriteContext; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.resolver.ClassResolver; -import org.apache.fory.util.DateTimeUtils; /** Serializers for all time related types. */ public class TimeSerializers { @@ -131,13 +130,16 @@ public LocalDateSerializer(Config config, boolean needToWriteRef) { public void write(WriteContext writeContext, LocalDate value) { MemoryBuffer buffer = writeContext.getBuffer(); if (config.isXlang()) { - // TODO use java encoding to support larger range. - buffer.writeInt32(DateTimeUtils.localDateToDays(value)); + writeXlangLocalDate(buffer, value); } else { writeLocalDate(buffer, value); } } + public static void writeXlangLocalDate(MemoryBuffer buffer, LocalDate value) { + buffer.writeVarInt64(value.toEpochDay()); + } + public static void writeLocalDate(MemoryBuffer buffer, LocalDate value) { buffer.writeInt32(value.getYear()); buffer.writeByte(value.getMonthValue()); @@ -148,11 +150,15 @@ public static void writeLocalDate(MemoryBuffer buffer, LocalDate value) { public LocalDate read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); if (config.isXlang()) { - return DateTimeUtils.daysToLocalDate(buffer.readInt32()); + return readXlangLocalDate(buffer); } return readLocalDate(buffer); } + public static LocalDate readXlangLocalDate(MemoryBuffer buffer) { + return LocalDate.ofEpochDay(buffer.readVarInt64()); + } + public static LocalDate readLocalDate(MemoryBuffer buffer) { int year = buffer.readInt32(); int month = buffer.readByte(); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java index 2b3141415d..63222590bd 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java @@ -21,6 +21,7 @@ import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertSame; +import static org.testng.Assert.assertThrows; import static org.testng.Assert.assertTrue; import java.math.BigDecimal; @@ -42,6 +43,8 @@ import org.apache.fory.ForyTestBase; import org.apache.fory.config.ForyBuilder; import org.apache.fory.config.Language; +import org.apache.fory.memory.MemoryBuffer; +import org.apache.fory.memory.MemoryUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -76,6 +79,84 @@ public void testBigInt(boolean referenceTracking) { fory1, new BigInteger("11111111110101010000283895380202208220050200000000111111111")); } + @Test(dataProvider = "referenceTrackingConfig") + public void testXlangDecimalRoundTrip(boolean referenceTracking) { + ForyBuilder builder = + Fory.builder() + .withLanguage(Language.XLANG) + .withRefTracking(referenceTracking) + .requireClassRegistration(false); + Fory fory1 = builder.build(); + Fory fory2 = builder.build(); + List decimalValues = + Arrays.asList( + BigDecimal.ZERO, + BigDecimal.ONE, + BigDecimal.ONE.negate(), + BigDecimal.valueOf(12345, 2), + new BigDecimal(BigInteger.valueOf(Long.MAX_VALUE), 0), + new BigDecimal(BigInteger.valueOf(Long.MIN_VALUE), 0), + new BigDecimal(BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE), 0), + new BigDecimal(BigInteger.valueOf(Long.MIN_VALUE).subtract(BigInteger.ONE), 0), + new BigDecimal(new BigInteger("123456789012345678901234567890123456789"), 37), + new BigDecimal(new BigInteger("-123456789012345678901234567890123456789"), -17)); + for (BigDecimal value : decimalValues) { + assertEquals(serDe(fory1, fory2, value), value); + } + } + + @Test + public void testXlangDecimalCodecCanonicalRoundTrip() { + List values = + Arrays.asList( + BigDecimal.ZERO, + BigDecimal.ONE, + BigDecimal.ONE.negate(), + BigDecimal.valueOf(100, 2), + new BigDecimal(BigInteger.valueOf(Long.MAX_VALUE), 0), + new BigDecimal(BigInteger.valueOf(Long.MIN_VALUE), 0), + new BigDecimal(BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE), 0), + new BigDecimal(BigInteger.valueOf(Long.MIN_VALUE).subtract(BigInteger.ONE), 0), + new BigDecimal(new BigInteger("999999999999999999999999999999999999999999"), 200)); + for (BigDecimal value : values) { + MemoryBuffer buffer = MemoryUtils.buffer(64); + DecimalSerializer.writeXlangDecimal(buffer, value.scale(), value.unscaledValue()); + buffer.readerIndex(0); + assertEquals(DecimalSerializer.readXlangDecimal(buffer), value); + } + } + + @Test + public void testXlangDecimalCodecRejectsNonCanonicalBigPayloads() { + MemoryBuffer zeroBigEncoding = MemoryUtils.buffer(16); + zeroBigEncoding.writeVarInt32(0); + zeroBigEncoding.writeVarUint64(1L); + zeroBigEncoding.readerIndex(0); + assertThrows( + IllegalArgumentException.class, () -> DecimalSerializer.readXlangDecimal(zeroBigEncoding)); + + MemoryBuffer trailingZeroPayload = MemoryUtils.buffer(16); + trailingZeroPayload.writeVarInt32(0); + trailingZeroPayload.writeVarUint64(9L); + trailingZeroPayload.writeBytes(new byte[] {1, 0}); + trailingZeroPayload.readerIndex(0); + assertThrows( + IllegalArgumentException.class, + () -> DecimalSerializer.readXlangDecimal(trailingZeroPayload)); + } + + @Test + public void testDecimalSerializerSelectionByLanguage() { + Fory nativeFory = + Fory.builder().withLanguage(Language.JAVA).requireClassRegistration(false).build(); + Fory xlangFory = + Fory.builder().withLanguage(Language.XLANG).requireClassRegistration(false).build(); + assertEquals(nativeFory.getSerializer(BigDecimal.class).getClass(), DecimalSerializer.class); + assertEquals(xlangFory.getSerializer(BigDecimal.class).getClass(), DecimalSerializer.class); + assertEquals(nativeFory.getSerializer(BigInteger.class).getClass(), BigIntegerSerializer.class); + assertEquals(xlangFory.getSerializer(BigInteger.class).getClass(), BigIntegerSerializer.class); + } + @Test(dataProvider = "javaFory") public void testAtomic(Fory fory) { assertTrue( diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/TimeSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/TimeSerializersTest.java index af5309b02f..25972107c5 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/TimeSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/TimeSerializersTest.java @@ -42,6 +42,8 @@ import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; import org.apache.fory.config.Language; +import org.apache.fory.memory.MemoryBuffer; +import org.apache.fory.memory.MemoryUtils; import org.apache.fory.util.DateTimeUtils; import org.testng.Assert; import org.testng.annotations.Test; @@ -68,6 +70,47 @@ public void testBasicTime() { serDeCheckSerializerAndEqual(fory, Period.of(100, 11, 20), "Time"); } + @Test + public void testXlangLocalDateUsesVarInt64Encoding() { + MemoryBuffer epochBuffer = MemoryUtils.buffer(16); + TimeSerializers.LocalDateSerializer.writeXlangLocalDate(epochBuffer, LocalDate.of(1970, 1, 1)); + Assert.assertEquals(epochBuffer.writerIndex(), 1); + epochBuffer.readerIndex(0); + Assert.assertEquals( + TimeSerializers.LocalDateSerializer.readXlangLocalDate(epochBuffer), + LocalDate.of(1970, 1, 1)); + + MemoryBuffer beforeEpochBuffer = MemoryUtils.buffer(16); + TimeSerializers.LocalDateSerializer.writeXlangLocalDate( + beforeEpochBuffer, LocalDate.of(1969, 12, 31)); + Assert.assertEquals(beforeEpochBuffer.writerIndex(), 1); + beforeEpochBuffer.readerIndex(0); + Assert.assertEquals( + TimeSerializers.LocalDateSerializer.readXlangLocalDate(beforeEpochBuffer), + LocalDate.of(1969, 12, 31)); + } + + @Test + public void testXlangLocalDateSupportsWideJavaRange() { + LocalDate[] values = {LocalDate.MIN, LocalDate.MAX}; + for (LocalDate value : values) { + MemoryBuffer buffer = MemoryUtils.buffer(16); + TimeSerializers.LocalDateSerializer.writeXlangLocalDate(buffer, value); + buffer.readerIndex(0); + Assert.assertEquals(TimeSerializers.LocalDateSerializer.readXlangLocalDate(buffer), value); + } + } + + @Test + public void testNativeLocalDateEncodingRemainsYearMonthDay() { + MemoryBuffer buffer = MemoryUtils.buffer(16); + LocalDate value = LocalDate.of(2024, 2, 29); + TimeSerializers.LocalDateSerializer.writeLocalDate(buffer, value); + Assert.assertEquals(buffer.writerIndex(), Integer.BYTES + 2); + buffer.readerIndex(0); + Assert.assertEquals(TimeSerializers.LocalDateSerializer.readLocalDate(buffer), value); + } + @Test(dataProvider = "foryCopyConfig") public void testBasicTime(Fory fory) { copyCheckWithoutSame(fory, new Date()); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java index 64095b133a..0618680a0f 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java @@ -181,6 +181,11 @@ public void testInteger(boolean enableCodegen) throws java.io.IOException { super.testInteger(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testDecimal(boolean enableCodegen) throws java.io.IOException { + super.testDecimal(enableCodegen); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testItem(boolean enableCodegen) throws java.io.IOException { super.testItem(enableCodegen); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/DartXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/DartXlangTest.java index f9d8336379..2799a58e1a 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/DartXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/DartXlangTest.java @@ -167,6 +167,11 @@ public void testInteger(boolean enableCodegen) throws java.io.IOException { super.testInteger(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testDecimal(boolean enableCodegen) throws java.io.IOException { + super.testDecimal(enableCodegen); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testItem(boolean enableCodegen) throws java.io.IOException { super.testItem(enableCodegen); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java index d80dca0f84..96ed9aeaf2 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java @@ -156,6 +156,11 @@ public void testInteger(boolean enableCodegen) throws java.io.IOException { super.testInteger(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testDecimal(boolean enableCodegen) throws java.io.IOException { + super.testDecimal(enableCodegen); + } + // this test failed more frequently when refactor, create two separate tests // to make debug more easy @Test(groups = "xlang") diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/JavaScriptXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/JavaScriptXlangTest.java index 0c883974a3..c7e9e05a74 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/JavaScriptXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/JavaScriptXlangTest.java @@ -501,6 +501,11 @@ public void testUnionXlang(boolean enableCodegen) throws java.io.IOException { super.testUnionXlang(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testDecimal(boolean enableCodegen) throws java.io.IOException { + super.testDecimal(enableCodegen); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testRefSchemaConsistent(boolean enableCodegen) throws java.io.IOException { super.testRefSchemaConsistent(enableCodegen); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java index 5c100b46b8..ebfedf37a4 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java @@ -135,6 +135,12 @@ public void testInteger(boolean enableCodegen) throws IOException { throw new SkipException("Skipping: similar test already covered in CrossLanguageTest"); } + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testDecimal(boolean enableCodegen) throws IOException { + super.testDecimal(enableCodegen); + } + // ============================================================================ // Explicitly re-declare inherited test methods to enable running individual // tests via Maven: mvn test -Dtest=org.apache.fory.xlang.PythonXlangTest#testXxx diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java index 1764013648..9e2e4f9e2b 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java @@ -149,6 +149,11 @@ public void testInteger(boolean enableCodegen) throws java.io.IOException { super.testInteger(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testDecimal(boolean enableCodegen) throws java.io.IOException { + super.testDecimal(enableCodegen); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testItem(boolean enableCodegen) throws java.io.IOException { super.testItem(enableCodegen); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/SwiftXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/SwiftXlangTest.java index 229d55372b..79bc291ff7 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/SwiftXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/SwiftXlangTest.java @@ -228,6 +228,11 @@ public void testInteger(boolean enableCodegen) throws java.io.IOException { super.testInteger(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testDecimal(boolean enableCodegen) throws java.io.IOException { + super.testDecimal(enableCodegen); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testItem(boolean enableCodegen) throws java.io.IOException { super.testItem(enableCodegen); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java index 047f4c22ef..459b7885e6 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java @@ -23,6 +23,8 @@ import com.google.common.hash.Hashing; import java.io.File; import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; @@ -811,6 +813,23 @@ static class Item1 { Integer f6; } + protected static List decimalValues() { + return Arrays.asList( + BigDecimal.ZERO, + new BigDecimal(BigInteger.ZERO, 3), + BigDecimal.ONE, + BigDecimal.ONE.negate(), + BigDecimal.valueOf(12345, 2), + new BigDecimal(BigInteger.valueOf(Long.MAX_VALUE), 0), + new BigDecimal(BigInteger.valueOf(Long.MIN_VALUE), 0), + new BigDecimal(BigInteger.valueOf(4611686018427387903L), 0), + new BigDecimal(BigInteger.valueOf(-4611686018427387904L), 0), + new BigDecimal(BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE), 0), + new BigDecimal(BigInteger.valueOf(Long.MIN_VALUE).subtract(BigInteger.ONE), 0), + new BigDecimal(new BigInteger("123456789012345678901234567890123456789"), 37), + new BigDecimal(new BigInteger("-123456789012345678901234567890123456789"), -17)); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testInteger(boolean enableCodegen) throws java.io.IOException { String caseName = "test_integer"; @@ -865,6 +884,28 @@ public void testInteger(boolean enableCodegen) throws java.io.IOException { Assert.assertEquals(fory.deserialize(buffer2), 0); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testDecimal(boolean enableCodegen) throws java.io.IOException { + String caseName = "test_decimal"; + Fory fory = + Fory.builder() + .withLanguage(Language.XLANG) + .withCodegen(enableCodegen) + .withCompatibleMode(CompatibleMode.COMPATIBLE) + .build(); + List values = decimalValues(); + MemoryBuffer buffer = MemoryUtils.buffer(64); + for (BigDecimal value : values) { + fory.serialize(buffer, value); + } + ExecutionContext ctx = prepareExecution(caseName, buffer.getBytes(0, buffer.writerIndex())); + runPeer(ctx); + MemoryBuffer result = readBuffer(ctx.dataFile()); + for (BigDecimal value : values) { + Assert.assertEquals(fory.deserialize(result), value); + } + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testItem(boolean enableCodegen) throws java.io.IOException { String caseName = "test_item"; diff --git a/javascript/packages/core/index.ts b/javascript/packages/core/index.ts index 8dfd570bf7..7a70db13f4 100644 --- a/javascript/packages/core/index.ts +++ b/javascript/packages/core/index.ts @@ -26,8 +26,9 @@ import { Serializer, Mode } from "./lib/type"; import Fory from "./lib/fory"; import { BinaryReader } from "./lib/reader"; import { BinaryWriter } from "./lib/writer"; -import { BFloat16, BFloat16Array } from "./lib/bfloat16"; +import { BFloat16, BFloat16Array } from "./lib/types/bfloat16"; import { ReadContext, WriteContext } from "./lib/context"; +import { Decimal } from "./lib/types/decimal"; export { Serializer, @@ -39,6 +40,7 @@ export { BinaryReader, BFloat16, BFloat16Array, + Decimal, ReadContext, WriteContext, }; diff --git a/javascript/packages/core/lib/gen/datetime.ts b/javascript/packages/core/lib/gen/datetime.ts index b785fec63d..c34fdadfd5 100644 --- a/javascript/packages/core/lib/gen/datetime.ts +++ b/javascript/packages/core/lib/gen/datetime.ts @@ -100,20 +100,20 @@ class DateSerializerGenerator extends BaseSerializerGenerator { const epoch = this.scope.declareByName("epoch", `new Date("1970/01/01 00:00").getTime()`); return ` if (${accessor} instanceof Date) { - ${this.builder.writer.writeInt32(`Math.floor((${accessor}.getTime() - ${epoch}) / 1000 / (24 * 60 * 60))`)} + ${this.builder.writer.writeVarInt64(`Math.floor((${accessor}.getTime() - ${epoch}) / 1000 / (24 * 60 * 60))`)} } else { - ${this.builder.writer.writeInt32(`Math.floor((${accessor} - ${epoch}) / 1000 / (24 * 60 * 60))`)} + ${this.builder.writer.writeVarInt64(`Math.floor((${accessor} - ${epoch}) / 1000 / (24 * 60 * 60))`)} } `; } read(accessor: (expr: string) => string): string { const epoch = this.scope.declareByName("epoch", `new Date("1970/01/01 00:00").getTime()`); - return accessor(`new Date(${epoch} + (${this.builder.reader.readInt32()} * (24 * 60 * 60) * 1000))`); + return accessor(`new Date(${epoch} + (Number(${this.builder.reader.readVarInt64()}) * (24 * 60 * 60) * 1000))`); } getFixedSize(): number { - return 7; + return 11; } } diff --git a/javascript/packages/core/lib/gen/decimal.ts b/javascript/packages/core/lib/gen/decimal.ts new file mode 100644 index 0000000000..6c3deddefb --- /dev/null +++ b/javascript/packages/core/lib/gen/decimal.ts @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { TypeInfo } from "../typeInfo"; +import { CodecBuilder } from "./builder"; +import { BaseSerializerGenerator } from "./serializer"; +import { CodegenRegistry } from "./router"; +import { TypeId } from "../type"; +import { Scope } from "./scope"; +import { Decimal, DecimalCodec } from "../types/decimal"; + +class DecimalSerializerGenerator extends BaseSerializerGenerator { + typeInfo: TypeInfo; + + constructor(typeInfo: TypeInfo, builder: CodecBuilder, scope: Scope) { + super(typeInfo, builder, scope); + this.typeInfo = typeInfo; + } + + write(accessor: string): string { + const codec = this.builder.getExternal(DecimalCodec.name); + const scale = this.scope.uniqueName("decimal_scale"); + const unscaled = this.scope.uniqueName("decimal_unscaled"); + const payload = this.scope.uniqueName("decimal_payload"); + const meta = this.scope.uniqueName("decimal_meta"); + return ` + const ${scale} = ${accessor}.scale; + const ${unscaled} = ${accessor}.unscaledValue; + ${this.builder.writer.writeVarInt32(scale)} + if (${codec}.canUseSmallEncoding(${unscaled})) { + ${this.builder.writer.writeVarUInt64(`(${codec}.encodeZigZag64(${unscaled}) << 1n)`)} + return; + } + const ${payload} = ${codec}.toCanonicalLittleEndianMagnitude(${unscaled}); + const ${meta} = (BigInt(${payload}.length) << 1n) | (${unscaled} < 0n ? 1n : 0n); + ${this.builder.writer.writeVarUInt64(`((${meta} << 1n) | 1n)`)} + ${this.builder.writer.buffer(payload)} + `; + } + + read(accessor: (expr: string) => string): string { + const codec = this.builder.getExternal(DecimalCodec.name); + const decimal = this.builder.getExternal(Decimal.name); + const scale = this.scope.uniqueName("decimal_scale"); + const header = this.scope.uniqueName("decimal_header"); + const meta = this.scope.uniqueName("decimal_meta"); + const length = this.scope.uniqueName("decimal_length"); + const payload = this.scope.uniqueName("decimal_payload"); + const magnitude = this.scope.uniqueName("decimal_magnitude"); + const unscaled = this.scope.uniqueName("decimal_unscaled"); + return ` + const ${scale} = ${this.builder.reader.readVarInt32()}; + const ${header} = ${this.builder.reader.readVarUInt64()}; + if ((${header} & 1n) === 0n) { + ${accessor(`new ${decimal}(${codec}.decodeZigZag64(${header} >> 1n), ${scale})`)} + return; + } + const ${meta} = ${header} >> 1n; + const ${length} = Number(${meta} >> 1n); + if (${length} <= 0 || ${length} > 0x7fffffff) { + throw new Error(\`Invalid decimal magnitude length \${${length}}.\`); + } + const ${payload} = ${this.builder.reader.buffer(length)}; + if (${payload}[${length} - 1] === 0) { + throw new Error("Non-canonical decimal payload: trailing zero byte."); + } + const ${magnitude} = ${codec}.fromCanonicalLittleEndianMagnitude(${payload}); + if (${magnitude} === 0n) { + throw new Error("Big decimal encoding must not represent zero."); + } + const ${unscaled} = ((${meta} & 1n) === 0n) ? ${magnitude} : -${magnitude}; + ${accessor(`new ${decimal}(${unscaled}, ${scale})`)} + `; + } + + getFixedSize(): number { + return 20; + } +} + +CodegenRegistry.register(TypeId.DECIMAL, DecimalSerializerGenerator); +CodegenRegistry.registerExternal(Decimal); +CodegenRegistry.registerExternal(DecimalCodec); diff --git a/javascript/packages/core/lib/gen/index.ts b/javascript/packages/core/lib/gen/index.ts index ee75d8138d..5e49878a04 100644 --- a/javascript/packages/core/lib/gen/index.ts +++ b/javascript/packages/core/lib/gen/index.ts @@ -27,6 +27,7 @@ import "./struct"; import "./string"; import "./bool"; import "./datetime"; +import "./decimal"; import "./map"; import "./number"; import "./set"; diff --git a/javascript/packages/core/lib/reader/index.ts b/javascript/packages/core/lib/reader/index.ts index 60218895e4..28be874d6a 100644 --- a/javascript/packages/core/lib/reader/index.ts +++ b/javascript/packages/core/lib/reader/index.ts @@ -21,7 +21,7 @@ import { LATIN1, UTF16, UTF8 } from "../type"; import { isNodeEnv } from "../util"; import { PlatformBuffer, alloc, fromUint8Array } from "../platformBuffer"; import { readLatin1String } from "./string"; -import { BFloat16 } from "../bfloat16"; +import { BFloat16 } from "../types/bfloat16"; export class BinaryReader { private sliceStringEnable; diff --git a/javascript/packages/core/lib/typeInfo.ts b/javascript/packages/core/lib/typeInfo.ts index 90cff00983..92dd0ef199 100644 --- a/javascript/packages/core/lib/typeInfo.ts +++ b/javascript/packages/core/lib/typeInfo.ts @@ -18,7 +18,8 @@ */ import { ForyTypeInfoSymbol, TypeId } from "./type"; -import { BFloat16 } from "./bfloat16"; +import { BFloat16 } from "./types/bfloat16"; +import { Decimal } from "./types/decimal"; const targetFields = new WeakMap any, { [key: string]: TypeInfo }>(); @@ -467,6 +468,10 @@ export type HintInput = T extends { type: typeof TypeId.DURATION; } ? Date + : T extends { + type: typeof TypeId.DECIMAL; + } + ? Decimal : T extends { type: typeof TypeId.TIMESTAMP; } @@ -537,6 +542,10 @@ export type HintResult = T extends never ? any : T extends { type: typeof TypeId.DURATION; } ? number + : T extends { + type: typeof TypeId.DECIMAL; + } + ? Decimal : T extends { type: typeof TypeId.DATE; } @@ -785,6 +794,11 @@ export const Type = { (TypeId.TIMESTAMP), ); }, + decimal() { + return TypeInfo.fromNonParam( + (TypeId.DECIMAL), + ); + }, boolArray() { return TypeInfo.fromNonParam( (TypeId.BOOL_ARRAY), diff --git a/javascript/packages/core/lib/typeResolver.ts b/javascript/packages/core/lib/typeResolver.ts index 69595c92e0..2c18fac6d8 100644 --- a/javascript/packages/core/lib/typeResolver.ts +++ b/javascript/packages/core/lib/typeResolver.ts @@ -21,6 +21,7 @@ import { ForyTypeInfoSymbol, WithForyClsInfo, Serializer, TypeId, MaxInt32, MinI import { Gen } from "./gen"; import { Dynamic, Type, TypeInfo } from "./typeInfo"; import { ReadContext, WriteContext } from "./context"; +import { Decimal } from "./types/decimal"; const uninitSerialize = { _initialized: false, @@ -94,6 +95,7 @@ export default class TypeResolver { private int64Serializer: null | Serializer = null; private boolSerializer: null | Serializer = null; private datetimeSerializer: null | Serializer = null; + private decimalSerializer: null | Serializer = null; private stringSerializer: null | Serializer = null; private setSerializer: null | Serializer = null; private arraySerializer: null | Serializer = null; @@ -195,6 +197,7 @@ export default class TypeResolver { registerSerializer(Type.timestamp()); registerSerializer(Type.duration()); registerSerializer(Type.date()); + registerSerializer(Type.decimal()); registerSerializer(Type.set(Type.any())); registerSerializer(Type.binary()); registerSerializer(Type.boolArray()); @@ -218,6 +221,7 @@ export default class TypeResolver { this.int64Serializer = this.getSerializerById(TypeId.INT64); this.boolSerializer = this.getSerializerById(TypeId.BOOL); this.datetimeSerializer = this.getSerializerById(TypeId.TIMESTAMP); + this.decimalSerializer = this.getSerializerById(TypeId.DECIMAL); this.stringSerializer = this.getSerializerById(TypeId.STRING); this.setSerializer = this.getSerializerById(TypeId.SET); this.arraySerializer = this.getSerializerById(TypeId.LIST); @@ -334,6 +338,10 @@ export default class TypeResolver { return this.stringSerializer; } + if (v instanceof Decimal) { + return this.decimalSerializer; + } + if (v instanceof Uint8Array) { return this.uint8ArraySerializer; } diff --git a/javascript/packages/core/lib/bfloat16.ts b/javascript/packages/core/lib/types/bfloat16.ts similarity index 100% rename from javascript/packages/core/lib/bfloat16.ts rename to javascript/packages/core/lib/types/bfloat16.ts diff --git a/javascript/packages/core/lib/types/decimal.ts b/javascript/packages/core/lib/types/decimal.ts new file mode 100644 index 0000000000..4baf3eec81 --- /dev/null +++ b/javascript/packages/core/lib/types/decimal.ts @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +const DECIMAL_SMALL_MIN = -(1n << 62n); +const DECIMAL_SMALL_MAX = (1n << 62n) - 1n; + +export class Decimal { + readonly unscaledValue: bigint; + readonly scale: number; + + constructor(unscaledValue: bigint | number | string, scale: number) { + if (!Number.isInteger(scale)) { + throw new Error(`Decimal scale must be an integer, got ${scale}`); + } + this.unscaledValue = BigInt(unscaledValue); + this.scale = scale; + } + + static from(unscaledValue: bigint | number | string, scale = 0): Decimal { + return new Decimal(unscaledValue, scale); + } + + equals(other: unknown): boolean { + return other instanceof Decimal + && other.scale === this.scale + && other.unscaledValue === this.unscaledValue; + } + + toString(): string { + return `${this.unscaledValue.toString()}e${-this.scale}`; + } +} + +export class DecimalCodec { + static canUseSmallEncoding(value: bigint): boolean { + return value >= DECIMAL_SMALL_MIN && value <= DECIMAL_SMALL_MAX; + } + + static encodeZigZag64(value: bigint): bigint { + return (value << 1n) ^ (value >> 63n); + } + + static decodeZigZag64(value: bigint): bigint { + return (value >> 1n) ^ -(value & 1n); + } + + static toCanonicalLittleEndianMagnitude(value: bigint): Uint8Array { + let magnitude = value < 0n ? -value : value; + if (magnitude === 0n) { + throw new Error("Zero must use the small decimal encoding."); + } + const bytes: number[] = []; + while (magnitude !== 0n) { + bytes.push(Number(magnitude & 0xffn)); + magnitude >>= 8n; + } + return Uint8Array.from(bytes); + } + + static fromCanonicalLittleEndianMagnitude(payload: Uint8Array): bigint { + let magnitude = 0n; + for (let i = payload.length - 1; i >= 0; i--) { + magnitude = (magnitude << 8n) | BigInt(payload[i]); + } + return magnitude; + } +} diff --git a/javascript/packages/core/lib/writer/index.ts b/javascript/packages/core/lib/writer/index.ts index fd2b96a544..f51c35b8a8 100644 --- a/javascript/packages/core/lib/writer/index.ts +++ b/javascript/packages/core/lib/writer/index.ts @@ -21,7 +21,7 @@ import { HalfMaxInt32, HalfMinInt32, Hps, LATIN1, UTF16, UTF8 } from "../type"; import { PlatformBuffer, alloc, strByteLength } from "../platformBuffer"; import { OwnershipError } from "../error"; import { toFloat16, toBFloat16 } from "./number"; -import { BFloat16 } from "../bfloat16"; +import { BFloat16 } from "../types/bfloat16"; const MAX_POOL_SIZE = 1024 * 1024 * 3; // 3MB diff --git a/javascript/test/crossLanguage.test.ts b/javascript/test/crossLanguage.test.ts index b58534ab42..8a4da7a536 100644 --- a/javascript/test/crossLanguage.test.ts +++ b/javascript/test/crossLanguage.test.ts @@ -20,6 +20,7 @@ import Fory, { BinaryReader, BinaryWriter, + Decimal, ReadContext, Type, Dynamic, @@ -50,6 +51,28 @@ const Long = { MIN_VALUE: BigInt("-9223372036854775808"), } +function decimal(unscaledValue: string | bigint | number, scale: number): Decimal { + return new Decimal(unscaledValue, scale); +} + +function decimalValues(): Decimal[] { + return [ + decimal(0n, 0), + decimal(0n, 3), + decimal(1n, 0), + decimal(-1n, 0), + decimal(12345n, 2), + decimal("9223372036854775807", 0), + decimal("-9223372036854775808", 0), + decimal("4611686018427387903", 0), + decimal("-4611686018427387904", 0), + decimal("9223372036854775808", 0), + decimal("-9223372036854775809", 0), + decimal("123456789012345678901234567890123456789", 37), + decimal("-123456789012345678901234567890123456789", -17), + ]; +} + describe("bool", () => { const dataFile = process.env["DATA_FILE"]; if (!dataFile) { @@ -531,6 +554,29 @@ describe("bool", () => { writeToFile(Buffer.concat(bfs)); }); + test("test_decimal", () => { + const fory = new Fory({ + compatible: true, + }); + + const expectedValues = decimalValues(); + const actualValues: Decimal[] = []; + let cursor = 0; + for (let i = 0; i < expectedValues.length; i++) { + const value = fory.deserialize(content.subarray(cursor)); + cursor += fory.readContext.reader.readGetCursor(); + expect(value).toBeInstanceOf(Decimal); + expect((value as Decimal).equals(expectedValues[i])).toBe(true); + actualValues.push(value as Decimal); + } + + const bfs = []; + for (const value of actualValues) { + bfs.push(fory.serialize(value)); + } + writeToFile(Buffer.concat(bfs)); + }); + test("test_item", () => { const fory = new Fory({ compatible: true diff --git a/javascript/test/datetime.test.ts b/javascript/test/datetime.test.ts index 8bbbfdc587..f68e1a8240 100644 --- a/javascript/test/datetime.test.ts +++ b/javascript/test/datetime.test.ts @@ -19,6 +19,7 @@ import Fory, { Type } from '../packages/core/index'; import {describe, expect, test} from '@jest/globals'; +import { TypeId } from '../packages/core/lib/type'; describe('datetime', () => { test('should date work', () => { @@ -46,5 +47,13 @@ describe('datetime', () => { ); expect(result).toEqual({ a: d, b: d.getTime() }) }); -}); + test('should use signed varint64 for date payloads', () => { + const fory = new Fory({ ref: true }); + const serializer = fory.register(Type.date()).serializer; + const value = new Date(1969, 11, 31); + const encoded = fory.serialize(value, serializer); + expect(Array.from(encoded)).toEqual([0x02, 0xff, TypeId.DATE, 0x01]); + expect(fory.deserialize(encoded, serializer)).toEqual(value); + }); +}); diff --git a/javascript/test/decimal.test.ts b/javascript/test/decimal.test.ts new file mode 100644 index 0000000000..2be7e9d06e --- /dev/null +++ b/javascript/test/decimal.test.ts @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import Fory, { Decimal, Type } from "../packages/core/index"; +import { describe, expect, test } from "@jest/globals"; + +function decimal(unscaledValue: string | bigint | number, scale: number): Decimal { + return new Decimal(unscaledValue, scale); +} + +describe("decimal", () => { + test("round-trips root decimal edge cases", () => { + const fory = new Fory(); + const values = [ + decimal(0n, 0), + decimal(0n, 3), + decimal(1n, 0), + decimal(-1n, 0), + decimal(12345n, 2), + decimal("9223372036854775807", 0), + decimal("-9223372036854775808", 0), + decimal("4611686018427387903", 0), + decimal("-4611686018427387904", 0), + decimal("9223372036854775808", 0), + decimal("-9223372036854775809", 0), + decimal("123456789012345678901234567890123456789", 37), + decimal("-123456789012345678901234567890123456789", -17), + ]; + + for (const value of values) { + const roundTrip = fory.deserialize(fory.serialize(value)) as Decimal; + expect(roundTrip).toBeInstanceOf(Decimal); + expect(roundTrip.equals(value)).toBe(true); + } + }); + + test("round-trips struct decimal fields", () => { + const fory = new Fory(); + const serializer = fory.register(Type.struct({ + typeName: "example.DecimalEnvelope", + }, { + amount: Type.decimal(), + note: Type.string(), + })).serializer; + const value = { + amount: decimal("123456789012345678901234567890123456789", 37), + note: "principal", + }; + + const roundTrip = fory.deserialize(fory.serialize(value, serializer), serializer) as { + amount: Decimal; + note: string; + }; + + expect(roundTrip.amount).toBeInstanceOf(Decimal); + expect(roundTrip.amount.equals(value.amount)).toBe(true); + expect(roundTrip.note).toBe("principal"); + }); + + test("rejects non-canonical big decimal payloads", () => { + const fory = new Fory(); + const zeroBigEncoding = Buffer.from([ + 0x02, + 0xff, + 0x28, + 0x00, + 0x01, + ]); + const trailingZeroPayload = Buffer.from([ + 0x02, + 0xff, + 0x28, + 0x00, + 0x09, + 0x01, + 0x00, + ]); + + expect(() => fory.deserialize(zeroBigEncoding)).toThrow(/Invalid decimal magnitude length/); + expect(() => fory.deserialize(trailingZeroPayload)).toThrow(/trailing zero byte/); + }); +}); diff --git a/kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/BuiltinClassSerializerTests.kt b/kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/BuiltinClassSerializerTests.kt index 789ee129e3..fd12050ebf 100644 --- a/kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/BuiltinClassSerializerTests.kt +++ b/kotlin/src/test/kotlin/org/apache/fory/serializer/kotlin/BuiltinClassSerializerTests.kt @@ -19,6 +19,8 @@ package org.apache.fory.serializer.kotlin +import java.math.BigDecimal +import java.math.BigInteger import kotlin.random.Random import kotlin.test.Test import kotlin.time.Duration @@ -228,4 +230,32 @@ class BuiltinClassSerializerTests { Assert.assertEquals(value.pattern, fory.deserialize(fory.serialize(value.pattern))) Assert.assertEquals(value.options, fory.deserialize(fory.serialize(value.options))) } + + @Test + fun testSerializeBigDecimal() { + val values = + listOf( + BigDecimal.ZERO, + BigDecimal(BigInteger.ZERO, 3), + BigDecimal.ONE, + BigDecimal.ONE.negate(), + BigDecimal.valueOf(12345, 2), + BigDecimal(BigInteger.valueOf(Long.MAX_VALUE), 0), + BigDecimal(BigInteger.valueOf(Long.MIN_VALUE), 0), + BigDecimal(BigInteger.valueOf(Long.MAX_VALUE).add(BigInteger.ONE), 0), + BigDecimal(BigInteger.valueOf(Long.MIN_VALUE).subtract(BigInteger.ONE), 0), + BigDecimal(BigInteger("123456789012345678901234567890123456789"), 37), + ) + for (language in listOf(Language.JAVA, Language.XLANG)) { + val fory = + Fory.builder() + .withLanguage(language) + .requireClassRegistration(true) + .withRefTracking(true) + .build() + for (value in values) { + Assert.assertEquals(value, fory.deserialize(fory.serialize(value))) + } + } + } } diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py index 6b547cecec..4d0beadf56 100644 --- a/python/pyfory/_serializer.py +++ b/python/pyfory/_serializer.py @@ -259,10 +259,16 @@ def write(self, write_context, value: datetime.date): if not isinstance(value, datetime.date): raise TypeError("{} should be {} instead of {}".format(value, datetime.date, type(value))) days = (value - _base_date).days - write_context.write_int32(days) + if self.type_resolver.xlang: + write_context.write_varint64(days) + else: + write_context.write_int32(days) def read(self, read_context): - days = read_context.read_int32() + if self.type_resolver.xlang: + days = read_context.read_varint64() + else: + days = read_context.read_int32() return _base_date + datetime.timedelta(days=days) diff --git a/python/pyfory/primitive.pxi b/python/pyfory/primitive.pxi index bda3c12aed..2383d933a8 100644 --- a/python/pyfory/primitive.pxi +++ b/python/pyfory/primitive.pxi @@ -237,10 +237,16 @@ cdef class DateSerializer(Serializer): ) ) days = (value - _base_date).days - write_context.write_int32(days) + if self.type_resolver.xlang: + write_context.write_varint64(days) + else: + write_context.write_int32(days) cpdef inline read(self, ReadContext read_context): - days = read_context.read_int32() + if self.type_resolver.xlang: + days = read_context.read_varint64() + else: + days = read_context.read_int32() return datetime.date.fromordinal(_base_date_ordinal + days) diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py index 0de40d1795..da1fab51ce 100644 --- a/python/pyfory/registry.py +++ b/python/pyfory/registry.py @@ -18,6 +18,7 @@ import array import dataclasses import datetime +import decimal import enum import functools import inspect @@ -58,6 +59,7 @@ Float32Serializer, Float64Serializer, StringSerializer, + DecimalSerializer, DateSerializer, TimestampSerializer, BytesSerializer, @@ -442,7 +444,7 @@ def _initialize_common(self): ) register(float, type_id=TypeId.FLOAT64, serializer=Float64Serializer) register(str, type_id=TypeId.STRING, serializer=StringSerializer) - # TODO(chaokunyang) DURATION DECIMAL + register(decimal.Decimal, type_id=TypeId.DECIMAL, serializer=DecimalSerializer) register(datetime.datetime, type_id=TypeId.TIMESTAMP, serializer=TimestampSerializer) register(datetime.date, type_id=TypeId.DATE, serializer=DateSerializer) register(bytes, type_id=TypeId.BINARY, serializer=BytesSerializer) diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index f2adb5fd26..30101757f6 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -18,12 +18,14 @@ import array import builtins import dataclasses +import decimal import importlib import inspect import marshal import os import pickle import types +from typing import Tuple from pyfory.serialization import Buffer from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG @@ -159,6 +161,110 @@ def read(self, buffer): return None +_MIN_INT64 = -(1 << 63) +_MAX_INT64 = (1 << 63) - 1 +_MAX_SMALL_ZIGZAG = (1 << 63) - 1 +_MIN_INT32 = -(1 << 31) +_MAX_INT32 = (1 << 31) - 1 +_UINT64_MOD = 1 << 64 + + +def _encode_zigzag64(value: int) -> int: + return (value << 1) ^ (value >> 63) + + +def _decode_zigzag64(value: int) -> int: + return (value >> 1) ^ -(value & 1) + + +def _can_use_small_decimal_encoding(unscaled: int) -> bool: + if unscaled < _MIN_INT64 or unscaled > _MAX_INT64: + return False + return _encode_zigzag64(unscaled) <= _MAX_SMALL_ZIGZAG + + +def _decimal_parts(value: decimal.Decimal) -> Tuple[int, int]: + if not value.is_finite(): + raise ValueError(f"Decimal value must be finite, got {value!r}") + sign, digits, exponent = value.as_tuple() + scale = -exponent + if scale < _MIN_INT32 or scale > _MAX_INT32: + raise ValueError(f"Decimal scale {scale} is outside signed int32 range") + unscaled = 0 + for digit in digits: + unscaled = unscaled * 10 + digit + if sign: + unscaled = -unscaled + return scale, unscaled + + +def _decimal_from_parts(scale: int, unscaled: int) -> decimal.Decimal: + if unscaled == 0: + digits = (0,) + sign = 0 + else: + sign = 1 if unscaled < 0 else 0 + digits = tuple(int(ch) for ch in str(abs(unscaled))) + return decimal.Decimal((sign, digits, -scale)) + + +def _write_decimal_parts(write_context, scale: int, unscaled: int): + write_context.write_varint32(scale) + if _can_use_small_decimal_encoding(unscaled): + header = _encode_zigzag64(unscaled) << 1 + _write_var_uint64(write_context, header) + return + magnitude = abs(unscaled) + if magnitude == 0: + raise ValueError("Zero must use the small decimal encoding") + payload = magnitude.to_bytes((magnitude.bit_length() + 7) // 8, "little", signed=False) + meta = (len(payload) << 1) | (1 if unscaled < 0 else 0) + _write_var_uint64(write_context, (meta << 1) | 1) + write_context.write_bytes(payload) + + +def _write_var_uint64(write_context, value: int): + try: + write_context.write_var_uint64(value) + except OverflowError: + write_context.write_var_uint64(value - _UINT64_MOD) + + +def _read_decimal_parts(read_context) -> Tuple[int, int]: + scale = read_context.read_varint32() + header = read_context.read_var_uint64() + if header < 0: + header += _UINT64_MOD + if (header & 1) == 0: + return scale, _decode_zigzag64(header >> 1) + meta = header >> 1 + sign = meta & 1 + length = meta >> 1 + if length <= 0: + raise ValueError(f"Invalid decimal magnitude length {length}") + payload = read_context.read_bytes(length) + if payload[-1] == 0: + raise ValueError("Non-canonical decimal payload: trailing zero byte") + magnitude = int.from_bytes(payload, "little", signed=False) + if magnitude == 0: + raise ValueError("Big decimal encoding must not represent zero") + return scale, -magnitude if sign else magnitude + + +class DecimalSerializer(Serializer): + def __init__(self, type_resolver, type_): + super().__init__(type_resolver, type_) + self.need_to_write_ref = False + + def write(self, write_context, value: decimal.Decimal): + scale, unscaled = _decimal_parts(value) + _write_decimal_parts(write_context, scale, unscaled) + + def read(self, read_context): + scale, unscaled = _read_decimal_parts(read_context) + return _decimal_from_parts(scale, unscaled) + + class PandasRangeIndexSerializer(Serializer): __slots__ = "_cached" diff --git a/python/pyfory/tests/test_serializer.py b/python/pyfory/tests/test_serializer.py index 8fc963b8a8..04da47ca27 100644 --- a/python/pyfory/tests/test_serializer.py +++ b/python/pyfory/tests/test_serializer.py @@ -16,6 +16,7 @@ # under the License. import array +import decimal import datetime import gc import io @@ -36,6 +37,7 @@ from pyfory.serialization import Buffer from pyfory import Fory, EnumSerializer from pyfory.serializer import ( + DecimalSerializer, TimestampSerializer, DateSerializer, PyArraySerializer, @@ -137,6 +139,10 @@ def test_basic_serializer(xlang): assert isinstance(typeinfo.serializer, DateSerializer) if xlang: assert typeinfo.type_id == TypeId.DATE + typeinfo = fory.type_resolver.get_type_info(decimal.Decimal) + assert isinstance(typeinfo.serializer, DecimalSerializer) + if xlang: + assert typeinfo.type_id == TypeId.DECIMAL assert ser_de(fory, True) is True assert ser_de(fory, False) is False assert ser_de(fory, -1) == -1 @@ -149,6 +155,8 @@ def test_basic_serializer(xlang): assert ser_de(fory, 1.0) == 1.0 assert ser_de(fory, -1.0) == -1.0 assert ser_de(fory, "str") == "str" + assert ser_de(fory, decimal.Decimal("1234567890.0123456789")) == decimal.Decimal("1234567890.0123456789") + assert ser_de(fory, decimal.Decimal("0.000")) == decimal.Decimal("0.000") assert ser_de(fory, b"") == b"" now = datetime.datetime.now(datetime.timezone.utc) assert ser_de(fory, now) == now @@ -164,6 +172,99 @@ def test_basic_serializer(xlang): assert ser_de(fory, set_) == set_ +@pytest.mark.parametrize("xlang", [True, False]) +def test_date_serializer_uses_xlang_varint64_and_native_int32(xlang): + fory = Fory(xlang=xlang, ref=False) + day = datetime.date(1969, 12, 31) + payload = fory.serialize(day) + buffer = Buffer(payload) + assert buffer.read_uint8() == 2 + assert buffer.read_int8() == -1 + assert buffer.read_uint8() == TypeId.DATE + if xlang: + assert buffer.read_varint64() == -1 + else: + assert buffer.read_int32() == -1 + assert buffer.get_reader_index() == len(payload) + + +@pytest.mark.parametrize("xlang", [False, True]) +def test_decimal_round_trip(xlang): + fory = Fory(xlang=xlang, ref=False) + values = [ + decimal.Decimal("0"), + decimal.Decimal("0.000"), + decimal.Decimal("1"), + decimal.Decimal("-1"), + decimal.Decimal("123.45"), + decimal.Decimal("-123.45"), + decimal.Decimal("9223372036854775807"), + decimal.Decimal("-9223372036854775808"), + decimal.Decimal("4611686018427387903"), + decimal.Decimal("-4611686018427387904"), + decimal.Decimal("9223372036854775808"), + decimal.Decimal("-9223372036854775809"), + decimal.Decimal("123456789012345678901234567890123456789e-37"), + decimal.Decimal("-123456789012345678901234567890123456789e17"), + ] + for value in values: + assert ser_de(fory, value) == value + + +def test_decimal_codec_canonical_round_trip(): + fory = Fory(xlang=True, ref=False) + buffer = Buffer.allocate(256) + values = [ + decimal.Decimal("0"), + decimal.Decimal("0.00"), + decimal.Decimal("1"), + decimal.Decimal("-1"), + decimal.Decimal("100e-2"), + decimal.Decimal("4611686018427387903"), + decimal.Decimal("-4611686018427387904"), + decimal.Decimal("9223372036854775808"), + decimal.Decimal("-9223372036854775809"), + decimal.Decimal("999999999999999999999999999999999999999999e-200"), + ] + serializer = DecimalSerializer(fory.type_resolver, decimal.Decimal) + for value in values: + buffer.set_reader_index(0) + buffer.set_writer_index(0) + serializer.write(buffer, value) + decoded = serializer.read(buffer) + assert decoded == value + + +def test_decimal_codec_rejects_non_canonical_big_payloads(): + fory = Fory(xlang=True, ref=False) + serializer = DecimalSerializer(fory.type_resolver, decimal.Decimal) + + zero_big_encoding = Buffer.allocate(32) + zero_big_encoding.write_varint32(0) + zero_big_encoding.write_var_uint64(1) + zero_big_encoding.set_reader_index(0) + with pytest.raises(ValueError): + serializer.read(zero_big_encoding) + + trailing_zero_payload = Buffer.allocate(32) + trailing_zero_payload.write_varint32(0) + trailing_zero_payload.write_var_uint64((((2 << 1) | 0) << 1) | 1) + trailing_zero_payload.write_bytes(b"\x01\x00") + trailing_zero_payload.set_reader_index(0) + with pytest.raises(ValueError, match="trailing zero byte"): + serializer.read(trailing_zero_payload) + + +def test_decimal_rejects_non_finite_values(): + fory = Fory(xlang=True, ref=False) + serializer = DecimalSerializer(fory.type_resolver, decimal.Decimal) + buffer = Buffer.allocate(32) + with pytest.raises(ValueError, match="must be finite"): + serializer.write(buffer, decimal.Decimal("NaN")) + with pytest.raises(ValueError, match="must be finite"): + serializer.write(buffer, decimal.Decimal("Infinity")) + + @pytest.mark.parametrize("xlang", [True, False]) def test_timestamp_serializer(xlang): """Test timestamp serialization. TimestampSerializer always returns UTC-aware datetimes.""" diff --git a/python/pyfory/tests/xlang_test_main.py b/python/pyfory/tests/xlang_test_main.py index acd71e0046..a8e5a526e3 100644 --- a/python/pyfory/tests/xlang_test_main.py +++ b/python/pyfory/tests/xlang_test_main.py @@ -26,6 +26,7 @@ import enum import logging import os +import decimal from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set @@ -45,6 +46,30 @@ def get_data_file() -> str: return os.environ["DATA_FILE"] +def decimal_from_parts(unscaled: int, scale: int) -> decimal.Decimal: + sign = 1 if unscaled < 0 else 0 + digits = tuple(int(ch) for ch in str(abs(unscaled))) if unscaled else (0,) + return decimal.Decimal((sign, digits, -scale)) + + +def decimal_values() -> List[decimal.Decimal]: + return [ + decimal_from_parts(0, 0), + decimal_from_parts(0, 3), + decimal_from_parts(1, 0), + decimal_from_parts(-1, 0), + decimal_from_parts(12345, 2), + decimal_from_parts(9223372036854775807, 0), + decimal_from_parts(-9223372036854775808, 0), + decimal_from_parts(4611686018427387903, 0), + decimal_from_parts(-4611686018427387904, 0), + decimal_from_parts(9223372036854775808, 0), + decimal_from_parts(-9223372036854775809, 0), + decimal_from_parts(123456789012345678901234567890123456789, 37), + decimal_from_parts(-123456789012345678901234567890123456789, -17), + ] + + # ============================================================================ # Test Data Classes - Must match XlangTestBase.java definitions # ============================================================================ @@ -717,6 +742,29 @@ def test_two_string_field_compatible(): f.write(new_bytes) +def test_decimal(): + data_file = get_data_file() + with open(data_file, "rb") as f: + data_bytes = f.read() + + buffer = pyfory.Buffer(data_bytes) + fory = pyfory.Fory(xlang=True, compatible=True) + expected_values = decimal_values() + actual_values = [] + for expected in expected_values: + value = fory.deserialize(buffer) + debug_print(f"Deserialized decimal: {value!r}") + assert isinstance(value, decimal.Decimal) + assert value.as_tuple() == expected.as_tuple(), f"Mismatch: {value!r} != {expected!r}" + actual_values.append(value) + + new_buffer = pyfory.Buffer.allocate(max(256, len(data_bytes) * 2)) + for value in actual_values: + fory.serialize(value, buffer=new_buffer) + with open(data_file, "wb") as f: + f.write(new_buffer.get_bytes(0, new_buffer.get_writer_index())) + + def test_schema_evolution_compatible(): """Test schema evolution: deserialize TwoStringFieldStruct as EmptyStruct.""" data_file = get_data_file() diff --git a/rust/fory-core/Cargo.toml b/rust/fory-core/Cargo.toml index 912782a74f..6a96e6f8dc 100644 --- a/rust/fory-core/Cargo.toml +++ b/rust/fory-core/Cargo.toml @@ -36,6 +36,7 @@ chrono = "0.4" thiserror = { default-features = false, version = "1.0" } num_enum = "0.5.1" paste = "1.0" +num-bigint = "0.4" [[bench]] diff --git a/rust/fory-core/benches/simd_bench.rs b/rust/fory-core/benches/simd_bench.rs index 2dacffbbe8..83d171a94a 100644 --- a/rust/fory-core/benches/simd_bench.rs +++ b/rust/fory-core/benches/simd_bench.rs @@ -20,7 +20,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use std::arch::x86_64::*; use fory_core::buffer::{Reader, Writer}; -use fory_core::meta::buffer_rw_string::{ +use fory_core::util::buffer_rw_string::{ read_latin1_simd, read_latin1_standard, write_latin1_simd, write_latin1_standard, write_latin1_string, }; diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs index 7caa636975..9bd7310476 100644 --- a/rust/fory-core/src/buffer.rs +++ b/rust/fory-core/src/buffer.rs @@ -16,8 +16,8 @@ // under the License. use crate::error::Error; -use crate::float16::float16; -use crate::meta::buffer_rw_string::read_latin1_simd; +use crate::types::float16::float16; +use crate::util::buffer_rw_string::read_latin1_simd; use byteorder::{ByteOrder, LittleEndian}; use std::cmp::max; diff --git a/rust/fory-core/src/resolver/context.rs b/rust/fory-core/src/context.rs similarity index 99% rename from rust/fory-core/src/resolver/context.rs rename to rust/fory-core/src/context.rs index 16c14bf1df..059f8935b5 100644 --- a/rust/fory-core/src/resolver/context.rs +++ b/rust/fory-core/src/context.rs @@ -24,9 +24,9 @@ use crate::error::Error; use crate::meta::MetaString; use crate::resolver::meta_resolver::{MetaReaderResolver, MetaWriterResolver}; use crate::resolver::meta_string_resolver::{MetaStringReaderResolver, MetaStringWriterResolver}; -use crate::resolver::ref_resolver::{RefReader, RefWriter}; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; -use crate::types; +use crate::resolver::{RefReader, RefWriter}; +use crate::resolver::{TypeInfo, TypeResolver}; +use crate::type_id as types; use crate::TypeId; use std::rc::Rc; diff --git a/rust/fory-core/src/error.rs b/rust/fory-core/src/error.rs index c198a98da9..07eb44e688 100644 --- a/rust/fory-core/src/error.rs +++ b/rust/fory-core/src/error.rs @@ -29,7 +29,7 @@ use std::borrow::Cow; -use crate::types::format_type_id; +use crate::type_id::format_type_id; use thiserror::Error; /// Global flag to check if FORY_PANIC_ON_ERROR environment variable is set at compile time. diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index 7ad755546e..bc7953441a 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -17,14 +17,15 @@ use crate::buffer::{Reader, Writer}; use crate::config::Config; +use crate::context::{ContextCache, ReadContext, WriteContext}; use crate::ensure; use crate::error::Error; -use crate::resolver::context::{ContextCache, ReadContext, WriteContext}; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::RefMode; +use crate::resolver::TypeResolver; use crate::serializer::ForyDefault; use crate::serializer::{Serializer, StructSerializer}; -use crate::types::config_flags::{IS_CROSS_LANGUAGE_FLAG, IS_NULL_FLAG}; -use crate::types::{RefMode, SIZE_OF_REF_AND_TYPE}; +use crate::type_id::config_flags::{IS_CROSS_LANGUAGE_FLAG, IS_NULL_FLAG}; +use crate::type_id::SIZE_OF_REF_AND_TYPE; use std::cell::UnsafeCell; use std::mem; use std::sync::atomic::{AtomicU64, Ordering}; diff --git a/rust/fory-core/src/lib.rs b/rust/fory-core/src/lib.rs index b9be2e7645..eb8083097d 100644 --- a/rust/fory-core/src/lib.rs +++ b/rust/fory-core/src/lib.rs @@ -27,11 +27,13 @@ //! //! - **`fory`**: Main serialization engine and public API //! - **`buffer`**: Efficient binary buffer management with Reader/Writer +//! - **`context`**: Per-operation read/write state and context pooling types //! - **`row`**: Row-based serialization for zero-copy operations //! - **`serializer`**: Type-specific serialization implementations //! - **`resolver`**: Type resolution and metadata management //! - **`meta`**: Metadata handling for schema evolution -//! - **`types`**: Core type definitions and constants +//! - **`types`**: Runtime value carriers such as decimal, float16, and weak refs +//! - **`type_id`**: Type IDs and protocol header helpers //! - **`error`**: Error handling and result types //! - **`util`**: Utility functions and helpers //! @@ -178,15 +180,15 @@ pub mod buffer; pub mod config; +pub mod context; pub mod error; -pub mod float16; pub mod fory; pub mod meta; pub mod resolver; pub mod row; pub mod serializer; +pub mod type_id; pub mod types; -pub use float16::float16 as Float16; pub mod util; // Re-export paste for use in macros @@ -194,10 +196,12 @@ pub use paste; pub use crate::buffer::{Reader, Writer}; pub use crate::config::Config; +pub use crate::context::{ReadContext, WriteContext}; pub use crate::error::Error; pub use crate::fory::{Fory, ForyBuilder}; -pub use crate::resolver::context::{ReadContext, WriteContext}; -pub use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; -pub use crate::serializer::weak::{ArcWeak, RcWeak}; +pub use crate::meta::{compute_field_hash, compute_struct_hash}; +pub use crate::resolver::{RefFlag, RefMode, TypeInfo, TypeResolver}; pub use crate::serializer::{read_data, write_data, ForyDefault, Serializer, StructSerializer}; -pub use crate::types::{RefFlag, RefMode, TypeId}; +pub use crate::type_id::TypeId; +pub use crate::types::float16::float16 as Float16; +pub use crate::types::{ArcWeak, Decimal, RcWeak}; diff --git a/rust/fory-core/src/meta/meta_string.rs b/rust/fory-core/src/meta/meta_string.rs index 6ee818f950..962ac4b437 100644 --- a/rust/fory-core/src/meta/meta_string.rs +++ b/rust/fory-core/src/meta/meta_string.rs @@ -17,7 +17,7 @@ use crate::ensure; use crate::error::Error; -use crate::meta::string_util; +use crate::util::is_latin; use std::sync::OnceLock; // equal to "std::i16::MAX" @@ -140,7 +140,7 @@ impl MetaStringEncoder { } fn is_latin(&self, s: &str) -> bool { - string_util::is_latin(s) + is_latin(s) } fn _encode(&self, input: &str) -> Result, Error> { diff --git a/rust/fory-core/src/meta/mod.rs b/rust/fory-core/src/meta/mod.rs index 59ef65d8be..90f07c556b 100644 --- a/rust/fory-core/src/meta/mod.rs +++ b/rust/fory-core/src/meta/mod.rs @@ -16,14 +16,13 @@ // under the License. mod meta_string; -mod string_util; mod type_meta; pub use meta_string::{ Encoding, MetaString, MetaStringDecoder, MetaStringEncoder, FIELD_NAME_DECODER, FIELD_NAME_ENCODER, NAMESPACE_DECODER, NAMESPACE_ENCODER, TYPE_NAME_DECODER, TYPE_NAME_ENCODER, }; -pub use string_util::{buffer_rw_string, get_latin1_length, is_latin, murmurhash3_x64_128}; pub use type_meta::{ - sort_fields, FieldInfo, FieldType, TypeMeta, NAMESPACE_ENCODINGS, TYPE_NAME_ENCODINGS, + compute_field_hash, compute_struct_hash, sort_fields, FieldInfo, FieldType, TypeMeta, + NAMESPACE_ENCODINGS, TYPE_NAME_ENCODINGS, }; diff --git a/rust/fory-core/src/meta/string_util.rs b/rust/fory-core/src/meta/string_util.rs deleted file mode 100644 index 9685567afc..0000000000 --- a/rust/fory-core/src/meta/string_util.rs +++ /dev/null @@ -1,915 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::mem; - -#[cfg(target_feature = "neon")] -use std::arch::aarch64::*; - -#[cfg(target_feature = "avx2")] -use std::arch::x86_64::*; - -#[cfg(target_feature = "sse2")] -use std::arch::x86_64::*; - -#[cfg(target_arch = "x86_64")] -pub const MIN_DIM_SIZE_AVX: usize = 32; - -#[cfg(any( - target_arch = "x86", - target_arch = "x86_64", - all(target_arch = "aarch64", target_feature = "neon") -))] -pub const MIN_DIM_SIZE_SIMD: usize = 16; - -#[cfg(target_arch = "x86_64")] -unsafe fn is_latin_avx(s: &str) -> bool { - let bytes = s.as_bytes(); - let len = bytes.len(); - let mut i = 0; - // SIMD skip ASCII - while i + MIN_DIM_SIZE_AVX <= len { - let chunk = _mm256_loadu_si256(bytes.as_ptr().add(i) as *const __m256i); - let hi_mask = _mm256_set1_epi8(0x80u8 as i8); - let masked = _mm256_and_si256(chunk, hi_mask); - let cmp = _mm256_cmpeq_epi8(masked, _mm256_setzero_si256()); - if _mm256_movemask_epi8(cmp) != -1 { - break; - } - i += MIN_DIM_SIZE_AVX; - } - // check latin in remaining chars - let s_tail = &s[i..]; - for c in s_tail.chars() { - if c as u32 > 0xFF { - return false; - } - } - true -} - -#[cfg(target_feature = "sse2")] -unsafe fn is_latin_sse(s: &str) -> bool { - let bytes = s.as_bytes(); - let len = bytes.len(); - let mut i = 0; - // SIMD skip ASCII - while i + MIN_DIM_SIZE_SIMD <= len { - let chunk = _mm_loadu_si128(bytes.as_ptr().add(i) as *const __m128i); - let hi_mask = _mm_set1_epi8(0x80u8 as i8); - let masked = _mm_and_si128(chunk, hi_mask); - let cmp = _mm_cmpeq_epi8(masked, _mm_setzero_si128()); - if _mm_movemask_epi8(cmp) != 0xFFFF { - break; - } - i += MIN_DIM_SIZE_SIMD; - } - // check latin in remaining chars - let s_tail = &s[i..]; - for c in s_tail.chars() { - if c as u32 > 0xFF { - return false; - } - } - true -} - -#[cfg(target_feature = "neon")] -unsafe fn is_latin_neon(s: &str) -> bool { - let bytes = s.as_bytes(); - let len = bytes.len(); - let mut i = 0; - // SIMD skip ASCII - while i + MIN_DIM_SIZE_SIMD <= len { - let chunk = vld1q_u8(bytes.as_ptr().add(i)); - let hi_mask = vdupq_n_u8(0x80); - let masked = vandq_u8(chunk, hi_mask); - if vmaxvq_u8(masked) != 0 { - break; - } - i += MIN_DIM_SIZE_SIMD; - } - // check latin in remaining chars - let s_tail = &s[i..]; - for c in s_tail.chars() { - if c as u32 > 0xFF { - return false; - } - } - true -} - -fn is_latin_standard(s: &str) -> bool { - s.chars().all(|c| c as u32 <= 0xFF) -} - -pub fn is_latin(s: &str) -> bool { - #[cfg(target_arch = "x86_64")] - { - if is_x86_feature_detected!("avx") - && is_x86_feature_detected!("fma") - && s.len() >= MIN_DIM_SIZE_AVX - { - return unsafe { is_latin_avx(s) }; - } - } - - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - { - if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD { - return unsafe { is_latin_sse(s) }; - } - } - - #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] - { - if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD { - return unsafe { is_latin_neon(s) }; - } - } - is_latin_standard(s) -} - -#[cfg(target_arch = "x86_64")] -unsafe fn get_latin1_length_avx(s: &str) -> i32 { - let bytes = s.as_bytes(); - let len = bytes.len(); - let mut count = 0; - // SIMD skip ASCII - while count + MIN_DIM_SIZE_AVX <= len { - let chunk = _mm256_loadu_si256(bytes.as_ptr().add(count) as *const __m256i); - let hi_mask = _mm256_set1_epi8(0x80u8 as i8); - let masked = _mm256_and_si256(chunk, hi_mask); - let cmp = _mm256_cmpeq_epi8(masked, _mm256_setzero_si256()); - if _mm256_movemask_epi8(cmp) != -1 { - break; - } - count += MIN_DIM_SIZE_AVX; - } - // check latin in remaining chars - let s_tail = &s[count..]; - for c in s_tail.chars() { - if c as u32 > 0xFF { - return -1; - } - count += 1; - } - count as i32 -} - -#[cfg(target_feature = "sse2")] -unsafe fn get_latin1_length_sse(s: &str) -> i32 { - let bytes = s.as_bytes(); - let len = bytes.len(); - let mut count = 0; - // SIMD skip ASCII - while count + MIN_DIM_SIZE_SIMD <= len { - let chunk = _mm_loadu_si128(bytes.as_ptr().add(count) as *const __m128i); - let hi_mask = _mm_set1_epi8(0x80u8 as i8); - let masked = _mm_and_si128(chunk, hi_mask); - let cmp = _mm_cmpeq_epi8(masked, _mm_setzero_si128()); - if _mm_movemask_epi8(cmp) != 0xFFFF { - break; - } - count += MIN_DIM_SIZE_SIMD; - } - // check latin in remaining chars - let s_tail = &s[count..]; - for c in s_tail.chars() { - if c as u32 > 0xFF { - return -1; - } - count += 1; - } - count as i32 -} - -#[cfg(target_feature = "neon")] -unsafe fn get_latin1_length_neon(s: &str) -> i32 { - let bytes = s.as_bytes(); - let len = bytes.len(); - let mut count = 0; - // SIMD skip ASCII - while count + MIN_DIM_SIZE_SIMD <= len { - let chunk = vld1q_u8(bytes.as_ptr().add(count)); - let hi_mask = vdupq_n_u8(0x80); - let masked = vandq_u8(chunk, hi_mask); - if vmaxvq_u8(masked) != 0 { - break; - } - count += MIN_DIM_SIZE_SIMD; - } - // check latin in remaining chars - let s_tail = &s[count..]; - for c in s_tail.chars() { - if c as u32 > 0xFF { - return -1; - } - count += 1; - } - count as i32 -} - -fn get_latin1_length_standard(s: &str) -> i32 { - let mut count = 0; - for c in s.chars() { - if c as u32 > 0xFF { - return -1; - } - count += 1; - } - count -} - -pub fn get_latin1_length(s: &str) -> i32 { - #[cfg(target_arch = "x86_64")] - { - if is_x86_feature_detected!("avx") - && is_x86_feature_detected!("fma") - && s.len() >= MIN_DIM_SIZE_AVX - { - return unsafe { get_latin1_length_avx(s) }; - } - } - - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - { - if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD { - return unsafe { get_latin1_length_sse(s) }; - } - } - - #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] - { - if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD { - return unsafe { get_latin1_length_neon(s) }; - } - } - get_latin1_length_standard(s) -} - -#[cfg(test)] -mod tests { - // Import content from external modules - use super::*; - use rand::Rng; - - fn generate_random_string(length: usize) -> String { - const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; - let mut rng = rand::thread_rng(); - - let result: String = (0..length) - .map(|_| { - let idx = rng.gen_range(0..CHARSET.len()); - CHARSET[idx] as char - }) - .collect(); - - result - } - - #[test] - fn test_is_latin() { - let s = generate_random_string(1000); - let not_latin_str = generate_random_string(1000) + "abc\u{1234}"; - - #[cfg(target_arch = "x86_64")] - { - if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") { - assert!(unsafe { is_latin_avx(&s) }); - assert!(!unsafe { is_latin_avx(¬_latin_str) }); - } - } - - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - { - if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD { - assert!(unsafe { is_latin_sse(&s) }); - assert!(!unsafe { is_latin_sse(¬_latin_str) }); - } - } - - #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] - { - if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD { - assert!(unsafe { is_latin_neon(&s) }); - assert!(!unsafe { is_latin_neon(¬_latin_str) }); - } - } - assert!(is_latin_standard(&s)); - assert!(!is_latin_standard(¬_latin_str)); - } -} - -fn fmix64(mut k: u64) -> u64 { - k ^= k >> 33; - k = k.wrapping_mul(0xff51afd7ed558ccdu64); - k ^= k >> 33; - k = k.wrapping_mul(0xc4ceb9fe1a85ec53u64); - k ^= k >> 33; - - k -} - -pub fn murmurhash3_x64_128(bytes: &[u8], seed: u64) -> (u64, u64) { - let c1 = 0x87c37b91114253d5u64; - let c2 = 0x4cf5ad432745937fu64; - let read_size = 16; - let len = bytes.len() as u64; - let block_count = len / read_size; - - let (mut h1, mut h2) = (seed, seed); - - for i in 0..block_count as usize { - let b64: &[u64] = unsafe { mem::transmute(bytes) }; - let (mut k1, mut k2) = (b64[i * 2], b64[i * 2 + 1]); - - k1 = k1.wrapping_mul(c1); - k1 = k1.rotate_left(31); - k1 = k1.wrapping_mul(c2); - h1 ^= k1; - - h1 = h1.rotate_left(27); - h1 = h1.wrapping_add(h2); - h1 = h1.wrapping_mul(5); - h1 = h1.wrapping_add(0x52dce729); - - k2 = k2.wrapping_mul(c2); - k2 = k2.rotate_left(33); - k2 = k2.wrapping_mul(c1); - h2 ^= k2; - - h2 = h2.rotate_left(31); - h2 = h2.wrapping_add(h1); - h2 = h2.wrapping_mul(5); - h2 = h2.wrapping_add(0x38495ab5); - } - let (mut k1, mut k2) = (0u64, 0u64); - - if len & 15 == 15 { - k2 ^= (bytes[(block_count * read_size) as usize + 14] as u64) << 48; - } - if len & 15 >= 14 { - k2 ^= (bytes[(block_count * read_size) as usize + 13] as u64) << 40; - } - if len & 15 >= 13 { - k2 ^= (bytes[(block_count * read_size) as usize + 12] as u64) << 32; - } - if len & 15 >= 12 { - k2 ^= (bytes[(block_count * read_size) as usize + 11] as u64) << 24; - } - if len & 15 >= 11 { - k2 ^= (bytes[(block_count * read_size) as usize + 10] as u64) << 16; - } - if len & 15 >= 10 { - k2 ^= (bytes[(block_count * read_size) as usize + 9] as u64) << 8; - } - if len & 15 >= 9 { - k2 ^= bytes[(block_count * read_size) as usize + 8] as u64; - k2 = k2.wrapping_mul(c2); - k2 = k2.rotate_left(33); - k2 = k2.wrapping_mul(c1); - h2 ^= k2; - } - - if len & 15 >= 8 { - k1 ^= (bytes[(block_count * read_size) as usize + 7] as u64) << 56; - } - if len & 15 >= 7 { - k1 ^= (bytes[(block_count * read_size) as usize + 6] as u64) << 48; - } - if len & 15 >= 6 { - k1 ^= (bytes[(block_count * read_size) as usize + 5] as u64) << 40; - } - if len & 15 >= 5 { - k1 ^= (bytes[(block_count * read_size) as usize + 4] as u64) << 32; - } - if len & 15 >= 4 { - k1 ^= (bytes[(block_count * read_size) as usize + 3] as u64) << 24; - } - if len & 15 >= 3 { - k1 ^= (bytes[(block_count * read_size) as usize + 2] as u64) << 16; - } - if len & 15 >= 2 { - k1 ^= (bytes[(block_count * read_size) as usize + 1] as u64) << 8; - } - if len & 15 >= 1 { - k1 ^= bytes[(block_count * read_size) as usize] as u64; - k1 = k1.wrapping_mul(c1); - k1 = k1.rotate_left(31); - k1 = k1.wrapping_mul(c2); - h1 ^= k1; - } - - h1 ^= bytes.len() as u64; - h2 ^= bytes.len() as u64; - - h1 = h1.wrapping_add(h2); - h2 = h2.wrapping_add(h1); - - h1 = fmix64(h1); - h2 = fmix64(h2); - - h1 = h1.wrapping_add(h2); - h2 = h2.wrapping_add(h1); - - (h1, h2) -} - -#[cfg(test)] -mod test_hash { - use super::murmurhash3_x64_128; - - #[test] - fn test_empty_string() { - assert!(murmurhash3_x64_128("".as_bytes(), 0) == (0, 0)); - } - - #[test] - fn test_tail_lengths() { - assert!( - murmurhash3_x64_128("1".as_bytes(), 0) == (8213365047359667313, 10676604921780958775) - ); - assert!( - murmurhash3_x64_128("12".as_bytes(), 0) == (5355690773644049813, 9855895140584599837) - ); - assert!( - murmurhash3_x64_128("123".as_bytes(), 0) == (10978418110857903978, 4791445053355511657) - ); - assert!( - murmurhash3_x64_128("1234".as_bytes(), 0) == (619023178690193332, 3755592904005385637) - ); - assert!( - murmurhash3_x64_128("12345".as_bytes(), 0) - == (2375712675693977547, 17382870096830835188) - ); - assert!( - murmurhash3_x64_128("123456".as_bytes(), 0) - == (16435832985690558678, 5882968373513761278) - ); - assert!( - murmurhash3_x64_128("1234567".as_bytes(), 0) - == (3232113351312417698, 4025181827808483669) - ); - assert!( - murmurhash3_x64_128("12345678".as_bytes(), 0) - == (4272337174398058908, 10464973996478965079) - ); - assert!( - murmurhash3_x64_128("123456789".as_bytes(), 0) - == (4360720697772133540, 11094893415607738629) - ); - assert!( - murmurhash3_x64_128("123456789a".as_bytes(), 0) - == (12594836289594257748, 2662019112679848245) - ); - assert!( - murmurhash3_x64_128("123456789ab".as_bytes(), 0) - == (6978636991469537545, 12243090730442643750) - ); - assert!( - murmurhash3_x64_128("123456789abc".as_bytes(), 0) - == (211890993682310078, 16480638721813329343) - ); - assert!( - murmurhash3_x64_128("123456789abcd".as_bytes(), 0) - == (12459781455342427559, 3193214493011213179) - ); - assert!( - murmurhash3_x64_128("123456789abcde".as_bytes(), 0) - == (12538342858731408721, 9820739847336455216) - ); - assert!( - murmurhash3_x64_128("123456789abcdef".as_bytes(), 0) - == (9165946068217512774, 2451472574052603025) - ); - assert!( - murmurhash3_x64_128("123456789abcdef1".as_bytes(), 0) - == (9259082041050667785, 12459473952842597282) - ); - } - - #[test] - fn test_large_data() { - assert!(murmurhash3_x64_128("Lorem ipsum dolor sit amet, consectetur adipiscing elit. Etiam at consequat massa. Cras eleifend pellentesque ex, at dignissim libero maximus ut. Sed eget nulla felis".as_bytes(), 0) - == (9455322759164802692, 17863277201603478371)); - } -} - -pub mod buffer_rw_string { - #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] - use std::arch::aarch64::*; - #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] - use std::arch::x86_64::*; - #[cfg(all( - any(target_arch = "x86", target_arch = "x86_64"), - target_feature = "sse2", - not(target_feature = "avx2") - ))] - use std::arch::x86_64::*; - - use crate::buffer::{Reader, Writer}; - use crate::error::Error; - - #[inline] - pub fn write_latin1_standard(writer: &mut Writer, s: &str) { - for c in s.chars() { - let b = c as u32; - assert!(b <= 0xFF, "Non-Latin1 character found"); - writer.write_u8(b as u8); - } - } - - #[inline(always)] - pub fn write_latin1_string(writer: &mut Writer, s: &str) { - if s.len() < 128 { - // Fast path for small buffers - let bytes = s.as_bytes(); - // CRITICAL: Only safe if ASCII (UTF-8 == Latin1 for ASCII) - let is_ascii = bytes.iter().all(|&b| b < 0x80); - if is_ascii { - writer.bf.reserve(s.len()); - writer.bf.extend_from_slice(bytes); - } else { - // Non-ASCII: must iterate chars to extract Latin1 byte values - writer.bf.reserve(s.len()); - for c in s.chars() { - let v = c as u32; - assert!(v <= 0xFF, "Non-Latin1 character found"); - writer.bf.push(v as u8); - } - } - return; - } - write_latin1_simd(writer, s); - } - - #[inline] - pub fn write_utf8_standard(writer: &mut Writer, s: &str) { - let bytes = s.as_bytes(); - writer.bf.extend_from_slice(bytes); - } - - #[inline] - pub fn write_utf16_standard(writer: &mut Writer, utf16: &[u16]) { - #[cfg(target_endian = "little")] - { - let total_bytes = utf16.len() * 2; - let old_len = writer.bf.len(); - writer.bf.reserve(total_bytes); - unsafe { - let dest = writer.bf.as_mut_ptr().add(old_len); - let src = utf16.as_ptr() as *const u8; - std::ptr::copy_nonoverlapping(src, dest, total_bytes); - writer.bf.set_len(old_len + total_bytes); - } - } - #[cfg(target_endian = "big")] - { - let total_bytes = utf16.len() * 2; - let old_len = writer.bf.len(); - writer.bf.reserve(total_bytes); - unsafe { - let dest = writer.bf.as_mut_ptr().add(old_len); - // Need to swap bytes for each u16 to little-endian - for (i, &unit) in utf16.iter().enumerate() { - let swapped = unit.swap_bytes(); - let ptr = dest.add(i * 2) as *mut u16; - std::ptr::write_unaligned(ptr, swapped); - } - writer.bf.set_len(old_len + total_bytes); - } - } - } - - #[inline] - pub fn read_latin1_standard(reader: &mut Reader, len: usize) -> Result { - let slice = reader.sub_slice(reader.get_cursor(), reader.get_cursor() + len)?; - let result: String = slice.iter().map(|&b| b as char).collect(); - reader.move_next(len); - Ok(result) - } - - #[inline] - pub fn read_utf8_standard(reader: &mut Reader, len: usize) -> Result { - unsafe { - let mut vec = Vec::with_capacity(len); - let src = reader.bf.as_ptr().add(reader.cursor); - let dst = vec.as_mut_ptr(); - // Use fastest possible copy - copy_nonoverlapping compiles to memcpy - std::ptr::copy_nonoverlapping(src, dst, len); - vec.set_len(len); - reader.move_next(len); - // Use from_utf8_lossy for safety - handles invalid UTF-8 gracefully - // If you're certain the data is valid UTF-8, use from_utf8_unchecked for more performance - Ok(String::from_utf8_lossy(&vec).into_owned()) - } - } - - #[inline] - pub fn read_utf16_standard(reader: &mut Reader, len: usize) -> Result { - if len % 2 != 0 { - return Err(Error::encoding_error("UTF-16 length must be even")); - } - unsafe { - let slice = std::slice::from_raw_parts(reader.bf.as_ptr().add(reader.cursor), len); - let units: Vec = slice - .chunks_exact(2) - .map(|c| u16::from_le_bytes([c[0], c[1]])) - .collect(); - reader.move_next(len); - Ok(String::from_utf16_lossy(&units)) - } - } - - #[inline] - fn is_ascii_bytes(bytes: &[u8]) -> bool { - let len = bytes.len(); - let mut i = 0; - - #[cfg(target_arch = "x86_64")] - unsafe { - if is_x86_feature_detected!("avx2") && len >= 32 { - while i + 32 <= len { - let chunk = _mm256_loadu_si256(bytes.as_ptr().add(i) as *const __m256i); - let mask = _mm256_movemask_epi8(chunk); - if mask != 0 { - return false; - } - i += 32; - } - } - } - - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - unsafe { - if is_x86_feature_detected!("sse2") && len >= 16 { - while i + 16 <= len { - let chunk = _mm_loadu_si128(bytes.as_ptr().add(i) as *const __m128i); - let mask = _mm_movemask_epi8(chunk); - if mask != 0 { - return false; - } - i += 16; - } - } - } - - #[cfg(target_arch = "aarch64")] - unsafe { - if std::arch::is_aarch64_feature_detected!("neon") && len >= 16 { - while i + 16 <= len { - let chunk = vld1q_u8(bytes.as_ptr().add(i)); - if vmaxvq_u8(chunk) >= 0x80 { - return false; - } - i += 16; - } - } - } - - // Scalar fallback - bytes[i..].iter().all(|&b| b < 0x80) - } - - #[inline] - pub fn write_latin1_simd(writer: &mut Writer, s: &str) { - if s.is_empty() { - return; - } - - let bytes = s.as_bytes(); - - // CRITICAL OPTIMIZATION: For ASCII strings, UTF-8 bytes == Latin1 bytes - // Check if all ASCII using SIMD - if is_ascii_bytes(bytes) { - // Zero-copy fast path: direct write - let len = bytes.len(); - writer.bf.reserve(len); - writer.bf.extend_from_slice(bytes); - } else { - // Non-ASCII: Must iterate chars to extract Latin1 byte values - // Example: 'À' in Rust String is UTF-8 [0xC3, 0x80] but Latin1 is [0xC0] - let mut buf: Vec = Vec::with_capacity(s.len()); - for c in s.chars() { - let v = c as u32; - assert!(v <= 0xFF, "Non-Latin1 character found"); - buf.push(v as u8); - } - let len = buf.len(); - writer.bf.reserve(len); - writer.bf.extend_from_slice(&buf); - } - } - - #[inline] - pub fn read_latin1_simd(reader: &mut Reader, len: usize) -> Result { - if len == 0 { - return Ok(String::new()); - } - let src = reader.sub_slice(reader.get_cursor(), reader.get_cursor() + len)?; - - // Pessimistic allocation: Latin1 0x80-0xFF expands to 2 bytes in UTF-8 - let mut out: Vec = Vec::with_capacity(len * 2); - - unsafe { - let out_ptr = out.as_mut_ptr(); - let mut out_len = 0usize; - let mut i = 0usize; - - // ---- AVX2 fast-path: process 32 ASCII bytes at once ---- - #[cfg(target_arch = "x86_64")] - { - if std::arch::is_x86_feature_detected!("avx2") { - use std::arch::x86_64::*; - while i + 32 <= len { - let ptr = src.as_ptr().add(i) as *const __m256i; - let chunk = _mm256_loadu_si256(ptr); - let mask = _mm256_movemask_epi8(chunk); - if mask == 0 { - // All ASCII: direct copy (no conversion needed) - _mm256_storeu_si256(out_ptr.add(out_len) as *mut __m256i, chunk); - out_len += 32; - i += 32; - continue; - } else { - // Contains Latin1 bytes, break to scalar - break; - } - } - } - } - - // ---- SSE2 fast-path: process 16 ASCII bytes at once ---- - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - { - if std::arch::is_x86_feature_detected!("sse2") { - use std::arch::x86_64::*; - while i + 16 <= len { - let ptr = src.as_ptr().add(i) as *const __m128i; - let chunk = _mm_loadu_si128(ptr); - let mask = _mm_movemask_epi8(chunk); - if mask == 0 { - // All ASCII: direct copy - _mm_storeu_si128(out_ptr.add(out_len) as *mut __m128i, chunk); - out_len += 16; - i += 16; - continue; - } else { - break; - } - } - } - } - - // ---- NEON fast-path: process 16 ASCII bytes at once ---- - #[cfg(target_arch = "aarch64")] - { - if std::arch::is_aarch64_feature_detected!("neon") { - use std::arch::aarch64::*; - while i + 16 <= len { - let ptr = src.as_ptr().add(i); - let v = vld1q_u8(ptr); - // Check if any byte >= 0x80 - if vmaxvq_u8(v) < 0x80 { - // All ASCII: direct copy - vst1q_u8(out_ptr.add(out_len), v); - out_len += 16; - i += 16; - continue; - } else { - break; - } - } - } - } - - // ---- Scalar fallback: convert Latin1 -> UTF-8 ---- - // ASCII (0x00-0x7F): copy as-is - // Latin1 (0x80-0xFF): encode as 2-byte UTF-8 - while i < len { - let b = *src.get_unchecked(i); - if b < 0x80 { - *out_ptr.add(out_len) = b; - out_len += 1; - } else { - // Latin1 byte 0x80-0xFF -> UTF-8 encoding - // Example: 0xC0 (À) -> [0xC3, 0x80] - *out_ptr.add(out_len) = 0xC0 | (b >> 6); - *out_ptr.add(out_len + 1) = 0x80 | (b & 0x3F); - out_len += 2; - } - i += 1; - } - - out.set_len(out_len); - } - reader.move_next(len); - Ok(unsafe { String::from_utf8_unchecked(out) }) - } - - #[cfg(test)] - mod tests { - use super::*; - use crate::buffer::{Reader, Writer}; - - #[test] - fn test_latin1() { - let samples = [ - "Hello World!", - "Rusty Café", - "1234567890", - "ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖרÙÚÛÜÝ", - ]; - - for s in samples { - let mut buffer = vec![]; - let mut writer = Writer::from_buffer(&mut buffer); - write_latin1_simd(&mut writer, s); - write_latin1_simd(&mut writer, s); - let bytes = &*writer.dump(); - let bytes_len = bytes.len() / 2; - let mut reader = Reader::new(bytes); - assert_eq!(read_latin1_standard(&mut reader, bytes_len).unwrap(), s); - assert_eq!(read_latin1_standard(&mut reader, bytes_len).unwrap(), s); - - let mut buffer = vec![]; - let mut writer = Writer::from_buffer(&mut buffer); - write_latin1_standard(&mut writer, s); - write_latin1_standard(&mut writer, s); - let bytes = &*writer.dump(); - let bytes_len = bytes.len() / 2; - let mut reader = Reader::new(bytes); - assert_eq!(read_latin1_simd(&mut reader, bytes_len).unwrap(), s); - assert_eq!(read_latin1_simd(&mut reader, bytes_len).unwrap(), s); - } - } - - #[test] - fn test_utf8() { - let samples = [ - "hello", - "rust语言", - "你好,世界", - "emoji 😀😃😄😁", - "mixed ASCII + 中文 + emoji 😁", - ]; - - for s in samples { - let bytes_len = s.len(); - - let mut buffer = vec![]; - let mut writer = Writer::from_buffer(&mut buffer); - write_utf8_standard(&mut writer, s); - write_utf8_standard(&mut writer, s); - let bytes = &*writer.dump(); - let mut reader = Reader::new(bytes); - assert_eq!(read_utf8_standard(&mut reader, bytes_len).unwrap(), s); - assert_eq!(read_utf8_standard(&mut reader, bytes_len).unwrap(), s); - } - } - - #[test] - fn test_utf16() { - let samples = [ - "hello", - "rust语言", - "你好,世界", - "emoji 😀😃😄😁", - "混合文字 + emoji 🐍💻🦀", - ]; - for s in samples { - let utf16: Vec = s.encode_utf16().collect(); - let bytes_len = utf16.len() * 2; - - let mut buffer = vec![]; - let mut writer = Writer::from_buffer(&mut buffer); - write_utf16_standard(&mut writer, &utf16); - write_utf16_standard(&mut writer, &utf16); - - let mut buffer = vec![]; - let mut writer = Writer::from_buffer(&mut buffer); - write_utf16_standard(&mut writer, &utf16); - write_utf16_standard(&mut writer, &utf16); - let bytes = &*writer.dump(); - let mut reader = Reader::new(bytes); - assert_eq!(read_utf16_standard(&mut reader, bytes_len).unwrap(), s); - assert_eq!(read_utf16_standard(&mut reader, bytes_len).unwrap(), s); - } - } - } -} diff --git a/rust/fory-core/src/meta/type_meta.rs b/rust/fory-core/src/meta/type_meta.rs index 81154997d9..00f896020d 100644 --- a/rust/fory-core/src/meta/type_meta.rs +++ b/rust/fory-core/src/meta/type_meta.rs @@ -18,15 +18,15 @@ use crate::buffer::{Reader, Writer}; use crate::error::Error; use crate::meta::{ - murmurhash3_x64_128, Encoding, MetaString, MetaStringDecoder, FIELD_NAME_DECODER, - FIELD_NAME_ENCODER, NAMESPACE_DECODER, TYPE_NAME_DECODER, + Encoding, MetaString, MetaStringDecoder, FIELD_NAME_DECODER, FIELD_NAME_ENCODER, + NAMESPACE_DECODER, TYPE_NAME_DECODER, }; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; -use crate::types::{ +use crate::resolver::{TypeInfo, TypeResolver}; +use crate::type_id::{ TypeId, BINARY, COMPATIBLE_STRUCT, ENUM, EXT, INT8_ARRAY, NAMED_COMPATIBLE_STRUCT, NAMED_ENUM, NAMED_EXT, NAMED_STRUCT, PRIMITIVE_TYPES, STRUCT, UINT8_ARRAY, UNKNOWN, }; -use crate::util::to_snake_case; +use crate::util::{murmurhash3_x64_128, to_snake_case}; /// Normalizes a type ID for comparison purposes in cross-language schema evolution. /// This treats all struct variants (STRUCT, COMPATIBLE_STRUCT, NAMED_STRUCT, @@ -75,6 +75,7 @@ const COMPRESS_META_FLAG: i64 = 0b1 << 9; const HAS_FIELDS_META_FLAG: i64 = 0b1 << 8; const NUM_HASH_BITS: i8 = 50; const NO_USER_TYPE_ID: u32 = u32::MAX; +const MAX_HASH32: u64 = (1 << 31) - 1; pub static NAMESPACE_ENCODINGS: &[Encoding] = &[ Encoding::Utf8, @@ -420,6 +421,20 @@ fn compute_schema_hash(field_infos: &[FieldInfo]) -> i64 { hash as i64 } +#[inline(always)] +pub fn compute_field_hash(hash: u32, id: i16) -> u32 { + let mut next_hash = (hash as u64) * 31 + (id as u64); + while next_hash >= MAX_HASH32 { + next_hash /= 7; + } + next_hash as u32 +} + +#[inline(always)] +pub fn compute_struct_hash(field_ids: impl IntoIterator) -> u32 { + field_ids.into_iter().fold(17u32, compute_field_hash) +} + /// Sorts field infos according to the provided sorted field names and assigns field IDs. /// /// This function takes a vector of field infos and a slice of sorted field names, @@ -717,7 +732,7 @@ impl TypeMeta { set_fields.push(field_info); } else if TypeId::MAP as u32 == type_id { map_fields.push(field_info); - } else if crate::types::is_internal_type(type_id) { + } else if crate::type_id::is_internal_type(type_id) { internal_type_fields.push(field_info); } else { other_fields.push(field_info); diff --git a/rust/fory-core/src/resolver/meta_resolver.rs b/rust/fory-core/src/resolver/meta_resolver.rs index b680461cad..38e07e4572 100644 --- a/rust/fory-core/src/resolver/meta_resolver.rs +++ b/rust/fory-core/src/resolver/meta_resolver.rs @@ -18,8 +18,8 @@ use crate::buffer::{Reader, Writer}; use crate::error::Error; use crate::meta::TypeMeta; -use crate::resolver::type_resolver::{TypeInfo, NO_USER_TYPE_ID}; -use crate::TypeResolver; +use crate::resolver::type_resolver::NO_USER_TYPE_ID; +use crate::resolver::{TypeInfo, TypeResolver}; use std::collections::HashMap; use std::rc::Rc; diff --git a/rust/fory-core/src/resolver/meta_string_resolver.rs b/rust/fory-core/src/resolver/meta_string_resolver.rs index 705308ddf0..f8593e4e4f 100644 --- a/rust/fory-core/src/resolver/meta_string_resolver.rs +++ b/rust/fory-core/src/resolver/meta_string_resolver.rs @@ -23,8 +23,9 @@ use std::sync::OnceLock; use crate::buffer::Writer; use crate::error::Error; -use crate::meta::{murmurhash3_x64_128, NAMESPACE_DECODER}; +use crate::meta::NAMESPACE_DECODER; use crate::meta::{Encoding, MetaString}; +use crate::util::murmurhash3_x64_128; use crate::Reader; #[derive(Debug, Clone, PartialEq, Eq, Hash)] diff --git a/rust/fory-core/src/resolver/mod.rs b/rust/fory-core/src/resolver/mod.rs index 35a3538fa6..08b657be42 100644 --- a/rust/fory-core/src/resolver/mod.rs +++ b/rust/fory-core/src/resolver/mod.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License. -pub mod context; pub mod meta_resolver; pub mod meta_string_resolver; pub mod ref_resolver; pub mod type_resolver; + +pub use ref_resolver::{RefFlag, RefMode, RefReader, RefWriter}; +pub use type_resolver::{TypeInfo, TypeResolver}; diff --git a/rust/fory-core/src/resolver/ref_resolver.rs b/rust/fory-core/src/resolver/ref_resolver.rs index 1a26ed7496..14ca87c581 100644 --- a/rust/fory-core/src/resolver/ref_resolver.rs +++ b/rust/fory-core/src/resolver/ref_resolver.rs @@ -17,12 +17,81 @@ use crate::buffer::{Reader, Writer}; use crate::error::Error; -use crate::types::RefFlag; +use num_enum::TryFromPrimitive; use std::any::Any; use std::collections::HashMap; use std::rc::Rc; use std::sync::Arc; +#[derive(Debug, TryFromPrimitive)] +#[repr(i8)] +pub enum RefFlag { + Null = -3, + // Ref indicates that object is a not-null value. + // We don't use another byte to indicate REF, so that we can save one byte. + Ref = -2, + // NotNullValueFlag indicates that the object is a non-null value. + NotNullValue = -1, + // RefValueFlag indicates that the object is a referencable and first read. + RefValue = 0, +} + +/// Controls how reference and null flags are handled during serialization. +/// +/// This enum combines nullable semantics and reference tracking into one parameter, +/// enabling fine-grained control per type and per field: +/// - `None` = non-nullable, no ref tracking (primitives) +/// - `NullOnly` = nullable, no circular ref tracking +/// - `Tracking` = nullable, with circular ref tracking (Rc/Arc/Weak) +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +#[repr(u8)] +pub enum RefMode { + /// Skip ref handling entirely. No ref/null flags are written/read. + /// Used for non-nullable primitives or when caller handles ref externally. + #[default] + None = 0, + + /// Only null check without reference tracking. + /// Write: NullFlag (-3) for None, NotNullValueFlag (-1) for Some. + /// Read: Read flag and return ForyDefault on null. + NullOnly = 1, + + /// Full reference tracking with circular reference support. + /// Write: Uses RefWriter which writes NullFlag, RefFlag+refId, or RefValueFlag. + /// Read: Uses RefReader with full reference resolution. + Tracking = 2, +} + +impl RefMode { + /// Create RefMode from nullable and track_ref flags. + #[inline] + pub const fn from_flags(nullable: bool, track_ref: bool) -> Self { + match (nullable, track_ref) { + (false, false) => RefMode::None, + (true, false) => RefMode::NullOnly, + (_, true) => RefMode::Tracking, + } + } + + /// Check if this mode reads/writes ref flags. + #[inline] + pub const fn has_ref_flag(self) -> bool { + !matches!(self, RefMode::None) + } + + /// Check if this mode tracks circular references. + #[inline] + pub const fn tracks_refs(self) -> bool { + matches!(self, RefMode::Tracking) + } + + /// Check if this mode handles nullable values. + #[inline] + pub const fn is_nullable(self) -> bool { + !matches!(self, RefMode::None) + } +} + /// Reference writer for tracking shared references during serialization. /// /// RefWriter maintains a mapping from object pointer addresses to reference IDs, @@ -34,7 +103,7 @@ use std::sync::Arc; /// /// ```rust /// use fory_core::buffer::Writer; -/// use fory_core::resolver::ref_resolver::RefWriter; +/// use fory_core::resolver::RefWriter; /// use std::rc::Rc; /// /// let mut ref_writer = RefWriter::new(); @@ -166,7 +235,7 @@ impl RefWriter { /// # Examples /// /// ```rust -/// use fory_core::resolver::ref_resolver::RefReader; +/// use fory_core::resolver::RefReader; /// use std::rc::Rc; /// /// let mut ref_reader = RefReader::new(); diff --git a/rust/fory-core/src/resolver/type_resolver.rs b/rust/fory-core/src/resolver/type_resolver.rs index fe9d813064..bbd6b3a547 100644 --- a/rust/fory-core/src/resolver/type_resolver.rs +++ b/rust/fory-core/src/resolver/type_resolver.rs @@ -15,14 +15,15 @@ // specific language governing permissions and limitations // under the License. -use super::context::{ReadContext, WriteContext}; +use crate::context::{ReadContext, WriteContext}; use crate::error::Error; use crate::meta::{ MetaString, TypeMeta, NAMESPACE_ENCODER, NAMESPACE_ENCODINGS, TYPE_NAME_ENCODER, TYPE_NAME_ENCODINGS, }; +use crate::resolver::RefMode; use crate::serializer::{ForyDefault, Serializer, StructSerializer}; -use crate::types::{get_ext_actual_type_id, is_enum_type_id, RefMode}; +use crate::type_id::{get_ext_actual_type_id, is_enum_type_id}; use crate::TypeId; use chrono::{NaiveDate, NaiveDateTime}; use std::collections::{HashSet, LinkedList}; @@ -554,7 +555,7 @@ impl TypeResolver { #[inline(always)] pub fn get_type_info_by_id(&self, type_id: u32) -> Option> { - if crate::types::is_internal_type(type_id) { + if crate::type_id::is_internal_type(type_id) { let index = type_id as usize; if index < self.internal_type_info_by_id.len() { return self.internal_type_info_by_id[index].clone(); @@ -731,7 +732,7 @@ impl TypeResolver { self.register_internal_serializer::(TypeId::INT128)?; self.register_internal_serializer::(TypeId::FLOAT32)?; self.register_internal_serializer::(TypeId::FLOAT64)?; - self.register_internal_serializer::(TypeId::FLOAT16)?; + self.register_internal_serializer::(TypeId::FLOAT16)?; self.register_internal_serializer::(TypeId::UINT8)?; self.register_internal_serializer::(TypeId::UINT16)?; self.register_internal_serializer::(TypeId::VAR_UINT32)?; @@ -741,6 +742,7 @@ impl TypeResolver { self.register_internal_serializer::(TypeId::STRING)?; self.register_internal_serializer::(TypeId::TIMESTAMP)?; self.register_internal_serializer::(TypeId::DATE)?; + self.register_internal_serializer::(TypeId::DECIMAL)?; self.register_internal_serializer::>(TypeId::BOOL_ARRAY)?; self.register_internal_serializer::>(TypeId::INT8_ARRAY)?; @@ -749,7 +751,9 @@ impl TypeResolver { self.register_internal_serializer::>(TypeId::INT64_ARRAY)?; self.register_internal_serializer::>(TypeId::FLOAT32_ARRAY)?; self.register_internal_serializer::>(TypeId::FLOAT64_ARRAY)?; - self.register_internal_serializer::>(TypeId::FLOAT16_ARRAY)?; + self.register_internal_serializer::>( + TypeId::FLOAT16_ARRAY, + )?; self.register_internal_serializer::>(TypeId::BINARY)?; self.register_internal_serializer::>(TypeId::UINT16_ARRAY)?; self.register_internal_serializer::>(TypeId::UINT32_ARRAY)?; @@ -827,7 +831,7 @@ impl TypeResolver { } let actual_type_id = T::fory_actual_type_id(id, register_by_name, self.compatible, self.xlang); - let user_type_id = if register_by_name || crate::types::is_internal_type(actual_type_id) { + let user_type_id = if register_by_name || crate::type_id::is_internal_type(actual_type_id) { NO_USER_TYPE_ID } else { id @@ -937,7 +941,7 @@ impl TypeResolver { // 1. Internal types (type_id < TypeId::BOUND) as they can be shared // 2. Types registered by name (they use shared type IDs like NAMED_STRUCT) if !register_by_name - && !crate::types::is_internal_type(actual_type_id) + && !crate::type_id::is_internal_type(actual_type_id) && self.user_type_info_by_id.contains_key(&user_type_id) { return Err(Error::type_error(format!( @@ -963,7 +967,7 @@ impl TypeResolver { self.rust_type_id_by_index[index] = Some(rs_type_id); // Insert partial type info into id maps - if crate::types::is_internal_type(actual_type_id) { + if crate::type_id::is_internal_type(actual_type_id) { let index = actual_type_id as usize; if index >= self.internal_type_info_by_id.len() { return Err(Error::not_allowed(format!( @@ -1151,7 +1155,7 @@ impl TypeResolver { // Check if type_id conflicts with any already registered type // Skip check for internal types as they can be shared - if !crate::types::is_internal_type(actual_type_id) + if !crate::type_id::is_internal_type(actual_type_id) && user_type_id != NO_USER_TYPE_ID && self.user_type_info_by_id.contains_key(&user_type_id) { @@ -1162,7 +1166,7 @@ impl TypeResolver { } // Insert partial type info into id maps - if crate::types::is_internal_type(actual_type_id) { + if crate::type_id::is_internal_type(actual_type_id) { let index = actual_type_id as usize; if index >= self.internal_type_info_by_id.len() { return Err(Error::not_allowed(format!( @@ -1253,7 +1257,7 @@ impl TypeResolver { // Iterate through all type infos uniformly for (type_rust_id, type_info) in type_infos.iter() { // Insert into id maps - if crate::types::is_internal_type(type_info.type_id as u32) { + if crate::type_id::is_internal_type(type_info.type_id as u32) { let index = type_info.type_id as usize; if index < internal_type_info_by_id.len() { internal_type_info_by_id[index] = Some(Rc::new(type_info.clone())); diff --git a/rust/fory-core/src/serializer/any.rs b/rust/fory-core/src/serializer/any.rs index 8a7e9f8e44..565032b234 100644 --- a/rust/fory-core/src/serializer/any.rs +++ b/rust/fory-core/src/serializer/any.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::context::{ReadContext, WriteContext}; use crate::ensure; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; +use crate::resolver::{RefFlag, RefMode}; +use crate::resolver::{TypeInfo, TypeResolver}; use crate::serializer::util::write_dyn_data_generic; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{RefFlag, RefMode, TypeId}; +use crate::type_id::TypeId; use std::any::Any; use std::rc::Rc; use std::sync::Arc; diff --git a/rust/fory-core/src/serializer/arc.rs b/rust/fory-core/src/serializer/arc.rs index 4ef4e50d55..2277078a36 100644 --- a/rust/fory-core/src/serializer/arc.rs +++ b/rust/fory-core/src/serializer/arc.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::context::{ReadContext, WriteContext}; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; +use crate::resolver::{RefFlag, RefMode}; +use crate::resolver::{TypeInfo, TypeResolver}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{RefFlag, RefMode, TypeId}; +use crate::type_id::TypeId; use std::rc::Rc; use std::sync::Arc; diff --git a/rust/fory-core/src/serializer/array.rs b/rust/fory-core/src/serializer/array.rs index 57d56ddf8a..35e6a31069 100644 --- a/rust/fory-core/src/serializer/array.rs +++ b/rust/fory-core/src/serializer/array.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::primitive_list; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; use std::mem; use std::mem::MaybeUninit; @@ -31,7 +31,7 @@ use super::collection::{ }; use super::list::{get_primitive_type_id, is_primitive_type}; use crate::ensure; -use crate::types::{RefFlag, RefMode}; +use crate::resolver::{RefFlag, RefMode}; // Collection header flags (matching collection.rs private constants) const TRACKING_REF: u8 = 0b1; diff --git a/rust/fory-core/src/serializer/bool.rs b/rust/fory-core/src/serializer/bool.rs index edffac5216..09307e168c 100644 --- a/rust/fory-core/src/serializer/bool.rs +++ b/rust/fory-core/src/serializer/bool.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::util::read_basic_type_info; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; use std::mem; impl Serializer for bool { diff --git a/rust/fory-core/src/serializer/box_.rs b/rust/fory-core/src/serializer/box_.rs index 9f992e0ec9..2e5c021971 100644 --- a/rust/fory-core/src/serializer/box_.rs +++ b/rust/fory-core/src/serializer/box_.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; +use crate::resolver::{TypeInfo, TypeResolver}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; use std::rc::Rc; impl Serializer for Box { diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index 2b590daee8..0fbe38c527 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::ensure; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; +use crate::resolver::{RefFlag, RefMode}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{need_to_write_type_for_field, RefFlag, RefMode, PRIMITIVE_ARRAY_TYPES}; +use crate::type_id::{need_to_write_type_for_field, PRIMITIVE_ARRAY_TYPES}; const TRACKING_REF: u8 = 0b1; diff --git a/rust/fory-core/src/serializer/core.rs b/rust/fory-core/src/serializer/core.rs index 840807e1ff..62bb628453 100644 --- a/rust/fory-core/src/serializer/core.rs +++ b/rust/fory-core/src/serializer/core.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::context::{ReadContext, WriteContext}; use crate::error::Error; use crate::meta::FieldInfo; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::TypeInfo; +use crate::resolver::{RefFlag, RefMode, TypeInfo, TypeResolver}; use crate::serializer::{bool, struct_}; -use crate::types::{RefFlag, RefMode, TypeId}; -use crate::TypeResolver; +use crate::type_id::TypeId; use std::any::Any; use std::rc::Rc; @@ -346,7 +345,7 @@ pub trait Serializer: 'static { /// /// ```rust /// use fory_core::{Serializer, ForyDefault}; - /// use fory_core::resolver::context::WriteContext; + /// use fory_core::WriteContext; /// use fory_core::error::Error; /// use std::any::Any; /// @@ -369,7 +368,7 @@ pub trait Serializer: 'static { /// Ok(()) /// } /// - /// fn fory_read_data(context: &mut fory_core::resolver::context::ReadContext) -> Result + /// fn fory_read_data(context: &mut fory_core::ReadContext) -> Result /// where /// Self: Sized + fory_core::ForyDefault, /// { @@ -378,7 +377,7 @@ pub trait Serializer: 'static { /// Ok(Point { x, y }) /// } /// - /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::TypeResolver) -> Result { + /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::resolver::TypeResolver) -> Result { /// Self::fory_get_type_id(type_resolver) /// } /// @@ -392,7 +391,7 @@ pub trait Serializer: 'static { /// /// ```rust /// use fory_core::{Serializer, ForyDefault, RefMode}; - /// use fory_core::resolver::context::WriteContext; + /// use fory_core::WriteContext; /// use fory_core::error::Error; /// use std::any::Any; /// @@ -421,7 +420,7 @@ pub trait Serializer: 'static { /// Ok(()) /// } /// - /// fn fory_read_data(context: &mut fory_core::resolver::context::ReadContext) -> Result + /// fn fory_read_data(context: &mut fory_core::ReadContext) -> Result /// where /// Self: Sized + fory_core::ForyDefault, /// { @@ -431,7 +430,7 @@ pub trait Serializer: 'static { /// Ok(Person { name, age, scores }) /// } /// - /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::TypeResolver) -> Result { + /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::resolver::TypeResolver) -> Result { /// Self::fory_get_type_id(type_resolver) /// } /// @@ -692,7 +691,7 @@ pub trait Serializer: 'static { /// /// ```rust /// use fory_core::{Serializer, ForyDefault}; - /// use fory_core::resolver::context::{ReadContext, WriteContext}; + /// use fory_core::{ReadContext, WriteContext}; /// use fory_core::error::Error; /// use std::any::Any; /// @@ -725,7 +724,7 @@ pub trait Serializer: 'static { /// Ok(Point { x, y }) /// } /// - /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::TypeResolver) -> Result { + /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::resolver::TypeResolver) -> Result { /// Self::fory_get_type_id(type_resolver) /// } /// @@ -739,7 +738,7 @@ pub trait Serializer: 'static { /// /// ```rust /// use fory_core::{Serializer, ForyDefault, RefMode}; - /// use fory_core::resolver::context::{ReadContext, WriteContext}; + /// use fory_core::{ReadContext, WriteContext}; /// use fory_core::error::Error; /// use std::any::Any; /// @@ -779,7 +778,7 @@ pub trait Serializer: 'static { /// Ok(Person { name, age, scores }) /// } /// - /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::TypeResolver) -> Result { + /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::resolver::TypeResolver) -> Result { /// Self::fory_get_type_id(type_resolver) /// } /// @@ -793,7 +792,7 @@ pub trait Serializer: 'static { /// /// ```rust,ignore /// use fory_core::{Serializer, ForyDefault}; - /// use fory_core::resolver::context::{ReadContext, WriteContext}; + /// use fory_core::{ReadContext, WriteContext}; /// use fory_core::error::Error; /// use std::any::Any; /// @@ -830,7 +829,7 @@ pub trait Serializer: 'static { /// Ok(Config { name, timeout }) /// } /// - /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::TypeResolver) -> Result { + /// fn fory_type_id_dyn(&self, type_resolver: &fory_core::resolver::TypeResolver) -> Result { /// Self::fory_get_type_id(type_resolver) /// } /// diff --git a/rust/fory-core/src/serializer/datetime.rs b/rust/fory-core/src/serializer/datetime.rs index 23f8c49685..5823d0c7c0 100644 --- a/rust/fory-core/src/serializer/datetime.rs +++ b/rust/fory-core/src/serializer/datetime.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::util::read_basic_type_info; use crate::serializer::ForyDefault; use crate::serializer::Serializer; -use crate::types::TypeId; +use crate::type_id::TypeId; use crate::util::EPOCH; -use chrono::{Duration as ChronoDuration, NaiveDate, NaiveDateTime}; +use chrono::{Duration as ChronoDuration, NaiveDate, NaiveDateTime, TimeDelta}; use std::mem; use std::time::Duration; @@ -89,22 +89,44 @@ impl Serializer for NaiveDate { #[inline(always)] fn fory_write_data(&self, context: &mut WriteContext) -> Result<(), Error> { let days_since_epoch = self.signed_duration_since(EPOCH).num_days(); - context.writer.write_i32(days_since_epoch as i32); + if context.is_xlang() { + context.writer.write_var_i64(days_since_epoch); + } else { + let native_days = i32::try_from(days_since_epoch).map_err(|_| { + Error::invalid_data(format!( + "date day count {} exceeds native i32 range", + days_since_epoch + )) + })?; + context.writer.write_i32(native_days); + } Ok(()) } #[inline(always)] fn fory_read_data(context: &mut ReadContext) -> Result { - let days = context.reader.read_i32()?; - use chrono::TimeDelta; - let duration = TimeDelta::days(days as i64); - let result = EPOCH + duration; - Ok(result) + let days = if context.is_xlang() { + context.reader.read_var_i64()? + } else { + i64::from(context.reader.read_i32()?) + }; + let duration = TimeDelta::try_days(days).ok_or_else(|| { + Error::invalid_data(format!( + "date day count {} is out of chrono::TimeDelta range", + days + )) + })?; + EPOCH.checked_add_signed(duration).ok_or_else(|| { + Error::invalid_data(format!( + "date day count {} is out of chrono::NaiveDate range", + days + )) + }) } #[inline(always)] fn fory_reserved_space() -> usize { - mem::size_of::() + 9 } #[inline(always)] diff --git a/rust/fory-core/src/serializer/decimal.rs b/rust/fory-core/src/serializer/decimal.rs new file mode 100644 index 0000000000..4775df108f --- /dev/null +++ b/rust/fory-core/src/serializer/decimal.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::buffer::{Reader, Writer}; +use crate::context::{ReadContext, WriteContext}; +use crate::error::Error; +use crate::resolver::TypeResolver; +use crate::serializer::util::read_basic_type_info; +use crate::serializer::{ForyDefault, Serializer}; +use crate::type_id::TypeId; +use crate::types::Decimal; +use num_bigint::{BigInt, Sign}; +use std::convert::TryFrom; + +impl Serializer for Decimal { + #[inline(always)] + fn fory_write_data(&self, context: &mut WriteContext) -> Result<(), Error> { + context.writer.write_var_i32(self.scale); + write_decimal_unscaled(&self.unscaled, &mut context.writer) + } + + #[inline(always)] + fn fory_read_data(context: &mut ReadContext) -> Result { + let scale = context.reader.read_var_i32()?; + let unscaled = read_decimal_unscaled(&mut context.reader)?; + Ok(Self { unscaled, scale }) + } + + #[inline(always)] + fn fory_get_type_id(_: &TypeResolver) -> Result { + Ok(TypeId::DECIMAL) + } + + #[inline(always)] + fn fory_type_id_dyn(&self, _: &TypeResolver) -> Result { + Ok(TypeId::DECIMAL) + } + + #[inline(always)] + fn fory_static_type_id() -> TypeId { + TypeId::DECIMAL + } + + #[inline(always)] + fn as_any(&self) -> &dyn std::any::Any { + self + } + + #[inline(always)] + fn fory_write_type_info(context: &mut WriteContext) -> Result<(), Error> { + context.writer.write_var_u32(TypeId::DECIMAL as u32); + Ok(()) + } + + #[inline(always)] + fn fory_read_type_info(context: &mut ReadContext) -> Result<(), Error> { + read_basic_type_info::(context) + } +} + +impl ForyDefault for Decimal { + #[inline(always)] + fn fory_default() -> Self { + Self { + unscaled: BigInt::from(0), + scale: 0, + } + } +} + +fn write_decimal_unscaled(value: &BigInt, writer: &mut Writer) -> Result<(), Error> { + if let Some(small_value) = can_use_small_encoding(value) { + writer.write_var_u64(encode_zigzag64(small_value) << 1); + return Ok(()); + } + + let (sign, payload) = value.to_bytes_le(); + if payload.is_empty() { + return Err(Error::invalid_data( + "zero must use the small decimal encoding".to_string(), + )); + } + let meta = ((payload.len() as u64) << 1) | u64::from(matches!(sign, Sign::Minus)); + writer.write_var_u64((meta << 1) | 1); + writer.write_bytes(&payload); + Ok(()) +} + +fn read_decimal_unscaled(reader: &mut Reader) -> Result { + let header = reader.read_var_u64()?; + if (header & 1) == 0 { + return Ok(BigInt::from(decode_zigzag64(header >> 1))); + } + + let meta = header >> 1; + let sign = (meta & 1) != 0; + let len = (meta >> 1) as usize; + if len == 0 { + return Err(Error::invalid_data( + "invalid decimal magnitude length 0".to_string(), + )); + } + let payload = reader.read_bytes(len)?; + if payload[len - 1] == 0 { + return Err(Error::invalid_data( + "non-canonical decimal payload: trailing zero byte".to_string(), + )); + } + let magnitude = BigInt::from_bytes_le(Sign::Plus, payload); + if magnitude == BigInt::from(0) { + return Err(Error::invalid_data( + "big decimal encoding must not represent zero".to_string(), + )); + } + Ok(if sign { -magnitude } else { magnitude }) +} + +fn can_use_small_encoding(value: &BigInt) -> Option { + let small_value = i64::try_from(value).ok()?; + if (encode_zigzag64(small_value) & (1u64 << 63)) == 0 { + Some(small_value) + } else { + None + } +} + +#[inline(always)] +fn encode_zigzag64(value: i64) -> u64 { + ((value << 1) ^ (value >> 63)) as u64 +} + +#[inline(always)] +fn decode_zigzag64(value: u64) -> i64 { + ((value >> 1) as i64) ^ -((value & 1) as i64) +} diff --git a/rust/fory-core/src/serializer/enum_.rs b/rust/fory-core/src/serializer/enum_.rs index 00ae71229f..1ab3d0f8ee 100644 --- a/rust/fory-core/src/serializer/enum_.rs +++ b/rust/fory-core/src/serializer/enum_.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::context::{ReadContext, WriteContext}; use crate::ensure; use crate::error::Error; use crate::meta::FieldInfo; -use crate::resolver::context::{ReadContext, WriteContext}; +use crate::resolver::{RefFlag, RefMode, TypeResolver}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{RefFlag, RefMode, TypeId}; -use crate::TypeResolver; +use crate::type_id::TypeId; #[inline(always)] pub fn actual_type_id(_type_id: u32, register_by_name: bool, _compatible: bool) -> u32 { diff --git a/rust/fory-core/src/serializer/heap.rs b/rust/fory-core/src/serializer/heap.rs index 623edba302..cfbd7be30e 100644 --- a/rust/fory-core/src/serializer/heap.rs +++ b/rust/fory-core/src/serializer/heap.rs @@ -15,17 +15,17 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::collection::{ read_collection_data, read_collection_type_info, write_collection_data, write_collection_type_info, }; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; use std::collections::BinaryHeap; use std::mem; diff --git a/rust/fory-core/src/serializer/list.rs b/rust/fory-core/src/serializer/list.rs index d93d583699..36ec760a4f 100644 --- a/rust/fory-core/src/serializer/list.rs +++ b/rust/fory-core/src/serializer/list.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::primitive_list; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; use std::collections::{LinkedList, VecDeque}; use std::mem; diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index 9c043a0486..acaf9ec911 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::context::{ReadContext, WriteContext}; use crate::ensure; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; +use crate::resolver::RefMode; +use crate::resolver::{TypeInfo, TypeResolver}; use crate::serializer::util::read_basic_type_info; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{need_to_write_type_for_field, RefMode, TypeId, SIZE_OF_REF_AND_TYPE}; +use crate::type_id::{need_to_write_type_for_field, TypeId, SIZE_OF_REF_AND_TYPE}; use std::collections::{BTreeMap, HashMap}; use std::rc::Rc; diff --git a/rust/fory-core/src/serializer/marker.rs b/rust/fory-core/src/serializer/marker.rs index 5045080cf7..24e919d021 100644 --- a/rust/fory-core/src/serializer/marker.rs +++ b/rust/fory-core/src/serializer/marker.rs @@ -17,11 +17,11 @@ //! Serializer implementations for marker types like `PhantomData`. +use crate::context::{ReadContext, WriteContext}; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; use std::marker::PhantomData; impl Serializer for PhantomData { diff --git a/rust/fory-core/src/serializer/mod.rs b/rust/fory-core/src/serializer/mod.rs index baf70905c5..1b2c971032 100644 --- a/rust/fory-core/src/serializer/mod.rs +++ b/rust/fory-core/src/serializer/mod.rs @@ -44,5 +44,6 @@ pub mod util; pub mod weak; mod core; +mod decimal; pub use any::{read_box_any, write_box_any}; pub use core::{read_data, write_data, ForyDefault, Serializer, StructSerializer}; diff --git a/rust/fory-core/src/serializer/mutex.rs b/rust/fory-core/src/serializer/mutex.rs index ae8eddf416..9af0b06a2b 100644 --- a/rust/fory-core/src/serializer/mutex.rs +++ b/rust/fory-core/src/serializer/mutex.rs @@ -41,11 +41,12 @@ //! You should serialize in a quiescent state with no concurrent mutation. //! - A poisoned mutex (from a panicked holder) will cause `.lock().unwrap()` to panic //! during serialization — it is assumed this is a programmer error. +use crate::context::{ReadContext, WriteContext}; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; +use crate::resolver::RefMode; +use crate::resolver::{TypeInfo, TypeResolver}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{RefMode, TypeId}; +use crate::type_id::TypeId; use std::rc::Rc; use std::sync::Mutex; diff --git a/rust/fory-core/src/serializer/number.rs b/rust/fory-core/src/serializer/number.rs index f62f7e9d07..50f83d0958 100644 --- a/rust/fory-core/src/serializer/number.rs +++ b/rust/fory-core/src/serializer/number.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::float16::float16; +use crate::types::float16::float16; use crate::buffer::{Reader, Writer}; +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::util::read_basic_type_info; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; macro_rules! impl_num_serializer { ($ty:ty, $writer:expr, $reader:expr, $field_type:expr) => { diff --git a/rust/fory-core/src/serializer/option.rs b/rust/fory-core/src/serializer/option.rs index 6b770f9002..8ae4e8ffc1 100644 --- a/rust/fory-core/src/serializer/option.rs +++ b/rust/fory-core/src/serializer/option.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; +use crate::resolver::{RefFlag, RefMode}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{RefFlag, RefMode, TypeId}; +use crate::type_id::TypeId; use std::rc::Rc; impl Serializer for Option { diff --git a/rust/fory-core/src/serializer/primitive_list.rs b/rust/fory-core/src/serializer/primitive_list.rs index df17663d07..c1b927bd3a 100644 --- a/rust/fory-core/src/serializer/primitive_list.rs +++ b/rust/fory-core/src/serializer/primitive_list.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::ensure; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; use crate::serializer::Serializer; -use crate::types::TypeId; +use crate::type_id::TypeId; #[cold] fn binary_size_limit_exceeded(size_bytes: usize, max: usize) -> Error { diff --git a/rust/fory-core/src/serializer/rc.rs b/rust/fory-core/src/serializer/rc.rs index 991746d558..a42e6dbeb0 100644 --- a/rust/fory-core/src/serializer/rc.rs +++ b/rust/fory-core/src/serializer/rc.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::context::{ReadContext, WriteContext}; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; +use crate::resolver::{RefFlag, RefMode}; +use crate::resolver::{TypeInfo, TypeResolver}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{RefFlag, RefMode, TypeId}; +use crate::type_id::TypeId; use std::rc::Rc; impl Serializer for Rc { diff --git a/rust/fory-core/src/serializer/refcell.rs b/rust/fory-core/src/serializer/refcell.rs index f8b7013f54..da47ad6783 100644 --- a/rust/fory-core/src/serializer/refcell.rs +++ b/rust/fory-core/src/serializer/refcell.rs @@ -32,11 +32,12 @@ //! let cell = RefCell::new(42); //! // Can be serialized by the Fory framework //! ``` +use crate::context::{ReadContext, WriteContext}; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; +use crate::resolver::RefMode; +use crate::resolver::{TypeInfo, TypeResolver}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{RefMode, TypeId}; +use crate::type_id::TypeId; use std::cell::RefCell; use std::rc::Rc; diff --git a/rust/fory-core/src/serializer/set.rs b/rust/fory-core/src/serializer/set.rs index df2ff454bc..d49bc5559f 100644 --- a/rust/fory-core/src/serializer/set.rs +++ b/rust/fory-core/src/serializer/set.rs @@ -15,17 +15,17 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::collection::{ read_collection_data, read_collection_type_info, write_collection_data, write_collection_type_info, }; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; use std::collections::{BTreeSet, HashSet}; use std::mem; diff --git a/rust/fory-core/src/serializer/skip.rs b/rust/fory-core/src/serializer/skip.rs index 177bdfc73d..9c51dd60f3 100644 --- a/rust/fory-core/src/serializer/skip.rs +++ b/rust/fory-core/src/serializer/skip.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; use crate::ensure; use crate::error::Error; use crate::meta::FieldType; -use crate::resolver::context::ReadContext; use crate::serializer::collection::{DECL_ELEMENT_TYPE, HAS_NULL, IS_SAME_TYPE}; use crate::serializer::util; use crate::serializer::Serializer; -use crate::types; -use crate::types::RefFlag; +use crate::type_id as types; use crate::util::ENABLE_FORY_DEBUG_OUTPUT; +use crate::RefFlag; use chrono::{Duration, NaiveDate, NaiveDateTime}; use std::rc::Rc; @@ -541,7 +541,7 @@ fn skip_value( // ============ FLOAT16 (TypeId = 17) ============ types::FLOAT16 => { - ::fory_read_data(context)?; + ::fory_read_data(context)?; } // ============ FLOAT32 (TypeId = 17) ============ @@ -648,6 +648,11 @@ fn skip_value( ::fory_read_data(context)?; } + // ============ DECIMAL (TypeId = 38) ============ + types::DECIMAL => { + ::fory_read_data(context)?; + } + // ============ BINARY (TypeId = 39) ============ types::BINARY => { as Serializer>::fory_read_data(context)?; @@ -700,7 +705,7 @@ fn skip_value( // ============ FLOAT16_ARRAY (TypeId = 53) ============ types::FLOAT16_ARRAY => { - as Serializer>::fory_read_data(context)?; + as Serializer>::fory_read_data(context)?; } // ============ FLOAT32_ARRAY (TypeId = 51) ============ diff --git a/rust/fory-core/src/serializer/string.rs b/rust/fory-core/src/serializer/string.rs index 8edef1ee47..093f873c76 100644 --- a/rust/fory-core/src/serializer/string.rs +++ b/rust/fory-core/src/serializer/string.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::util::read_basic_type_info; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; use std::mem; #[allow(dead_code)] diff --git a/rust/fory-core/src/serializer/struct_.rs b/rust/fory-core/src/serializer/struct_.rs index 9c6871e5b3..d22fe1626a 100644 --- a/rust/fory-core/src/serializer/struct_.rs +++ b/rust/fory-core/src/serializer/struct_.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::context::{ReadContext, WriteContext}; use crate::ensure; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; +use crate::resolver::{RefFlag, RefMode}; use crate::serializer::{Serializer, StructSerializer}; -use crate::types::{RefFlag, RefMode, TypeId}; +use crate::type_id::TypeId; use crate::util::ENABLE_FORY_DEBUG_OUTPUT; use std::any::Any; @@ -61,7 +62,7 @@ pub fn read_type_info_fast(context: &mut ReadContext) -> Re .get_type_resolver() .get_type_id_by_index(T::fory_type_index())?; let local_type_id_u32 = local_type_id as u32; - if !crate::types::needs_user_type_id(local_type_id_u32) { + if !crate::type_id::needs_user_type_id(local_type_id_u32) { return read_type_info::(context); } let remote_type_id = context.reader.read_u8()? as u32; diff --git a/rust/fory-core/src/serializer/trait_object.rs b/rust/fory-core/src/serializer/trait_object.rs index 6367e533cc..bfbdfd9d0b 100644 --- a/rust/fory-core/src/serializer/trait_object.rs +++ b/rust/fory-core/src/serializer/trait_object.rs @@ -18,13 +18,13 @@ // Re-exports for use in macros - these are needed for macro expansion in user crates // Even though they appear unused in this file, they are used by the macro-generated code +use crate::context::{ReadContext, WriteContext}; use crate::ensure; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; +use crate::resolver::{TypeInfo, TypeResolver}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::RefMode; use crate::RefFlag; +use crate::RefMode; use crate::TypeId; use std::rc::Rc; @@ -156,7 +156,7 @@ macro_rules! register_trait_type { } #[inline(always)] - fn fory_type_id_dyn(&self, type_resolver: &fory_core::TypeResolver) -> Result { + fn fory_type_id_dyn(&self, type_resolver: &fory_core::resolver::TypeResolver) -> Result { let any_ref = ::as_any(&**self); let concrete_type_id = any_ref.type_id(); type_resolver @@ -220,7 +220,7 @@ macro_rules! register_trait_type { $crate::not_allowed!("fory_read_data should not be called directly on polymorphic Box trait object", stringify!($trait_name)) } - fn fory_get_type_id(_type_resolver: &fory_core::TypeResolver) -> Result { + fn fory_get_type_id(_type_resolver: &fory_core::resolver::TypeResolver) -> Result { $crate::not_allowed!("fory_get_type_id should not be called directly on polymorphic Box trait object", stringify!($trait_name)) } @@ -231,7 +231,7 @@ macro_rules! register_trait_type { #[inline(always)] fn fory_reserved_space() -> usize { - $crate::types::SIZE_OF_REF_AND_TYPE + $crate::type_id::SIZE_OF_REF_AND_TYPE } #[inline(always)] @@ -502,7 +502,7 @@ macro_rules! impl_smart_pointer_serializer { } #[inline(always)] - fn fory_get_type_id(_type_resolver: &fory_core::TypeResolver) -> Result { + fn fory_get_type_id(_type_resolver: &fory_core::resolver::TypeResolver) -> Result { Ok(fory_core::TypeId::STRUCT) } @@ -532,7 +532,7 @@ macro_rules! impl_smart_pointer_serializer { } #[inline(always)] - fn fory_type_id_dyn(&self, type_resolver: &fory_core::TypeResolver) -> Result { + fn fory_type_id_dyn(&self, type_resolver: &fory_core::resolver::TypeResolver) -> Result { let any_obj = ::as_any(&*self.0); let concrete_type_id = any_obj.type_id(); type_resolver diff --git a/rust/fory-core/src/serializer/tuple.rs b/rust/fory-core/src/serializer/tuple.rs index 59570407c5..a0789a9069 100644 --- a/rust/fory-core/src/serializer/tuple.rs +++ b/rust/fory-core/src/serializer/tuple.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::context::{ReadContext, WriteContext}; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::RefMode; +use crate::resolver::TypeResolver; use crate::serializer::collection::{read_collection_type_info, write_collection_type_info}; use crate::serializer::skip::skip_any_value; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{RefMode, TypeId}; +use crate::type_id::TypeId; use std::mem; // Unit type () implementation - represents an empty/unit value with no data diff --git a/rust/fory-core/src/serializer/unsigned_number.rs b/rust/fory-core/src/serializer/unsigned_number.rs index 26c9a58c60..e5f49d0084 100644 --- a/rust/fory-core/src/serializer/unsigned_number.rs +++ b/rust/fory-core/src/serializer/unsigned_number.rs @@ -16,13 +16,13 @@ // under the License. use crate::buffer::{Reader, Writer}; +use crate::context::ReadContext; +use crate::context::WriteContext; use crate::error::Error; -use crate::resolver::context::ReadContext; -use crate::resolver::context::WriteContext; -use crate::resolver::type_resolver::TypeResolver; +use crate::resolver::TypeResolver; use crate::serializer::util::read_basic_type_info; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::TypeId; +use crate::type_id::TypeId; // Macro for xlang-compatible unsigned types (u8, u16, u32, u64) macro_rules! impl_xlang_unsigned_num_serializer { diff --git a/rust/fory-core/src/serializer/util.rs b/rust/fory-core/src/serializer/util.rs index a712af5e9c..db256969ff 100644 --- a/rust/fory-core/src/serializer/util.rs +++ b/rust/fory-core/src/serializer/util.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::context::{ReadContext, WriteContext}; use crate::ensure; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; use crate::serializer::Serializer; -use crate::types::TypeId; -use crate::types::{is_user_type, ENUM, NAMED_ENUM, UNION}; +use crate::type_id::TypeId; +use crate::type_id::{is_user_type, ENUM, NAMED_ENUM, UNION}; #[inline(always)] pub(crate) fn read_basic_type_info(context: &mut ReadContext) -> Result<(), Error> { diff --git a/rust/fory-core/src/serializer/weak.rs b/rust/fory-core/src/serializer/weak.rs index 26b3e6f2fd..f9f4e0849e 100644 --- a/rust/fory-core/src/serializer/weak.rs +++ b/rust/fory-core/src/serializer/weak.rs @@ -15,314 +15,18 @@ // specific language governing permissions and limitations // under the License. -//! Weak pointer serialization support for `Rc` and `Arc`. -//! -//! This module provides [`RcWeak`] and [`ArcWeak`] wrapper types that integrate -//! Rust's `std::rc::Weak` / `std::sync::Weak` into the Fory serialization framework, -//! with full support for: -// -//! - **Reference identity tracking** — weak pointers serialize as references to their -//! corresponding strong pointers, ensuring shared and circular references in the graph -//! are preserved without duplication. -//! - **Null weak pointers** — if the strong pointer has been dropped or was never set, -//! the weak will serialize as a `Null` flag. -//! - **Forward references during deserialization** — if the strong pointer appears later -//! in the serialized data, the weak will be resolved after deserialization using -//! `RefReader` callbacks. -//! -//! ## When to use -//! -//! Use these wrappers when your graph structure contains parent/child relationships -//! or other shared edges where a strong pointer would cause a reference cycle. -//! Storing a weak pointer avoids owning the target strongly, but serialization -//! will preserve the link by reference ID. -//! -//! ## Example — Parent/Child Graph -//! -//! ```rust,ignore -//! use fory_core::RcWeak; -//! use fory_core::Fory; -//! use std::cell::RefCell; -//! use std::rc::Rc; -//! use fory_derive::ForyObject; -//! -//! #[derive(ForyObject)] -//! struct Node { -//! value: i32, -//! parent: RcWeak>, -//! children: Vec>>, -//! } -//! -//! let mut fory = Fory::default(); -//! fory.register::(2000); -//! -//! let parent = Rc::new(RefCell::new(Node { -//! value: 1, -//! parent: RcWeak::new(), -//! children: vec![], -//! })); -//! -//! let child1 = Rc::new(RefCell::new(Node { -//! value: 2, -//! parent: RcWeak::from(&parent), -//! children: vec![], -//! })); -//! -//! let child2 = Rc::new(RefCell::new(Node { -//! value: 3, -//! parent: RcWeak::from(&parent), -//! children: vec![], -//! })); -//! -//! parent.borrow_mut().children.push(child1); -//! parent.borrow_mut().children.push(child2); -//! -//! let serialized = fory.serialize(&parent); -//! let deserialized: Rc> = fory.deserialize(&serialized).unwrap(); -//! -//! assert_eq!(deserialized.borrow().children.len(), 2); -//! for child in &deserialized.borrow().children { -//! let upgraded_parent = child.borrow().parent.upgrade().unwrap(); -//! assert!(Rc::ptr_eq(&deserialized, &upgraded_parent)); -//! } -//! ``` -//! -//! ## Example — Arc for Multi-Threaded Graphs -//! -//! ```rust,ignore -//! use fory_core::ArcWeak; -//! use fory_core::Fory; -//! use std::sync::{Arc, Mutex}; -//! use fory_derive::ForyObject; -//! -//! #[derive(ForyObject)] -//! struct Node { -//! value: i32, -//! parent: ArcWeak>, -//! } -//! -//! let mut fory = Fory::default(); -//! fory.register::(2001); -//! -//! let parent = Arc::new(Mutex::new(Node { value: 1, parent: ArcWeak::new() })); -//! let child = Arc::new(Mutex::new(Node { value: 2, parent: ArcWeak::from(&parent) })); -//! -//! let serialized = fory.serialize(&child); -//! let deserialized: Arc> = fory.deserialize(&serialized).unwrap(); -//! assert_eq!(deserialized.lock().unwrap().value, 2); -//! ``` -//! -//! ## Notes -//! -//! - These types share the same `UnsafeCell` across clones, so updating a weak in one clone -//! will update all of them. -//! - During serialization, weak pointers **never serialize the target object's data directly** -//! — they only emit a reference to the already-serialized strong pointer, or `Null`. -//! - During deserialization, unresolved references will be patched up by `RefReader::add_callback` -//! once the strong pointer becomes available. +//! Serialization support for [`crate::types::weak::RcWeak`] and [`crate::types::weak::ArcWeak`]. +use crate::context::{ReadContext, WriteContext}; use crate::error::Error; -use crate::resolver::context::{ReadContext, WriteContext}; -use crate::resolver::type_resolver::{TypeInfo, TypeResolver}; +use crate::resolver::{RefFlag, RefMode}; +use crate::resolver::{TypeInfo, TypeResolver}; use crate::serializer::{ForyDefault, Serializer}; -use crate::types::{RefFlag, RefMode, TypeId}; -use std::cell::UnsafeCell; +use crate::type_id::TypeId; +use crate::types::{ArcWeak, RcWeak}; use std::rc::Rc; use std::sync::Arc; -/// A serializable wrapper around `std::rc::Weak`. -/// -/// `RcWeak` is designed for use in graph-like structures where nodes may need to hold -/// non-owning references to other nodes (e.g., parent pointers), and you still want them -/// to round-trip through serialization while preserving reference identity. -/// -/// Unlike a raw `Weak`, cloning `RcWeak` keeps all clones pointing to the same -/// internal `UnsafeCell`, so updates via deserialization callbacks affect all copies. -/// -/// # Example -/// See module-level docs for a complete graph example. -/// -/// # Null handling -/// If the target `Rc` has been dropped or never assigned, `upgrade()` returns `None` -/// and serialization will write a `RefFlag::Null` instead of a reference ID. -pub struct RcWeak { - // Use Rc so that clones share the same cell - inner: Rc>>, -} - -impl std::fmt::Debug for RcWeak { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("RcWeak") - .field("strong_count", &self.strong_count()) - .field("weak_count", &self.weak_count()) - .finish() - } -} - -impl RcWeak { - pub fn new() -> Self { - RcWeak { - inner: Rc::new(UnsafeCell::new(std::rc::Weak::new())), - } - } -} - -impl RcWeak { - pub fn upgrade(&self) -> Option> { - unsafe { (*self.inner.get()).upgrade() } - } - - pub fn strong_count(&self) -> usize { - unsafe { (*self.inner.get()).strong_count() } - } - - pub fn weak_count(&self) -> usize { - unsafe { (*self.inner.get()).weak_count() } - } - - pub fn ptr_eq(&self, other: &Self) -> bool { - unsafe { std::rc::Weak::ptr_eq(&*self.inner.get(), &*other.inner.get()) } - } - - pub fn update(&self, weak: std::rc::Weak) { - unsafe { - *self.inner.get() = weak; - } - } - - pub fn from_std(weak: std::rc::Weak) -> Self { - RcWeak { - inner: Rc::new(UnsafeCell::new(weak)), - } - } -} - -impl PartialEq for RcWeak { - fn eq(&self, other: &Self) -> bool { - self.ptr_eq(other) - } -} - -impl Eq for RcWeak {} - -impl Default for RcWeak { - fn default() -> Self { - Self::new() - } -} - -impl Clone for RcWeak { - fn clone(&self) -> Self { - // Clone the Rc, not the inner Weak - this way clones share the same cell! - RcWeak { - inner: self.inner.clone(), - } - } -} - -impl From<&Rc> for RcWeak { - fn from(rc: &Rc) -> Self { - RcWeak::from_std(Rc::downgrade(rc)) - } -} - -unsafe impl Send for RcWeak where std::rc::Weak: Send {} -unsafe impl Sync for RcWeak where std::rc::Weak: Sync {} - -/// A serializable wrapper around `std::sync::Weak` (thread-safe). -/// -/// `ArcWeak` works exactly like [`RcWeak`] but is intended for use with -/// multi-threaded shared graphs where strong pointers are `Arc`. -/// -/// All clones of an `ArcWeak` share the same `UnsafeCell` so deserialization -/// updates propagate to all copies. -/// -/// # Example -/// See module-level docs for an `Arc>` usage example. -pub struct ArcWeak { - // Use Arc so that clones share the same cell - inner: Arc>>, -} - -impl std::fmt::Debug for ArcWeak { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ArcWeak") - .field("strong_count", &self.strong_count()) - .field("weak_count", &self.weak_count()) - .finish() - } -} - -impl ArcWeak { - pub fn new() -> Self { - ArcWeak { - inner: Arc::new(UnsafeCell::new(std::sync::Weak::new())), - } - } -} - -impl ArcWeak { - pub fn upgrade(&self) -> Option> { - unsafe { (*self.inner.get()).upgrade() } - } - - pub fn strong_count(&self) -> usize { - unsafe { (*self.inner.get()).strong_count() } - } - - pub fn weak_count(&self) -> usize { - unsafe { (*self.inner.get()).weak_count() } - } - - pub fn ptr_eq(&self, other: &Self) -> bool { - unsafe { std::sync::Weak::ptr_eq(&*self.inner.get(), &*other.inner.get()) } - } - - pub fn update(&self, weak: std::sync::Weak) { - unsafe { - *self.inner.get() = weak; - } - } - - pub fn from_std(weak: std::sync::Weak) -> Self { - ArcWeak { - inner: Arc::new(UnsafeCell::new(weak)), - } - } -} - -impl PartialEq for ArcWeak { - fn eq(&self, other: &Self) -> bool { - self.ptr_eq(other) - } -} - -impl Eq for ArcWeak {} - -impl Default for ArcWeak { - fn default() -> Self { - Self::new() - } -} - -impl Clone for ArcWeak { - fn clone(&self) -> Self { - // Clone the Arc, not the inner Weak - this way clones share the same cell! - ArcWeak { - inner: self.inner.clone(), - } - } -} - -impl From<&Arc> for ArcWeak { - fn from(arc: &Arc) -> Self { - ArcWeak::from_std(Arc::downgrade(arc)) - } -} - -unsafe impl Send for ArcWeak {} -unsafe impl Sync for ArcWeak {} - impl Serializer for RcWeak { fn fory_is_shared_ref() -> bool { true @@ -651,13 +355,10 @@ fn read_arc_weak( if let Some(arc) = context.ref_reader.get_arc_ref::(ref_id) { weak.update(Arc::downgrade(&arc)); } else { - // Capture the raw pointer to the UnsafeCell so we can update it in the callback - let weak_ptr = weak.inner.get(); + let callback_weak = weak.clone(); context.ref_reader.add_callback(Box::new(move |ref_reader| { if let Some(arc) = ref_reader.get_arc_ref::(ref_id) { - unsafe { - *weak_ptr = Arc::downgrade(&arc); - } + callback_weak.update(Arc::downgrade(&arc)); } })); } diff --git a/rust/fory-core/src/types.rs b/rust/fory-core/src/type_id.rs similarity index 82% rename from rust/fory-core/src/types.rs rename to rust/fory-core/src/type_id.rs index 6d68b3a3c8..ecb51628c3 100644 --- a/rust/fory-core/src/types.rs +++ b/rust/fory-core/src/type_id.rs @@ -18,81 +18,6 @@ use num_enum::{IntoPrimitive, TryFromPrimitive}; use std::mem; -#[allow(dead_code)] -pub enum StringFlag { - LATIN1 = 0, - UTF8 = 1, -} - -#[derive(Debug, TryFromPrimitive)] -#[repr(i8)] -pub enum RefFlag { - Null = -3, - // Ref indicates that object is a not-null value. - // We don't use another byte to indicate REF, so that we can save one byte. - Ref = -2, - // NotNullValueFlag indicates that the object is a non-null value. - NotNullValue = -1, - // RefValueFlag indicates that the object is a referencable and first read. - RefValue = 0, -} - -/// Controls how reference and null flags are handled during serialization. -/// -/// This enum combines nullable semantics and reference tracking into one parameter, -/// enabling fine-grained control per type and per field: -/// - `None` = non-nullable, no ref tracking (primitives) -/// - `NullOnly` = nullable, no circular ref tracking -/// - `Tracking` = nullable, with circular ref tracking (Rc/Arc/Weak) -#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] -#[repr(u8)] -pub enum RefMode { - /// Skip ref handling entirely. No ref/null flags are written/read. - /// Used for non-nullable primitives or when caller handles ref externally. - #[default] - None = 0, - - /// Only null check without reference tracking. - /// Write: NullFlag (-3) for None, NotNullValueFlag (-1) for Some. - /// Read: Read flag and return ForyDefault on null. - NullOnly = 1, - - /// Full reference tracking with circular reference support. - /// Write: Uses RefWriter which writes NullFlag, RefFlag+refId, or RefValueFlag. - /// Read: Uses RefReader with full reference resolution. - Tracking = 2, -} - -impl RefMode { - /// Create RefMode from nullable and track_ref flags. - #[inline] - pub const fn from_flags(nullable: bool, track_ref: bool) -> Self { - match (nullable, track_ref) { - (false, false) => RefMode::None, - (true, false) => RefMode::NullOnly, - (_, true) => RefMode::Tracking, - } - } - - /// Check if this mode reads/writes ref flags. - #[inline] - pub const fn has_ref_flag(self) -> bool { - !matches!(self, RefMode::None) - } - - /// Check if this mode tracks circular references. - #[inline] - pub const fn tracks_refs(self) -> bool { - matches!(self, RefMode::Tracking) - } - - /// Check if this mode handles nullable values. - #[inline] - pub const fn is_nullable(self) -> bool { - !matches!(self, RefMode::None) - } -} - #[derive(Clone, Copy, Debug, PartialEq, TryFromPrimitive, IntoPrimitive)] #[allow(non_camel_case_types)] #[repr(u8)] @@ -259,21 +184,6 @@ pub const fn is_enum_type_id(type_id: TypeId) -> bool { matches!(type_id, TypeId::ENUM | TypeId::NAMED_ENUM | TypeId::UNION) } -const MAX_UNT32: u64 = (1 << 31) - 1; - -// todo: struct hash -#[allow(dead_code)] -pub fn compute_string_hash(s: &str) -> u32 { - let mut hash: u64 = 17; - s.as_bytes().iter().for_each(|b| { - hash = (hash * 31) + (*b as u64); - while hash >= MAX_UNT32 { - hash /= 7; - } - }); - hash as u32 -} - pub static BASIC_TYPES: [TypeId; 34] = [ TypeId::BOOL, TypeId::INT8, @@ -500,29 +410,12 @@ pub const fn needs_user_type_id(type_id: u32) -> bool { ) } -pub fn compute_field_hash(hash: u32, id: i16) -> u32 { - let mut new_hash: u64 = (hash as u64) * 31 + (id as u64); - while new_hash >= MAX_UNT32 { - new_hash /= 7; - } - new_hash as u32 -} - pub mod config_flags { pub const IS_NULL_FLAG: u8 = 1 << 0; pub const IS_CROSS_LANGUAGE_FLAG: u8 = 1 << 1; pub const IS_OUT_OF_BAND_FLAG: u8 = 1 << 2; } -#[derive(Debug, PartialEq)] -pub enum Mode { - // Type declaration must be consistent between serialization peer and deserialization peer. - SchemaConsistent, - // Type declaration can be different between serialization peer and deserialization peer. - // They can add/delete fields independently. - Compatible, -} - // every object start with i8 i16 reference flag and type flag pub const SIZE_OF_REF_AND_TYPE: usize = mem::size_of::() + mem::size_of::(); diff --git a/rust/fory-core/src/types/decimal.rs b/rust/fory-core/src/types/decimal.rs new file mode 100644 index 0000000000..ba1a708b5f --- /dev/null +++ b/rust/fory-core/src/types/decimal.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use num_bigint::BigInt; + +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct Decimal { + pub unscaled: BigInt, + pub scale: i32, +} + +impl Decimal { + pub fn new(unscaled: BigInt, scale: i32) -> Self { + Self { unscaled, scale } + } +} diff --git a/rust/fory-core/src/float16.rs b/rust/fory-core/src/types/float16.rs similarity index 100% rename from rust/fory-core/src/float16.rs rename to rust/fory-core/src/types/float16.rs diff --git a/rust/fory-core/src/types/mod.rs b/rust/fory-core/src/types/mod.rs new file mode 100644 index 0000000000..a73f92366b --- /dev/null +++ b/rust/fory-core/src/types/mod.rs @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod decimal; +pub mod float16; +pub mod weak; + +pub use decimal::Decimal; +pub use weak::{ArcWeak, RcWeak}; diff --git a/rust/fory-core/src/types/weak.rs b/rust/fory-core/src/types/weak.rs new file mode 100644 index 0000000000..91cfab6eae --- /dev/null +++ b/rust/fory-core/src/types/weak.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Weak pointer runtime carriers for `Rc` and `Arc`. +//! +//! This module provides [`RcWeak`] and [`ArcWeak`] wrapper types that integrate +//! Rust's `std::rc::Weak` / `std::sync::Weak` into the Fory type system. + +use std::cell::UnsafeCell; +use std::rc::Rc; +use std::sync::Arc; + +/// A serializable runtime wrapper around `std::rc::Weak`. +/// +/// `RcWeak` is designed for graph-like structures where nodes may need to hold +/// non-owning references to other nodes, and you still want them to round-trip +/// through serialization while preserving reference identity. +pub struct RcWeak { + // Use Rc so that clones share the same cell. + inner: Rc>>, +} + +impl std::fmt::Debug for RcWeak { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RcWeak") + .field("strong_count", &self.strong_count()) + .field("weak_count", &self.weak_count()) + .finish() + } +} + +impl RcWeak { + pub fn new() -> Self { + RcWeak { + inner: Rc::new(UnsafeCell::new(std::rc::Weak::new())), + } + } +} + +impl RcWeak { + pub fn upgrade(&self) -> Option> { + unsafe { (*self.inner.get()).upgrade() } + } + + pub fn strong_count(&self) -> usize { + unsafe { (*self.inner.get()).strong_count() } + } + + pub fn weak_count(&self) -> usize { + unsafe { (*self.inner.get()).weak_count() } + } + + pub fn ptr_eq(&self, other: &Self) -> bool { + unsafe { std::rc::Weak::ptr_eq(&*self.inner.get(), &*other.inner.get()) } + } + + pub fn update(&self, weak: std::rc::Weak) { + unsafe { + *self.inner.get() = weak; + } + } + + pub fn from_std(weak: std::rc::Weak) -> Self { + RcWeak { + inner: Rc::new(UnsafeCell::new(weak)), + } + } +} + +impl PartialEq for RcWeak { + fn eq(&self, other: &Self) -> bool { + self.ptr_eq(other) + } +} + +impl Eq for RcWeak {} + +impl Default for RcWeak { + fn default() -> Self { + Self::new() + } +} + +impl Clone for RcWeak { + fn clone(&self) -> Self { + RcWeak { + inner: self.inner.clone(), + } + } +} + +impl From<&Rc> for RcWeak { + fn from(rc: &Rc) -> Self { + RcWeak::from_std(Rc::downgrade(rc)) + } +} + +unsafe impl Send for RcWeak where std::rc::Weak: Send {} +unsafe impl Sync for RcWeak where std::rc::Weak: Sync {} + +/// A serializable runtime wrapper around `std::sync::Weak`. +/// +/// `ArcWeak` works like [`RcWeak`] but is intended for multi-threaded shared +/// graphs where strong pointers are `Arc`. +pub struct ArcWeak { + // Use Arc so that clones share the same cell. + inner: Arc>>, +} + +impl std::fmt::Debug for ArcWeak { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ArcWeak") + .field("strong_count", &self.strong_count()) + .field("weak_count", &self.weak_count()) + .finish() + } +} + +impl ArcWeak { + pub fn new() -> Self { + ArcWeak { + inner: Arc::new(UnsafeCell::new(std::sync::Weak::new())), + } + } +} + +impl ArcWeak { + pub fn upgrade(&self) -> Option> { + unsafe { (*self.inner.get()).upgrade() } + } + + pub fn strong_count(&self) -> usize { + unsafe { (*self.inner.get()).strong_count() } + } + + pub fn weak_count(&self) -> usize { + unsafe { (*self.inner.get()).weak_count() } + } + + pub fn ptr_eq(&self, other: &Self) -> bool { + unsafe { std::sync::Weak::ptr_eq(&*self.inner.get(), &*other.inner.get()) } + } + + pub fn update(&self, weak: std::sync::Weak) { + unsafe { + *self.inner.get() = weak; + } + } + + pub fn from_std(weak: std::sync::Weak) -> Self { + ArcWeak { + inner: Arc::new(UnsafeCell::new(weak)), + } + } +} + +impl PartialEq for ArcWeak { + fn eq(&self, other: &Self) -> bool { + self.ptr_eq(other) + } +} + +impl Eq for ArcWeak {} + +impl Default for ArcWeak { + fn default() -> Self { + Self::new() + } +} + +impl Clone for ArcWeak { + fn clone(&self) -> Self { + ArcWeak { + inner: self.inner.clone(), + } + } +} + +impl From<&Arc> for ArcWeak { + fn from(arc: &Arc) -> Self { + ArcWeak::from_std(Arc::downgrade(arc)) + } +} + +unsafe impl Send for ArcWeak {} +unsafe impl Sync for ArcWeak {} diff --git a/rust/fory-core/src/util/mod.rs b/rust/fory-core/src/util/mod.rs index 7df7640909..071094d958 100644 --- a/rust/fory-core/src/util/mod.rs +++ b/rust/fory-core/src/util/mod.rs @@ -18,7 +18,10 @@ mod string_util; mod sync; -pub use string_util::{to_camel_case, to_snake_case, to_utf8}; +pub use string_util::{ + buffer_rw_string, compute_string_hash, get_latin1_length, is_latin, murmurhash3_x64_128, + to_camel_case, to_snake_case, to_utf8, StringFlag, +}; pub use sync::{Spinlock, SpinlockGuard}; use chrono::NaiveDate; diff --git a/rust/fory-core/src/util/string_util.rs b/rust/fory-core/src/util/string_util.rs index ca6a17bcae..26bb768ca3 100644 --- a/rust/fory-core/src/util/string_util.rs +++ b/rust/fory-core/src/util/string_util.rs @@ -15,8 +15,17 @@ // specific language governing permissions and limitations // under the License. +use std::mem; use std::ptr; +const MAX_HASH32: u64 = (1 << 31) - 1; + +#[allow(dead_code)] +pub enum StringFlag { + LATIN1 = 0, + UTF8 = 1, +} + /// Swaps the high 8 bits and the low 8 bits of a 16-bit value. fn swap_endian(value: u16) -> u16 { value.rotate_right(8) @@ -24,15 +33,13 @@ fn swap_endian(value: u16) -> u16 { /// Converts UTF-16 encoded data to UTF-8. pub fn to_utf8(utf16: &[u16], is_little_endian: bool) -> Result, String> { - // Pre-allocating capacity to avoid dynamic resizing - // Longest case: 1 u16 to 3 u8 + // Pre-allocating capacity to avoid dynamic resizing. + // Longest case: 1 u16 to 3 u8. let mut utf8_bytes: Vec = Vec::with_capacity(utf16.len() * 3); - // For unsafe write to Vec let ptr = utf8_bytes.as_mut_ptr(); let mut offset = 0; let mut iter = utf16.iter(); while let Some(&wc) = iter.next() { - // Using big endian in this conversion let wc = if is_little_endian { swap_endian(wc) } else { @@ -40,16 +47,12 @@ pub fn to_utf8(utf16: &[u16], is_little_endian: bool) -> Result, String> }; match wc { code_point if code_point < 0x80 => { - // 1-byte UTF-8 - // [0000|0000|0ccc|cccc] => [0ccc|cccc] unsafe { ptr.add(offset).write(code_point as u8); } offset += 1; } code_point if code_point < 0x800 => { - // 2-byte UTF-8 - // [0000|0bbb|bbcc|cccc] => [110|bbbbb], [10|cccccc] let bytes = [ ((code_point >> 6) & 0b1_1111) as u8 | 0b1100_0000, (code_point & 0b11_1111) as u8 | 0b1000_0000, @@ -60,8 +63,6 @@ pub fn to_utf8(utf16: &[u16], is_little_endian: bool) -> Result, String> offset += 2; } wc1 if (0xd800..=0xdbff).contains(&wc1) => { - // Surrogate pair (4-byte UTF-8) - // Need extra u16, 2 u16 -> 4 u8 if let Some(&wc2) = iter.next() { let wc2 = if is_little_endian { swap_endian(wc2) @@ -71,11 +72,8 @@ pub fn to_utf8(utf16: &[u16], is_little_endian: bool) -> Result, String> if !(0xdc00..=0xdfff).contains(&wc2) { return Err("Invalid UTF-16 string: wrong surrogate pair".to_string()); } - // utf16 to unicode let code_point = ((((wc1 as u32) - 0xd800) << 10) | ((wc2 as u32) - 0xdc00)) + 0x10000; - // 11110??? 10?????? 10?????? 10?????? - // Need 21 bit suffix of code_point let bytes = [ ((code_point >> 18) & 0b111) as u8 | 0b1111_0000, ((code_point >> 12) & 0b11_1111) as u8 | 0b1000_0000, @@ -91,9 +89,6 @@ pub fn to_utf8(utf16: &[u16], is_little_endian: bool) -> Result, String> } } _ => { - // 3-byte UTF-8, 1 u16 -> 3 u8 - // [aaaa|bbbb|bbcc|cccc] => [1110|aaaa], [10|bbbbbb], [10|cccccc] - // Need 16 bit suffix of wc, as same as wc itself let bytes = [ ((wc >> 12) | 0b1110_0000) as u8, ((wc >> 6) & 0b11_1111) as u8 | 0b1000_0000, @@ -107,25 +102,18 @@ pub fn to_utf8(utf16: &[u16], is_little_endian: bool) -> Result, String> } } unsafe { - // As ptr.write don't change the length utf8_bytes.set_len(offset); } Ok(utf8_bytes) } /// Converts a camelCase or PascalCase string to snake_case. -/// Used for cross-language field name matching since Java uses camelCase -/// and Rust uses snake_case. pub fn to_snake_case(name: &str) -> String { let mut result = String::with_capacity(name.len() + 4); let chars: Vec = name.chars().collect(); for (i, &c) in chars.iter().enumerate() { if c.is_ascii_uppercase() { - // Add underscore before uppercase unless: - // - It's the first character - // - Previous char was uppercase and next is uppercase or doesn't exist - // (e.g., "HTTPRequest" -> "http_request", not "h_t_t_p_request") if i > 0 { let prev_upper = chars.get(i - 1).is_some_and(|c| c.is_ascii_uppercase()); let next_upper_or_end = chars.get(i + 1).map_or(true, |c| c.is_ascii_uppercase()); @@ -142,8 +130,6 @@ pub fn to_snake_case(name: &str) -> String { } /// Converts a snake_case string to lowerCamelCase. -/// Used for cross-language field name serialization since Rust uses snake_case -/// but other languages (Java, etc.) expect camelCase in the wire format. pub fn to_camel_case(name: &str) -> String { let mut result = String::with_capacity(name.len()); let mut capitalize_next = false; @@ -161,8 +147,513 @@ pub fn to_camel_case(name: &str) -> String { result } +#[allow(dead_code)] +pub fn compute_string_hash(s: &str) -> u32 { + let mut hash: u64 = 17; + s.as_bytes().iter().for_each(|b| { + hash = (hash * 31) + (*b as u64); + while hash >= MAX_HASH32 { + hash /= 7; + } + }); + hash as u32 +} + +#[cfg(target_feature = "neon")] +use std::arch::aarch64::*; + +#[cfg(target_feature = "avx2")] +use std::arch::x86_64::*; + +#[cfg(target_feature = "sse2")] +use std::arch::x86_64::*; + +#[cfg(target_arch = "x86_64")] +pub const MIN_DIM_SIZE_AVX: usize = 32; + +#[cfg(any( + target_arch = "x86", + target_arch = "x86_64", + all(target_arch = "aarch64", target_feature = "neon") +))] +pub const MIN_DIM_SIZE_SIMD: usize = 16; + +#[cfg(target_arch = "x86_64")] +unsafe fn is_latin_avx(s: &str) -> bool { + let bytes = s.as_bytes(); + let len = bytes.len(); + let mut i = 0; + // SIMD skip ASCII + while i + MIN_DIM_SIZE_AVX <= len { + let chunk = _mm256_loadu_si256(bytes.as_ptr().add(i) as *const __m256i); + let hi_mask = _mm256_set1_epi8(0x80u8 as i8); + let masked = _mm256_and_si256(chunk, hi_mask); + let cmp = _mm256_cmpeq_epi8(masked, _mm256_setzero_si256()); + if _mm256_movemask_epi8(cmp) != -1 { + break; + } + i += MIN_DIM_SIZE_AVX; + } + // check latin in remaining chars + let s_tail = &s[i..]; + for c in s_tail.chars() { + if c as u32 > 0xFF { + return false; + } + } + true +} + +#[cfg(target_feature = "sse2")] +unsafe fn is_latin_sse(s: &str) -> bool { + let bytes = s.as_bytes(); + let len = bytes.len(); + let mut i = 0; + // SIMD skip ASCII + while i + MIN_DIM_SIZE_SIMD <= len { + let chunk = _mm_loadu_si128(bytes.as_ptr().add(i) as *const __m128i); + let hi_mask = _mm_set1_epi8(0x80u8 as i8); + let masked = _mm_and_si128(chunk, hi_mask); + let cmp = _mm_cmpeq_epi8(masked, _mm_setzero_si128()); + if _mm_movemask_epi8(cmp) != 0xFFFF { + break; + } + i += MIN_DIM_SIZE_SIMD; + } + // check latin in remaining chars + let s_tail = &s[i..]; + for c in s_tail.chars() { + if c as u32 > 0xFF { + return false; + } + } + true +} + +#[cfg(target_feature = "neon")] +unsafe fn is_latin_neon(s: &str) -> bool { + let bytes = s.as_bytes(); + let len = bytes.len(); + let mut i = 0; + // SIMD skip ASCII + while i + MIN_DIM_SIZE_SIMD <= len { + let chunk = vld1q_u8(bytes.as_ptr().add(i)); + let hi_mask = vdupq_n_u8(0x80); + let masked = vandq_u8(chunk, hi_mask); + if vmaxvq_u8(masked) != 0 { + break; + } + i += MIN_DIM_SIZE_SIMD; + } + // check latin in remaining chars + let s_tail = &s[i..]; + for c in s_tail.chars() { + if c as u32 > 0xFF { + return false; + } + } + true +} + +fn is_latin_standard(s: &str) -> bool { + s.chars().all(|c| c as u32 <= 0xFF) +} + +pub fn is_latin(s: &str) -> bool { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && s.len() >= MIN_DIM_SIZE_AVX + { + return unsafe { is_latin_avx(s) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { is_latin_sse(s) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { is_latin_neon(s) }; + } + } + is_latin_standard(s) +} + +#[cfg(target_arch = "x86_64")] +unsafe fn get_latin1_length_avx(s: &str) -> i32 { + let bytes = s.as_bytes(); + let len = bytes.len(); + let mut count = 0; + // SIMD skip ASCII + while count + MIN_DIM_SIZE_AVX <= len { + let chunk = _mm256_loadu_si256(bytes.as_ptr().add(count) as *const __m256i); + let hi_mask = _mm256_set1_epi8(0x80u8 as i8); + let masked = _mm256_and_si256(chunk, hi_mask); + let cmp = _mm256_cmpeq_epi8(masked, _mm256_setzero_si256()); + if _mm256_movemask_epi8(cmp) != -1 { + break; + } + count += MIN_DIM_SIZE_AVX; + } + // check latin in remaining chars + let s_tail = &s[count..]; + for c in s_tail.chars() { + if c as u32 > 0xFF { + return -1; + } + count += 1; + } + count as i32 +} + +#[cfg(target_feature = "sse2")] +unsafe fn get_latin1_length_sse(s: &str) -> i32 { + let bytes = s.as_bytes(); + let len = bytes.len(); + let mut count = 0; + // SIMD skip ASCII + while count + MIN_DIM_SIZE_SIMD <= len { + let chunk = _mm_loadu_si128(bytes.as_ptr().add(count) as *const __m128i); + let hi_mask = _mm_set1_epi8(0x80u8 as i8); + let masked = _mm_and_si128(chunk, hi_mask); + let cmp = _mm_cmpeq_epi8(masked, _mm_setzero_si128()); + if _mm_movemask_epi8(cmp) != 0xFFFF { + break; + } + count += MIN_DIM_SIZE_SIMD; + } + // check latin in remaining chars + let s_tail = &s[count..]; + for c in s_tail.chars() { + if c as u32 > 0xFF { + return -1; + } + count += 1; + } + count as i32 +} + +#[cfg(target_feature = "neon")] +unsafe fn get_latin1_length_neon(s: &str) -> i32 { + let bytes = s.as_bytes(); + let len = bytes.len(); + let mut count = 0; + // SIMD skip ASCII + while count + MIN_DIM_SIZE_SIMD <= len { + let chunk = vld1q_u8(bytes.as_ptr().add(count)); + let hi_mask = vdupq_n_u8(0x80); + let masked = vandq_u8(chunk, hi_mask); + if vmaxvq_u8(masked) != 0 { + break; + } + count += MIN_DIM_SIZE_SIMD; + } + // check latin in remaining chars + let s_tail = &s[count..]; + for c in s_tail.chars() { + if c as u32 > 0xFF { + return -1; + } + count += 1; + } + count as i32 +} + +fn get_latin1_length_standard(s: &str) -> i32 { + let mut count = 0; + for c in s.chars() { + if c as u32 > 0xFF { + return -1; + } + count += 1; + } + count +} + +pub fn get_latin1_length(s: &str) -> i32 { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") + && is_x86_feature_detected!("fma") + && s.len() >= MIN_DIM_SIZE_AVX + { + return unsafe { get_latin1_length_avx(s) }; + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { get_latin1_length_sse(s) }; + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD { + return unsafe { get_latin1_length_neon(s) }; + } + } + get_latin1_length_standard(s) +} + +#[cfg(test)] +mod latin_tests { + // Import content from external modules + use super::*; + use rand::Rng; + + fn generate_random_string(length: usize) -> String { + const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + let mut rng = rand::thread_rng(); + + let result: String = (0..length) + .map(|_| { + let idx = rng.gen_range(0..CHARSET.len()); + CHARSET[idx] as char + }) + .collect(); + + result + } + + #[test] + fn test_is_latin() { + let s = generate_random_string(1000); + let not_latin_str = generate_random_string(1000) + "abc\u{1234}"; + + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx") && is_x86_feature_detected!("fma") { + assert!(unsafe { is_latin_avx(&s) }); + assert!(!unsafe { is_latin_avx(¬_latin_str) }); + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if is_x86_feature_detected!("sse") && s.len() >= MIN_DIM_SIZE_SIMD { + assert!(unsafe { is_latin_sse(&s) }); + assert!(!unsafe { is_latin_sse(¬_latin_str) }); + } + } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + if std::arch::is_aarch64_feature_detected!("neon") && s.len() >= MIN_DIM_SIZE_SIMD { + assert!(unsafe { is_latin_neon(&s) }); + assert!(!unsafe { is_latin_neon(¬_latin_str) }); + } + } + assert!(is_latin_standard(&s)); + assert!(!is_latin_standard(¬_latin_str)); + } +} + +fn fmix64(mut k: u64) -> u64 { + k ^= k >> 33; + k = k.wrapping_mul(0xff51afd7ed558ccdu64); + k ^= k >> 33; + k = k.wrapping_mul(0xc4ceb9fe1a85ec53u64); + k ^= k >> 33; + + k +} + +pub fn murmurhash3_x64_128(bytes: &[u8], seed: u64) -> (u64, u64) { + let c1 = 0x87c37b91114253d5u64; + let c2 = 0x4cf5ad432745937fu64; + let read_size = 16; + let len = bytes.len() as u64; + let block_count = len / read_size; + + let (mut h1, mut h2) = (seed, seed); + + for i in 0..block_count as usize { + let b64: &[u64] = unsafe { mem::transmute(bytes) }; + let (mut k1, mut k2) = (b64[i * 2], b64[i * 2 + 1]); + + k1 = k1.wrapping_mul(c1); + k1 = k1.rotate_left(31); + k1 = k1.wrapping_mul(c2); + h1 ^= k1; + + h1 = h1.rotate_left(27); + h1 = h1.wrapping_add(h2); + h1 = h1.wrapping_mul(5); + h1 = h1.wrapping_add(0x52dce729); + + k2 = k2.wrapping_mul(c2); + k2 = k2.rotate_left(33); + k2 = k2.wrapping_mul(c1); + h2 ^= k2; + + h2 = h2.rotate_left(31); + h2 = h2.wrapping_add(h1); + h2 = h2.wrapping_mul(5); + h2 = h2.wrapping_add(0x38495ab5); + } + let (mut k1, mut k2) = (0u64, 0u64); + + if len & 15 == 15 { + k2 ^= (bytes[(block_count * read_size) as usize + 14] as u64) << 48; + } + if len & 15 >= 14 { + k2 ^= (bytes[(block_count * read_size) as usize + 13] as u64) << 40; + } + if len & 15 >= 13 { + k2 ^= (bytes[(block_count * read_size) as usize + 12] as u64) << 32; + } + if len & 15 >= 12 { + k2 ^= (bytes[(block_count * read_size) as usize + 11] as u64) << 24; + } + if len & 15 >= 11 { + k2 ^= (bytes[(block_count * read_size) as usize + 10] as u64) << 16; + } + if len & 15 >= 10 { + k2 ^= (bytes[(block_count * read_size) as usize + 9] as u64) << 8; + } + if len & 15 >= 9 { + k2 ^= bytes[(block_count * read_size) as usize + 8] as u64; + k2 = k2.wrapping_mul(c2); + k2 = k2.rotate_left(33); + k2 = k2.wrapping_mul(c1); + h2 ^= k2; + } + + if len & 15 >= 8 { + k1 ^= (bytes[(block_count * read_size) as usize + 7] as u64) << 56; + } + if len & 15 >= 7 { + k1 ^= (bytes[(block_count * read_size) as usize + 6] as u64) << 48; + } + if len & 15 >= 6 { + k1 ^= (bytes[(block_count * read_size) as usize + 5] as u64) << 40; + } + if len & 15 >= 5 { + k1 ^= (bytes[(block_count * read_size) as usize + 4] as u64) << 32; + } + if len & 15 >= 4 { + k1 ^= (bytes[(block_count * read_size) as usize + 3] as u64) << 24; + } + if len & 15 >= 3 { + k1 ^= (bytes[(block_count * read_size) as usize + 2] as u64) << 16; + } + if len & 15 >= 2 { + k1 ^= (bytes[(block_count * read_size) as usize + 1] as u64) << 8; + } + if len & 15 >= 1 { + k1 ^= bytes[(block_count * read_size) as usize] as u64; + k1 = k1.wrapping_mul(c1); + k1 = k1.rotate_left(31); + k1 = k1.wrapping_mul(c2); + h1 ^= k1; + } + + h1 ^= bytes.len() as u64; + h2 ^= bytes.len() as u64; + + h1 = h1.wrapping_add(h2); + h2 = h2.wrapping_add(h1); + + h1 = fmix64(h1); + h2 = fmix64(h2); + + h1 = h1.wrapping_add(h2); + h2 = h2.wrapping_add(h1); + + (h1, h2) +} + +#[cfg(test)] +mod test_hash { + use super::murmurhash3_x64_128; + + #[test] + fn test_empty_string() { + assert!(murmurhash3_x64_128("".as_bytes(), 0) == (0, 0)); + } + + #[test] + fn test_tail_lengths() { + assert!( + murmurhash3_x64_128("1".as_bytes(), 0) == (8213365047359667313, 10676604921780958775) + ); + assert!( + murmurhash3_x64_128("12".as_bytes(), 0) == (5355690773644049813, 9855895140584599837) + ); + assert!( + murmurhash3_x64_128("123".as_bytes(), 0) == (10978418110857903978, 4791445053355511657) + ); + assert!( + murmurhash3_x64_128("1234".as_bytes(), 0) == (619023178690193332, 3755592904005385637) + ); + assert!( + murmurhash3_x64_128("12345".as_bytes(), 0) + == (2375712675693977547, 17382870096830835188) + ); + assert!( + murmurhash3_x64_128("123456".as_bytes(), 0) + == (16435832985690558678, 5882968373513761278) + ); + assert!( + murmurhash3_x64_128("1234567".as_bytes(), 0) + == (3232113351312417698, 4025181827808483669) + ); + assert!( + murmurhash3_x64_128("12345678".as_bytes(), 0) + == (4272337174398058908, 10464973996478965079) + ); + assert!( + murmurhash3_x64_128("123456789".as_bytes(), 0) + == (4360720697772133540, 11094893415607738629) + ); + assert!( + murmurhash3_x64_128("123456789a".as_bytes(), 0) + == (12594836289594257748, 2662019112679848245) + ); + assert!( + murmurhash3_x64_128("123456789ab".as_bytes(), 0) + == (6978636991469537545, 12243090730442643750) + ); + assert!( + murmurhash3_x64_128("123456789abc".as_bytes(), 0) + == (211890993682310078, 16480638721813329343) + ); + assert!( + murmurhash3_x64_128("123456789abcd".as_bytes(), 0) + == (12459781455342427559, 3193214493011213179) + ); + assert!( + murmurhash3_x64_128("123456789abcde".as_bytes(), 0) + == (12538342858731408721, 9820739847336455216) + ); + assert!( + murmurhash3_x64_128("123456789abcdef".as_bytes(), 0) + == (9165946068217512774, 2451472574052603025) + ); + assert!( + murmurhash3_x64_128("123456789abcdef1".as_bytes(), 0) + == (9259082041050667785, 12459473952842597282) + ); + } + + #[test] + fn test_large_data() { + assert!(murmurhash3_x64_128("Lorem ipsum dolor sit amet, consectetur adipiscing elit. Etiam at consequat massa. Cras eleifend pellentesque ex, at dignissim libero maximus ut. Sed eget nulla felis".as_bytes(), 0) + == (9455322759164802692, 17863277201603478371)); + } +} + #[cfg(test)] -mod tests { +mod case_tests { use super::*; #[test] @@ -183,3 +674,407 @@ mod tests { assert_eq!(to_camel_case("a_b_c"), "aBC"); } } + +pub mod buffer_rw_string { + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + use std::arch::aarch64::*; + #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] + use std::arch::x86_64::*; + #[cfg(all( + any(target_arch = "x86", target_arch = "x86_64"), + target_feature = "sse2", + not(target_feature = "avx2") + ))] + use std::arch::x86_64::*; + + use crate::buffer::{Reader, Writer}; + use crate::error::Error; + + #[inline] + pub fn write_latin1_standard(writer: &mut Writer, s: &str) { + for c in s.chars() { + let b = c as u32; + assert!(b <= 0xFF, "Non-Latin1 character found"); + writer.write_u8(b as u8); + } + } + + #[inline(always)] + pub fn write_latin1_string(writer: &mut Writer, s: &str) { + if s.len() < 128 { + // Fast path for small buffers + let bytes = s.as_bytes(); + // CRITICAL: Only safe if ASCII (UTF-8 == Latin1 for ASCII) + let is_ascii = bytes.iter().all(|&b| b < 0x80); + if is_ascii { + writer.bf.reserve(s.len()); + writer.bf.extend_from_slice(bytes); + } else { + // Non-ASCII: must iterate chars to extract Latin1 byte values + writer.bf.reserve(s.len()); + for c in s.chars() { + let v = c as u32; + assert!(v <= 0xFF, "Non-Latin1 character found"); + writer.bf.push(v as u8); + } + } + return; + } + write_latin1_simd(writer, s); + } + + #[inline] + pub fn write_utf8_standard(writer: &mut Writer, s: &str) { + let bytes = s.as_bytes(); + writer.bf.extend_from_slice(bytes); + } + + #[inline] + pub fn write_utf16_standard(writer: &mut Writer, utf16: &[u16]) { + #[cfg(target_endian = "little")] + { + let total_bytes = utf16.len() * 2; + let old_len = writer.bf.len(); + writer.bf.reserve(total_bytes); + unsafe { + let dest = writer.bf.as_mut_ptr().add(old_len); + let src = utf16.as_ptr() as *const u8; + std::ptr::copy_nonoverlapping(src, dest, total_bytes); + writer.bf.set_len(old_len + total_bytes); + } + } + #[cfg(target_endian = "big")] + { + let total_bytes = utf16.len() * 2; + let old_len = writer.bf.len(); + writer.bf.reserve(total_bytes); + unsafe { + let dest = writer.bf.as_mut_ptr().add(old_len); + // Need to swap bytes for each u16 to little-endian + for (i, &unit) in utf16.iter().enumerate() { + let swapped = unit.swap_bytes(); + let ptr = dest.add(i * 2) as *mut u16; + std::ptr::write_unaligned(ptr, swapped); + } + writer.bf.set_len(old_len + total_bytes); + } + } + } + + #[inline] + pub fn read_latin1_standard(reader: &mut Reader, len: usize) -> Result { + let slice = reader.sub_slice(reader.get_cursor(), reader.get_cursor() + len)?; + let result: String = slice.iter().map(|&b| b as char).collect(); + reader.move_next(len); + Ok(result) + } + + #[inline] + pub fn read_utf8_standard(reader: &mut Reader, len: usize) -> Result { + unsafe { + let mut vec = Vec::with_capacity(len); + let src = reader.bf.as_ptr().add(reader.cursor); + let dst = vec.as_mut_ptr(); + // Use fastest possible copy - copy_nonoverlapping compiles to memcpy + std::ptr::copy_nonoverlapping(src, dst, len); + vec.set_len(len); + reader.move_next(len); + // Use from_utf8_lossy for safety - handles invalid UTF-8 gracefully + // If you're certain the data is valid UTF-8, use from_utf8_unchecked for more performance + Ok(String::from_utf8_lossy(&vec).into_owned()) + } + } + + #[inline] + pub fn read_utf16_standard(reader: &mut Reader, len: usize) -> Result { + if len % 2 != 0 { + return Err(Error::encoding_error("UTF-16 length must be even")); + } + unsafe { + let slice = std::slice::from_raw_parts(reader.bf.as_ptr().add(reader.cursor), len); + let units: Vec = slice + .chunks_exact(2) + .map(|c| u16::from_le_bytes([c[0], c[1]])) + .collect(); + reader.move_next(len); + Ok(String::from_utf16_lossy(&units)) + } + } + + #[inline] + fn is_ascii_bytes(bytes: &[u8]) -> bool { + let len = bytes.len(); + let mut i = 0; + + #[cfg(target_arch = "x86_64")] + unsafe { + if is_x86_feature_detected!("avx2") && len >= 32 { + while i + 32 <= len { + let chunk = _mm256_loadu_si256(bytes.as_ptr().add(i) as *const __m256i); + let mask = _mm256_movemask_epi8(chunk); + if mask != 0 { + return false; + } + i += 32; + } + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + unsafe { + if is_x86_feature_detected!("sse2") && len >= 16 { + while i + 16 <= len { + let chunk = _mm_loadu_si128(bytes.as_ptr().add(i) as *const __m128i); + let mask = _mm_movemask_epi8(chunk); + if mask != 0 { + return false; + } + i += 16; + } + } + } + + #[cfg(target_arch = "aarch64")] + unsafe { + if std::arch::is_aarch64_feature_detected!("neon") && len >= 16 { + while i + 16 <= len { + let chunk = vld1q_u8(bytes.as_ptr().add(i)); + if vmaxvq_u8(chunk) >= 0x80 { + return false; + } + i += 16; + } + } + } + + // Scalar fallback + bytes[i..].iter().all(|&b| b < 0x80) + } + + #[inline] + pub fn write_latin1_simd(writer: &mut Writer, s: &str) { + if s.is_empty() { + return; + } + + let bytes = s.as_bytes(); + + // CRITICAL OPTIMIZATION: For ASCII strings, UTF-8 bytes == Latin1 bytes + // Check if all ASCII using SIMD + if is_ascii_bytes(bytes) { + // Zero-copy fast path: direct write + let len = bytes.len(); + writer.bf.reserve(len); + writer.bf.extend_from_slice(bytes); + } else { + // Non-ASCII: Must iterate chars to extract Latin1 byte values + // Example: 'À' in Rust String is UTF-8 [0xC3, 0x80] but Latin1 is [0xC0] + let mut buf: Vec = Vec::with_capacity(s.len()); + for c in s.chars() { + let v = c as u32; + assert!(v <= 0xFF, "Non-Latin1 character found"); + buf.push(v as u8); + } + let len = buf.len(); + writer.bf.reserve(len); + writer.bf.extend_from_slice(&buf); + } + } + + #[inline] + pub fn read_latin1_simd(reader: &mut Reader, len: usize) -> Result { + if len == 0 { + return Ok(String::new()); + } + let src = reader.sub_slice(reader.get_cursor(), reader.get_cursor() + len)?; + + // Pessimistic allocation: Latin1 0x80-0xFF expands to 2 bytes in UTF-8 + let mut out: Vec = Vec::with_capacity(len * 2); + + unsafe { + let out_ptr = out.as_mut_ptr(); + let mut out_len = 0usize; + let mut i = 0usize; + + // ---- AVX2 fast-path: process 32 ASCII bytes at once ---- + #[cfg(target_arch = "x86_64")] + { + if std::arch::is_x86_feature_detected!("avx2") { + use std::arch::x86_64::*; + while i + 32 <= len { + let ptr = src.as_ptr().add(i) as *const __m256i; + let chunk = _mm256_loadu_si256(ptr); + let mask = _mm256_movemask_epi8(chunk); + if mask == 0 { + // All ASCII: direct copy (no conversion needed) + _mm256_storeu_si256(out_ptr.add(out_len) as *mut __m256i, chunk); + out_len += 32; + i += 32; + continue; + } else { + // Contains Latin1 bytes, break to scalar + break; + } + } + } + } + + // ---- SSE2 fast-path: process 16 ASCII bytes at once ---- + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if std::arch::is_x86_feature_detected!("sse2") { + use std::arch::x86_64::*; + while i + 16 <= len { + let ptr = src.as_ptr().add(i) as *const __m128i; + let chunk = _mm_loadu_si128(ptr); + let mask = _mm_movemask_epi8(chunk); + if mask == 0 { + // All ASCII: direct copy + _mm_storeu_si128(out_ptr.add(out_len) as *mut __m128i, chunk); + out_len += 16; + i += 16; + continue; + } else { + break; + } + } + } + } + + // ---- NEON fast-path: process 16 ASCII bytes at once ---- + #[cfg(target_arch = "aarch64")] + { + if std::arch::is_aarch64_feature_detected!("neon") { + use std::arch::aarch64::*; + while i + 16 <= len { + let ptr = src.as_ptr().add(i); + let v = vld1q_u8(ptr); + // Check if any byte >= 0x80 + if vmaxvq_u8(v) < 0x80 { + // All ASCII: direct copy + vst1q_u8(out_ptr.add(out_len), v); + out_len += 16; + i += 16; + continue; + } else { + break; + } + } + } + } + + // ---- Scalar fallback: convert Latin1 -> UTF-8 ---- + // ASCII (0x00-0x7F): copy as-is + // Latin1 (0x80-0xFF): encode as 2-byte UTF-8 + while i < len { + let b = *src.get_unchecked(i); + if b < 0x80 { + *out_ptr.add(out_len) = b; + out_len += 1; + } else { + // Latin1 byte 0x80-0xFF -> UTF-8 encoding + // Example: 0xC0 (À) -> [0xC3, 0x80] + *out_ptr.add(out_len) = 0xC0 | (b >> 6); + *out_ptr.add(out_len + 1) = 0x80 | (b & 0x3F); + out_len += 2; + } + i += 1; + } + + out.set_len(out_len); + } + reader.move_next(len); + Ok(unsafe { String::from_utf8_unchecked(out) }) + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::buffer::{Reader, Writer}; + + #[test] + fn test_latin1() { + let samples = [ + "Hello World!", + "Rusty Café", + "1234567890", + "ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖרÙÚÛÜÝ", + ]; + + for s in samples { + let mut buffer = vec![]; + let mut writer = Writer::from_buffer(&mut buffer); + write_latin1_simd(&mut writer, s); + write_latin1_simd(&mut writer, s); + let bytes = &*writer.dump(); + let bytes_len = bytes.len() / 2; + let mut reader = Reader::new(bytes); + assert_eq!(read_latin1_standard(&mut reader, bytes_len).unwrap(), s); + assert_eq!(read_latin1_standard(&mut reader, bytes_len).unwrap(), s); + + let mut buffer = vec![]; + let mut writer = Writer::from_buffer(&mut buffer); + write_latin1_standard(&mut writer, s); + write_latin1_standard(&mut writer, s); + let bytes = &*writer.dump(); + let bytes_len = bytes.len() / 2; + let mut reader = Reader::new(bytes); + assert_eq!(read_latin1_simd(&mut reader, bytes_len).unwrap(), s); + assert_eq!(read_latin1_simd(&mut reader, bytes_len).unwrap(), s); + } + } + + #[test] + fn test_utf8() { + let samples = [ + "hello", + "rust语言", + "你好,世界", + "emoji 😀😃😄😁", + "mixed ASCII + 中文 + emoji 😁", + ]; + + for s in samples { + let bytes_len = s.len(); + + let mut buffer = vec![]; + let mut writer = Writer::from_buffer(&mut buffer); + write_utf8_standard(&mut writer, s); + write_utf8_standard(&mut writer, s); + let bytes = &*writer.dump(); + let mut reader = Reader::new(bytes); + assert_eq!(read_utf8_standard(&mut reader, bytes_len).unwrap(), s); + assert_eq!(read_utf8_standard(&mut reader, bytes_len).unwrap(), s); + } + } + + #[test] + fn test_utf16() { + let samples = [ + "hello", + "rust语言", + "你好,世界", + "emoji 😀😃😄😁", + "混合文字 + emoji 🐍💻🦀", + ]; + for s in samples { + let utf16: Vec = s.encode_utf16().collect(); + let bytes_len = utf16.len() * 2; + + let mut buffer = vec![]; + let mut writer = Writer::from_buffer(&mut buffer); + write_utf16_standard(&mut writer, &utf16); + write_utf16_standard(&mut writer, &utf16); + + let mut buffer = vec![]; + let mut writer = Writer::from_buffer(&mut buffer); + write_utf16_standard(&mut writer, &utf16); + write_utf16_standard(&mut writer, &utf16); + let bytes = &*writer.dump(); + let mut reader = Reader::new(bytes); + assert_eq!(read_utf16_standard(&mut reader, bytes_len).unwrap(), s); + assert_eq!(read_utf16_standard(&mut reader, bytes_len).unwrap(), s); + } + } + } +} diff --git a/rust/fory-derive/src/object/derive_enum.rs b/rust/fory-derive/src/object/derive_enum.rs index 3949618b94..dbeaf57298 100644 --- a/rust/fory-derive/src/object/derive_enum.rs +++ b/rust/fory-derive/src/object/derive_enum.rs @@ -41,9 +41,9 @@ pub fn gen_actual_type_id(data_enum: &DataEnum) -> TokenStream { quote! { if xlang { if register_by_name { - fory_core::types::TypeId::NAMED_UNION as u32 + fory_core::type_id::TypeId::NAMED_UNION as u32 } else { - fory_core::types::TypeId::TYPED_UNION as u32 + fory_core::type_id::TypeId::TYPED_UNION as u32 } } else { fory_core::serializer::enum_::actual_type_id(type_id, register_by_name, compatible) @@ -171,7 +171,7 @@ pub(crate) fn gen_named_variant_meta_type_impl_with_enum_name( &[#(#field_name_literals),*] } - fn fory_fields_info(type_resolver: &fory_core::TypeResolver) -> Result, fory_core::error::Error> { + fn fory_fields_info(type_resolver: &fory_core::resolver::TypeResolver) -> Result, fory_core::error::Error> { #fields_info_ts } } @@ -210,7 +210,7 @@ fn xlang_variant_branches(data_enum: &DataEnum, default_variant_value: u32) -> V Self::#ident => { context.writer.write_var_u32(#tag_value); // Write null flag for unit variant (no value) - context.writer.write_i8(fory_core::types::RefFlag::Null as i8); + context.writer.write_i8(fory_core::RefFlag::Null as i8); } } } else { @@ -228,7 +228,7 @@ fn xlang_variant_branches(data_enum: &DataEnum, default_variant_value: u32) -> V Self::#ident(ref value) => { context.writer.write_var_u32(#tag_value); use fory_core::serializer::Serializer; - value.fory_write(context, fory_core::types::RefMode::Tracking, true, false)?; + value.fory_write(context, fory_core::RefMode::Tracking, true, false)?; } } } else { @@ -248,7 +248,7 @@ fn xlang_variant_branches(data_enum: &DataEnum, default_variant_value: u32) -> V Self::#ident { ref #field_ident } => { context.writer.write_var_u32(#tag_value); use fory_core::serializer::Serializer; - #field_ident.fory_write(context, fory_core::types::RefMode::Tracking, true, false)?; + #field_ident.fory_write(context, fory_core::RefMode::Tracking, true, false)?; } } } else { @@ -462,7 +462,7 @@ pub fn gen_write_type_info(data_enum: &DataEnum) -> TokenStream { quote! { if context.is_xlang() { let rs_type_id = std::any::TypeId::of::(); - context.write_any_type_info(fory_core::types::UNKNOWN, rs_type_id)?; + context.write_any_type_info(fory_core::type_id::UNKNOWN, rs_type_id)?; Ok(()) } else { fory_core::serializer::enum_::write_type_info::(context) @@ -565,7 +565,7 @@ fn xlang_variant_read_branches( quote! { #tag_value => { use fory_core::serializer::Serializer; - let value = <#field_ty as Serializer>::fory_read(context, fory_core::types::RefMode::Tracking, true)?; + let value = <#field_ty as Serializer>::fory_read(context, fory_core::RefMode::Tracking, true)?; Ok(Self::#ident(value)) } } @@ -592,7 +592,7 @@ fn xlang_variant_read_branches( quote! { #tag_value => { use fory_core::serializer::Serializer; - let value = <#field_ty as Serializer>::fory_read(context, fory_core::types::RefMode::Tracking, true)?; + let value = <#field_ty as Serializer>::fory_read(context, fory_core::RefMode::Tracking, true)?; Ok(Self::#ident { #field_ident: value }) } } diff --git a/rust/fory-derive/src/object/field_meta.rs b/rust/fory-derive/src/object/field_meta.rs index d95b09c6c3..e50112efad 100644 --- a/rust/fory-derive/src/object/field_meta.rs +++ b/rust/fory-derive/src/object/field_meta.rs @@ -29,7 +29,7 @@ //! Both `compress` and `encoding` are converted to a `type_id` internally. If both are //! specified, they must not conflict. -use fory_core::types::TypeId; +use fory_core::type_id::TypeId; use quote::ToTokens; use std::collections::HashMap; use syn::{Field, GenericArgument, PathArguments, Type}; diff --git a/rust/fory-derive/src/object/misc.rs b/rust/fory-derive/src/object/misc.rs index 9919c44101..c248b0ed42 100644 --- a/rust/fory-derive/src/object/misc.rs +++ b/rust/fory-derive/src/object/misc.rs @@ -54,7 +54,7 @@ fn hash(fields: &[&Field]) -> TokenStream { static name_hash_once: Once = Once::new(); unsafe { name_hash_once.call_once(|| { - name_hash = fory_core::types::compute_struct_hash(vec![#(#props),*]); + name_hash = fory_core::meta::compute_struct_hash(vec![#(#props),*]); }); name_hash } @@ -71,9 +71,9 @@ pub fn gen_actual_type_id() -> TokenStream { pub fn gen_actual_type_id_no_evolving() -> TokenStream { quote! { if register_by_name { - fory_core::types::TypeId::NAMED_STRUCT as u32 + fory_core::type_id::TypeId::NAMED_STRUCT as u32 } else { - fory_core::types::TypeId::STRUCT as u32 + fory_core::type_id::TypeId::STRUCT as u32 } } } @@ -123,8 +123,8 @@ pub fn gen_field_fields_info(source_fields: &[SourceField<'_>]) -> TokenStream { let has_array_override = matches!( meta.type_id, Some(tid) - if tid == fory_core::types::TypeId::INT8_ARRAY as i16 - || tid == fory_core::types::TypeId::UINT8_ARRAY as i16 + if tid == fory_core::type_id::TypeId::INT8_ARRAY as i16 + || tid == fory_core::type_id::TypeId::UINT8_ARRAY as i16 ) && (inner_ty_str == "Vec" || inner_ty_str == "Vec" || inner_ty_str.starts_with("[u8;") @@ -134,30 +134,30 @@ pub fn gen_field_fields_info(source_fields: &[SourceField<'_>]) -> TokenStream { // Generate FieldType directly with the correct type ID based on meta.type_id let type_id_ts = match (inner_ty_str.as_str(), meta.type_id) { // i32: VARINT32 (default) or INT32 (fixed) - ("i32", Some(tid)) if tid == fory_core::types::TypeId::INT32 as i16 => { - quote! { fory_core::types::TypeId::INT32 as u32 } + ("i32", Some(tid)) if tid == fory_core::type_id::TypeId::INT32 as i16 => { + quote! { fory_core::type_id::TypeId::INT32 as u32 } } ("i32", _) => { - quote! { fory_core::types::TypeId::VARINT32 as u32 } + quote! { fory_core::type_id::TypeId::VARINT32 as u32 } } // u32: VAR_UINT32 (default) or UINT32 (fixed) - ("u32", Some(tid)) if tid == fory_core::types::TypeId::INT32 as i16 => { - quote! { fory_core::types::TypeId::UINT32 as u32 } + ("u32", Some(tid)) if tid == fory_core::type_id::TypeId::INT32 as i16 => { + quote! { fory_core::type_id::TypeId::UINT32 as u32 } } ("u32", _) => { - quote! { fory_core::types::TypeId::VAR_UINT32 as u32 } + quote! { fory_core::type_id::TypeId::VAR_UINT32 as u32 } } // u64: VAR_UINT64 (default), UINT64 (fixed), or TAGGED_UINT64 (tagged) - ("u64", Some(tid)) if tid == fory_core::types::TypeId::INT32 as i16 => { - quote! { fory_core::types::TypeId::UINT64 as u32 } + ("u64", Some(tid)) if tid == fory_core::type_id::TypeId::INT32 as i16 => { + quote! { fory_core::type_id::TypeId::UINT64 as u32 } } ("u64", Some(tid)) - if tid == fory_core::types::TypeId::TAGGED_UINT64 as i16 => + if tid == fory_core::type_id::TypeId::TAGGED_UINT64 as i16 => { - quote! { fory_core::types::TypeId::TAGGED_UINT64 as u32 } + quote! { fory_core::type_id::TypeId::TAGGED_UINT64 as u32 } } ("u64", _) => { - quote! { fory_core::types::TypeId::VAR_UINT64 as u32 } + quote! { fory_core::type_id::TypeId::VAR_UINT64 as u32 } } _ => unreachable!(), }; @@ -177,11 +177,11 @@ pub fn gen_field_fields_info(source_fields: &[SourceField<'_>]) -> TokenStream { } } else if has_array_override { let type_id_ts = match meta.type_id { - Some(tid) if tid == fory_core::types::TypeId::INT8_ARRAY as i16 => { - quote! { fory_core::types::TypeId::INT8_ARRAY as u32 } + Some(tid) if tid == fory_core::type_id::TypeId::INT8_ARRAY as i16 => { + quote! { fory_core::type_id::TypeId::INT8_ARRAY as u32 } } - Some(tid) if tid == fory_core::types::TypeId::UINT8_ARRAY as i16 => { - quote! { fory_core::types::TypeId::UINT8_ARRAY as u32 } + Some(tid) if tid == fory_core::type_id::TypeId::UINT8_ARRAY as i16 => { + quote! { fory_core::type_id::TypeId::UINT8_ARRAY as u32 } } _ => unreachable!(), }; @@ -218,12 +218,12 @@ pub fn gen_field_fields_info(source_fields: &[SourceField<'_>]) -> TokenStream { StructField::VecBox(_) | StructField::VecRc(_) | StructField::VecArc(_) => { quote! { fory_core::meta::FieldInfo::new_with_id(#field_id, #name, fory_core::meta::FieldType { - type_id: fory_core::types::TypeId::LIST as u32, + type_id: fory_core::type_id::TypeId::LIST as u32, user_type_id: u32::MAX, nullable: #nullable, track_ref: #track_ref, generics: vec![fory_core::meta::FieldType { - type_id: fory_core::types::TypeId::UNKNOWN as u32, + type_id: fory_core::type_id::TypeId::UNKNOWN as u32, user_type_id: u32::MAX, nullable: false, track_ref: false, @@ -239,14 +239,14 @@ pub fn gen_field_fields_info(source_fields: &[SourceField<'_>]) -> TokenStream { let key_generic_token = generic_tree_to_tokens(&key_generic_tree); quote! { fory_core::meta::FieldInfo::new_with_id(#field_id, #name, fory_core::meta::FieldType { - type_id: fory_core::types::TypeId::MAP as u32, + type_id: fory_core::type_id::TypeId::MAP as u32, user_type_id: u32::MAX, nullable: #nullable, track_ref: #track_ref, generics: vec![ #key_generic_token, fory_core::meta::FieldType { - type_id: fory_core::types::TypeId::UNKNOWN as u32, + type_id: fory_core::type_id::TypeId::UNKNOWN as u32, user_type_id: u32::MAX, nullable: false, track_ref: false, @@ -259,7 +259,7 @@ pub fn gen_field_fields_info(source_fields: &[SourceField<'_>]) -> TokenStream { _ => { quote! { fory_core::meta::FieldInfo::new_with_id(#field_id, #name, fory_core::meta::FieldType { - type_id: fory_core::types::TypeId::UNKNOWN as u32, + type_id: fory_core::type_id::TypeId::UNKNOWN as u32, user_type_id: u32::MAX, nullable: #nullable, track_ref: #track_ref, diff --git a/rust/fory-derive/src/object/read.rs b/rust/fory-derive/src/object/read.rs index e7ae9dd9ec..dc16783ce5 100644 --- a/rust/fory-derive/src/object/read.rs +++ b/rust/fory-derive/src/object/read.rs @@ -68,8 +68,8 @@ fn gen_compatible_unsigned_read( quote! { // Read u32 based on remote type_id match _field.field_type.type_id { - fory_core::types::UINT32 => context.reader.read_u32()?, - fory_core::types::VAR_UINT32 => context.reader.read_var_u32()?, + fory_core::type_id::UINT32 => context.reader.read_u32()?, + fory_core::type_id::VAR_UINT32 => context.reader.read_var_u32()?, _ => context.reader.read_var_u32()?, // Default to varint } } @@ -78,9 +78,9 @@ fn gen_compatible_unsigned_read( quote! { // Read u64 based on remote type_id match _field.field_type.type_id { - fory_core::types::UINT64 => context.reader.read_u64()?, - fory_core::types::VAR_UINT64 => context.reader.read_var_u64()?, - fory_core::types::TAGGED_UINT64 => context.reader.read_tagged_u64()?, + fory_core::type_id::UINT64 => context.reader.read_u64()?, + fory_core::type_id::VAR_UINT64 => context.reader.read_var_u64()?, + fory_core::type_id::TAGGED_UINT64 => context.reader.read_tagged_u64()?, _ => context.reader.read_var_u64()?, // Default to varint } } @@ -188,7 +188,7 @@ pub(crate) fn declare_var(source_fields: &[SourceField<'_>]) -> Vec } } else if extract_type_name(&field.ty) == "float16" { quote! { - let mut #var_name: fory_core::float16::float16 = fory_core::float16::float16::ZERO; + let mut #var_name: fory_core::types::float16::float16 = fory_core::types::float16::float16::ZERO; } } else if extract_type_name(&field.ty) == "bool" { quote! { @@ -408,7 +408,7 @@ pub fn gen_read_field(field: &Field, private_ident: &Ident, field_name: &str) -> // for struct-type fields in compatible mode, even for non-nullable fields. quote! { let read_type_info = if context.is_compatible() { - fory_core::types::need_to_write_type_for_field( + fory_core::type_id::need_to_write_type_for_field( <#ty as fory_core::Serializer>::fory_static_type_id() ) } else { @@ -745,7 +745,7 @@ pub(crate) fn gen_read_compatible_match_arm_body( fory_core::RefMode::None }; // For ref-tracked struct types, Java writes type info after RefValue flag - let read_type_info = fory_core::types::need_to_write_type_for_field( + let read_type_info = fory_core::type_id::need_to_write_type_for_field( <#ty as fory_core::Serializer>::fory_static_type_id() ); #var_name = Some(<#ty as fory_core::Serializer>::fory_read(context, ref_mode, read_type_info)?); @@ -808,7 +808,7 @@ pub(crate) fn gen_read_compatible_match_arm_body( fory_core::RefMode::None }; // For ref-tracked struct types, Java writes type info after RefValue flag - let read_type_info = fory_core::types::need_to_write_type_for_field( + let read_type_info = fory_core::type_id::need_to_write_type_for_field( <#ty as fory_core::Serializer>::fory_static_type_id() ); #var_name = Some(<#ty as fory_core::Serializer>::fory_read(context, ref_mode, read_type_info)?); @@ -828,7 +828,7 @@ pub(crate) fn gen_read_compatible_match_arm_body( fory_core::RefMode::None }; // For ref-tracked struct types, Java writes type info after RefValue flag - let read_type_info = fory_core::types::need_to_write_type_for_field( + let read_type_info = fory_core::type_id::need_to_write_type_for_field( <#ty as fory_core::Serializer>::fory_static_type_id() ); #var_name = <#ty as fory_core::Serializer>::fory_read(context, ref_mode, read_type_info)?; diff --git a/rust/fory-derive/src/object/serializer.rs b/rust/fory-derive/src/object/serializer.rs index b9dec57c89..1ee2b0d221 100644 --- a/rust/fory-derive/src/object/serializer.rs +++ b/rust/fory-derive/src/object/serializer.rs @@ -178,23 +178,23 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea #get_sorted_field_names_ts } - fn fory_fields_info(type_resolver: &fory_core::resolver::type_resolver::TypeResolver) -> Result, fory_core::error::Error> { + fn fory_fields_info(type_resolver: &fory_core::resolver::TypeResolver) -> Result, fory_core::error::Error> { #fields_info_ts } - fn fory_variants_fields_info(type_resolver: &fory_core::resolver::type_resolver::TypeResolver) -> Result)>, fory_core::error::Error> { + fn fory_variants_fields_info(type_resolver: &fory_core::resolver::TypeResolver) -> Result)>, fory_core::error::Error> { #variants_fields_info_ts } #[inline] - fn fory_read_compatible(context: &mut fory_core::resolver::context::ReadContext, type_info: std::rc::Rc) -> Result { + fn fory_read_compatible(context: &mut fory_core::ReadContext, type_info: std::rc::Rc) -> Result { #read_compatible_ts } } impl #impl_generics fory_core::Serializer for #name #ty_generics #where_clause { #[inline(always)] - fn fory_get_type_id(type_resolver: &fory_core::resolver::type_resolver::TypeResolver) -> Result { + fn fory_get_type_id(type_resolver: &fory_core::resolver::TypeResolver) -> Result { let type_id = type_resolver .get_type_id(&std::any::TypeId::of::(), #type_idx) .map_err(fory_core::error::Error::enhance_type_error::)?; @@ -202,7 +202,7 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea } #[inline(always)] - fn fory_type_id_dyn(&self, type_resolver: &fory_core::resolver::type_resolver::TypeResolver) -> Result { + fn fory_type_id_dyn(&self, type_resolver: &fory_core::resolver::TypeResolver) -> Result { Self::fory_get_type_id(type_resolver) } @@ -225,37 +225,37 @@ pub fn derive_serializer(ast: &syn::DeriveInput, attrs: ForyAttrs) -> TokenStrea } #[inline(always)] - fn fory_write(&self, context: &mut fory_core::resolver::context::WriteContext, ref_mode: fory_core::RefMode, write_type_info: bool, _: bool) -> Result<(), fory_core::error::Error> { + fn fory_write(&self, context: &mut fory_core::WriteContext, ref_mode: fory_core::RefMode, write_type_info: bool, _: bool) -> Result<(), fory_core::error::Error> { #write_ts } #[inline] - fn fory_write_data(&self, context: &mut fory_core::resolver::context::WriteContext) -> Result<(), fory_core::error::Error> { + fn fory_write_data(&self, context: &mut fory_core::WriteContext) -> Result<(), fory_core::error::Error> { #write_data_ts } #[inline(always)] - fn fory_write_type_info(context: &mut fory_core::resolver::context::WriteContext) -> Result<(), fory_core::error::Error> { + fn fory_write_type_info(context: &mut fory_core::WriteContext) -> Result<(), fory_core::error::Error> { #write_type_info_ts } #[inline(always)] - fn fory_read(context: &mut fory_core::resolver::context::ReadContext, ref_mode: fory_core::RefMode, read_type_info: bool) -> Result { + fn fory_read(context: &mut fory_core::ReadContext, ref_mode: fory_core::RefMode, read_type_info: bool) -> Result { #read_ts } #[inline(always)] - fn fory_read_with_type_info(context: &mut fory_core::resolver::context::ReadContext, ref_mode: fory_core::RefMode, type_info: std::rc::Rc) -> Result { + fn fory_read_with_type_info(context: &mut fory_core::ReadContext, ref_mode: fory_core::RefMode, type_info: std::rc::Rc) -> Result { #read_with_type_info_ts } #[inline] - fn fory_read_data( context: &mut fory_core::resolver::context::ReadContext) -> Result { + fn fory_read_data( context: &mut fory_core::ReadContext) -> Result { #read_data_ts } #[inline(always)] - fn fory_read_type_info(context: &mut fory_core::resolver::context::ReadContext) -> Result<(), fory_core::error::Error> { + fn fory_read_type_info(context: &mut fory_core::ReadContext) -> Result<(), fory_core::error::Error> { #read_type_info_ts } } diff --git a/rust/fory-derive/src/object/util.rs b/rust/fory-derive/src/object/util.rs index cc938a1170..cb9adc4e9e 100644 --- a/rust/fory-derive/src/object/util.rs +++ b/rust/fory-derive/src/object/util.rs @@ -19,7 +19,7 @@ use crate::util::{ detect_collection_with_trait_object, is_arc_dyn_trait, is_box_dyn_trait, is_rc_dyn_trait, CollectionTraitInfo, }; -use fory_core::types::{TypeId, PRIMITIVE_ARRAY_TYPE_MAP}; +use fory_core::type_id::{TypeId, PRIMITIVE_ARRAY_TYPE_MAP}; use fory_core::util::to_snake_case; use proc_macro2::{Ident, TokenStream}; use quote::{format_ident, quote, ToTokens}; @@ -494,7 +494,7 @@ pub(super) fn generic_tree_to_tokens(node: &TypeNode) -> TokenStream { if is_type_parameter(&node.name) { return quote! { fory_core::meta::FieldType::new( - fory_core::types::TypeId::UNKNOWN as u32, + fory_core::type_id::TypeId::UNKNOWN as u32, true, vec![] ) @@ -505,10 +505,10 @@ pub(super) fn generic_tree_to_tokens(node: &TypeNode) -> TokenStream { if node.name == "Tuple" { return quote! { fory_core::meta::FieldType::new( - fory_core::types::TypeId::LIST as u32, + fory_core::type_id::TypeId::LIST as u32, true, vec![fory_core::meta::FieldType::new( - fory_core::types::TypeId::UNKNOWN as u32, + fory_core::type_id::TypeId::UNKNOWN as u32, true, vec![] )] @@ -525,21 +525,21 @@ pub(super) fn generic_tree_to_tokens(node: &TypeNode) -> TokenStream { if is_primitive_elem { // For primitive arrays, use primitive array type ID let type_id_token = match elem_node.name.as_str() { - "bool" => quote! { fory_core::types::TypeId::BOOL_ARRAY as u32 }, - "i8" => quote! { fory_core::types::TypeId::INT8_ARRAY as u32 }, - "i16" => quote! { fory_core::types::TypeId::INT16_ARRAY as u32 }, - "i32" => quote! { fory_core::types::TypeId::INT32_ARRAY as u32 }, - "i64" => quote! { fory_core::types::TypeId::INT64_ARRAY as u32 }, - "i128" => quote! { fory_core::types::TypeId::INT128_ARRAY as u32 }, - "float16" => quote! { fory_core::types::TypeId::FLOAT16_ARRAY as u32 }, - "f32" => quote! { fory_core::types::TypeId::FLOAT32_ARRAY as u32 }, - "f64" => quote! { fory_core::types::TypeId::FLOAT64_ARRAY as u32 }, - "u8" => quote! { fory_core::types::TypeId::BINARY as u32 }, - "u16" => quote! { fory_core::types::TypeId::UINT16_ARRAY as u32 }, - "u32" => quote! { fory_core::types::TypeId::UINT32_ARRAY as u32 }, - "u64" => quote! { fory_core::types::TypeId::UINT64_ARRAY as u32 }, - "u128" => quote! { fory_core::types::TypeId::U128_ARRAY as u32 }, - _ => quote! { fory_core::types::TypeId::LIST as u32 }, + "bool" => quote! { fory_core::type_id::TypeId::BOOL_ARRAY as u32 }, + "i8" => quote! { fory_core::type_id::TypeId::INT8_ARRAY as u32 }, + "i16" => quote! { fory_core::type_id::TypeId::INT16_ARRAY as u32 }, + "i32" => quote! { fory_core::type_id::TypeId::INT32_ARRAY as u32 }, + "i64" => quote! { fory_core::type_id::TypeId::INT64_ARRAY as u32 }, + "i128" => quote! { fory_core::type_id::TypeId::INT128_ARRAY as u32 }, + "float16" => quote! { fory_core::type_id::TypeId::FLOAT16_ARRAY as u32 }, + "f32" => quote! { fory_core::type_id::TypeId::FLOAT32_ARRAY as u32 }, + "f64" => quote! { fory_core::type_id::TypeId::FLOAT64_ARRAY as u32 }, + "u8" => quote! { fory_core::type_id::TypeId::BINARY as u32 }, + "u16" => quote! { fory_core::type_id::TypeId::UINT16_ARRAY as u32 }, + "u32" => quote! { fory_core::type_id::TypeId::UINT32_ARRAY as u32 }, + "u64" => quote! { fory_core::type_id::TypeId::UINT64_ARRAY as u32 }, + "u128" => quote! { fory_core::type_id::TypeId::U128_ARRAY as u32 }, + _ => quote! { fory_core::type_id::TypeId::LIST as u32 }, }; return quote! { fory_core::meta::FieldType::new( @@ -552,7 +552,7 @@ pub(super) fn generic_tree_to_tokens(node: &TypeNode) -> TokenStream { // For non-primitive arrays, use LIST type ID with element type as generic return quote! { fory_core::meta::FieldType::new( - fory_core::types::TypeId::LIST as u32, + fory_core::type_id::TypeId::LIST as u32, false, vec![#elem_token] ) @@ -574,10 +574,10 @@ pub(super) fn generic_tree_to_tokens(node: &TypeNode) -> TokenStream { if inner.name == "Tuple" { return quote! { fory_core::meta::FieldType::new( - fory_core::types::TypeId::LIST as u32, + fory_core::type_id::TypeId::LIST as u32, true, vec![fory_core::meta::FieldType::new( - fory_core::types::TypeId::UNKNOWN as u32, + fory_core::type_id::TypeId::UNKNOWN as u32, true, vec![] )] @@ -643,23 +643,23 @@ pub(super) fn generic_tree_to_tokens(node: &TypeNode) -> TokenStream { let mut generics = vec![#(#children_tokens),*] as Vec; // For tuples and sets, if no generic info is available, add UNKNOWN element // This handles type aliases to tuples where we can't detect the tuple at macro time - if (type_id == fory_core::types::TypeId::LIST as u32 - || type_id == fory_core::types::TypeId::SET as u32) + if (type_id == fory_core::type_id::TypeId::LIST as u32 + || type_id == fory_core::type_id::TypeId::SET as u32) && generics.is_empty() { generics.push(fory_core::meta::FieldType::new( - fory_core::types::TypeId::UNKNOWN as u32, + fory_core::type_id::TypeId::UNKNOWN as u32, true, vec![] )); } - let is_custom = !fory_core::types::is_internal_type(type_id); + let is_custom = !fory_core::type_id::is_internal_type(type_id); if is_custom { let type_info = <#ty as fory_core::serializer::Serializer>::fory_get_type_info(type_resolver)?; type_id = type_info.get_type_id() as u32; user_type_id = type_info.get_user_type_id(); - if type_id == fory_core::types::TypeId::TYPED_UNION as u32 - || type_id == fory_core::types::TypeId::NAMED_UNION as u32 { - type_id = fory_core::types::TypeId::UNION as u32; + if type_id == fory_core::type_id::TypeId::TYPED_UNION as u32 + || type_id == fory_core::type_id::TypeId::NAMED_UNION as u32 { + type_id = fory_core::type_id::TypeId::UNION as u32; user_type_id = u32::MAX; } if type_resolver.is_xlang() && generics.len() > 0 { @@ -667,9 +667,9 @@ pub(super) fn generic_tree_to_tokens(node: &TypeNode) -> TokenStream { } else { generics = vec![]; } - } else if type_id == fory_core::types::TypeId::TYPED_UNION as u32 - || type_id == fory_core::types::TypeId::NAMED_UNION as u32 { - type_id = fory_core::types::TypeId::UNION as u32; + } else if type_id == fory_core::type_id::TypeId::TYPED_UNION as u32 + || type_id == fory_core::type_id::TypeId::NAMED_UNION as u32 { + type_id = fory_core::type_id::TypeId::UNION as u32; } fory_core::meta::FieldType { type_id, @@ -797,7 +797,7 @@ pub(super) fn get_primitive_writer_method_with_encoding( type_name: &str, meta: &super::field_meta::ForyFieldMeta, ) -> &'static str { - use fory_core::types::TypeId; + use fory_core::type_id::TypeId; // Handle i32 with type_id if type_name == "i32" { @@ -863,7 +863,7 @@ pub(super) fn get_primitive_reader_method_with_encoding( type_name: &str, meta: &super::field_meta::ForyFieldMeta, ) -> &'static str { - use fory_core::types::TypeId; + use fory_core::type_id::TypeId; // Handle i32 with type_id if type_name == "i32" { @@ -1483,7 +1483,7 @@ fn compute_struct_fingerprint(fields: &[&Field]) -> String { /// Generates TokenStream for struct version hash (computed at compile time). pub(crate) fn gen_struct_version_hash_ts(fields: &[&Field]) -> TokenStream { let fingerprint = compute_struct_fingerprint(fields); - let (hash, _) = fory_core::meta::murmurhash3_x64_128(fingerprint.as_bytes(), 47); + let (hash, _) = fory_core::util::murmurhash3_x64_128(fingerprint.as_bytes(), 47); let version_hash = (hash & 0xFFFF_FFFF) as i32; quote! { diff --git a/rust/fory-derive/src/object/write.rs b/rust/fory-derive/src/object/write.rs index 8300e8784b..f6781b4585 100644 --- a/rust/fory-derive/src/object/write.rs +++ b/rust/fory-derive/src/object/write.rs @@ -25,7 +25,7 @@ use super::util::{ StructField, }; use crate::util::SourceField; -use fory_core::types::TypeId; +use fory_core::type_id::TypeId; use proc_macro2::{Ident, TokenStream}; use quote::quote; use syn::Field; @@ -36,71 +36,71 @@ pub fn gen_reserved_space(fields: &[&Field]) -> TokenStream { match classify_trait_object_field(ty) { StructField::BoxDyn => { quote! { - fory_core::types::SIZE_OF_REF_AND_TYPE + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } StructField::RcDyn(trait_name) => { let types = create_wrapper_types_rc(&trait_name); let wrapper_ty = types.wrapper_ty; quote! { - <#wrapper_ty as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + <#wrapper_ty as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } StructField::ArcDyn(trait_name) => { let types = create_wrapper_types_arc(&trait_name); let wrapper_ty = types.wrapper_ty; quote! { - <#wrapper_ty as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + <#wrapper_ty as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } StructField::VecRc(trait_name) => { let types = create_wrapper_types_rc(&trait_name); let wrapper_ty = types.wrapper_ty; quote! { - as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } StructField::VecArc(trait_name) => { let types = create_wrapper_types_arc(&trait_name); let wrapper_ty = types.wrapper_ty; quote! { - as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } StructField::VecBox(_) => { // Vec> uses standard Vec serialization quote! { - <#ty as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + <#ty as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } StructField::HashMapBox(_, _) => { // HashMap> uses standard HashMap serialization quote! { - <#ty as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + <#ty as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } StructField::HashMapRc(key_ty, trait_name) => { let types = create_wrapper_types_rc(&trait_name); let wrapper_ty = types.wrapper_ty; quote! { - as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } StructField::HashMapArc(key_ty, trait_name) => { let types = create_wrapper_types_arc(&trait_name); let wrapper_ty = types.wrapper_ty; quote! { - as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } StructField::Forward => { quote! { - <#ty as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + <#ty as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } _ => { quote! { - <#ty as fory_core::Serializer>::fory_reserved_space() + fory_core::types::SIZE_OF_REF_AND_TYPE + <#ty as fory_core::Serializer>::fory_reserved_space() + fory_core::type_id::SIZE_OF_REF_AND_TYPE } } } @@ -321,7 +321,7 @@ fn gen_write_field_impl( // for struct-type fields in compatible mode, even for non-nullable fields. quote! { let write_type_info = if context.is_compatible() { - fory_core::types::need_to_write_type_for_field( + fory_core::type_id::need_to_write_type_for_field( <#ty as fory_core::Serializer>::fory_static_type_id() ) } else { diff --git a/rust/fory/src/lib.rs b/rust/fory/src/lib.rs index d8199eb227..2a8912284b 100644 --- a/rust/fory/src/lib.rs +++ b/rust/fory/src/lib.rs @@ -1202,7 +1202,7 @@ pub use fory_core::{ error::Error, fory::Fory, fory::ForyBuilder, register_trait_type, row::from_row, row::to_row, - types::TypeId, ArcWeak, ForyDefault, RcWeak, ReadContext, Reader, Serializer, TypeResolver, - WriteContext, Writer, + ArcWeak, ForyDefault, RcWeak, ReadContext, Reader, RefFlag, RefMode, Serializer, TypeId, + TypeResolver, WriteContext, Writer, }; pub use fory_derive::{ForyObject, ForyRow}; diff --git a/rust/tests/Cargo.toml b/rust/tests/Cargo.toml index 3f9ee2343a..59e8546703 100644 --- a/rust/tests/Cargo.toml +++ b/rust/tests/Cargo.toml @@ -27,3 +27,4 @@ fory-core = { path = "../fory-core" } fory-derive = { path = "../fory-derive" } chrono = "0.4" +num-bigint = "0.4" diff --git a/rust/tests/tests/test_array.rs b/rust/tests/tests/test_array.rs index 784afcecd8..54076eeca6 100644 --- a/rust/tests/tests/test_array.rs +++ b/rust/tests/tests/test_array.rs @@ -372,7 +372,7 @@ fn test_array_rc_trait_objects() { #[test] fn test_array_float16() { - use fory_core::float16::float16; + use fory_core::types::float16::float16; let fory = fory_core::fory::Fory::default(); let arr = [ float16::from_f32(1.0), @@ -389,7 +389,7 @@ fn test_array_float16() { #[test] fn test_array_float16_special_values() { - use fory_core::float16::float16; + use fory_core::types::float16::float16; let fory = fory_core::fory::Fory::default(); let arr = [ float16::INFINITY, diff --git a/rust/tests/tests/test_cross_language.rs b/rust/tests/tests/test_cross_language.rs index 712d6309c0..f3da1c030a 100644 --- a/rust/tests/tests/test_cross_language.rs +++ b/rust/tests/tests/test_cross_language.rs @@ -18,12 +18,14 @@ use chrono::{NaiveDate, NaiveDateTime}; use fory_core::buffer::{Reader, Writer}; use fory_core::error::Error; -use fory_core::meta::murmurhash3_x64_128; -use fory_core::resolver::context::{ReadContext, WriteContext}; +use fory_core::resolver::TypeResolver; use fory_core::serializer::{ForyDefault, Serializer}; -use fory_core::TypeResolver; -use fory_core::{read_data, write_data, Fory}; +use fory_core::type_id::TypeId; +use fory_core::util::murmurhash3_x64_128; +use fory_core::{read_data, write_data, Decimal, Fory}; +use fory_core::{ReadContext, WriteContext}; use fory_derive::ForyObject; +use num_bigint::BigInt; use std::collections::{HashMap, HashSet}; use std::{fs, vec}; @@ -113,6 +115,38 @@ fn test_buffer() { fs::write(&data_file_path, writer.dump()).unwrap(); } +#[test] +#[allow(deprecated)] +fn test_naive_date_uses_var_i64_day_count() { + let fory = Fory::builder().xlang(true).track_ref(false).build(); + let day = NaiveDate::from_ymd_opt(1969, 12, 31).unwrap(); + let mut buf = Vec::new(); + fory.serialize_to(&mut buf, &day).unwrap(); + + let mut reader = Reader::new(buf.as_slice()); + assert_eq!(reader.read_u8().unwrap(), 2); + assert_eq!(reader.read_i8().unwrap(), -1); + assert_eq!(reader.read_u8().unwrap(), TypeId::DATE as u8); + assert_eq!(reader.read_var_i64().unwrap(), -1); + assert_eq!(reader.get_cursor(), buf.len()); +} + +#[test] +#[allow(deprecated)] +fn test_naive_date_uses_i32_day_count_in_native_mode() { + let fory = Fory::builder().xlang(false).track_ref(false).build(); + let day = NaiveDate::from_ymd_opt(1969, 12, 31).unwrap(); + let mut buf = Vec::new(); + fory.serialize_to(&mut buf, &day).unwrap(); + + let mut reader = Reader::new(buf.as_slice()); + assert_eq!(reader.read_u8().unwrap(), 0); + assert_eq!(reader.read_i8().unwrap(), -1); + assert_eq!(reader.read_u8().unwrap(), TypeId::DATE as u8); + assert_eq!(reader.read_i32().unwrap(), -1); + assert_eq!(reader.get_cursor(), buf.len()); +} + #[test] #[ignore] fn test_buffer_var() { @@ -632,6 +666,47 @@ fn test_integer() { fs::write(&data_file_path, buf).unwrap(); } +fn decimal_value(unscaled: &str, scale: i32) -> Decimal { + Decimal::new( + BigInt::parse_bytes(unscaled.as_bytes(), 10).expect("invalid decimal"), + scale, + ) +} + +#[test] +#[ignore] +fn test_decimal() { + let data_file_path = get_data_file(); + let bytes = fs::read(&data_file_path).unwrap(); + let mut reader = Reader::new(bytes.as_slice()); + let fory = Fory::builder().compatible(true).xlang(true).build(); + let values = vec![ + decimal_value("0", 0), + decimal_value("0", 3), + decimal_value("1", 0), + decimal_value("-1", 0), + decimal_value("12345", 2), + decimal_value("9223372036854775807", 0), + decimal_value("-9223372036854775808", 0), + decimal_value("4611686018427387903", 0), + decimal_value("-4611686018427387904", 0), + decimal_value("9223372036854775808", 0), + decimal_value("-9223372036854775809", 0), + decimal_value("123456789012345678901234567890123456789", 37), + decimal_value("-123456789012345678901234567890123456789", -17), + ]; + for expected in &values { + let actual: Decimal = fory.deserialize_from(&mut reader).unwrap(); + assert_eq!(&actual, expected); + } + + let mut buf = Vec::new(); + for value in &values { + fory.serialize_to(&mut buf, value).unwrap(); + } + fs::write(&data_file_path, buf).unwrap(); +} + #[derive(ForyObject, Debug, PartialEq)] struct MyStruct { id: i32, diff --git a/rust/tests/tests/test_debug.rs b/rust/tests/tests/test_debug.rs index a83e3a9f8d..3c5141e8da 100644 --- a/rust/tests/tests/test_debug.rs +++ b/rust/tests/tests/test_debug.rs @@ -19,11 +19,11 @@ use std::any::Any; use std::sync::{Mutex, MutexGuard, OnceLock}; use fory_core::fory::Fory; -use fory_core::resolver::context::{ReadContext, WriteContext}; use fory_core::serializer::struct_::{ reset_struct_debug_hooks, set_after_read_field_func, set_after_write_field_func, set_before_read_field_func, set_before_write_field_func, }; +use fory_core::{ReadContext, WriteContext}; #[derive(fory_derive::ForyObject)] #[fory(debug)] diff --git a/rust/tests/tests/test_decimal.rs b/rust/tests/tests/test_decimal.rs new file mode 100644 index 0000000000..52681a6044 --- /dev/null +++ b/rust/tests/tests/test_decimal.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use fory_core::buffer::Reader; +use fory_core::type_id::config_flags::IS_CROSS_LANGUAGE_FLAG; +use fory_core::{Decimal, Fory, RefFlag, TypeId}; +use num_bigint::BigInt; + +fn decimal(unscaled: &str, scale: i32) -> Decimal { + Decimal::new( + BigInt::parse_bytes(unscaled.as_bytes(), 10).expect("invalid decimal test value"), + scale, + ) +} + +#[test] +fn test_decimal_round_trip() { + let fory = Fory::builder().xlang(true).build(); + let values = vec![ + Decimal::new(BigInt::from(0), 0), + Decimal::new(BigInt::from(0), 3), + Decimal::new(BigInt::from(1), 0), + Decimal::new(BigInt::from(-1), 0), + Decimal::new(BigInt::from(12_345), 2), + Decimal::new(BigInt::from(i64::MAX), 0), + Decimal::new(BigInt::from(i64::MIN), 0), + Decimal::new(BigInt::from(i64::MAX) + BigInt::from(1), 0), + Decimal::new(BigInt::from(i64::MIN) - BigInt::from(1), 0), + decimal("123456789012345678901234567890123456789", 37), + decimal("-123456789012345678901234567890123456789", -17), + ]; + + for value in values { + let bytes = fory.serialize(&value).unwrap(); + let decoded: Decimal = fory.deserialize(&bytes).unwrap(); + assert_eq!(value, decoded); + } +} + +#[test] +fn test_decimal_wire_format() { + let fory = Fory::builder().xlang(true).build(); + let bytes = fory.serialize(&Decimal::new(BigInt::from(0), 2)).unwrap(); + let mut reader = Reader::new(bytes.as_slice()); + assert_eq!(reader.read_u8().unwrap(), IS_CROSS_LANGUAGE_FLAG); + assert_eq!(reader.read_i8().unwrap(), RefFlag::NotNullValue as i8); + assert_eq!(reader.read_var_u32().unwrap(), TypeId::DECIMAL as u32); + assert_eq!(reader.read_var_i32().unwrap(), 2); + assert_eq!(reader.read_var_u64().unwrap(), 0); + + let bytes = fory.serialize(&decimal("9223372036854775808", 0)).unwrap(); + let mut reader = Reader::new(bytes.as_slice()); + assert_eq!(reader.read_u8().unwrap(), IS_CROSS_LANGUAGE_FLAG); + assert_eq!(reader.read_i8().unwrap(), RefFlag::NotNullValue as i8); + assert_eq!(reader.read_var_u32().unwrap(), TypeId::DECIMAL as u32); + assert_eq!(reader.read_var_i32().unwrap(), 0); + assert_eq!(reader.read_var_u64().unwrap() & 1, 1); +} + +#[test] +fn test_decimal_rejects_non_canonical_big_payload() { + let fory = Fory::builder().xlang(true).build(); + + let payload = vec![ + IS_CROSS_LANGUAGE_FLAG, + RefFlag::NotNullValue as i8 as u8, + TypeId::DECIMAL as u8, + 0x00, + 0x01, + ]; + assert!(fory.deserialize::(&payload).is_err()); + + let payload = vec![ + IS_CROSS_LANGUAGE_FLAG, + RefFlag::NotNullValue as i8 as u8, + TypeId::DECIMAL as u8, + 0x00, + 0x09, + 0x01, + 0x00, + ]; + let err = fory.deserialize::(&payload).unwrap_err(); + assert!(err.to_string().contains("trailing zero byte")); +} diff --git a/rust/tests/tests/test_enum.rs b/rust/tests/tests/test_enum.rs index a355ed11d9..6b2b6b831c 100644 --- a/rust/tests/tests/test_enum.rs +++ b/rust/tests/tests/test_enum.rs @@ -104,7 +104,7 @@ fn named_enum() { #[test] fn struct_with_enum_field() { use fory_core::serializer::Serializer; - use fory_core::types::TypeId; + use fory_core::type_id::TypeId; // Define a simple enum #[derive(ForyObject, Debug, PartialEq, Clone)] @@ -153,7 +153,7 @@ fn struct_with_enum_field() { #[test] fn union_compatible_enum_xlang_format() { use fory_core::serializer::Serializer; - use fory_core::types::TypeId; + use fory_core::type_id::TypeId; // Define a Union-compatible enum (each variant has exactly one field) #[derive(ForyObject, Debug, PartialEq, Clone)] @@ -201,7 +201,7 @@ fn union_compatible_enum_xlang_format() { #[test] fn struct_with_enum_field_explicit_nullable() { use fory_core::serializer::Serializer; - use fory_core::types::TypeId; + use fory_core::type_id::TypeId; #[derive(ForyObject, Debug, PartialEq, Clone)] enum Status { diff --git a/rust/tests/tests/test_ext.rs b/rust/tests/tests/test_ext.rs index 2947695696..ec2c2c69dc 100644 --- a/rust/tests/tests/test_ext.rs +++ b/rust/tests/tests/test_ext.rs @@ -17,9 +17,9 @@ use fory_core::error::Error; use fory_core::fory::Fory; -use fory_core::resolver::context::{ReadContext, WriteContext}; +use fory_core::resolver::TypeResolver; use fory_core::serializer::{ForyDefault, Serializer}; -use fory_core::TypeResolver; +use fory_core::{ReadContext, WriteContext}; use fory_derive::ForyObject; #[test] diff --git a/rust/tests/tests/test_fory.rs b/rust/tests/tests/test_fory.rs index ebd2bb29f4..671263a01a 100644 --- a/rust/tests/tests/test_fory.rs +++ b/rust/tests/tests/test_fory.rs @@ -290,7 +290,7 @@ fn test_unregistered_type_error_message() { #[test] fn test_type_mismatch_error_shows_type_name() { - use fory_core::types::{format_type_id, TypeId}; + use fory_core::type_id::{format_type_id, TypeId}; // Test internal type (BOOL = 1), no registered_id let formatted = format_type_id(TypeId::BOOL as u32); diff --git a/rust/tests/tests/test_list.rs b/rust/tests/tests/test_list.rs index b119697e4b..6edc8faff5 100644 --- a/rust/tests/tests/test_list.rs +++ b/rust/tests/tests/test_list.rs @@ -141,7 +141,7 @@ fn test_struct_with_collections() { #[test] fn test_vec_float16_basic() { - use fory_core::float16::float16; + use fory_core::types::float16::float16; let fory = fory_core::fory::Fory::default(); let vec: Vec = vec![ float16::from_f32(1.0), @@ -159,7 +159,7 @@ fn test_vec_float16_basic() { #[test] fn test_vec_float16_special_values() { - use fory_core::float16::float16; + use fory_core::types::float16::float16; let fory = fory_core::fory::Fory::default(); let vec: Vec = vec![ float16::INFINITY, @@ -181,7 +181,7 @@ fn test_vec_float16_special_values() { #[test] fn test_vec_float16_empty() { - use fory_core::float16::float16; + use fory_core::types::float16::float16; let fory = fory_core::fory::Fory::default(); let vec: Vec = vec![]; let bin = fory.serialize(&vec).unwrap(); diff --git a/rust/tests/tests/test_marker.rs b/rust/tests/tests/test_marker.rs index 258198d8bd..3beab27609 100644 --- a/rust/tests/tests/test_marker.rs +++ b/rust/tests/tests/test_marker.rs @@ -25,7 +25,7 @@ use fory_core::fory::Fory; use fory_core::serializer::Serializer; -use fory_core::types::TypeId; +use fory_core::type_id::TypeId; use fory_derive::ForyObject; use std::marker::PhantomData; @@ -118,14 +118,14 @@ fn test_nested_struct_with_phantom_data() { #[test] fn test_union_type_id() { assert_eq!(TypeId::UNION as i16, 33); - assert_eq!(fory_core::types::UNION, 33); + assert_eq!(fory_core::type_id::UNION, 33); } /// Test that NONE TypeId matches xlang spec (36) #[test] fn test_none_type_id() { assert_eq!(TypeId::NONE as i16, 36); - assert_eq!(fory_core::types::NONE, 36); + assert_eq!(fory_core::type_id::NONE, 36); } /// Test that PhantomData uses NONE TypeId (no runtime data) diff --git a/rust/tests/tests/test_meta.rs b/rust/tests/tests/test_meta.rs index deae54dd77..349749d2e8 100644 --- a/rust/tests/tests/test_meta.rs +++ b/rust/tests/tests/test_meta.rs @@ -16,7 +16,7 @@ // under the License. use fory_core::meta::{FieldInfo, FieldType, MetaString, TypeMeta}; -use fory_core::types::TypeId; +use fory_core::type_id::TypeId; #[test] fn test_meta_hash() { diff --git a/rust/tests/tests/test_ref_resolver.rs b/rust/tests/tests/test_ref_resolver.rs index 03b0d78658..643c077c92 100644 --- a/rust/tests/tests/test_ref_resolver.rs +++ b/rust/tests/tests/test_ref_resolver.rs @@ -18,8 +18,8 @@ //! Tests for RefWriter and RefReader functionality use fory_core::buffer::Writer; -use fory_core::resolver::ref_resolver::{RefReader, RefWriter}; -use fory_core::serializer::weak::{ArcWeak, RcWeak}; +use fory_core::resolver::{RefReader, RefWriter}; +use fory_core::{ArcWeak, RcWeak}; use std::rc::Rc; use std::sync::Arc; diff --git a/rust/tests/tests/test_simple_struct.rs b/rust/tests/tests/test_simple_struct.rs index b0cb285ff2..583ba1a92a 100644 --- a/rust/tests/tests/test_simple_struct.rs +++ b/rust/tests/tests/test_simple_struct.rs @@ -220,7 +220,7 @@ fn test_compatible_map_to_empty_struct() { #[test] fn test_struct_with_float16_fields() { - use fory_core::float16::float16; + use fory_core::types::float16::float16; #[derive(ForyObject, Debug)] struct Float16Data { diff --git a/rust/tests/tests/test_tuple.rs b/rust/tests/tests/test_tuple.rs index 864e3cf5f6..f735343bca 100644 --- a/rust/tests/tests/test_tuple.rs +++ b/rust/tests/tests/test_tuple.rs @@ -221,7 +221,7 @@ fn test_homogeneous_tuple_unsigned() { #[test] fn test_tuple_type_id() { use fory_core::serializer::Serializer; - use fory_core::types::TypeId; + use fory_core::type_id::TypeId; assert_eq!(<(i32, i32)>::fory_static_type_id(), TypeId::LIST); assert_eq!(<(i32, String)>::fory_static_type_id(), TypeId::LIST); assert_eq!(<(i32,)>::fory_static_type_id(), TypeId::LIST); diff --git a/rust/tests/tests/test_util.rs b/rust/tests/tests/test_util.rs index 35e2e7406e..c3db365ec0 100644 --- a/rust/tests/tests/test_util.rs +++ b/rust/tests/tests/test_util.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. -use fory_core::meta::{get_latin1_length, is_latin}; -use fory_core::util::to_utf8; +use fory_core::util::{get_latin1_length, is_latin, to_utf8}; #[test] fn test_to_utf8() { diff --git a/rust/tests/tests/test_weak.rs b/rust/tests/tests/test_weak.rs index 60174846fa..6f0dcc47ca 100644 --- a/rust/tests/tests/test_weak.rs +++ b/rust/tests/tests/test_weak.rs @@ -16,7 +16,7 @@ // under the License. use fory_core::fory::Fory; -use fory_core::serializer::weak::{ArcWeak, RcWeak}; +use fory_core::{ArcWeak, RcWeak}; use fory_derive::ForyObject; use std::cell::RefCell; use std::rc::Rc; diff --git a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaTest.scala b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaTest.scala index 220d366d34..99323590a6 100644 --- a/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaTest.scala +++ b/scala/src/test/scala/org/apache/fory/serializer/scala/ScalaTest.scala @@ -19,6 +19,7 @@ package org.apache.fory.serializer.scala +import java.math.{BigDecimal => JBigDecimal, BigInteger} import org.apache.fory.Fory import org.apache.fory.config.Language import org.scalatest.matchers.should.Matchers @@ -41,6 +42,32 @@ class ScalaTest extends AnyWordSpec with Matchers { val p = SomePackageObject.SomeClass(1) fory.deserialize(fory.serialize(p)) shouldEqual p } + "serialize/deserialize java.math.BigDecimal" in { + val values = Seq( + JBigDecimal.ZERO, + new JBigDecimal(BigInteger.ZERO, 3), + JBigDecimal.ONE, + JBigDecimal.ONE.negate(), + JBigDecimal.valueOf(12345, 2), + new JBigDecimal(BigInteger.valueOf(Long.MaxValue), 0), + new JBigDecimal(BigInteger.valueOf(Long.MinValue), 0), + new JBigDecimal(BigInteger.valueOf(Long.MaxValue).add(BigInteger.ONE), 0), + new JBigDecimal(BigInteger.valueOf(Long.MinValue).subtract(BigInteger.ONE), 0), + new JBigDecimal(new BigInteger("123456789012345678901234567890123456789"), 37) + ) + Seq(Language.JAVA, Language.XLANG).foreach { language => + val decimalFory = Fory.builder() + .withLanguage(language) + .withRefTracking(true) + .withScalaOptimizationEnabled(true) + .requireClassRegistration(false) + .suppressClassRegistrationWarnings(false) + .build() + values.foreach { value => + decimalFory.deserialize(decimalFory.serialize(value)) shouldEqual value + } + } + } } "serialize/deserialize package object in app" in { // If we move code in main here, we can't reproduce https://github.com/apache/fory/issues/1165. diff --git a/swift/Sources/Fory/DateTimeSerializers.swift b/swift/Sources/Fory/DateTimeSerializers.swift index 4248554a98..9f9521cd96 100644 --- a/swift/Sources/Fory/DateTimeSerializers.swift +++ b/swift/Sources/Fory/DateTimeSerializers.swift @@ -17,55 +17,203 @@ import Foundation -public struct ForyDate: Serializer, Equatable, Hashable { +private let nanosPerSecond: Int64 = 1_000_000_000 +private let secondsPerDay = 86_400.0 +private let localDateCalendar: Calendar = { + var calendar = Calendar(identifier: .gregorian) + calendar.timeZone = TimeZone(secondsFromGMT: 0)! + return calendar +}() + +@inline(__always) +private func normalizeTimestampComponents(for date: Date) -> (seconds: Int64, nanos: UInt32) { + let time = date.timeIntervalSince1970 + let seconds = Int64(floor(time)) + let nanos = Int64((time - Double(seconds)) * Double(nanosPerSecond)) + return normalizeTimestampComponents(seconds: seconds, nanos: nanos) +} + +@inline(__always) +private func normalizeTimestampComponents(seconds: Int64, nanos: Int64) -> (seconds: Int64, nanos: UInt32) { + var normalizedSeconds = seconds + nanos / nanosPerSecond + var normalizedNanos = nanos % nanosPerSecond + if normalizedNanos < 0 { + normalizedNanos += nanosPerSecond + normalizedSeconds -= 1 + } + return (normalizedSeconds, UInt32(normalizedNanos)) +} + +@inline(__always) +private func timestampDate(seconds: Int64, nanos: UInt32) -> Date { + Date(timeIntervalSince1970: Double(seconds) + Double(nanos) / Double(nanosPerSecond)) +} + +@inline(__always) +private func localDateDaysSinceEpoch(for date: Date) throws -> Int32 { + let days = floor(date.timeIntervalSince1970 / secondsPerDay) + guard days >= Double(Int32.min), days <= Double(Int32.max) else { + throw ForyError.encodingError("date daysSinceEpoch is out of Int32 range") + } + return Int32(days) +} + +@inline(__always) +private func localDateFromDaysSinceEpoch(_ daysSinceEpoch: Int32) -> Date { + Date(timeIntervalSince1970: Double(daysSinceEpoch) * secondsPerDay) +} + +@inline(__always) +private func localDateComponents(_ daysSinceEpoch: Int32) -> DateComponents { + localDateCalendar.dateComponents([.year, .month, .day], from: localDateFromDaysSinceEpoch(daysSinceEpoch)) +} + +@inline(__always) +private func writeScalarValue( + _ value: T?, + context: WriteContext, + refMode: RefMode, + writeTypeInfo: Bool, + typeID: TypeId, + writePayload: (T) throws -> Void +) throws { + switch refMode { + case .none: + guard let value else { + throw ForyError.encodingError("nil value requires nullable ref mode") + } + if writeTypeInfo { + context.writeStaticTypeInfo(typeID) + } + try writePayload(value) + case .nullOnly, .tracking: + guard let value else { + context.buffer.writeInt8(RefFlag.null.rawValue) + return + } + context.buffer.writeInt8(RefFlag.notNullValue.rawValue) + if writeTypeInfo { + context.writeStaticTypeInfo(typeID) + } + try writePayload(value) + } +} + +@inline(__always) +private func readScalarNullableValue( + context: ReadContext, + refMode: RefMode, + readPayload: () throws -> T +) throws -> T? { + switch refMode { + case .none: + return try readPayload() + case .nullOnly: + let rawFlag = try context.buffer.readInt8() + switch rawFlag { + case RefFlag.null.rawValue: + return nil + case RefFlag.notNullValue.rawValue: + return try readPayload() + case RefFlag.refValue.rawValue: + if context.trackRef { + let reservedRefID = context.refReader.reserveRefID() + let value = try readPayload() + context.refReader.storeRef(value, at: reservedRefID) + return value + } + return try readPayload() + case RefFlag.ref.rawValue: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: T.self) + default: + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + case .tracking: + let rawFlag = try context.buffer.readInt8() + guard let flag = RefFlag(rawValue: rawFlag) else { + throw ForyError.refError("invalid ref flag \(rawFlag)") + } + switch flag { + case .null: + return nil + case .ref: + let refID = try context.buffer.readVarUInt32() + return try context.refReader.readRef(refID, as: T.self) + case .refValue: + let reservedRefID = context.trackRef ? context.refReader.reserveRefID() : nil + let value = try readPayload() + if let reservedRefID { + context.refReader.storeRef(value, at: reservedRefID) + } + return value + case .notNullValue: + return try readPayload() + } + } +} + +@inline(__always) +private func readTypeID(_ context: ReadContext, expectedTypeIDs: [TypeId]) throws -> TypeId { + let rawTypeID = UInt32(try context.buffer.readUInt8()) + guard let actualTypeID = TypeId(rawValue: rawTypeID) else { + throw ForyError.invalidData("unknown type id \(rawTypeID)") + } + if expectedTypeIDs.contains(actualTypeID) { + return actualTypeID + } + if let expectedTypeID = expectedTypeIDs.first, expectedTypeIDs.count == 1 { + throw ForyError.typeMismatch(expected: expectedTypeID.rawValue, actual: rawTypeID) + } + let expectedList = expectedTypeIDs.map(\.rawValue).map(String.init).joined(separator: ", ") + throw ForyError.invalidData("expected one of type ids [\(expectedList)], got \(rawTypeID)") +} + +public struct LocalDate: Serializer, Equatable, Hashable, Comparable { public var daysSinceEpoch: Int32 public init(daysSinceEpoch: Int32 = 0) { self.daysSinceEpoch = daysSinceEpoch } - public static func foryDefault() -> ForyDate { - .init() + public static func fromEpochDay(_ epochDay: Int32) -> LocalDate { + .init(daysSinceEpoch: epochDay) } - public static var staticTypeId: TypeId { - .date + public init(date: Date) throws { + self.daysSinceEpoch = try localDateDaysSinceEpoch(for: date) } - public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { - context.writeStaticTypeInfo(staticTypeId) + public func toEpochDay() -> Int32 { + daysSinceEpoch } - public static func foryReadTypeInfo(_ context: ReadContext) throws -> TypeInfo? { - try context.readStaticTypeInfo(staticTypeId) + public func toDate() -> Date { + localDateFromDaysSinceEpoch(daysSinceEpoch) } - public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { - _ = hasGenerics - context.buffer.writeInt32(daysSinceEpoch) + public var year: Int { + localDateComponents(daysSinceEpoch).year ?? 1970 } - public static func foryReadData(_ context: ReadContext) throws -> ForyDate { - .init(daysSinceEpoch: try context.buffer.readInt32()) + public var month: Int { + localDateComponents(daysSinceEpoch).month ?? 1 } -} -public struct ForyTimestamp: Serializer, Equatable, Hashable { - public var seconds: Int64 - public var nanos: UInt32 + public var day: Int { + localDateComponents(daysSinceEpoch).day ?? 1 + } - public init(seconds: Int64 = 0, nanos: UInt32 = 0) { - let normalized = Self.normalize(seconds: seconds, nanos: Int64(nanos)) - self.seconds = normalized.seconds - self.nanos = normalized.nanos + public static func < (lhs: LocalDate, rhs: LocalDate) -> Bool { + lhs.daysSinceEpoch < rhs.daysSinceEpoch } - public static func foryDefault() -> ForyTimestamp { + public static func foryDefault() -> LocalDate { .init() } public static var staticTypeId: TypeId { - .timestamp + .date } public static func foryWriteStaticTypeInfo(_ context: WriteContext) throws { @@ -78,35 +226,121 @@ public struct ForyTimestamp: Serializer, Equatable, Hashable { public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { _ = hasGenerics - context.buffer.writeInt64(seconds) - context.buffer.writeUInt32(nanos) + try context.writeLocalDate(self) + } + + public static func foryReadData(_ context: ReadContext) throws -> LocalDate { + try context.readLocalDate() } +} - public static func foryReadData(_ context: ReadContext) throws -> ForyTimestamp { - .init(seconds: try context.buffer.readInt64(), nanos: try context.buffer.readUInt32()) +public extension WriteContext { + @inline(__always) + func writeTimestamp(_ value: Date) throws { + let normalized = normalizeTimestampComponents(for: value) + buffer.writeInt64(normalized.seconds) + buffer.writeUInt32(normalized.nanos) } - public init(date: Date) { - let time = date.timeIntervalSince1970 - let seconds = Int64(floor(time)) - let nanos = Int64((time - Double(seconds)) * 1_000_000_000.0) - let normalized = Self.normalize(seconds: seconds, nanos: nanos) - self.seconds = normalized.seconds - self.nanos = normalized.nanos + @inline(__always) + func writeLocalDate(_ value: LocalDate) throws { + if xlang { + buffer.writeVarInt64(Int64(value.daysSinceEpoch)) + } else { + buffer.writeInt32(value.daysSinceEpoch) + } } - public func toDate() -> Date { - Date(timeIntervalSince1970: Double(seconds) + Double(nanos) / 1_000_000_000.0) + @inline(__always) + func writeTimestamp( + _ value: Date?, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + try writeScalarValue( + value, + context: self, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + typeID: .timestamp, + writePayload: { try self.writeTimestamp($0) } + ) } - private static func normalize(seconds: Int64, nanos: Int64) -> (seconds: Int64, nanos: UInt32) { - var normalizedSeconds = seconds + nanos / 1_000_000_000 - var normalizedNanos = nanos % 1_000_000_000 - if normalizedNanos < 0 { - normalizedNanos += 1_000_000_000 - normalizedSeconds -= 1 + @inline(__always) + func writeLocalDate( + _ value: LocalDate?, + refMode: RefMode, + writeTypeInfo: Bool + ) throws { + try writeScalarValue( + value, + context: self, + refMode: refMode, + writeTypeInfo: writeTypeInfo, + typeID: .date, + writePayload: { try self.writeLocalDate($0) } + ) + } +} + +public extension ReadContext { + @inline(__always) + func readTimestamp() throws -> Date { + timestampDate(seconds: try buffer.readInt64(), nanos: try buffer.readUInt32()) + } + + @inline(__always) + func readLocalDate() throws -> LocalDate { + if xlang { + guard let daysSinceEpoch = Int32(exactly: try buffer.readVarInt64()) else { + throw ForyError.invalidData("date daysSinceEpoch is out of Int32 range") + } + return .init(daysSinceEpoch: daysSinceEpoch) } - return (normalizedSeconds, UInt32(normalizedNanos)) + return .init(daysSinceEpoch: try buffer.readInt32()) + } + + @inline(__always) + func readNullableTimestamp( + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Date? { + try readScalarNullableValue(context: self, refMode: refMode) { + if readTypeInfo { + _ = try readTypeID(self, expectedTypeIDs: [.timestamp]) + } + return try self.readTimestamp() + } + } + + @inline(__always) + func readTimestamp( + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Date { + try readNullableTimestamp(refMode: refMode, readTypeInfo: readTypeInfo) ?? Date.foryDefault() + } + + @inline(__always) + func readNullableLocalDate( + refMode: RefMode, + readTypeInfo: Bool + ) throws -> LocalDate? { + try readScalarNullableValue(context: self, refMode: refMode) { + if readTypeInfo { + _ = try readTypeID(self, expectedTypeIDs: [.date]) + } + return try self.readLocalDate() + } + } + + @inline(__always) + func readLocalDate( + refMode: RefMode, + readTypeInfo: Bool + ) throws -> LocalDate { + try readNullableLocalDate(refMode: refMode, readTypeInfo: readTypeInfo) ?? LocalDate.foryDefault() } } @@ -165,13 +399,18 @@ extension Date: Serializer { public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { _ = hasGenerics - let ts = ForyTimestamp(date: self) - context.buffer.writeInt64(ts.seconds) - context.buffer.writeUInt32(ts.nanos) + try context.writeTimestamp(self) } public static func foryReadData(_ context: ReadContext) throws -> Date { - let ts = ForyTimestamp(seconds: try context.buffer.readInt64(), nanos: try context.buffer.readUInt32()) - return ts.toDate() + try context.readTimestamp() + } + + public static func foryRead( + _ context: ReadContext, + refMode: RefMode, + readTypeInfo: Bool + ) throws -> Date { + try context.readTimestamp(refMode: refMode, readTypeInfo: readTypeInfo) } } diff --git a/swift/Sources/Fory/Decimal.swift b/swift/Sources/Fory/Decimal.swift new file mode 100644 index 0000000000..3c5cef18d0 --- /dev/null +++ b/swift/Sources/Fory/Decimal.swift @@ -0,0 +1,285 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import Foundation + +private let decimalSmallPositiveMax: UInt64 = 0x3fff_ffff_ffff_ffff +private let decimalSmallNegativeAbsMax: UInt64 = 0x4000_0000_0000_0000 +private let decimalLengthMask: UInt8 = 0x0f +private let decimalNegativeMask: UInt8 = 0x10 +private let decimalCompactMask: UInt8 = 0x20 +private let decimalHeaderSize = 4 +private let decimalMaxMantissaWords = 8 +private let decimalMaxMagnitudeBytes = decimalMaxMantissaWords * 2 + +private struct FoundationDecimalWireState { + let scale: Int32 + let signum: Int8 + let magnitude: [UInt8] +} + +@inline(__always) +private func normalizeDecimalMagnitude(_ magnitude: [UInt8]) -> [UInt8] { + var normalized = magnitude + while let last = normalized.last, last == 0 { + normalized.removeLast() + } + return normalized +} + +@inline(__always) +private func decimalMagnitudeBytes(from value: UInt64) -> [UInt8] { + guard value != 0 else { + return [] + } + var remaining = value + var bytes: [UInt8] = [] + bytes.reserveCapacity(8) + while remaining != 0 { + bytes.append(UInt8(truncatingIfNeeded: remaining)) + remaining >>= 8 + } + return bytes +} + +@inline(__always) +private func decimalUInt64(from magnitude: [UInt8]) -> UInt64? { + guard magnitude.count <= 8 else { + return nil + } + var result: UInt64 = 0 + for (index, byte) in magnitude.enumerated() { + result |= UInt64(byte) << (index * 8) + } + return result +} + +private func divideDecimalMagnitudeBy10(_ magnitude: inout [UInt8]) -> UInt8 { + guard !magnitude.isEmpty else { + return 0 + } + var remainder = 0 + for index in stride(from: magnitude.count - 1, through: 0, by: -1) { + let value = remainder * 256 + Int(magnitude[index]) + magnitude[index] = UInt8(value / 10) + remainder = value % 10 + } + magnitude = normalizeDecimalMagnitude(magnitude) + return UInt8(remainder) +} + +private func decimalMagnitudeString(signum: Int8, magnitude: [UInt8]) -> String { + let normalized = normalizeDecimalMagnitude(magnitude) + guard !normalized.isEmpty else { + return "0" + } + + var remainderMagnitude = normalized + var digits: [Character] = [] + while !remainderMagnitude.isEmpty { + let digit = divideDecimalMagnitudeBy10(&remainderMagnitude) + digits.append(Character(String(UnicodeScalar(48 + Int(digit))!))) + } + let decimalDigits = String(digits.reversed()) + return signum < 0 ? "-\(decimalDigits)" : decimalDigits +} + +@inline(__always) +private func encodeDecimalZigZag64(_ value: Int64) -> UInt64 { + UInt64(bitPattern: (value << 1) ^ (value >> 63)) +} + +@inline(__always) +private func decodeDecimalZigZag64(_ value: UInt64) -> Int64 { + let shifted = Int64(bitPattern: value >> 1) + let mask = -Int64(value & 1) + return shifted ^ mask +} + +private func decimalMantissaWords(from magnitude: [UInt8]) throws -> (words: [UInt16], length: Int) { + let normalized = normalizeDecimalMagnitude(magnitude) + guard normalized.count <= decimalMaxMagnitudeBytes else { + throw ForyError.invalidData( + "decimal magnitude with \(normalized.count) bytes exceeds Foundation.Decimal precision" + ) + } + var words = Array(repeating: UInt16(0), count: decimalMaxMantissaWords) + for (index, byte) in normalized.enumerated() { + let wordIndex = index / 2 + let shift = (index % 2) * 8 + words[wordIndex] |= UInt16(byte) << shift + } + var length = words.count + while length > 0, words[length - 1] == 0 { + length -= 1 + } + return (words, length) +} + +private func foundationDecimalExponent(forScale scale: Int32) throws -> Int8 { + let exponent = 0 - Int64(scale) + guard exponent >= Int64(Int8.min), exponent <= Int64(Int8.max) else { + throw ForyError.invalidData( + "decimal scale \(scale) is out of Foundation.Decimal exponent range" + ) + } + return Int8(exponent) +} + +private func foundationDecimalWireState(_ value: Decimal) -> FoundationDecimalWireState { + var compact = value + NSDecimalCompact(&compact) + return withUnsafeBytes(of: &compact) { raw in + let exponent = Int8(bitPattern: raw[0]) + let flags = raw[1] + let length = min(Int(flags & decimalLengthMask), decimalMaxMantissaWords) + let isNegative = (flags & decimalNegativeMask) != 0 + + var magnitude: [UInt8] = [] + magnitude.reserveCapacity(length * 2) + for index in 0.. Decimal { + let normalized = normalizeDecimalMagnitude(magnitude) + let exponent = try foundationDecimalExponent(forScale: scale) + let mantissa = try decimalMantissaWords(from: normalized) + + var value = Decimal.zero + withUnsafeMutableBytes(of: &value) { raw in + raw.initializeMemory(as: UInt8.self, repeating: 0) + raw[0] = UInt8(bitPattern: exponent) + var flags = UInt8(truncatingIfNeeded: mantissa.length) + if signum < 0 && !normalized.isEmpty { + flags |= decimalNegativeMask + } + if mantissa.length > 0 { + flags |= decimalCompactMask + } + raw[1] = flags + for index in 0..> 8) + } + } + NSDecimalCompact(&value) + return value +} + +private func smallUnscaledValueForWire(_ state: FoundationDecimalWireState) -> Int64? { + guard let magnitude = decimalUInt64(from: state.magnitude) else { + return nil + } + if state.signum >= 0 { + guard magnitude <= decimalSmallPositiveMax else { + return nil + } + return Int64(magnitude) + } + guard magnitude <= decimalSmallNegativeAbsMax else { + return nil + } + return -Int64(magnitude) +} + +extension Decimal { + internal var foryScale: Int32 { + foundationDecimalWireState(self).scale + } + + internal var foryUnscaledString: String { + let state = foundationDecimalWireState(self) + return decimalMagnitudeString(signum: state.signum, magnitude: state.magnitude) + } +} + +extension Decimal: Serializer { + public static func foryDefault() -> Decimal { + .zero + } + + public static var staticTypeId: TypeId { + .decimal + } + + public func foryWriteData(_ context: WriteContext, hasGenerics: Bool) throws { + _ = hasGenerics + let state = foundationDecimalWireState(self) + context.buffer.writeVarInt32(state.scale) + if let small = smallUnscaledValueForWire(state) { + let header = encodeDecimalZigZag64(small) << 1 + context.buffer.writeVarUInt64(header) + return + } + + guard !state.magnitude.isEmpty else { + throw ForyError.invalidData("zero must use the small decimal encoding") + } + let sign: UInt64 = state.signum < 0 ? 1 : 0 + let meta = (UInt64(state.magnitude.count) << 1) | sign + let header = (meta << 1) | 1 + context.buffer.writeVarUInt64(header) + context.buffer.writeBytes(state.magnitude) + } + + public static func foryReadData(_ context: ReadContext) throws -> Decimal { + let scale = try context.buffer.readVarInt32() + let header = try context.buffer.readVarUInt64() + if (header & 1) == 0 { + let unscaled = decodeDecimalZigZag64(header >> 1) + let signum: Int8 = unscaled == 0 ? 0 : (unscaled < 0 ? -1 : 1) + return try buildFoundationDecimal( + signum: signum, + magnitude: decimalMagnitudeBytes(from: unscaled.magnitude), + scale: scale + ) + } + + let meta = header >> 1 + let signum: Int8 = (meta & 1) == 0 ? 1 : -1 + let length = Int(meta >> 1) + guard length > 0 else { + throw ForyError.invalidData("invalid decimal magnitude length \(length)") + } + let payload = try context.buffer.readBytes(count: length) + guard payload[length - 1] != 0 else { + throw ForyError.invalidData("non-canonical decimal payload: trailing zero byte") + } + let normalized = normalizeDecimalMagnitude(payload) + guard !normalized.isEmpty else { + throw ForyError.invalidData("big decimal encoding must not represent zero") + } + return try buildFoundationDecimal(signum: signum, magnitude: normalized, scale: scale) + } +} diff --git a/swift/Sources/Fory/FieldSkipper.swift b/swift/Sources/Fory/FieldSkipper.swift index d61c6462d0..0273fe35e9 100644 --- a/swift/Sources/Fory/FieldSkipper.swift +++ b/swift/Sources/Fory/FieldSkipper.swift @@ -168,7 +168,9 @@ public extension ReadContext { case .timestamp: return try Date.foryRead(self, refMode: .none, readTypeInfo: false) case .date: - return try ForyDate.foryRead(self, refMode: .none, readTypeInfo: false) + return try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) + case .decimal: + return try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) case .binary, .uint8Array: return try Data.foryRead(self, refMode: .none, readTypeInfo: false) case .boolArray: diff --git a/swift/Sources/Fory/Fory.swift b/swift/Sources/Fory/Fory.swift index e560c04e4d..0c5ed23366 100644 --- a/swift/Sources/Fory/Fory.swift +++ b/swift/Sources/Fory/Fory.swift @@ -83,6 +83,7 @@ public final class Fory { self.writeContext = WriteContext( buffer: ByteBuffer(), typeResolver: typeResolver, + xlang: self.config.xlang, trackRef: self.config.trackRef, compatible: self.config.compatible, checkClassVersion: self.config.checkClassVersion, @@ -92,6 +93,7 @@ public final class Fory { self.readContext = ReadContext( buffer: ByteBuffer(), typeResolver: typeResolver, + xlang: self.config.xlang, trackRef: self.config.trackRef, compatible: self.config.compatible, checkClassVersion: self.config.checkClassVersion, diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift index d0e3e46e88..899ec2a296 100644 --- a/swift/Sources/Fory/ReadContext.swift +++ b/swift/Sources/Fory/ReadContext.swift @@ -22,6 +22,7 @@ private let typeMetaSizeMask = 0xFF public final class ReadContext { public let buffer: ByteBuffer let typeResolver: TypeResolver + public let xlang: Bool public let trackRef: Bool public let compatible: Bool public let checkClassVersion: Bool @@ -40,6 +41,7 @@ public final class ReadContext { init( buffer: ByteBuffer, typeResolver: TypeResolver, + xlang: Bool = false, trackRef: Bool, compatible: Bool = false, checkClassVersion: Bool = true, @@ -49,6 +51,7 @@ public final class ReadContext { ) { self.buffer = buffer self.typeResolver = typeResolver + self.xlang = xlang self.trackRef = trackRef self.compatible = compatible self.checkClassVersion = checkClassVersion @@ -445,7 +448,9 @@ public final class ReadContext { case .timestamp: value = try Date.foryRead(self, refMode: .none, readTypeInfo: false) case .date: - value = try ForyDate.foryRead(self, refMode: .none, readTypeInfo: false) + value = try LocalDate.foryRead(self, refMode: .none, readTypeInfo: false) + case .decimal: + value = try Decimal.foryRead(self, refMode: .none, readTypeInfo: false) case .binary, .uint8Array: value = try Data.foryRead(self, refMode: .none, readTypeInfo: false) case .boolArray: diff --git a/swift/Sources/Fory/WriteContext.swift b/swift/Sources/Fory/WriteContext.swift index 1fcf8ae7f4..466f710b12 100644 --- a/swift/Sources/Fory/WriteContext.swift +++ b/swift/Sources/Fory/WriteContext.swift @@ -56,6 +56,7 @@ final class MetaStringWriteState { public final class WriteContext { public let buffer: ByteBuffer let typeResolver: TypeResolver + public let xlang: Bool public let trackRef: Bool public let compatible: Bool public let checkClassVersion: Bool @@ -71,6 +72,7 @@ public final class WriteContext { convenience init( buffer: ByteBuffer, typeResolver: TypeResolver, + xlang: Bool = false, trackRef: Bool, compatible: Bool = false, checkClassVersion: Bool = true, @@ -79,6 +81,7 @@ public final class WriteContext { self.init( buffer: buffer, typeResolver: typeResolver, + xlang: xlang, trackRef: trackRef, compatible: compatible, checkClassVersion: checkClassVersion, @@ -90,6 +93,7 @@ public final class WriteContext { init( buffer: ByteBuffer, typeResolver: TypeResolver, + xlang: Bool, trackRef: Bool, compatible: Bool, checkClassVersion: Bool, @@ -98,6 +102,7 @@ public final class WriteContext { ) { self.buffer = buffer self.typeResolver = typeResolver + self.xlang = xlang self.trackRef = trackRef self.compatible = compatible self.checkClassVersion = checkClassVersion diff --git a/swift/Sources/ForyMacro/ForyObjectMacro.swift b/swift/Sources/ForyMacro/ForyObjectMacro.swift index 03c57b3e5d..f1b022a8ec 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacro.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacro.swift @@ -189,6 +189,7 @@ private enum ParsedEnumKind { private struct ParsedEnumPayloadField { let label: String? let typeText: String + let isOptional: Bool let hasGenerics: Bool } @@ -268,6 +269,7 @@ private func parseEnumDecl(_ enumDecl: EnumDeclSyntax) throws -> ParsedEnumDecl .init( label: label, typeText: payloadType, + isOptional: optional.isOptional, hasGenerics: hasGenerics ) ) @@ -413,9 +415,9 @@ private func buildTaggedUnionEnumDecls(_ cases: [ParsedEnumCase], accessPrefix: var lines: [String] = [] lines.append("case \(enumCasePattern(enumCase)):") lines.append(" context.buffer.writeVarUInt32(\(caseID))") - for payloadIndex in enumCase.payload.indices { + for (payloadIndex, payloadField) in enumCase.payload.enumerated() { let variableName = "__value\(payloadIndex)" - let hasGenerics = enumCase.payload[payloadIndex].hasGenerics ? "true" : "false" + let hasGenerics = payloadField.hasGenerics ? "true" : "false" lines.append( " try \(variableName).foryWrite(context, refMode: .tracking, writeTypeInfo: true, hasGenerics: \(hasGenerics))" ) @@ -665,7 +667,9 @@ private func parseForyFieldConfiguration( continue } - throw MacroExpansionErrorMessage("@ForyField supports only 'id' and 'encoding' arguments") + throw MacroExpansionErrorMessage( + "@ForyField supports only 'id' and 'encoding' arguments" + ) } } @@ -673,7 +677,10 @@ private func parseForyFieldConfiguration( return nil } - return ParsedForyFieldConfiguration(encoding: parsedEncoding, id: parsedID) + return ParsedForyFieldConfiguration( + encoding: parsedEncoding, + id: parsedID + ) } private func parseForyObjectConfiguration(_ attribute: AttributeSyntax) throws -> ParsedForyObjectConfiguration { @@ -1369,7 +1376,7 @@ private func compatibleTypeMetaFieldExpression( typeText: field.typeText, nullableExpression: field.isOptional ? "true" : "false", trackRefExpression: fieldTrackRefExpression, - explicitTypeID: field.customCodecType == nil ? nil : field.typeID + explicitTypeID: field.customCodecType != nil ? field.typeID : nil ) } @@ -1557,7 +1564,9 @@ private struct TypeClassification { let primitiveSize: Int } -private func classifyType(_ typeText: String) -> TypeClassification { +private func classifyType( + _ typeText: String +) -> TypeClassification { let normalized = trimKnownModulePrefix(trimType(typeText)) if isDynamicAnyConcreteType(normalized) { return .init(typeID: 0, isPrimitive: false, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: false, primitiveSize: 0) @@ -1590,10 +1599,20 @@ private func classifyType(_ typeText: String) -> TypeClassification { return .init(typeID: 21, isPrimitive: false, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: false, primitiveSize: 0) case "Data": return .init(typeID: 41, isPrimitive: false, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: false, primitiveSize: 0) - case "Date", "ForyTimestamp": - return .init(typeID: 38, isPrimitive: false, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: false, primitiveSize: 0) - case "ForyDate": + case "Date": + return .init( + typeID: 38, + isPrimitive: false, + isBuiltIn: true, + isCollection: false, + isMap: false, + isCompressedNumeric: false, + primitiveSize: 0 + ) + case "LocalDate": return .init(typeID: 39, isPrimitive: false, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: false, primitiveSize: 0) + case "Decimal": + return .init(typeID: 40, isPrimitive: false, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: false, primitiveSize: 0) default: break } diff --git a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift index b2c973089b..71a9ed3d5c 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacroReadGeneration.swift @@ -492,7 +492,10 @@ private func compatibleDefaultDecl(_ field: ParsedField) -> String { } private func fieldNeedsGeneralSchemaRead(_ field: ParsedField) -> Bool { - field.dynamicAnyCodec != nil || field.customCodecType != nil || field.isOptional || field.typeID == 27 + field.dynamicAnyCodec != nil || + field.customCodecType != nil || + field.isOptional || + field.typeID == 27 } private func fieldNeedsGeneralCompatibleRead(_ field: ParsedField) -> Bool { diff --git a/swift/Tests/ForyTests/DateTimeTests.swift b/swift/Tests/ForyTests/DateTimeTests.swift index 599897a34d..1c8e693fa9 100644 --- a/swift/Tests/ForyTests/DateTimeTests.swift +++ b/swift/Tests/ForyTests/DateTimeTests.swift @@ -19,35 +19,40 @@ import Foundation import Testing @testable import Fory +private let secondsPerDay = 86_400.0 + @ForyObject private struct DateMacroHolder { - var day: ForyDate = .init() + var day: LocalDate = .foryDefault() + var instant: Date = .foryDefault() - var timestamp: ForyTimestamp = .init() + var timestamp: Date = .foryDefault() +} + +private func midnightUTC(daysSinceEpoch: Int32) -> Date { + Date(timeIntervalSince1970: Double(daysSinceEpoch) * secondsPerDay) +} + +private func localDate(_ daysSinceEpoch: Int32) -> LocalDate { + .init(daysSinceEpoch: daysSinceEpoch) } @Test func dateAndTimestampTypeIds() { - #expect(ForyDate.staticTypeId == .date) - #expect(ForyTimestamp.staticTypeId == .timestamp) #expect(Duration.staticTypeId == .duration) + #expect(LocalDate.staticTypeId == .date) #expect(Date.staticTypeId == .timestamp) } @Test func dateAndTimestampRoundTrip() throws { - let fory = Fory() + let fory = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) - let day = ForyDate(daysSinceEpoch: 18_745) + let day = localDate(18_745) let dayData = try fory.serialize(day) - let dayDecoded: ForyDate = try fory.deserialize(dayData) + let dayDecoded: LocalDate = try fory.deserialize(dayData) #expect(dayDecoded == day) - let ts = ForyTimestamp(seconds: -123, nanos: 987_654_321) - let tsData = try fory.serialize(ts) - let tsDecoded: ForyTimestamp = try fory.deserialize(tsData) - #expect(tsDecoded == ts) - let duration = Duration.seconds(-7) + Duration.nanoseconds(12_000_000) let durationData = try fory.serialize(duration) let durationDecoded: Duration = try fory.deserialize(durationData) @@ -60,22 +65,147 @@ func dateAndTimestampRoundTrip() throws { #expect(diff < 0.000_001) } +@Test +func localDateConvenienceMethodsExposeEpochAndCalendarViews() throws { + let beforeEpoch = LocalDate.fromEpochDay(-1) + let leapDay = LocalDate.fromEpochDay(19_782) + let epoch = try LocalDate(date: Date(timeIntervalSince1970: 0)) + + #expect(beforeEpoch.toEpochDay() == -1) + #expect(beforeEpoch.year == 1969) + #expect(beforeEpoch.month == 12) + #expect(beforeEpoch.day == 31) + #expect(leapDay.year == 2024) + #expect(leapDay.month == 2) + #expect(leapDay.day == 29) + #expect(epoch == .fromEpochDay(0)) + #expect(beforeEpoch < epoch) + #expect(abs(epoch.toDate().timeIntervalSince1970) < 0.000_001) +} + +@Test +func dateAndTimestampContextHelpersUseExpectedWireProtocols() throws { + let xlangWriteBuffer = ByteBuffer() + let xlangTypeResolver = TypeResolver(trackRef: false) + let xlangWriteContext = WriteContext( + buffer: xlangWriteBuffer, + typeResolver: xlangTypeResolver, + xlang: true, + trackRef: false, + compatible: true, + checkClassVersion: true, + maxDepth: 5 + ) + + let xlangLocalDate = localDate(-1) + try xlangWriteContext.writeLocalDate(xlangLocalDate, refMode: .nullOnly, writeTypeInfo: true) + #expect( + Array(xlangWriteBuffer.copyToData()) == [ + UInt8(bitPattern: RefFlag.notNullValue.rawValue), + UInt8(LocalDate.staticTypeId.rawValue), + 0x01, + ] + ) + + let xlangReadContext = ReadContext( + buffer: ByteBuffer(data: xlangWriteBuffer.copyToData()), + typeResolver: xlangTypeResolver, + xlang: true, + trackRef: false, + compatible: true, + checkClassVersion: true, + maxCollectionSize: 1_000_000, + maxBinarySize: 64 * 1024 * 1024, + maxDepth: 5 + ) + let xlangLocalDateDecoded = try xlangReadContext.readLocalDate(refMode: RefMode.nullOnly, readTypeInfo: true) + #expect(xlangLocalDateDecoded == xlangLocalDate) + + let writeBuffer = ByteBuffer() + let typeResolver = TypeResolver(trackRef: false) + let writeContext = WriteContext( + buffer: writeBuffer, + typeResolver: typeResolver, + xlang: false, + trackRef: false, + compatible: true, + checkClassVersion: true, + maxDepth: 5 + ) + + let localDate = localDate(-1) + + try writeContext.writeLocalDate(localDate, refMode: .nullOnly, writeTypeInfo: true) + + let readContext = ReadContext( + buffer: ByteBuffer(data: writeBuffer.copyToData()), + typeResolver: typeResolver, + xlang: false, + trackRef: false, + compatible: true, + checkClassVersion: true, + maxCollectionSize: 1_000_000, + maxBinarySize: 64 * 1024 * 1024, + maxDepth: 5 + ) + + let localDateDecoded = try readContext.readLocalDate(refMode: RefMode.nullOnly, readTypeInfo: true) + + #expect(localDateDecoded == localDate) + #expect(Array(writeBuffer.copyToData()) == [ + UInt8(bitPattern: RefFlag.notNullValue.rawValue), + UInt8(LocalDate.staticTypeId.rawValue), + 0xFF, + 0xFF, + 0xFF, + 0xFF, + ]) + + let timestampBuffer = ByteBuffer() + let timestampWriteContext = WriteContext( + buffer: timestampBuffer, + typeResolver: typeResolver, + xlang: false, + trackRef: false, + compatible: true, + checkClassVersion: true, + maxDepth: 5 + ) + let instant = Date(timeIntervalSince1970: 123_456.000_001) + try timestampWriteContext.writeTimestamp(instant, refMode: .nullOnly, writeTypeInfo: true) + + let timestampReadContext = ReadContext( + buffer: ByteBuffer(data: timestampBuffer.copyToData()), + typeResolver: typeResolver, + xlang: false, + trackRef: false, + compatible: true, + checkClassVersion: true, + maxCollectionSize: 1_000_000, + maxBinarySize: 64 * 1024 * 1024, + maxDepth: 5 + ) + let timestampDecoded = try timestampReadContext.readTimestamp(refMode: RefMode.nullOnly, readTypeInfo: true) + #expect(abs(timestampDecoded.timeIntervalSince1970 - instant.timeIntervalSince1970) < 0.000_001) +} + @Test func dateAndTimestampMacroFieldRoundTrip() throws { let fory = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) fory.register(DateMacroHolder.self, id: 901) let value = DateMacroHolder( - day: .init(daysSinceEpoch: 20_001), + day: localDate(20_001), instant: Date(timeIntervalSince1970: 123_456.000_001), - timestamp: .init(seconds: 44, nanos: 12_345) + timestamp: Date(timeIntervalSince1970: 44.000_012_345) ) let data = try fory.serialize(value) let decoded: DateMacroHolder = try fory.deserialize(data) #expect(decoded.day == value.day) - #expect(decoded.timestamp == value.timestamp) - let diff = abs(decoded.instant.timeIntervalSince1970 - value.instant.timeIntervalSince1970) - #expect(diff < 0.000_001) + let instantDiff = abs(decoded.instant.timeIntervalSince1970 - value.instant.timeIntervalSince1970) + #expect(instantDiff < 0.000_001) + let timestampDiff = abs(decoded.timestamp.timeIntervalSince1970 - value.timestamp.timeIntervalSince1970) + #expect(timestampDiff < 0.000_001) } diff --git a/swift/Tests/ForyTests/DecimalTests.swift b/swift/Tests/ForyTests/DecimalTests.swift new file mode 100644 index 0000000000..19485f3554 --- /dev/null +++ b/swift/Tests/ForyTests/DecimalTests.swift @@ -0,0 +1,123 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import Foundation +import Testing +@testable import Fory + +@ForyObject +private struct DecimalEnvelope: Equatable { + var amount: Decimal = .zero + var note: String = "" +} + +private func makeDecimal(unscaled: String, scale: Int32) throws -> Decimal { + var digits = unscaled + var sign = "" + if digits.first == "-" { + sign = "-" + digits.removeFirst() + } + guard !digits.isEmpty, digits.allSatisfy(\.isNumber) else { + throw ForyError.invalidData("failed to create decimal \(unscaled) scale \(scale)") + } + if scale == 0 { + guard let value = Decimal(string: sign + digits, locale: Locale(identifier: "en_US_POSIX")) else { + throw ForyError.invalidData("failed to create decimal \(unscaled) scale \(scale)") + } + return value + } + let valueString: String + if scale > 0 { + let scaleInt = Int(scale) + if digits.count > scaleInt { + let split = digits.index(digits.endIndex, offsetBy: -scaleInt) + valueString = sign + String(digits[.. String { + guard !unscaled.isEmpty else { + throw PeerError.invalidFieldValue("decimal unscaled value must not be empty") + } + + var digits = unscaled + var sign = "" + if digits.first == "-" { + sign = "-" + digits.removeFirst() + } else if digits.first == "+" { + digits.removeFirst() + } + guard !digits.isEmpty, digits.allSatisfy(\.isNumber) else { + throw PeerError.invalidFieldValue("decimal unscaled value must contain only digits") + } + + if scale == 0 { + return sign + digits + } + if scale > 0 { + let scaleInt = Int(scale) + if digits.count > scaleInt { + let split = digits.index(digits.endIndex, offsetBy: -scaleInt) + return sign + String(digits[.. Decimal { + let valueString = try decimalValueString(unscaled: unscaled, scale: scale) + guard let value = Decimal(string: valueString, locale: Locale(identifier: "en_US_POSIX")) else { + throw PeerError.invalidFieldValue("failed to parse decimal \(valueString)") + } + return value +} + +private func decimalValues() throws -> [Decimal] { + try [ + decimal("0", scale: 0), + decimal("0", scale: 3), + decimal("1", scale: 0), + decimal("-1", scale: 0), + decimal("12345", scale: 2), + decimal("9223372036854775807", scale: 0), + decimal("-9223372036854775808", scale: 0), + decimal("4611686018427387903", scale: 0), + decimal("-4611686018427387904", scale: 0), + decimal("9223372036854775808", scale: 0), + decimal("-9223372036854775809", scale: 0), + decimal("123456789012345678901234567890123456789", scale: 37), + decimal("-123456789012345678901234567890123456789", scale: -17), + ] +} + private func verifyBufferCase(_ caseName: String, _ payload: [UInt8]) throws -> [UInt8] { let inputBuffer = ByteBuffer(bytes: payload) let outputBuffer = ByteBuffer(capacity: payload.count) @@ -508,8 +566,8 @@ private func handleCrossLanguageSerializer(_ bytes: [UInt8]) throws -> [UInt8] { let f32: Float = try fory.deserialize(from: buffer) let f64: Double = try fory.deserialize(from: buffer) let str: String = try fory.deserialize(from: buffer) - let day: ForyDate = try fory.deserialize(from: buffer) - let ts: ForyTimestamp = try fory.deserialize(from: buffer) + let day: LocalDate = try fory.deserialize(from: buffer) + let ts: Date = try fory.deserialize(from: buffer) let boolArray: [Bool] = try fory.deserialize(from: buffer) let byteArray: [UInt8] = try fory.deserialize(from: buffer) let shortArray: [Int16] = try fory.deserialize(from: buffer) @@ -651,6 +709,24 @@ private func handleColor(_ bytes: [UInt8]) throws -> [UInt8] { } } +private func handleDecimal(_ bytes: [UInt8]) throws -> [UInt8] { + let expectedValues = try decimalValues() + let fory = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + let rewritten = try roundTripStream(bytes) { buffer, out in + for expected in expectedValues { + let value: Decimal = try fory.deserialize(from: buffer) + guard value == expected else { + throw PeerError.invalidFieldValue("unexpected decimal value \(value), expected \(expected)") + } + try fory.serialize(value, to: &out) + } + } + guard rewritten == bytes else { + throw ForyError.invalidData("decimal roundtrip bytes differ from the Java payload") + } + return rewritten +} + private func handleStructWithList(_ bytes: [UInt8]) throws -> [UInt8] { let fory = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) fory.register(StructWithList.self, id: 201) @@ -919,6 +995,8 @@ private func rewritePayload(caseName: String, bytes: [UInt8]) throws -> [UInt8] return try handleMap(bytes) case "test_integer": return try handleInteger(bytes) + case "test_decimal": + return try handleDecimal(bytes) case "test_item": return try handleItem(bytes) case "test_color":