diff --git a/AGENTS.md b/AGENTS.md index 5ac63b639e..f0a0c799b0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -67,6 +67,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Preserve protocol compatibility across languages. - Read and respect `docs/specification/xlang_type_mapping.md` when changing cross-language type behavior. - Handle byte order correctly for cross-platform compatibility. +- If the reference implementation is not right, do not tweak another language's correct implementation to align with a wrong reference implementation just to make tests pass; fix the runtime that diverged from the spec. ## Git And Review Rules diff --git a/cpp/README.md b/cpp/README.md index 735fad2940..a26dd27d09 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -13,6 +13,7 @@ The C++ implementation provides high-performance serialization with compile-time - **Type-Safe**: Compile-time type checking with template specialization - **Shared References**: Automatic tracking of shared and circular references - **Schema Evolution**: Compatible mode for independent schema changes +- **Reduced-Precision Types**: `fory::float16_t` and `fory::bfloat16_t` scalars with dense `std::vector<...>` array carriers - **Two Formats**: Object graph serialization and zero-copy row-based format - **Modern C++17**: Clean API using modern C++ features diff --git a/cpp/fory/serialization/array_serializer.h b/cpp/fory/serialization/array_serializer.h index 344cf748e0..aef3851a49 100644 --- a/cpp/fory/serialization/array_serializer.h +++ b/cpp/fory/serialization/array_serializer.h @@ -402,5 +402,118 @@ template struct Serializer> { } }; +/// Serializer for std::array +template struct Serializer> { + static constexpr TypeId type_id = TypeId::BFLOAT16_ARRAY; + + static inline void write_type_info(WriteContext &ctx) { + ctx.write_uint8(static_cast(type_id)); + } + + static inline void read_type_info(ReadContext &ctx) { + uint32_t actual = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return; + } + if (!type_id_matches(actual, static_cast(type_id))) { + ctx.set_error( + Error::type_mismatch(actual, static_cast(type_id))); + } + } + + static inline void write(const std::array &arr, + WriteContext &ctx, RefMode ref_mode, bool write_type, + bool has_generics = false) { + write_not_null_ref_flag(ctx, ref_mode); + if (write_type) { + ctx.write_uint8(static_cast(type_id)); + } + write_data(arr, ctx); + } + + static inline void write_data(const std::array &arr, + WriteContext &ctx) { + Buffer &buffer = ctx.buffer(); + constexpr size_t max_size = 8 + N * sizeof(bfloat16_t); + buffer.grow(static_cast(max_size)); + uint32_t writer_index = buffer.writer_index(); + writer_index += buffer.put_var_uint32( + writer_index, static_cast(N * sizeof(bfloat16_t))); + if constexpr (N > 0) { + if constexpr (FORY_LITTLE_ENDIAN) { + buffer.unsafe_put(writer_index, arr.data(), N * sizeof(bfloat16_t)); + } else { + for (size_t i = 0; i < N; ++i) { + uint16_t bits = util::to_little_endian(arr[i].to_bits()); + buffer.unsafe_put(writer_index + i * sizeof(bfloat16_t), &bits, + sizeof(bfloat16_t)); + } + } + } + buffer.writer_index(writer_index + N * sizeof(bfloat16_t)); + } + + static inline void write_data_generic(const std::array &arr, + WriteContext &ctx, bool has_generics) { + write_data(arr, ctx); + } + + static inline std::array + read(ReadContext &ctx, RefMode ref_mode, bool read_type) { + bool has_value = read_null_only_flag(ctx, ref_mode); + if (ctx.has_error() || !has_value) { + return std::array(); + } + if (read_type) { + uint32_t type_id_read = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return std::array(); + } + if (type_id_read != static_cast(type_id)) { + ctx.set_error( + Error::type_mismatch(type_id_read, static_cast(type_id))); + return std::array(); + } + } + return read_data(ctx); + } + + static inline std::array read_data(ReadContext &ctx) { + uint32_t size_bytes = ctx.read_var_uint32(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return std::array(); + } + uint32_t length = size_bytes / sizeof(bfloat16_t); + if (length != N) { + ctx.set_error(Error::invalid_data("Array size mismatch: expected " + + std::to_string(N) + " but got " + + std::to_string(length))); + return std::array(); + } + std::array arr; + if constexpr (N > 0) { + if constexpr (FORY_LITTLE_ENDIAN) { + ctx.read_bytes(arr.data(), N * sizeof(bfloat16_t), ctx.error()); + } else { + for (size_t i = 0; i < N; ++i) { + uint16_t bits; + ctx.read_bytes(&bits, sizeof(bfloat16_t), ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return arr; + } + arr[i] = bfloat16_t::from_bits(util::to_little_endian(bits)); + } + } + } + return arr; + } + + static inline std::array + read_with_type_info(ReadContext &ctx, RefMode ref_mode, + const TypeInfo &type_info) { + return read(ctx, ref_mode, false); + } +}; + } // namespace serialization } // namespace fory diff --git a/cpp/fory/serialization/basic_serializer.h b/cpp/fory/serialization/basic_serializer.h index fb4210fa9f..13ab2a3116 100644 --- a/cpp/fory/serialization/basic_serializer.h +++ b/cpp/fory/serialization/basic_serializer.h @@ -22,6 +22,7 @@ #include "fory/serialization/context.h" #include "fory/serialization/serializer_traits.h" #include "fory/type/type.h" +#include "fory/util/bfloat16.h" #include "fory/util/error.h" #include "fory/util/float16.h" #include @@ -603,6 +604,77 @@ template <> struct Serializer { } }; +/// bfloat16_t serializer +template <> struct Serializer { + static constexpr TypeId type_id = TypeId::BFLOAT16; + + static inline void write_type_info(WriteContext &ctx) { + ctx.write_uint8(static_cast(type_id)); + } + + static inline void read_type_info(ReadContext &ctx) { + uint32_t actual = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return; + } + if (actual != static_cast(type_id)) { + ctx.set_error( + Error::type_mismatch(actual, static_cast(type_id))); + } + } + + static inline void write(bfloat16_t value, WriteContext &ctx, + RefMode ref_mode, bool write_type, bool = false) { + write_not_null_ref_flag(ctx, ref_mode); + if (write_type) { + ctx.write_uint8(static_cast(type_id)); + } + write_data(value, ctx); + } + + static inline void write_data(bfloat16_t value, WriteContext &ctx) { + ctx.write_bytes(&value, sizeof(bfloat16_t)); + } + + static inline void write_data_generic(bfloat16_t value, WriteContext &ctx, + bool) { + write_data(value, ctx); + } + + static inline bfloat16_t read(ReadContext &ctx, RefMode ref_mode, + bool read_type) { + bool has_value = read_null_only_flag(ctx, ref_mode); + if (ctx.has_error() || !has_value) { + return bfloat16_t::from_bits(0); + } + if (read_type) { + uint32_t type_id_read = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return bfloat16_t::from_bits(0); + } + if (type_id_read != static_cast(type_id)) { + ctx.set_error( + Error::type_mismatch(type_id_read, static_cast(type_id))); + return bfloat16_t::from_bits(0); + } + } + return ctx.read_bf16(ctx.error()); + } + + static inline bfloat16_t read_data(ReadContext &ctx) { + return ctx.read_bf16(ctx.error()); + } + + static inline bfloat16_t read_data_generic(ReadContext &ctx, bool) { + return read_data(ctx); + } + + static inline bfloat16_t + read_with_type_info(ReadContext &ctx, RefMode ref_mode, const TypeInfo &) { + return read(ctx, ref_mode, false); + } +}; + // ============================================================================ // Character Type Serializers (C++ native only, not supported in xlang mode) // ============================================================================ diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index a53e53d858..0dadca10f3 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -751,6 +751,112 @@ template struct Serializer> { } }; +/// Vector serializer for bfloat16_t — typed array path (BFLOAT16_ARRAY). +template struct Serializer> { + static constexpr TypeId type_id = TypeId::BFLOAT16_ARRAY; + + static inline void write_type_info(WriteContext &ctx) { + ctx.write_uint8(static_cast(type_id)); + } + + static inline void read_type_info(ReadContext &ctx) { + uint32_t actual = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return; + } + if (!type_id_matches(actual, static_cast(type_id))) { + ctx.set_error( + Error::type_mismatch(actual, static_cast(type_id))); + } + } + + static inline void write(const std::vector &vec, + WriteContext &ctx, RefMode ref_mode, bool write_type, + bool has_generics = false) { + write_not_null_ref_flag(ctx, ref_mode); + if (write_type) { + ctx.write_uint8(static_cast(type_id)); + } + write_data(vec, ctx); + } + + static inline void write_data(const std::vector &vec, + WriteContext &ctx) { + uint64_t total_bytes = + static_cast(vec.size()) * sizeof(bfloat16_t); + if (total_bytes > std::numeric_limits::max()) { + ctx.set_error(Error::invalid("Vector byte size exceeds uint32_t range")); + return; + } + Buffer &buffer = ctx.buffer(); + size_t max_size = 8 + total_bytes; + buffer.grow(static_cast(max_size)); + uint32_t writer_index = buffer.writer_index(); + writer_index += + buffer.put_var_uint32(writer_index, static_cast(total_bytes)); + if (total_bytes > 0) { + buffer.unsafe_put(writer_index, vec.data(), + static_cast(total_bytes)); + } + buffer.writer_index(writer_index + static_cast(total_bytes)); + } + + static inline void + write_data_generic(const std::vector &vec, + WriteContext &ctx, bool has_generics) { + write_data(vec, ctx); + } + + static inline std::vector + read(ReadContext &ctx, RefMode ref_mode, bool read_type) { + bool has_value = read_null_only_flag(ctx, ref_mode); + if (ctx.has_error() || !has_value) { + return std::vector(); + } + if (read_type) { + uint32_t type_id_read = ctx.read_uint8(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return std::vector(); + } + if (type_id_read != static_cast(type_id)) { + ctx.set_error( + Error::type_mismatch(type_id_read, static_cast(type_id))); + return std::vector(); + } + } + return read_data(ctx); + } + + static inline std::vector + read_with_type_info(ReadContext &ctx, RefMode ref_mode, + const TypeInfo &type_info) { + return read(ctx, ref_mode, false); + } + + static inline std::vector read_data(ReadContext &ctx) { + uint32_t total_bytes_u32 = ctx.read_var_uint32(ctx.error()); + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return std::vector(); + } + if (FORY_PREDICT_FALSE(total_bytes_u32 > ctx.config().max_binary_size)) { + ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); + return std::vector(); + } + size_t elem_count = total_bytes_u32 / sizeof(bfloat16_t); + if (total_bytes_u32 % sizeof(bfloat16_t) != 0) { + ctx.set_error(Error::invalid_data( + "Vector byte size not aligned with bfloat16_t element size")); + return std::vector(); + } + std::vector result(elem_count); + if (total_bytes_u32 > 0) { + ctx.read_bytes(result.data(), static_cast(total_bytes_u32), + ctx.error()); + } + return result; + } +}; + /// Vector serializer for non-bool, non-arithmetic types template struct Serializer< diff --git a/cpp/fory/serialization/context.h b/cpp/fory/serialization/context.h index 1dd270f56f..9b80630552 100644 --- a/cpp/fory/serialization/context.h +++ b/cpp/fory/serialization/context.h @@ -563,6 +563,11 @@ class ReadContext { return buffer().read_f16(error); } + /// Read bfloat16_t from buffer. Sets error on failure. + FORY_ALWAYS_INLINE bfloat16_t read_bf16(Error &error) { + return buffer().read_bf16(error); + } + /// Read uint32_t value as varint from buffer. Sets error on failure. FORY_ALWAYS_INLINE uint32_t read_var_uint32(Error &error) { return buffer().read_var_uint32(error); diff --git a/cpp/fory/serialization/skip.cc b/cpp/fory/serialization/skip.cc index fdb06dd037..9835e0b453 100644 --- a/cpp/fory/serialization/skip.cc +++ b/cpp/fory/serialization/skip.cc @@ -630,33 +630,13 @@ void skip_field_value(ReadContext &ctx, const FieldType &field_type, return; } } - // Read array length - uint32_t len = ctx.read_var_uint32(ctx.error()); + // Typed primitive arrays encode payload size in bytes, not element count. + uint32_t payload_size = ctx.read_var_uint32(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { return; } - // Calculate element size - size_t elem_size = 1; - switch (tid) { - case TypeId::INT16_ARRAY: - case TypeId::FLOAT16_ARRAY: - case TypeId::BFLOAT16_ARRAY: - elem_size = 2; - break; - case TypeId::INT32_ARRAY: - case TypeId::FLOAT32_ARRAY: - elem_size = 4; - break; - case TypeId::INT64_ARRAY: - case TypeId::FLOAT64_ARRAY: - elem_size = 8; - break; - default: - break; - } - - ctx.buffer().increase_reader_index(len * elem_size, ctx.error()); + ctx.buffer().increase_reader_index(payload_size, ctx.error()); return; } diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index 71fcaff9e2..cdd4a5f40d 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -142,6 +142,9 @@ FORY_ALWAYS_INLINE uint32_t put_primitive_at(T value, Buffer &buffer, } else if constexpr (std::is_same_v) { buffer.unsafe_put(offset, value.to_bits()); return 2; + } else if constexpr (std::is_same_v) { + buffer.unsafe_put(offset, value.to_bits()); + return 2; } else if constexpr (std::is_same_v) { buffer.unsafe_put(offset, value); return 4; @@ -180,6 +183,8 @@ FORY_ALWAYS_INLINE void put_fixed_primitive_at(T value, Buffer &buffer, buffer.unsafe_put(offset, static_cast(value)); } else if constexpr (std::is_same_v) { buffer.unsafe_put(offset, value.to_bits()); + } else if constexpr (std::is_same_v) { + buffer.unsafe_put(offset, value.to_bits()); } else if constexpr (std::is_same_v) { buffer.unsafe_put(offset, value); } else if constexpr (std::is_same_v) { @@ -774,6 +779,7 @@ template struct CompileTimeFieldHelpers { std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v; } @@ -814,7 +820,8 @@ template struct CompileTimeFieldHelpers { return 1; } else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { return 2; } else if constexpr (is_configurable_int_v) { return configurable_int_fixed_size_bytes(); @@ -1212,7 +1219,7 @@ template struct CompileTimeFieldHelpers { if (sa != sb) return sa > sb; if (a_tid != b_tid) - return a_tid > b_tid; // type_id descending to match Java + return a_tid < b_tid; // type_id ascending int cmp = compare_identifier(a, b); if (cmp != 0) { return cmp < 0; @@ -2068,6 +2075,7 @@ template <> struct is_raw_primitive : std::true_type {}; template <> struct is_raw_primitive : std::true_type {}; template <> struct is_raw_primitive : std::true_type {}; template <> struct is_raw_primitive : std::true_type {}; +template <> struct is_raw_primitive : std::true_type {}; template <> struct is_raw_primitive : std::true_type {}; template <> struct is_raw_primitive : std::true_type {}; template @@ -2148,6 +2156,16 @@ FORY_ALWAYS_INLINE float16_t read_primitive_by_type_id( read_primitive_by_type_id(ctx, type_id, error)); } +template <> +FORY_ALWAYS_INLINE bfloat16_t read_primitive_by_type_id( + ReadContext &ctx, uint32_t type_id, Error &error) { + if (static_cast(type_id) == TypeId::BFLOAT16) { + return ctx.read_bf16(error); + } + return bfloat16_t::from_float( + read_primitive_by_type_id(ctx, type_id, error)); +} + /// Helper to read a primitive field directly using Error* pattern. /// This bypasses Serializer::read for better performance. /// Returns the read value; sets error on failure. @@ -2188,6 +2206,8 @@ FORY_ALWAYS_INLINE FieldType read_primitive_field_direct(ReadContext &ctx, return static_cast(ctx.read_int64(error)); } else if constexpr (std::is_same_v) { return ctx.read_f16(error); + } else if constexpr (std::is_same_v) { + return ctx.read_bf16(error); } else if constexpr (std::is_same_v) { return ctx.read_float(error); } else if constexpr (std::is_same_v) { @@ -2554,7 +2574,8 @@ template constexpr size_t fixed_primitive_size() { return 1; } else if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { return 2; } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || @@ -2607,6 +2628,8 @@ FORY_ALWAYS_INLINE T read_fixed_primitive_at(Buffer &buffer, uint32_t offset) { return static_cast(buffer.unsafe_get(offset)); } else if constexpr (std::is_same_v) { return float16_t::from_bits(buffer.unsafe_get(offset)); + } else if constexpr (std::is_same_v) { + return bfloat16_t::from_bits(buffer.unsafe_get(offset)); } else if constexpr (std::is_same_v) { return buffer.unsafe_get(offset); } else if constexpr (std::is_same_v || diff --git a/cpp/fory/serialization/type_resolver.cc b/cpp/fory/serialization/type_resolver.cc index d945831993..d5038ddd64 100644 --- a/cpp/fory/serialization/type_resolver.cc +++ b/cpp/fory/serialization/type_resolver.cc @@ -799,7 +799,7 @@ bool numeric_sorter(const FieldInfo &a, const FieldInfo &b) { int32_t size_b = get_primitive_type_size(b_id); // Sort by: nullable (false first), compress (false first), size (larger - // first), type_id (descending to match Java), field_name + // first), type_id (ascending), field_name if (a_nullable != b_nullable) return !a_nullable; // non-nullable first if (compress_a != compress_b) @@ -807,7 +807,7 @@ bool numeric_sorter(const FieldInfo &a, const FieldInfo &b) { if (size_a != size_b) return size_a > size_b; // larger size first if (a_id != b_id) - return a_id > b_id; // type_id descending to match Java + return a_id < b_id; // type_id ascending std::string a_key = field_sort_key(a); std::string b_key = field_sort_key(b); if (a_key != b_key) { diff --git a/cpp/fory/serialization/xlang_test_main.cc b/cpp/fory/serialization/xlang_test_main.cc index e8f4247de6..db6d02a827 100644 --- a/cpp/fory/serialization/xlang_test_main.cc +++ b/cpp/fory/serialization/xlang_test_main.cc @@ -271,11 +271,8 @@ struct AnimalMapHolder { // ============================================================================ struct EmptyStructEvolution { - bool placeholder = false; // C++ templates require at least one field - bool operator==(const EmptyStructEvolution &other) const { - return placeholder == other.placeholder; - } - FORY_STRUCT(EmptyStructEvolution, placeholder); + bool operator==(const EmptyStructEvolution &) const { return true; } + FORY_STRUCT(EmptyStructEvolution); }; struct OneStringFieldStruct { @@ -295,6 +292,21 @@ struct TwoStringFieldStruct { FORY_STRUCT(TwoStringFieldStruct, f1, f2); }; +struct ReducedPrecisionFloatStruct { + fory::float16_t float16_value; + fory::bfloat16_t bfloat16_value; + std::vector float16_array; + std::vector bfloat16_array; + bool operator==(const ReducedPrecisionFloatStruct &other) const { + return float16_value == other.float16_value && + bfloat16_value == other.bfloat16_value && + float16_array == other.float16_array && + bfloat16_array == other.bfloat16_array; + } + FORY_STRUCT(ReducedPrecisionFloatStruct, float16_value, bfloat16_value, + float16_array, bfloat16_array); +}; + enum class TestEnum : int32_t { VALUE_A = 0, VALUE_B = 1, VALUE_C = 2 }; FORY_ENUM(TestEnum, VALUE_A, VALUE_B, VALUE_C); @@ -923,6 +935,9 @@ void run_test_one_string_field_compatible(const std::string &data_file); void run_test_two_string_field_compatible(const std::string &data_file); void run_test_schema_evolution_compatible(const std::string &data_file); void run_test_schema_evolution_compatible_reverse(const std::string &data_file); +void run_test_reduced_precision_float_struct(const std::string &data_file); +void run_test_reduced_precision_float_struct_compatible_skip( + const std::string &data_file); void run_test_one_enum_field_schema(const std::string &data_file); void run_test_one_enum_field_compatible(const std::string &data_file); void run_test_two_enum_field_compatible(const std::string &data_file); @@ -1021,6 +1036,11 @@ int main(int argc, char **argv) { run_test_schema_evolution_compatible(data_file); } else if (case_name == "test_schema_evolution_compatible_reverse") { run_test_schema_evolution_compatible_reverse(data_file); + } else if (case_name == "test_reduced_precision_float_struct") { + run_test_reduced_precision_float_struct(data_file); + } else if (case_name == + "test_reduced_precision_float_struct_compatible_skip") { + run_test_reduced_precision_float_struct_compatible_skip(data_file); } else if (case_name == "test_one_enum_field_schema") { run_test_one_enum_field_schema(data_file); } else if (case_name == "test_one_enum_field_compatible") { @@ -2191,6 +2211,54 @@ void run_test_schema_evolution_compatible_reverse( write_file(data_file, out); } +void run_test_reduced_precision_float_struct(const std::string &data_file) { + auto bytes = read_file(data_file); + auto fory = build_fory(false, true); + ensure_ok(fory.register_struct(213), + "register ReducedPrecisionFloatStruct"); + + Buffer buffer = make_buffer(bytes); + auto value = read_next(fory, buffer); + + if (value.float16_value.to_bits() != 0x3E00u) { + fail("ReducedPrecisionFloatStruct float16_value mismatch"); + } + if (value.bfloat16_value.to_bits() != 0x3FC0u) { + fail("ReducedPrecisionFloatStruct bfloat16_value mismatch"); + } + if (value.float16_array.size() != 3 || + value.float16_array[0].to_bits() != 0x0000u || + value.float16_array[1].to_bits() != 0x3C00u || + value.float16_array[2].to_bits() != 0xBC00u) { + fail("ReducedPrecisionFloatStruct float16_array mismatch"); + } + if (value.bfloat16_array.size() != 3 || + value.bfloat16_array[0].to_bits() != 0x0000u || + value.bfloat16_array[1].to_bits() != 0x3F80u || + value.bfloat16_array[2].to_bits() != 0xBF80u) { + fail("ReducedPrecisionFloatStruct bfloat16_array mismatch"); + } + + std::vector out; + append_serialized(fory, value, out); + write_file(data_file, out); +} + +void run_test_reduced_precision_float_struct_compatible_skip( + const std::string &data_file) { + auto bytes = read_file(data_file); + auto fory = build_fory(true, true); + ensure_ok(fory.register_struct(213), + "register EmptyStructEvolution for reduced precision skip"); + + Buffer buffer = make_buffer(bytes); + auto value = read_next(fory, buffer); + + std::vector out; + append_serialized(fory, value, out); + write_file(data_file, out); +} + // ============================================================================ // Schema Evolution Tests - Enum Fields // ============================================================================ diff --git a/cpp/fory/util/bfloat16.h b/cpp/fory/util/bfloat16.h new file mode 100644 index 0000000000..0a22a0c003 --- /dev/null +++ b/cpp/fory/util/bfloat16.h @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace fory { + +/// Public carrier for xlang bfloat16 values. +/// +/// Use `from_bits()` and `to_bits()` for exact wire control, or `from_float()` +/// and `to_float()` for numeric conversion. The canonical dense array carrier +/// is `std::vector`. +struct bfloat16_t { + uint16_t bits; + + [[nodiscard]] uint16_t to_bits() const noexcept { return bits; } + [[nodiscard]] static bfloat16_t from_bits(uint16_t bits) noexcept { + return bfloat16_t{bits}; + } + + [[nodiscard]] float to_float() const noexcept { + const uint32_t raw = static_cast(bits) << 16; + float value = 0.0f; + std::memcpy(&value, &raw, sizeof(value)); + return value; + } + + [[nodiscard]] static bfloat16_t from_float(float value) noexcept { + uint32_t raw = 0; + std::memcpy(&raw, &value, sizeof(raw)); + if ((raw & 0x7F800000u) == 0x7F800000u && (raw & 0x007FFFFFu) != 0u) { + return from_bits(0x7FC0u); + } + const uint32_t lsb = (raw >> 16) & 1u; + const uint32_t rounding_bias = 0x7FFFu + lsb; + return from_bits(static_cast((raw + rounding_bias) >> 16)); + } + + [[nodiscard]] static bool is_nan(bfloat16_t v) noexcept { + return (v.bits & 0x7F80u) == 0x7F80u && (v.bits & 0x007Fu) != 0u; + } + [[nodiscard]] static bool is_inf(bfloat16_t v) noexcept { + return (v.bits & 0x7FFFu) == 0x7F80u; + } + [[nodiscard]] static bool is_inf(bfloat16_t v, int sign) noexcept { + if (sign == 0) { + return is_inf(v); + } + return sign > 0 ? v.bits == 0x7F80u : v.bits == 0xFF80u; + } + [[nodiscard]] static bool is_zero(bfloat16_t v) noexcept { + return (v.bits & 0x7FFFu) == 0u; + } + [[nodiscard]] static bool signbit(bfloat16_t v) noexcept { + return (v.bits & 0x8000u) != 0u; + } + [[nodiscard]] static bool is_subnormal(bfloat16_t v) noexcept { + return (v.bits & 0x7F80u) == 0u && (v.bits & 0x007Fu) != 0u; + } + [[nodiscard]] static bool is_normal(bfloat16_t v) noexcept { + const uint16_t exp = v.bits & 0x7F80u; + return exp != 0u && exp != 0x7F80u; + } + [[nodiscard]] static bool is_finite(bfloat16_t v) noexcept { + return (v.bits & 0x7F80u) != 0x7F80u; + } + [[nodiscard]] static bool equal(bfloat16_t a, bfloat16_t b) noexcept { + if (is_nan(a) || is_nan(b)) { + return false; + } + if (is_zero(a) && is_zero(b)) { + return true; + } + return a.bits == b.bits; + } + [[nodiscard]] static bool less(bfloat16_t a, bfloat16_t b) noexcept { + if (is_nan(a) || is_nan(b)) { + return false; + } + if (is_zero(a) && is_zero(b)) { + return false; + } + const bool neg_a = signbit(a); + const bool neg_b = signbit(b); + if (neg_a != neg_b) { + return neg_a; + } + return neg_a ? a.bits > b.bits : a.bits < b.bits; + } + [[nodiscard]] static bool less_eq(bfloat16_t a, bfloat16_t b) noexcept { + return equal(a, b) || less(a, b); + } + [[nodiscard]] static bool greater(bfloat16_t a, bfloat16_t b) noexcept { + return less(b, a); + } + [[nodiscard]] static bool greater_eq(bfloat16_t a, bfloat16_t b) noexcept { + return equal(a, b) || greater(a, b); + } + [[nodiscard]] static int compare(bfloat16_t a, bfloat16_t b) noexcept { + if (is_nan(a) || is_nan(b)) { + return 0; + } + if (equal(a, b)) { + return 0; + } + return less(a, b) ? -1 : 1; + } + [[nodiscard]] static std::string to_string(bfloat16_t v) { + return std::to_string(v.to_float()); + } + [[nodiscard]] static bfloat16_t add(bfloat16_t a, bfloat16_t b) noexcept { + return from_float(a.to_float() + b.to_float()); + } + [[nodiscard]] static bfloat16_t sub(bfloat16_t a, bfloat16_t b) noexcept { + return from_float(a.to_float() - b.to_float()); + } + [[nodiscard]] static bfloat16_t mul(bfloat16_t a, bfloat16_t b) noexcept { + return from_float(a.to_float() * b.to_float()); + } + [[nodiscard]] static bfloat16_t div(bfloat16_t a, bfloat16_t b) noexcept { + return from_float(a.to_float() / b.to_float()); + } + [[nodiscard]] static bfloat16_t neg(bfloat16_t a) noexcept { + return from_bits(static_cast(a.bits ^ 0x8000u)); + } + [[nodiscard]] static bfloat16_t abs(bfloat16_t a) noexcept { + return from_bits(static_cast(a.bits & 0x7FFFu)); + } + + bfloat16_t &operator+=(bfloat16_t rhs) noexcept { + *this = add(*this, rhs); + return *this; + } + bfloat16_t &operator-=(bfloat16_t rhs) noexcept { + *this = sub(*this, rhs); + return *this; + } + bfloat16_t &operator*=(bfloat16_t rhs) noexcept { + *this = mul(*this, rhs); + return *this; + } + bfloat16_t &operator/=(bfloat16_t rhs) noexcept { + *this = div(*this, rhs); + return *this; + } +}; + +static_assert(sizeof(bfloat16_t) == 2); +static_assert(std::is_trivial_v); +static_assert(std::is_standard_layout_v); + +[[nodiscard]] inline bfloat16_t operator+(bfloat16_t a, bfloat16_t b) noexcept { + return bfloat16_t::add(a, b); +} +[[nodiscard]] inline bfloat16_t operator-(bfloat16_t a, bfloat16_t b) noexcept { + return bfloat16_t::sub(a, b); +} +[[nodiscard]] inline bfloat16_t operator*(bfloat16_t a, bfloat16_t b) noexcept { + return bfloat16_t::mul(a, b); +} +[[nodiscard]] inline bfloat16_t operator/(bfloat16_t a, bfloat16_t b) noexcept { + return bfloat16_t::div(a, b); +} +[[nodiscard]] inline bfloat16_t operator-(bfloat16_t a) noexcept { + return bfloat16_t::neg(a); +} +[[nodiscard]] inline bfloat16_t operator+(bfloat16_t a) noexcept { return a; } +[[nodiscard]] inline bool operator==(bfloat16_t a, bfloat16_t b) noexcept { + return bfloat16_t::equal(a, b); +} +[[nodiscard]] inline bool operator!=(bfloat16_t a, bfloat16_t b) noexcept { + return !bfloat16_t::equal(a, b); +} +[[nodiscard]] inline bool operator<(bfloat16_t a, bfloat16_t b) noexcept { + return bfloat16_t::less(a, b); +} +[[nodiscard]] inline bool operator<=(bfloat16_t a, bfloat16_t b) noexcept { + return bfloat16_t::less_eq(a, b); +} +[[nodiscard]] inline bool operator>(bfloat16_t a, bfloat16_t b) noexcept { + return bfloat16_t::greater(a, b); +} +[[nodiscard]] inline bool operator>=(bfloat16_t a, bfloat16_t b) noexcept { + return bfloat16_t::greater_eq(a, b); +} + +} // namespace fory + +namespace std { +template <> struct hash { + size_t operator()(fory::bfloat16_t v) const noexcept { + uint16_t bits = fory::bfloat16_t::is_zero(v) ? 0u : v.to_bits(); + return std::hash{}(bits); + } +}; +} // namespace std diff --git a/cpp/fory/util/buffer.h b/cpp/fory/util/buffer.h index 226ade1380..d64a87c9a6 100644 --- a/cpp/fory/util/buffer.h +++ b/cpp/fory/util/buffer.h @@ -27,6 +27,7 @@ #include #include +#include "fory/util/bfloat16.h" #include "fory/util/bit_util.h" #include "fory/util/error.h" #include "fory/util/float16.h" @@ -763,6 +764,14 @@ class Buffer { increase_writer_index(2); } + /// Write bfloat16_t as fixed 2 bytes (raw IEEE 754 bits, little-endian). + /// Automatically grows buffer and advances writer index. + FORY_ALWAYS_INLINE void write_bf16(bfloat16_t value) { + grow(2); + unsafe_put(writer_index_, value.to_bits()); + increase_writer_index(2); + } + /// write uint32_t value as varint to buffer at current writer index. /// Automatically grows buffer and advances writer index. FORY_ALWAYS_INLINE void write_var_uint32(uint32_t value) { @@ -976,6 +985,17 @@ class Buffer { return value; } + /// Read bfloat16_t from buffer. Sets error on bounds violation. + FORY_ALWAYS_INLINE bfloat16_t read_bf16(Error &error) { + if (FORY_PREDICT_FALSE(!ensure_readable(2, error))) { + return bfloat16_t::from_bits(0); + } + bfloat16_t value = + bfloat16_t::from_bits(unsafe_get(reader_index_)); + reader_index_ += 2; + return value; + } + /// Read uint32_t value as varint from buffer. Sets error on bounds violation. FORY_ALWAYS_INLINE uint32_t read_var_uint32(Error &error) { if (FORY_PREDICT_FALSE(!ensure_readable(1, error))) { diff --git a/csharp/README.md b/csharp/README.md index b7db2f445c..f310a625ee 100644 --- a/csharp/README.md +++ b/csharp/README.md @@ -13,6 +13,7 @@ The C# implementation provides high-performance object graph serialization for . - Source-generator-based serializers for `[ForyObject]` types - Optional shared/circular reference tracking (`TrackRef(true)`) - Compatible mode for schema evolution +- Reduced-precision carriers for `Half` / `BFloat16` scalars and `Half[]` / `List` / `BFloat16[]` / `List` array payloads - Thread-safe runtime wrapper (`ThreadSafeFory`) for concurrent workloads - Dynamic object serialization APIs for heterogeneous payloads diff --git a/csharp/src/Fory.Generator/ForyObjectGenerator.cs b/csharp/src/Fory.Generator/ForyObjectGenerator.cs index b25ac6a4e8..3f9894e62f 100644 --- a/csharp/src/Fory.Generator/ForyObjectGenerator.cs +++ b/csharp/src/Fory.Generator/ForyObjectGenerator.cs @@ -187,6 +187,8 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" global::Apache.Fory.TypeId.UInt64 or"); sb.AppendLine(" global::Apache.Fory.TypeId.VarUInt64 or"); sb.AppendLine(" global::Apache.Fory.TypeId.TaggedUInt64 or"); + sb.AppendLine(" global::Apache.Fory.TypeId.Float16 or"); + sb.AppendLine(" global::Apache.Fory.TypeId.BFloat16 or"); sb.AppendLine(" global::Apache.Fory.TypeId.Float32 or"); sb.AppendLine(" global::Apache.Fory.TypeId.Float64 or"); sb.AppendLine(" global::Apache.Fory.TypeId.String => true,"); @@ -213,6 +215,8 @@ private static void EmitObjectSerializer(StringBuilder sb, TypeModel model) sb.AppendLine(" global::Apache.Fory.TypeId.UInt64 => context.Reader.ReadUInt64(),"); sb.AppendLine(" global::Apache.Fory.TypeId.VarUInt64 => context.Reader.ReadVarUInt64(),"); sb.AppendLine(" global::Apache.Fory.TypeId.TaggedUInt64 => context.Reader.ReadTaggedUInt64(),"); + sb.AppendLine(" global::Apache.Fory.TypeId.Float16 => global::System.BitConverter.UInt16BitsToHalf(context.Reader.ReadUInt16()),"); + sb.AppendLine(" global::Apache.Fory.TypeId.BFloat16 => global::Apache.Fory.BFloat16.FromBits(context.Reader.ReadUInt16()),"); sb.AppendLine(" global::Apache.Fory.TypeId.Float32 => context.Reader.ReadFloat32(),"); sb.AppendLine(" global::Apache.Fory.TypeId.Float64 => context.Reader.ReadFloat64(),"); sb.AppendLine(" global::Apache.Fory.TypeId.String => global::Apache.Fory.StringSerializer.ReadString(context),"); @@ -906,6 +910,12 @@ private static bool TryBuildDirectPayloadWrite(uint typeId, string valueExpr, ou case 15: writeCode = $"context.Writer.WriteTaggedUInt64({valueExpr});"; return true; + case 17: + writeCode = $"context.Writer.WriteUInt16(global::System.BitConverter.HalfToUInt16Bits({valueExpr}));"; + return true; + case 18: + writeCode = $"context.Writer.WriteUInt16({valueExpr}.ToBits());"; + return true; case 19: writeCode = $"context.Writer.WriteFloat32({valueExpr});"; return true; @@ -970,6 +980,12 @@ private static bool TryBuildDirectPayloadRead(uint typeId, out string? readExpr) case 15: readExpr = "context.Reader.ReadTaggedUInt64()"; return true; + case 17: + readExpr = "global::System.BitConverter.UInt16BitsToHalf(context.Reader.ReadUInt16())"; + return true; + case 18: + readExpr = "global::Apache.Fory.BFloat16.FromBits(context.Reader.ReadUInt16())"; + return true; case 19: readExpr = "context.Reader.ReadFloat32()"; return true; @@ -1053,7 +1069,7 @@ private static string BuildSchemaFingerprintExpression(ImmutableArray loc.IsInSource); @@ -1479,10 +1500,10 @@ private static ImmutableArray SortMembers(ImmutableArray 49, // ushort -> uint16 array 12 => 50, // uint -> uint32 array 14 => 51, // ulong -> uint64 array + 17 => 53, // Half -> float16 array + 18 => 54, // BFloat16 -> bfloat16 array 19 => 55, // float -> float32 array 20 => 56, // double -> float64 array _ => null, @@ -1926,7 +1963,8 @@ private static string ToSnakeCase(string name) { bool prevUpper = char.IsUpper(name[i - 1]); bool nextUpperOrEnd = i + 1 >= name.Length || char.IsUpper(name[i + 1]); - if (!prevUpper || !nextUpperOrEnd) + bool leadingPascalBoundary = i == 1 && prevUpper && !nextUpperOrEnd; + if ((!prevUpper || !nextUpperOrEnd) && !leadingPascalBoundary) { sb.Append('_'); } diff --git a/csharp/src/Fory/AnySerializer.cs b/csharp/src/Fory/AnySerializer.cs index 2be087ca45..88bb69cf88 100644 --- a/csharp/src/Fory/AnySerializer.cs +++ b/csharp/src/Fory/AnySerializer.cs @@ -255,6 +255,12 @@ private static bool TryWriteKnownTypeInfo(object value, WriteContext context) case ulong: context.Writer.WriteUInt8((byte)TypeId.VarUInt64); return true; + case Half: + context.Writer.WriteUInt8((byte)TypeId.Float16); + return true; + case BFloat16: + context.Writer.WriteUInt8((byte)TypeId.BFloat16); + return true; case float: context.Writer.WriteUInt8((byte)TypeId.Float32); return true; @@ -291,6 +297,12 @@ private static bool TryWriteKnownTypeInfo(object value, WriteContext context) case ulong[]: context.Writer.WriteUInt8((byte)TypeId.UInt64Array); return true; + case Half[]: + context.Writer.WriteUInt8((byte)TypeId.Float16Array); + return true; + case BFloat16[]: + context.Writer.WriteUInt8((byte)TypeId.BFloat16Array); + return true; case float[]: context.Writer.WriteUInt8((byte)TypeId.Float32Array); return true; @@ -343,6 +355,12 @@ private static bool TryWriteKnownPayload(object value, WriteContext context) case ulong v: context.Writer.WriteVarUInt64(v); return true; + case Half v: + context.Writer.WriteUInt16(BitConverter.HalfToUInt16Bits(v)); + return true; + case BFloat16 v: + context.Writer.WriteUInt16(v.ToBits()); + return true; case float v: context.Writer.WriteFloat32(v); return true; @@ -424,6 +442,20 @@ private static bool TryWriteKnownPayload(object value, WriteContext context) context.Writer.WriteUInt64(v[i]); } return true; + case Half[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 2)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteUInt16(BitConverter.HalfToUInt16Bits(v[i])); + } + return true; + case BFloat16[] v: + context.Writer.WriteVarUInt32((uint)(v.Length * 2)); + for (int i = 0; i < v.Length; i++) + { + context.Writer.WriteUInt16(v[i].ToBits()); + } + return true; case float[] v: context.Writer.WriteVarUInt32((uint)(v.Length * 4)); for (int i = 0; i < v.Length; i++) diff --git a/csharp/src/Fory/BFloat16.cs b/csharp/src/Fory/BFloat16.cs new file mode 100644 index 0000000000..36963cf058 --- /dev/null +++ b/csharp/src/Fory/BFloat16.cs @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +namespace Apache.Fory; + +/// +/// Represents one IEEE 754 bfloat16 value. +/// +/// +/// +/// is the public Fory carrier for xlang BFLOAT16 values. It stores +/// the canonical 16-bit wire representation and converts to and from using +/// round-to-nearest-even semantics. +/// +/// +/// Use and when you need exact bit-preserving +/// control. Use or the explicit conversion operators when you want +/// numeric conversion. +/// +/// +/// For xlang bfloat16_array payloads, use BFloat16[] or List<BFloat16>. +/// Both carriers map to the packed 16-bit array wire format, so a dedicated list wrapper is not +/// required. +/// +/// +public readonly struct BFloat16 : IEquatable, IComparable +{ + private const uint Float32ExpMask = 0x7F80_0000u; + private const uint Float32MantissaMask = 0x007F_FFFFu; + private const ushort CanonicalNaNBits = 0x7FC0; + + /// + /// Gets the canonical quiet NaN value. + /// + public static BFloat16 NaN => FromBits(CanonicalNaNBits); + + /// + /// Gets positive infinity. + /// + public static BFloat16 PositiveInfinity => FromBits(0x7F80); + + /// + /// Gets negative infinity. + /// + public static BFloat16 NegativeInfinity => FromBits(0xFF80); + + /// + /// Gets positive zero. + /// + public static BFloat16 Zero => default; + + /// + /// Gets negative zero. + /// + public static BFloat16 NegativeZero => FromBits(0x8000); + + /// + /// Gets the raw IEEE 754 bfloat16 bits. + /// + public ushort Bits { get; } + + /// + /// Initializes a new instance from the exact wire bits. + /// + /// Raw bfloat16 bits. + public BFloat16(ushort bits) + { + Bits = bits; + } + + /// + /// Creates a value from exact wire bits. + /// + public static BFloat16 FromBits(ushort bits) => new(bits); + + /// + /// Creates a value from a . + /// + public static BFloat16 FromSingle(float value) => new(Float32ToBFloat16Bits(value)); + + /// + /// Returns the exact wire bits. + /// + public ushort ToBits() => Bits; + + /// + /// Converts this value to . + /// + public float ToSingle() => BFloat16BitsToFloat32(Bits); + + /// + /// Converts this value to . + /// + public double ToDouble() => ToSingle(); + + /// + /// Returns whether the value is NaN. + /// + public bool IsNaN => (Bits & 0x7F80) == 0x7F80 && (Bits & 0x007F) != 0; + + /// + /// Returns whether the value is infinite. + /// + public bool IsInfinity => (Bits & 0x7F80) == 0x7F80 && (Bits & 0x007F) == 0; + + /// + /// Returns whether the value is finite. + /// + public bool IsFinite => (Bits & 0x7F80) != 0x7F80; + + /// + /// Returns whether the value is numerically zero, including signed zero. + /// + public bool IsZero => (Bits & 0x7FFF) == 0; + + /// + /// Returns whether the sign bit is set. + /// + public bool SignBit => (Bits & 0x8000) != 0; + + /// + /// Reinterprets a as bfloat16 bits using round-to-nearest-even. + /// + public static ushort ToBits(float value) => Float32ToBFloat16Bits(value); + + /// + /// Reinterprets bfloat16 bits as a . + /// + public static float ToSingle(ushort bits) => BFloat16BitsToFloat32(bits); + + /// + /// Converts from to . + /// + public static explicit operator BFloat16(float value) => FromSingle(value); + + /// + /// Converts from to . + /// + public static explicit operator BFloat16(double value) => FromSingle((float)value); + + /// + /// Converts to . + /// + public static implicit operator float(BFloat16 value) => value.ToSingle(); + + /// + /// Converts to . + /// + public static implicit operator double(BFloat16 value) => value.ToDouble(); + + public int CompareTo(BFloat16 other) + { + return ToSingle().CompareTo(other.ToSingle()); + } + + public bool Equals(BFloat16 other) + { + return Bits == other.Bits; + } + + public override bool Equals(object? obj) + { + return obj is BFloat16 other && Equals(other); + } + + public override int GetHashCode() + { + return Bits.GetHashCode(); + } + + public override string ToString() + { + return ToSingle().ToString(System.Globalization.CultureInfo.InvariantCulture); + } + + public static bool operator ==(BFloat16 left, BFloat16 right) => left.Equals(right); + + public static bool operator !=(BFloat16 left, BFloat16 right) => !left.Equals(right); + + public static bool operator <(BFloat16 left, BFloat16 right) => left.CompareTo(right) < 0; + + public static bool operator <=(BFloat16 left, BFloat16 right) => left.CompareTo(right) <= 0; + + public static bool operator >(BFloat16 left, BFloat16 right) => left.CompareTo(right) > 0; + + public static bool operator >=(BFloat16 left, BFloat16 right) => left.CompareTo(right) >= 0; + + internal static ushort Float32ToBFloat16Bits(float value) + { + uint bits32 = BitConverter.SingleToUInt32Bits(value); + if ((bits32 & Float32ExpMask) == Float32ExpMask && (bits32 & Float32MantissaMask) != 0) + { + return CanonicalNaNBits; + } + + uint lsb = (bits32 >> 16) & 1u; + uint rounded = bits32 + 0x7FFFu + lsb; + return unchecked((ushort)(rounded >> 16)); + } + + internal static float BFloat16BitsToFloat32(ushort bits) + { + return BitConverter.UInt32BitsToSingle((uint)bits << 16); + } +} diff --git a/csharp/src/Fory/FieldSkipper.cs b/csharp/src/Fory/FieldSkipper.cs index 30353bb66c..652d59cfac 100644 --- a/csharp/src/Fory/FieldSkipper.cs +++ b/csharp/src/Fory/FieldSkipper.cs @@ -66,6 +66,14 @@ public static void SkipFieldValue(ReadContext context, TypeMetaFieldType fieldTy return context.TypeResolver.GetSerializer().Read(context, refMode, false); case (uint)TypeId.VarInt64: return context.TypeResolver.GetSerializer().Read(context, refMode, false); + case (uint)TypeId.Float16: + return context.TypeResolver.GetSerializer().Read(context, refMode, false); + case (uint)TypeId.BFloat16: + return context.TypeResolver.GetSerializer().Read(context, refMode, false); + case (uint)TypeId.Float16Array: + return context.TypeResolver.GetSerializer().Read(context, refMode, false); + case (uint)TypeId.BFloat16Array: + return context.TypeResolver.GetSerializer().Read(context, refMode, false); case (uint)TypeId.Float32: return context.TypeResolver.GetSerializer().Read(context, refMode, false); case (uint)TypeId.Float64: diff --git a/csharp/src/Fory/PrimitiveArraySerializers.cs b/csharp/src/Fory/PrimitiveArraySerializers.cs index 505ec0dab4..0f618da9a5 100644 --- a/csharp/src/Fory/PrimitiveArraySerializers.cs +++ b/csharp/src/Fory/PrimitiveArraySerializers.cs @@ -295,6 +295,78 @@ public override ulong[] ReadData(ReadContext context) } } +internal sealed class Float16ArraySerializer : Serializer +{ + + + + public override Half[] DefaultValue => null!; + + public override void WriteData(WriteContext context, in Half[] value, bool hasGenerics) + { + _ = hasGenerics; + Half[] safe = value ?? []; + context.Writer.WriteVarUInt32((uint)(safe.Length * 2)); + for (int i = 0; i < safe.Length; i++) + { + context.Writer.WriteUInt16(BitConverter.HalfToUInt16Bits(safe[i])); + } + } + + public override Half[] ReadData(ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 1) != 0) + { + throw new InvalidDataException("float16 array payload size mismatch"); + } + + Half[] values = new Half[payloadSize / 2]; + for (int i = 0; i < values.Length; i++) + { + values[i] = BitConverter.UInt16BitsToHalf(context.Reader.ReadUInt16()); + } + + return values; + } +} + +internal sealed class BFloat16ArraySerializer : Serializer +{ + + + + public override BFloat16[] DefaultValue => null!; + + public override void WriteData(WriteContext context, in BFloat16[] value, bool hasGenerics) + { + _ = hasGenerics; + BFloat16[] safe = value ?? []; + context.Writer.WriteVarUInt32((uint)(safe.Length * 2)); + for (int i = 0; i < safe.Length; i++) + { + context.Writer.WriteUInt16(safe[i].ToBits()); + } + } + + public override BFloat16[] ReadData(ReadContext context) + { + int payloadSize = checked((int)context.Reader.ReadVarUInt32()); + if ((payloadSize & 1) != 0) + { + throw new InvalidDataException("bfloat16 array payload size mismatch"); + } + + BFloat16[] values = new BFloat16[payloadSize / 2]; + for (int i = 0; i < values.Length; i++) + { + values[i] = BFloat16.FromBits(context.Reader.ReadUInt16()); + } + + return values; + } +} + internal sealed class Float32ArraySerializer : Serializer { diff --git a/csharp/src/Fory/PrimitiveCollectionSerializers.cs b/csharp/src/Fory/PrimitiveCollectionSerializers.cs index 3d41bd26e1..2eac41ea5c 100644 --- a/csharp/src/Fory/PrimitiveCollectionSerializers.cs +++ b/csharp/src/Fory/PrimitiveCollectionSerializers.cs @@ -274,6 +274,56 @@ public override List ReadData(ReadContext context) } } +internal sealed class ListHalfSerializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + + + + public override List DefaultValue => null!; + + public override void WriteData(WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(context, list.Count, hasGenerics, TypeId.Float16, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteUInt16(BitConverter.HalfToUInt16Bits(list[i])); + } + } + + public override List ReadData(ReadContext context) + { + return Fallback.ReadData(context); + } +} + +internal sealed class ListBFloat16Serializer : Serializer> +{ + private static readonly ListSerializer Fallback = new(); + + + + + public override List DefaultValue => null!; + + public override void WriteData(WriteContext context, in List value, bool hasGenerics) + { + List list = value ?? []; + PrimitiveCollectionHeader.WriteListHeader(context, list.Count, hasGenerics, TypeId.BFloat16, false); + for (int i = 0; i < list.Count; i++) + { + context.Writer.WriteUInt16(list[i].ToBits()); + } + } + + public override List ReadData(ReadContext context) + { + return Fallback.ReadData(context); + } +} + internal sealed class ListFloatSerializer : Serializer> { private static readonly ListSerializer Fallback = new(); diff --git a/csharp/src/Fory/PrimitiveDictionarySerializers.cs b/csharp/src/Fory/PrimitiveDictionarySerializers.cs index d03ef70f94..05173961e9 100644 --- a/csharp/src/Fory/PrimitiveDictionarySerializers.cs +++ b/csharp/src/Fory/PrimitiveDictionarySerializers.cs @@ -245,6 +245,48 @@ public static ulong Read(ReadContext context) } } +internal readonly struct Float16PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.Float16; + + public static bool IsNullable => false; + + public static Half DefaultValue => default; + + public static bool IsNone(Half value) => false; + + public static void Write(WriteContext context, Half value) + { + context.Writer.WriteUInt16(BitConverter.HalfToUInt16Bits(value)); + } + + public static Half Read(ReadContext context) + { + return BitConverter.UInt16BitsToHalf(context.Reader.ReadUInt16()); + } +} + +internal readonly struct BFloat16PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec +{ + public static TypeId WireTypeId => TypeId.BFloat16; + + public static bool IsNullable => false; + + public static BFloat16 DefaultValue => default; + + public static bool IsNone(BFloat16 value) => false; + + public static void Write(WriteContext context, BFloat16 value) + { + context.Writer.WriteUInt16(value.ToBits()); + } + + public static BFloat16 Read(ReadContext context) + { + return BFloat16.FromBits(context.Reader.ReadUInt16()); + } +} + internal readonly struct Float32PrimitiveDictionaryCodec : IPrimitiveDictionaryCodec { public static TypeId WireTypeId => TypeId.Float32; diff --git a/csharp/src/Fory/PrimitiveSerializers.cs b/csharp/src/Fory/PrimitiveSerializers.cs index ad5e3e959d..2d3b639616 100644 --- a/csharp/src/Fory/PrimitiveSerializers.cs +++ b/csharp/src/Fory/PrimitiveSerializers.cs @@ -177,6 +177,40 @@ public override ulong ReadData(ReadContext context) } } +public sealed class Float16Serializer : Serializer +{ + + public override Half DefaultValue => default; + + public override void WriteData(WriteContext context, in Half value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteUInt16(BitConverter.HalfToUInt16Bits(value)); + } + + public override Half ReadData(ReadContext context) + { + return BitConverter.UInt16BitsToHalf(context.Reader.ReadUInt16()); + } +} + +public sealed class BFloat16Serializer : Serializer +{ + + public override BFloat16 DefaultValue => default; + + public override void WriteData(WriteContext context, in BFloat16 value, bool hasGenerics) + { + _ = hasGenerics; + context.Writer.WriteUInt16(value.ToBits()); + } + + public override BFloat16 ReadData(ReadContext context) + { + return BFloat16.FromBits(context.Reader.ReadUInt16()); + } +} + public sealed class Float32Serializer : Serializer { diff --git a/csharp/src/Fory/TypeInfo.cs b/csharp/src/Fory/TypeInfo.cs index 31bc821c8d..c7952ceb52 100644 --- a/csharp/src/Fory/TypeInfo.cs +++ b/csharp/src/Fory/TypeInfo.cs @@ -305,6 +305,18 @@ private static bool TryResolveBuiltInTypeId(Type type, out TypeId typeId) return true; } + if (type == typeof(Half)) + { + typeId = TypeId.Float16; + return true; + } + + if (type == typeof(BFloat16)) + { + typeId = TypeId.BFloat16; + return true; + } + if (type == typeof(double)) { typeId = TypeId.Float64; @@ -383,6 +395,18 @@ private static bool TryResolveBuiltInTypeId(Type type, out TypeId typeId) return true; } + if (type == typeof(Half[])) + { + typeId = TypeId.Float16Array; + return true; + } + + if (type == typeof(BFloat16[])) + { + typeId = TypeId.BFloat16Array; + return true; + } + if (type == typeof(double[])) { typeId = TypeId.Float64Array; diff --git a/csharp/src/Fory/TypeMeta.cs b/csharp/src/Fory/TypeMeta.cs index ff05b507a7..4d8415260a 100644 --- a/csharp/src/Fory/TypeMeta.cs +++ b/csharp/src/Fory/TypeMeta.cs @@ -103,7 +103,7 @@ public static string LowerCamelToLowerUnderscore(string name) } } - return sb.ToString(); + return sb.ToString().Replace("b_float16", "bfloat16"); } } diff --git a/csharp/src/Fory/TypeResolver.cs b/csharp/src/Fory/TypeResolver.cs index 2c70e01286..6d69557c93 100644 --- a/csharp/src/Fory/TypeResolver.cs +++ b/csharp/src/Fory/TypeResolver.cs @@ -61,6 +61,8 @@ private static class GenericTypeCache (typeof(int), typeof(Int32PrimitiveDictionaryCodec)), (typeof(long), typeof(Int64PrimitiveDictionaryCodec)), (typeof(bool), typeof(BoolPrimitiveDictionaryCodec)), + (typeof(Half), typeof(Float16PrimitiveDictionaryCodec)), + (typeof(BFloat16), typeof(BFloat16PrimitiveDictionaryCodec)), (typeof(double), typeof(Float64PrimitiveDictionaryCodec)), (typeof(float), typeof(Float32PrimitiveDictionaryCodec)), (typeof(uint), typeof(UInt32PrimitiveDictionaryCodec)), @@ -84,6 +86,8 @@ private static class GenericTypeCache (typeof(ushort), typeof(UInt16PrimitiveDictionaryCodec)), (typeof(uint), typeof(UInt32PrimitiveDictionaryCodec)), (typeof(ulong), typeof(UInt64PrimitiveDictionaryCodec)), + (typeof(Half), typeof(Float16PrimitiveDictionaryCodec)), + (typeof(BFloat16), typeof(BFloat16PrimitiveDictionaryCodec)), (typeof(float), typeof(Float32PrimitiveDictionaryCodec)), (typeof(double), typeof(Float64PrimitiveDictionaryCodec))); @@ -95,6 +99,8 @@ private static class GenericTypeCache (typeof(ushort), typeof(UInt16PrimitiveDictionaryCodec)), (typeof(uint), typeof(UInt32PrimitiveDictionaryCodec)), (typeof(ulong), typeof(UInt64PrimitiveDictionaryCodec)), + (typeof(Half), typeof(Float16PrimitiveDictionaryCodec)), + (typeof(BFloat16), typeof(BFloat16PrimitiveDictionaryCodec)), (typeof(float), typeof(Float32PrimitiveDictionaryCodec)), (typeof(double), typeof(Float64PrimitiveDictionaryCodec))); @@ -986,6 +992,8 @@ private TypeInfo ResolveAnyBuiltInTypeInfo(TypeId wireTypeId) TypeId.UInt16 => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.UInt32 or TypeId.VarUInt32 => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.UInt64 or TypeId.VarUInt64 or TypeId.TaggedUInt64 => GetTypeInfo().WithWireTypeInfo(wireTypeId), + TypeId.Float16 => GetTypeInfo().WithWireTypeInfo(wireTypeId), + TypeId.BFloat16 => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.Float32 => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.Float64 => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.String => GetTypeInfo().WithWireTypeInfo(wireTypeId), @@ -1002,6 +1010,8 @@ private TypeInfo ResolveAnyBuiltInTypeInfo(TypeId wireTypeId) TypeId.UInt16Array => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.UInt32Array => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.UInt64Array => GetTypeInfo().WithWireTypeInfo(wireTypeId), + TypeId.Float16Array => GetTypeInfo().WithWireTypeInfo(wireTypeId), + TypeId.BFloat16Array => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.Float32Array => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.Float64Array => GetTypeInfo().WithWireTypeInfo(wireTypeId), TypeId.List => GetTypeInfo>().WithWireTypeInfo(wireTypeId), @@ -1443,6 +1453,16 @@ private TypeInfo CreateBindingCore(Type type) return TypeInfo.Create(type, new Float32Serializer()); } + if (type == typeof(Half)) + { + return TypeInfo.Create(type, new Float16Serializer()); + } + + if (type == typeof(BFloat16)) + { + return TypeInfo.Create(type, new BFloat16Serializer()); + } + if (type == typeof(double)) { return TypeInfo.Create(type, new Float64Serializer()); @@ -1508,6 +1528,16 @@ private TypeInfo CreateBindingCore(Type type) return TypeInfo.Create(type, new Float32ArraySerializer()); } + if (type == typeof(Half[])) + { + return TypeInfo.Create(type, new Float16ArraySerializer()); + } + + if (type == typeof(BFloat16[])) + { + return TypeInfo.Create(type, new BFloat16ArraySerializer()); + } + if (type == typeof(double[])) { return TypeInfo.Create(type, new Float64ArraySerializer()); @@ -1583,6 +1613,16 @@ private TypeInfo CreateBindingCore(Type type) return TypeInfo.Create(type, new ListFloatSerializer()); } + if (type == typeof(List)) + { + return TypeInfo.Create(type, new ListHalfSerializer()); + } + + if (type == typeof(List)) + { + return TypeInfo.Create(type, new ListBFloat16Serializer()); + } + if (type == typeof(List)) { return TypeInfo.Create(type, new ListDoubleSerializer()); diff --git a/csharp/tests/Fory.XlangPeer/Program.cs b/csharp/tests/Fory.XlangPeer/Program.cs index 11809c4448..80030eaf22 100644 --- a/csharp/tests/Fory.XlangPeer/Program.cs +++ b/csharp/tests/Fory.XlangPeer/Program.cs @@ -217,6 +217,8 @@ private static byte[] ExecuteCase(string caseName, byte[] input) "test_two_string_field_compatible" => CaseTwoStringFieldCompatible(input), "test_schema_evolution_compatible" => CaseSchemaEvolutionCompatible(input), "test_schema_evolution_compatible_reverse" => CaseSchemaEvolutionCompatibleReverse(input), + "test_reduced_precision_float_struct" => CaseReducedPrecisionFloatStruct(input), + "test_reduced_precision_float_struct_compatible_skip" => CaseReducedPrecisionFloatStructCompatibleSkip(input), "test_one_enum_field_schema" => CaseOneEnumFieldSchema(input), "test_one_enum_field_compatible" => CaseOneEnumFieldCompatible(input), "test_two_enum_field_compatible" => CaseTwoEnumFieldCompatible(input), @@ -781,6 +783,38 @@ private static byte[] CaseSchemaEvolutionCompatibleReverse(byte[] input) return RoundTripSingle(input, fory); } + private static byte[] CaseReducedPrecisionFloatStruct(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: false); + fory.Register(213); + + ReadOnlySequence sequence = new(input); + ReducedPrecisionFloatStruct value = fory.Deserialize(ref sequence); + EnsureConsumed(sequence, nameof(CaseReducedPrecisionFloatStruct)); + Ensure(BitConverter.HalfToUInt16Bits(value.Float16Value) == 0x3E00, "float16_value mismatch"); + Ensure(value.BFloat16Value.Bits == 0x3FC0, "bfloat16_value mismatch"); + Ensure( + value.Float16Array is { Length: 3 } + && BitConverter.HalfToUInt16Bits(value.Float16Array[0]) == 0x0000 + && BitConverter.HalfToUInt16Bits(value.Float16Array[1]) == 0x3C00 + && BitConverter.HalfToUInt16Bits(value.Float16Array[2]) == 0xBC00, + "float16_array mismatch"); + Ensure( + value.BFloat16Array is { Length: 3 } + && value.BFloat16Array[0].Bits == 0x0000 + && value.BFloat16Array[1].Bits == 0x3F80 + && value.BFloat16Array[2].Bits == 0xBF80, + "bfloat16_array mismatch"); + return fory.Serialize(value); + } + + private static byte[] CaseReducedPrecisionFloatStructCompatibleSkip(byte[] input) + { + ForyRuntime fory = BuildFory(compatible: true); + fory.Register(213); + return RoundTripSingle(input, fory); + } + private static byte[] CaseOneEnumFieldSchema(byte[] input) { ForyRuntime fory = BuildFory(compatible: false); @@ -1175,6 +1209,20 @@ public sealed class TwoStringFieldStruct public string F2 { get; set; } = string.Empty; } +[ForyObject] +public sealed class EmptyStruct +{ +} + +[ForyObject] +public sealed class ReducedPrecisionFloatStruct +{ + public Half Float16Value { get; set; } + public BFloat16 BFloat16Value { get; set; } + public Half[] Float16Array { get; set; } = []; + public BFloat16[] BFloat16Array { get; set; } = []; +} + [ForyObject] public enum TestEnum { diff --git a/dart/packages/fory-test/lib/entity/xlang_test_models.dart b/dart/packages/fory-test/lib/entity/xlang_test_models.dart index 52fd9303fc..366e61f2e9 100644 --- a/dart/packages/fory-test/lib/entity/xlang_test_models.dart +++ b/dart/packages/fory-test/lib/entity/xlang_test_models.dart @@ -278,6 +278,16 @@ class TwoStringFieldStruct { String f2 = ''; } +@ForyStruct() +class ReducedPrecisionFloatStruct { + ReducedPrecisionFloatStruct(); + + Float16 float16Value = const Float16.fromBits(0); + Bfloat16 bfloat16Value = const Bfloat16.fromBits(0); + Float16List float16Array = Float16List(0); + Bfloat16List bfloat16Array = Bfloat16List(0); +} + @ForyStruct() class OneEnumFieldStruct { OneEnumFieldStruct(); diff --git a/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart b/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart index 8d45d65380..0f66eadbd6 100644 --- a/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart +++ b/dart/packages/fory-test/test/cross_lang_test/xlang_test_main.dart @@ -563,6 +563,16 @@ void _runCase(String caseName) { registerXlangType(fory, UnsignedSchemaCompatible, id: 502); _roundTripFory(fory); return; + case 'test_reduced_precision_float_struct': + final fory = _newFory(); + registerXlangType(fory, ReducedPrecisionFloatStruct, id: 213); + _roundTripFory(fory); + return; + case 'test_reduced_precision_float_struct_compatible_skip': + final fory = _newFory(compatible: true); + registerXlangType(fory, EmptyStruct, id: 213); + _roundTripFory(fory); + return; default: throw UnsupportedError('Unknown Dart xlang case: $caseName'); } diff --git a/dart/packages/fory/lib/src/codegen/fory_generator.dart b/dart/packages/fory/lib/src/codegen/fory_generator.dart index 49ef309b8c..3697cefd5b 100644 --- a/dart/packages/fory/lib/src/codegen/fory_generator.dart +++ b/dart/packages/fory/lib/src/codegen/fory_generator.dart @@ -1761,7 +1761,7 @@ GeneratedFieldType( if (sizeCompare != 0) { return sizeCompare; } - final typeCompare = right.fieldType.typeId - left.fieldType.typeId; + final typeCompare = left.fieldType.typeId - right.fieldType.typeId; if (typeCompare != 0) { return typeCompare; } diff --git a/dart/packages/fory/lib/src/serializer/scalar_serializers.dart b/dart/packages/fory/lib/src/serializer/scalar_serializers.dart index 9d715cc2ca..0fbfee8053 100644 --- a/dart/packages/fory/lib/src/serializer/scalar_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/scalar_serializers.dart @@ -22,7 +22,7 @@ import 'dart:typed_data'; import 'package:fory/src/context/read_context.dart'; import 'package:fory/src/context/write_context.dart'; import 'package:fory/src/serializer/serializer.dart'; -import 'package:fory/src/string_encoding.dart'; +import 'package:fory/src/util/string_util.dart'; import 'package:fory/src/types/decimal.dart'; // The small form reserves the low header bit to distinguish small/big diff --git a/dart/packages/fory/lib/src/string_encoding.dart b/dart/packages/fory/lib/src/util/string_util.dart similarity index 100% rename from dart/packages/fory/lib/src/string_encoding.dart rename to dart/packages/fory/lib/src/util/string_util.dart diff --git a/dart/packages/fory/test/string_encoding_test.dart b/dart/packages/fory/test/string_encoding_test.dart index bc07b26ab7..4a5011cc02 100644 --- a/dart/packages/fory/test/string_encoding_test.dart +++ b/dart/packages/fory/test/string_encoding_test.dart @@ -18,7 +18,7 @@ */ import 'package:fory/src/buffer.dart'; -import 'package:fory/src/string_encoding.dart'; +import 'package:fory/src/util/string_util.dart'; import 'package:test/test.dart'; void main() { diff --git a/dart/packages/fory/test/string_serializer_test.dart b/dart/packages/fory/test/string_serializer_test.dart index 7ab15f8f79..5b5b3f1f26 100644 --- a/dart/packages/fory/test/string_serializer_test.dart +++ b/dart/packages/fory/test/string_serializer_test.dart @@ -21,7 +21,7 @@ import 'dart:convert'; import 'dart:typed_data'; import 'package:fory/fory.dart'; -import 'package:fory/src/string_encoding.dart'; +import 'package:fory/src/util/string_util.dart'; import 'package:test/test.dart'; void main() { diff --git a/docs/guide/cpp/cross-language.md b/docs/guide/cpp/cross-language.md index 5632d45b2c..b0f308b5b7 100644 --- a/docs/guide/cpp/cross-language.md +++ b/docs/guide/cpp/cross-language.md @@ -137,15 +137,17 @@ print(f"Timestamp: {msg.timestamp}") ### Primitive Types -| C++ Type | Java Type | Python Type | Go Type | Rust Type | -| --------- | --------- | ----------- | --------- | --------- | -| `bool` | `boolean` | `bool` | `bool` | `bool` | -| `int8_t` | `byte` | `int` | `int8` | `i8` | -| `int16_t` | `short` | `int` | `int16` | `i16` | -| `int32_t` | `int` | `int` | `int32` | `i32` | -| `int64_t` | `long` | `int` | `int64` | `i64` | -| `float` | `float` | `float` | `float32` | `f32` | -| `double` | `double` | `float` | `float64` | `f64` | +| C++ Type | Java Type | Python Type | Go Type | Rust Type | +| ------------------ | ---------- | ----------------- | ------------------- | ---------- | +| `bool` | `boolean` | `bool` | `bool` | `bool` | +| `int8_t` | `byte` | `int` | `int8` | `i8` | +| `int16_t` | `short` | `int` | `int16` | `i16` | +| `int32_t` | `int` | `int` | `int32` | `i32` | +| `int64_t` | `long` | `int` | `int64` | `i64` | +| `float` | `float` | `float` | `float32` | `f32` | +| `double` | `double` | `float` | `float64` | `f64` | +| `fory::float16_t` | `Float16` | `pyfory.float16` | `float16.Float16` | `Float16` | +| `fory::bfloat16_t` | `BFloat16` | `pyfory.bfloat16` | `bfloat16.BFloat16` | `BFloat16` | ### String Types @@ -155,11 +157,13 @@ print(f"Timestamp: {msg.timestamp}") ### Collection Types -| C++ Type | Java Type | Python Type | Go Type | -| ---------------- | ---------- | ----------- | ---------------- | -| `std::vector` | `List` | `list` | `[]T` | -| `std::set` | `Set` | `set` | `map[T]struct{}` | -| `std::map` | `Map` | `dict` | `map[K]V` | +| C++ Type | Java Type | Python Type | Go Type | Rust Type | +| ------------------------------- | -------------- | --------------- | --------------------- | --------------- | +| `std::vector` | `List` | `list` | `[]T` | `Vec` | +| `std::vector` | `Float16List` | `float16array` | `[]float16.Float16` | `Vec` | +| `std::vector` | `BFloat16List` | `bfloat16array` | `[]bfloat16.BFloat16` | `Vec` | +| `std::set` | `Set` | `set` | `map[T]struct{}` | `HashSet` | +| `std::map` | `Map` | `dict` | `map[K]V` | `HashMap` | ### Temporal Types diff --git a/docs/guide/cpp/supported-types.md b/docs/guide/cpp/supported-types.md index ab191a9ad0..c2ca398e09 100644 --- a/docs/guide/cpp/supported-types.md +++ b/docs/guide/cpp/supported-types.md @@ -25,22 +25,23 @@ This page documents all types supported by Fory C++ serialization. All C++ primitive types are supported with efficient binary encoding: -| Type | Size | Fory TypeId | Notes | -| ---------- | ------ | ----------- | --------------------- | -| `bool` | 1 byte | BOOL | True/false | -| `int8_t` | 1 byte | INT8 | Signed byte | -| `uint8_t` | 1 byte | INT8 | Unsigned byte | -| `int16_t` | 2 byte | INT16 | Signed short | -| `uint16_t` | 2 byte | INT16 | Unsigned short | -| `int32_t` | 4 byte | INT32 | Signed integer | -| `uint32_t` | 4 byte | INT32 | Unsigned integer | -| `int64_t` | 8 byte | INT64 | Signed long | -| `uint64_t` | 8 byte | INT64 | Unsigned long | -| `float` | 4 byte | FLOAT32 | IEEE 754 single | -| `double` | 8 byte | FLOAT64 | IEEE 754 double | -| `char` | 1 byte | INT8 | Character (as signed) | -| `char16_t` | 2 byte | INT16 | 16-bits characters | -| `char32_t` | 4 byte | INT32 | 32-bits characters | +| Type | Size | Fory TypeId | Notes | +| ------------------ | ------ | ----------- | --------------------- | +| `bool` | 1 byte | BOOL | True/false | +| `int8_t` | 1 byte | INT8 | Signed byte | +| `uint8_t` | 1 byte | INT8 | Unsigned byte | +| `int16_t` | 2 byte | INT16 | Signed short | +| `uint16_t` | 2 byte | INT16 | Unsigned short | +| `int32_t` | 4 byte | INT32 | Signed integer | +| `uint32_t` | 4 byte | INT32 | Unsigned integer | +| `int64_t` | 8 byte | INT64 | Signed long | +| `uint64_t` | 8 byte | INT64 | Unsigned long | +| `float` | 4 byte | FLOAT32 | IEEE 754 single | +| `double` | 8 byte | FLOAT64 | IEEE 754 double | +| `fory::bfloat16_t` | 2 byte | BFLOAT16 | IEEE 754 bfloat16 | +| `char` | 1 byte | INT8 | Character (as signed) | +| `char16_t` | 2 byte | INT16 | 16-bits characters | +| `char32_t` | 4 byte | INT32 | 32-bits characters | ```cpp int32_t value = 42; @@ -71,6 +72,8 @@ assert(text == decoded); `std::vector` for any serializable element type: +`std::vector` is the dense array carrier for xlang `bfloat16_array`. + ```cpp std::vector numbers{1, 2, 3, 4, 5}; auto bytes = fory.serialize(numbers).value(); diff --git a/docs/guide/csharp/cross-language.md b/docs/guide/csharp/cross-language.md index 2b11ceb70e..a1c8af9456 100644 --- a/docs/guide/csharp/cross-language.md +++ b/docs/guide/csharp/cross-language.md @@ -93,6 +93,8 @@ value = fory.deserialize(payload_from_csharp) See [xlang guide](../xlang/index.md) for complete mapping. +For reduced-precision numeric payloads, use `Half` / `Half[]` or `List` for xlang `float16`, and `BFloat16` / `BFloat16[]` or `List` for xlang `bfloat16`. + ## Best Practices 1. Keep type IDs stable and documented. diff --git a/docs/guide/csharp/supported-types.md b/docs/guide/csharp/supported-types.md index 0b63977887..e1a92ac0f4 100644 --- a/docs/guide/csharp/supported-types.md +++ b/docs/guide/csharp/supported-types.md @@ -29,6 +29,7 @@ This page summarizes built-in and generated type support in Apache Fory™ C#. | `sbyte`, `short`, `int`, `long` | Supported | | `byte`, `ushort`, `uint`, `ulong` | Supported | | `float`, `double` | Supported | +| `Half`, `BFloat16` | Supported | | `string` | Supported | | `byte[]` | Supported | | Nullable primitives (for example `int?`) | Supported | @@ -36,6 +37,8 @@ This page summarizes built-in and generated type support in Apache Fory™ C#. ## Arrays - Primitive numeric arrays (`bool[]`, `int[]`, `ulong[]`, etc.) +- `Half[]`, `List` for `float16_array` +- `BFloat16[]`, `List` for `bfloat16_array` - `byte[]` - General arrays (`T[]`) through collection serializers diff --git a/docs/guide/java/cross-language.md b/docs/guide/java/cross-language.md index f1dfd4eb65..6309d25a6a 100644 --- a/docs/guide/java/cross-language.md +++ b/docs/guide/java/cross-language.md @@ -158,6 +158,8 @@ Not all Java types have equivalents in other languages. When using xlang mode: - Use **primitive types** (`int`, `long`, `double`, `String`) for maximum compatibility - Use **standard collections** (`List`, `Map`, `Set`) instead of language-specific ones +- Use **reduced-precision carriers** (`Float16`, `BFloat16`, `Float16List`, `BFloat16List`) for 16-bit float payloads +- Treat `Float16[]` and `BFloat16[]` as xlang `list` carriers; use `Float16List` and `BFloat16List` when the wire type must be `float16_array` or `bfloat16_array` - Avoid **Java-specific types** like `Optional`, `BigDecimal` (unless the target language supports them) - See [Type Mapping Guide](../../specification/xlang_type_mapping.md) for complete compatibility matrix diff --git a/docs/guide/python/cross-language.md b/docs/guide/python/cross-language.md index b5b96f5caa..61f6d1dabd 100644 --- a/docs/guide/python/cross-language.md +++ b/docs/guide/python/cross-language.md @@ -109,18 +109,43 @@ class TypedData: double_value: pyfory.float64 # 64-bit float ``` +## Reduced-Precision Types + +`pyfory.serialization` exports Cython-only carrier types for xlang reduced-precision values: + +- `float16` and `float16array` +- `bfloat16` and `bfloat16array` + +These names are compiled into the `pyfory.serialization` extension and re-exported from `pyfory`. There is no pure-Python fallback module for them. + +The scalar wrappers behave like reduced-precision numeric value types. They support arithmetic and +ordering with Python numeric operands, and each operation quantizes the result back to the wrapper's +own format (`pyfory.float16` or `pyfory.bfloat16`). + +The array wrappers are value-oriented public APIs. Construct them from Python numeric values with +`pyfory.float16array([...])`, `pyfory.float16array.from_values([...])`, +`pyfory.bfloat16array([...])`, or `pyfory.bfloat16array.from_values([...])`. Use +`from_buffer(...)` and `to_buffer()` only when you already need packed little-endian `uint16` +storage and want the raw-buffer fast path. Both array carriers also implement the CPython buffer +protocol, so `memoryview(pyfory.float16array(...))` and `memoryview(pyfory.bfloat16array(...))` +expose the packed `uint16` storage directly. + ## Type Mapping -| Python | Java | Rust | Go | -| ---------------- | -------- | --------- | --------- | -| `str` | `String` | `String` | `string` | -| `int` | `long` | `i64` | `int64` | -| `pyfory.int32` | `int` | `i32` | `int32` | -| `pyfory.int64` | `long` | `i64` | `int64` | -| `float` | `double` | `f64` | `float64` | -| `pyfory.float32` | `float` | `f32` | `float32` | -| `list` | `List` | `Vec` | `[]T` | -| `dict` | `Map` | `HashMap` | `map[K]V` | +| Python | Java | Rust | Go | +| ---------------------- | -------------- | --------------- | --------------------- | +| `str` | `String` | `String` | `string` | +| `int` | `long` | `i64` | `int64` | +| `pyfory.int32` | `int` | `i32` | `int32` | +| `pyfory.int64` | `long` | `i64` | `int64` | +| `float` | `double` | `f64` | `float64` | +| `pyfory.float32` | `float` | `f32` | `float32` | +| `pyfory.float16` | `Float16` | `Float16` | `float16.Float16` | +| `pyfory.bfloat16` | `BFloat16` | `BFloat16` | `bfloat16.BFloat16` | +| `pyfory.float16array` | `Float16List` | `Vec` | `[]float16.Float16` | +| `pyfory.bfloat16array` | `BFloat16List` | `Vec` | `[]bfloat16.BFloat16` | +| `list` | `List` | `Vec` | `[]T` | +| `dict` | `Map` | `HashMap` | `map[K]V` | ## Differences from Python Native Mode diff --git a/docs/guide/rust/basic-serialization.md b/docs/guide/rust/basic-serialization.md index c12857a679..9ac5e5d5ab 100644 --- a/docs/guide/rust/basic-serialization.md +++ b/docs/guide/rust/basic-serialization.md @@ -81,12 +81,13 @@ assert_eq!(person, decoded); ### Primitive Types -| Rust Type | Description | -| ------------------------- | --------------- | -| `bool` | Boolean | -| `i8`, `i16`, `i32`, `i64` | Signed integers | -| `f32`, `f64` | Floating point | -| `String` | UTF-8 string | +| Rust Type | Description | +| ------------------------- | --------------------------- | +| `bool` | Boolean | +| `i8`, `i16`, `i32`, `i64` | Signed integers | +| `f32`, `f64` | Floating point | +| `BFloat16` | 16-bit brain floating point | +| `String` | UTF-8 string | ### Collections @@ -102,6 +103,8 @@ assert_eq!(person, decoded); | `BinaryHeap` | Binary heap | | `Option` | Optional value | +`Vec` is the dense xlang carrier for `bfloat16_array` payloads. + ### Smart Pointers | Rust Type | Description | diff --git a/docs/guide/rust/cross-language.md b/docs/guide/rust/cross-language.md index dac48ec6a0..245f17542d 100644 --- a/docs/guide/rust/cross-language.md +++ b/docs/guide/rust/cross-language.md @@ -133,16 +133,22 @@ See [xlang_type_mapping.md](../../specification/xlang_type_mapping.md) for compl ### Common Type Mappings -| Rust | Java | Python | -| -------------- | ------------ | ------------- | -| `i32` | `int` | `int32` | -| `i64` | `long` | `int64` | -| `f32` | `float` | `float32` | -| `f64` | `double` | `float64` | -| `String` | `String` | `str` | -| `Vec` | `List` | `List[T]` | -| `HashMap` | `Map` | `Dict[K,V]` | -| `Option` | nullable `T` | `Optional[T]` | +| Rust | Java | Python | +| --------------- | -------------- | --------------- | +| `i32` | `int` | `int32` | +| `i64` | `long` | `int64` | +| `f32` | `float` | `float32` | +| `f64` | `double` | `float64` | +| `Float16` | `Float16` | `float16` | +| `BFloat16` | `BFloat16` | `bfloat16` | +| `String` | `String` | `str` | +| `Vec` | `List` | `List[T]` | +| `Vec` | `Float16List` | `float16array` | +| `Vec` | `BFloat16List` | `bfloat16array` | +| `[Float16; N]` | `Float16List` | `float16array` | +| `[BFloat16; N]` | `BFloat16List` | `bfloat16array` | +| `HashMap` | `Map` | `Dict[K,V]` | +| `Option` | nullable `T` | `Optional[T]` | ## Best Practices diff --git a/docs/guide/xlang/serialization.md b/docs/guide/xlang/serialization.md index a646db6e57..17fbd5c5ce 100644 --- a/docs/guide/xlang/serialization.md +++ b/docs/guide/xlang/serialization.md @@ -25,6 +25,13 @@ This page demonstrates cross-language serialization patterns with examples in al Common types can be serialized automatically without registration: primitive numeric types, string, binary, array, list, map, and more. +Reduced-precision floating-point values are also part of the built-in xlang type system: + +- `float16` and `float16_array` +- `bfloat16` and `bfloat16_array` + +Use the language-specific carrier types documented in the type mapping reference. Python exposes the Cython-only `float16`, `float16array`, `bfloat16`, and `bfloat16array` names from `pyfory.serialization`; the Python array carriers are constructed from Python numeric values, while `from_buffer(...)` is reserved for packed raw storage. Go uses the `float16` and `bfloat16` packages for scalar, slice, and array carriers; JavaScript uses `number` / `number[]` for `float16` and `BFloat16` / `BFloat16Array` for `bfloat16`; Java uses `Float16List` / `BFloat16List` for xlang `*_array` payloads, while `Float16[]` / `BFloat16[]` stay on the general `list` path; C++, Rust, and C# provide their own dedicated scalar and array carriers. + ### Java ```java diff --git a/docs/specification/xlang_serialization_spec.md b/docs/specification/xlang_serialization_spec.md index 5019c70d83..c18c32decd 100644 --- a/docs/specification/xlang_serialization_spec.md +++ b/docs/specification/xlang_serialization_spec.md @@ -76,8 +76,10 @@ This specification defines the Fory xlang binary format. The format is dynamic r - date: a naive date without timezone, encoded as a signed varint64 count of days since the Unix epoch. - decimal: an exact decimal value encoded as a signed `scale` and an exact `unscaled` integer. - binary: an variable-length array of bytes. -- array: only allow 1d numeric components. Other arrays will be taken as List. The implementation should support the - interoperability between array and list. +- array: in current xlang, only one-dimensional primitive/numeric arrays have dedicated wire types. Other arrays are + taken as `list`, and implementations should support interoperability between array and list carriers. Internal type + ID `ARRAY (42)` is reserved for a future dedicated multi-dimensional array encoding and is not emitted by the current + xlang format. - bool_array: one dimensional bool array. - int8_array: one dimensional int8 array. - int16_array: one dimensional int16 array. @@ -162,65 +164,65 @@ Named types (`NAMED_*`) do not embed a user ID; their names are carried in metad #### Internal Type ID Table -| Type ID | Name | Description | -| ------- | ----------------------- | --------------------------------------------------- | -| 0 | UNKNOWN | Unknown type, used for dynamic typing | -| 1 | BOOL | Boolean value | -| 2 | INT8 | 8-bit signed integer | -| 3 | INT16 | 16-bit signed integer | -| 4 | INT32 | 32-bit signed integer | -| 5 | VARINT32 | Variable-length encoded 32-bit signed integer | -| 6 | INT64 | 64-bit signed integer | -| 7 | VARINT64 | Variable-length encoded 64-bit signed integer | -| 8 | TAGGED_INT64 | Hybrid encoded 64-bit signed integer | -| 9 | UINT8 | 8-bit unsigned integer | -| 10 | UINT16 | 16-bit unsigned integer | -| 11 | UINT32 | 32-bit unsigned integer | -| 12 | VAR_UINT32 | Variable-length encoded 32-bit unsigned integer | -| 13 | UINT64 | 64-bit unsigned integer | -| 14 | VAR_UINT64 | Variable-length encoded 64-bit unsigned integer | -| 15 | TAGGED_UINT64 | Hybrid encoded 64-bit unsigned integer | -| 16 | FLOAT8 | 8-bit floating point (float8) | -| 17 | FLOAT16 | 16-bit floating point (half precision) | -| 18 | BFLOAT16 | 16-bit brain floating point | -| 19 | FLOAT32 | 32-bit floating point (single precision) | -| 20 | FLOAT64 | 64-bit floating point (double precision) | -| 21 | STRING | UTF-8/UTF-16/Latin1 encoded string | -| 22 | LIST | Ordered collection (List, Array, Vector) | -| 23 | SET | Unordered collection of unique elements | -| 24 | MAP | Key-value mapping | -| 25 | ENUM | Enum registered by numeric ID | -| 26 | NAMED_ENUM | Enum registered by namespace + type name | -| 27 | STRUCT | Struct registered by numeric ID (schema consistent) | -| 28 | COMPATIBLE_STRUCT | Struct with schema evolution support (by ID) | -| 29 | NAMED_STRUCT | Struct registered by namespace + type name | -| 30 | NAMED_COMPATIBLE_STRUCT | Struct with schema evolution (by name) | -| 31 | EXT | Extension type registered by numeric ID | -| 32 | NAMED_EXT | Extension type registered by namespace + type name | -| 33 | UNION | Union value, schema identity not embedded | -| 34 | TYPED_UNION | Union value with registered numeric type ID | -| 35 | NAMED_UNION | Union value with embedded type name/TypeDef | -| 36 | NONE | Empty/unit type (no data) | -| 37 | DURATION | Time duration (seconds + nanoseconds) | -| 38 | TIMESTAMP | Point in time (seconds + nanoseconds since epoch) | -| 39 | DATE | Date without timezone (signed varint64 days) | -| 40 | DECIMAL | Arbitrary precision decimal (scale + unscaled) | -| 41 | BINARY | Raw binary data | -| 42 | ARRAY | Generic array type | -| 43 | BOOL_ARRAY | 1D boolean array | -| 44 | INT8_ARRAY | 1D int8 array | -| 45 | INT16_ARRAY | 1D int16 array | -| 46 | INT32_ARRAY | 1D int32 array | -| 47 | INT64_ARRAY | 1D int64 array | -| 48 | UINT8_ARRAY | 1D uint8 array | -| 49 | UINT16_ARRAY | 1D uint16 array | -| 50 | UINT32_ARRAY | 1D uint32 array | -| 51 | UINT64_ARRAY | 1D uint64 array | -| 52 | FLOAT8_ARRAY | 1D float8 array | -| 53 | FLOAT16_ARRAY | 1D float16 array | -| 54 | BFLOAT16_ARRAY | 1D bfloat16 array | -| 55 | FLOAT32_ARRAY | 1D float32 array | -| 56 | FLOAT64_ARRAY | 1D float64 array | +| Type ID | Name | Description | +| ------- | ----------------------- | ------------------------------------------------------ | +| 0 | UNKNOWN | Unknown type, used for dynamic typing | +| 1 | BOOL | Boolean value | +| 2 | INT8 | 8-bit signed integer | +| 3 | INT16 | 16-bit signed integer | +| 4 | INT32 | 32-bit signed integer | +| 5 | VARINT32 | Variable-length encoded 32-bit signed integer | +| 6 | INT64 | 64-bit signed integer | +| 7 | VARINT64 | Variable-length encoded 64-bit signed integer | +| 8 | TAGGED_INT64 | Hybrid encoded 64-bit signed integer | +| 9 | UINT8 | 8-bit unsigned integer | +| 10 | UINT16 | 16-bit unsigned integer | +| 11 | UINT32 | 32-bit unsigned integer | +| 12 | VAR_UINT32 | Variable-length encoded 32-bit unsigned integer | +| 13 | UINT64 | 64-bit unsigned integer | +| 14 | VAR_UINT64 | Variable-length encoded 64-bit unsigned integer | +| 15 | TAGGED_UINT64 | Hybrid encoded 64-bit unsigned integer | +| 16 | FLOAT8 | 8-bit floating point (float8) | +| 17 | FLOAT16 | 16-bit floating point (half precision) | +| 18 | BFLOAT16 | 16-bit brain floating point | +| 19 | FLOAT32 | 32-bit floating point (single precision) | +| 20 | FLOAT64 | 64-bit floating point (double precision) | +| 21 | STRING | UTF-8/UTF-16/Latin1 encoded string | +| 22 | LIST | Ordered collection (List, Array, Vector) | +| 23 | SET | Unordered collection of unique elements | +| 24 | MAP | Key-value mapping | +| 25 | ENUM | Enum registered by numeric ID | +| 26 | NAMED_ENUM | Enum registered by namespace + type name | +| 27 | STRUCT | Struct registered by numeric ID (schema consistent) | +| 28 | COMPATIBLE_STRUCT | Struct with schema evolution support (by ID) | +| 29 | NAMED_STRUCT | Struct registered by namespace + type name | +| 30 | NAMED_COMPATIBLE_STRUCT | Struct with schema evolution (by name) | +| 31 | EXT | Extension type registered by numeric ID | +| 32 | NAMED_EXT | Extension type registered by namespace + type name | +| 33 | UNION | Union value, schema identity not embedded | +| 34 | TYPED_UNION | Union value with registered numeric type ID | +| 35 | NAMED_UNION | Union value with embedded type name/TypeDef | +| 36 | NONE | Empty/unit type (no data) | +| 37 | DURATION | Time duration (seconds + nanoseconds) | +| 38 | TIMESTAMP | Point in time (seconds + nanoseconds since epoch) | +| 39 | DATE | Date without timezone (signed varint64 days) | +| 40 | DECIMAL | Arbitrary precision decimal (scale + unscaled) | +| 41 | BINARY | Raw binary data | +| 42 | ARRAY | Reserved for future dedicated multi-dimensional arrays | +| 43 | BOOL_ARRAY | 1D boolean array | +| 44 | INT8_ARRAY | 1D int8 array | +| 45 | INT16_ARRAY | 1D int16 array | +| 46 | INT32_ARRAY | 1D int32 array | +| 47 | INT64_ARRAY | 1D int64 array | +| 48 | UINT8_ARRAY | 1D uint8 array | +| 49 | UINT16_ARRAY | 1D uint16 array | +| 50 | UINT32_ARRAY | 1D uint32 array | +| 51 | UINT64_ARRAY | 1D uint64 array | +| 52 | FLOAT8_ARRAY | 1D float8 array | +| 53 | FLOAT16_ARRAY | 1D float16 array | +| 54 | BFLOAT16_ARRAY | 1D bfloat16 array | +| 55 | FLOAT32_ARRAY | 1D float32 array | +| 56 | FLOAT64_ARRAY | 1D float64 array | #### Type ID Encoding for User Types @@ -424,7 +426,10 @@ After the type ID: - If meta share is disabled, write `namespace` and `type_name` as meta strings. - If meta share is enabled, write a shared TypeDef entry (see below). - **UNION**: no extra bytes at this layer. -- **LIST / SET / MAP / ARRAY / primitives**: no extra bytes at this layer. +- **LIST / SET / MAP / primitives**: no extra bytes at this layer. + +`ARRAY (42)` is reserved for a future xlang extension for dedicated multi-dimensional arrays and +is not used in current xlang streams. Unregistered types are serialized as named types: @@ -482,7 +487,9 @@ The 8-byte header is a little-endian uint64: - If meta size >= 0xFF, the low 8 bits are set to 0xFF and an extra `varuint32(meta_size - 0xFF)` follows immediately after the header. - Bit 8: `HAS_FIELDS_META` (1 = fields metadata present). -- Bit 9: `COMPRESS_META` (1 = body is compressed; decompress before parsing). +- Bit 9: `COMPRESS_META` is reserved for a future xlang metadata-compression extension. + Current xlang writers MUST leave this bit unset and current xlang readers MUST treat a set bit + as unsupported. - Bits 10-13: reserved for future extension (must be zero). - High 50 bits: hash of the TypeDef body. @@ -1026,14 +1033,30 @@ Format: ``` - `seconds`: Number of seconds in the duration, encoded as a signed varint64. Can be positive or negative. -- `nanoseconds`: Nanosecond adjustment to the duration, encoded as a signed int32. Value range is [0, 999,999,999] for positive durations, and [-999,999,999, 0] for negative durations. +- `nanoseconds`: Nanosecond adjustment to the duration, encoded as a signed int32. Notes: - The duration is stored as two separate fields to maintain precision and avoid overflow issues. - Seconds are encoded using varint64 for compact representation of common duration values. - Nanoseconds are stored as a fixed int32 since the range is limited. -- The sign of the duration is determined by the seconds field. When seconds is 0, the sign is determined by nanoseconds. + +#### Canonical Rules + +- Writers MUST normalize durations so `nanoseconds` is always in `[0, 1_000_000_000)`. +- Zero MUST be encoded as `seconds = 0` and `nanoseconds = 0`. +- Negative sub-second durations MUST borrow one second and use a positive nanosecond adjustment. + Example: `-0.5s` is encoded as `seconds = -1`, `nanoseconds = 500_000_000`. +- More generally, the encoded pair MUST satisfy: + - `duration = seconds + nanoseconds / 1_000_000_000` + - `0 <= nanoseconds < 1_000_000_000` + +#### Final Value + +After decoding `seconds` and `nanoseconds`, the duration value is reconstructed as the exact +duration represented by: + +`seconds + nanoseconds / 1_000_000_000` ### collection/list @@ -1142,8 +1165,10 @@ Float array specifics: #### Multi-dimensional arrays -Xlang does not define a dedicated tensor encoding. Multi-dimensional arrays are serialized as -nested lists, while one-dimensional primitive arrays use the `*_ARRAY` type IDs. +Current xlang does not define a dedicated multi-dimensional array/tensor encoding. Multi-dimensional +arrays are serialized as nested lists, while one-dimensional primitive arrays use the `*_ARRAY` +type IDs. Internal type ID `ARRAY (42)` is reserved for a future dedicated multi-dimensional array +encoding and is not used in current xlang streams. #### object array @@ -1229,12 +1254,15 @@ The implementation can accumulate read count with map size to decide whether to ### enum -Enums are serialized as an unsigned var int tag. For plain enums, this tag is typically the -declaration ordinal. Some implementations or generated enum forms may instead use an explicit -stable enum value or variant ID. If the encoding relies on declaration order, reordering enum -values can change the deserialized result. In such cases, users should prefer an explicit stable -ID-based encoding or register a custom enum serializer that writes a stable string representation -with unique hash disabled. +Enums are serialized as an unsigned varint enum ID. + +- If the enum definition provides an explicit enum ID / variant ID / stable numeric tag for a + value, that ID MUST be used. +- If no explicit enum ID is specified, the declaration ordinal is used as the enum ID by default. + +This means the wire contract is always an enum ID. When the enum ID comes from declaration order, +reordering enum values changes the wire IDs and can change the deserialized result. For +cross-language or long-lived schemas, users should prefer explicit stable enum IDs. ### timestamp @@ -1373,7 +1401,7 @@ Within each group, apply the following sort keys in order until a difference is 1. **Compression category**: fixed-size numeric and boolean types first, then compressed numeric types (`VARINT32`, `VAR_UINT32`, `VARINT64`, `VAR_UINT64`, `TAGGED_INT64`, `TAGGED_UINT64`). 2. **Primitive size** (descending): 8-byte > 4-byte > 2-byte > 1-byte. -3. **Internal type ID** (descending) as a tie-breaker for equal sizes. +3. **Internal type ID** (ascending) as a tie-breaker for equal sizes. 4. **Field identifier** (lexicographic ascending). **Built-in / Collection / Map groups (3-5):** diff --git a/docs/specification/xlang_type_mapping.md b/docs/specification/xlang_type_mapping.md index ee62b1423d..b0185aad51 100644 --- a/docs/specification/xlang_type_mapping.md +++ b/docs/specification/xlang_type_mapping.md @@ -48,64 +48,72 @@ When reading type IDs: ## Type Mapping -| Fory Type | Fory Type ID | Java | Python | Javascript | C++ | Golang | Rust | -| ----------------------- | ------------ | --------------- | ------------------------ | ------------------- | ------------------------------ | ---------------- | ----------------- | -| bool | 1 | bool/Boolean | bool | Boolean | bool | bool | bool | -| int8 | 2 | byte/Byte | int/pyfory.int8 | Type.int8() | int8_t | int8 | i8 | -| int16 | 3 | short/Short | int/pyfory.int16 | Type.int16() | int16_t | int16 | i16 | -| int32 | 4 | int/Integer | int/pyfory.fixed_int32 | Type.int32() | int32_t | int32 | i32 | -| varint32 | 5 | int/Integer | int/pyfory.int32 | Type.varint32() | int32_t | int32 | i32 | -| int64 | 6 | long/Long | int/pyfory.fixed_int64 | Type.int64() | int64_t | int64 | i64 | -| varint64 | 7 | long/Long | int/pyfory.int64 | Type.varint64() | int64_t | int64 | i64 | -| tagged_int64 | 8 | long/Long | int/pyfory.tagged_int64 | Type.tagged_int64() | int64_t | int64 | i64 | -| uint8 | 9 | short/Short | int/pyfory.uint8 | Type.uint8() | uint8_t | uint8 | u8 | -| uint16 | 10 | int/Integer | int/pyfory.uint16 | Type.uint16() | uint16_t | uint16 | u16 | -| uint32 | 11 | long/Long | int/pyfory.fixed_uint32 | Type.uint32() | uint32_t | uint32 | u32 | -| var_uint32 | 12 | long/Long | int/pyfory.uint32 | Type.varUInt32() | uint32_t | uint32 | u32 | -| uint64 | 13 | long/Long | int/pyfory.fixed_uint64 | Type.uint64() | uint64_t | uint64 | u64 | -| var_uint64 | 14 | long/Long | int/pyfory.uint64 | Type.varUInt64() | uint64_t | uint64 | u64 | -| tagged_uint64 | 15 | long/Long | int/pyfory.tagged_uint64 | Type.taggedUInt64() | uint64_t | uint64 | u64 | -| float8 | 16 | / | / | / | / | / | / | -| float16 | 17 | Float16 | float/pyfory.float16 | Type.float16() | fory::float16_t | fory.float16 | fory::f16 | -| bfloat16 | 18 | Bfloat16 | / | / | / | / | / | -| float32 | 19 | float/Float | float/pyfory.float32 | Type.float32() | float | float32 | f32 | -| float64 | 20 | double/Double | float/pyfory.float64 | Type.float64() | double | float64 | f64 | -| string | 21 | String | str | String | string | string | String/str | -| list | 22 | List/Collection | list/tuple | array | vector | slice | Vec | -| set | 23 | Set | set | / | set | fory.Set | Set | -| map | 24 | Map | dict | Map | unordered_map | map | HashMap | -| enum | 25 | Enum subclasses | enum subclasses | / | enum | / | enum | -| named_enum | 26 | Enum subclasses | enum subclasses | / | enum | / | enum | -| struct | 27 | pojo/record | data class | object | struct/class | struct | struct | -| compatible_struct | 28 | pojo/record | data class | object | struct/class | struct | struct | -| named_struct | 29 | pojo/record | data class | object | struct/class | struct | struct | -| named_compatible_struct | 30 | pojo/record | data class | object | struct/class | struct | struct | -| ext | 31 | pojo/record | data class | object | struct/class | struct | struct | -| named_ext | 32 | pojo/record | data class | object | struct/class | struct | struct | -| union | 33 | Union | typing.Union | / | `std::variant` | / | tagged union enum | -| none | 36 | null | None | null | `std::monostate` | nil | `()` | -| duration | 37 | Duration | timedelta | Number | duration | Duration | Duration | -| timestamp | 38 | Instant | datetime | Number | std::chrono::nanoseconds | Time | DateTime | -| date | 39 | LocalDate | datetime.date | Date | fory::serialization::Date | fory.Date | chrono::NaiveDate | -| decimal | 40 | BigDecimal | Decimal | Decimal | / | fory.Decimal | fory::Decimal | -| binary | 41 | byte[] | bytes | / | `uint8_t[n]/vector` | `[n]uint8/[]T` | `Vec` | -| array | 42 | array | np.ndarray | / | / | array/slice | Vec | -| bool_array | 43 | bool[] | ndarray(np.bool\_) | / | `bool[n]` | `[n]bool/[]T` | `Vec` | -| int8_array | 44 | byte[] | ndarray(int8) | / | `int8_t[n]/vector` | `[n]int8/[]T` | `Vec` | -| int16_array | 45 | short[] | ndarray(int16) | / | `int16_t[n]/vector` | `[n]int16/[]T` | `Vec` | -| int32_array | 46 | int[] | ndarray(int32) | / | `int32_t[n]/vector` | `[n]int32/[]T` | `Vec` | -| int64_array | 47 | long[] | ndarray(int64) | / | `int64_t[n]/vector` | `[n]int64/[]T` | `Vec` | -| uint8_array | 48 | short[] | ndarray(uint8) | / | `uint8_t[n]/vector` | `[n]uint8/[]T` | `Vec` | -| uint16_array | 49 | int[] | ndarray(uint16) | / | `uint16_t[n]/vector` | `[n]uint16/[]T` | `Vec` | -| uint32_array | 50 | long[] | ndarray(uint32) | / | `uint32_t[n]/vector` | `[n]uint32/[]T` | `Vec` | -| uint64_array | 51 | long[] | ndarray(uint64) | / | `uint64_t[n]/vector` | `[n]uint64/[]T` | `Vec` | -| float8_array | 52 | / | / | / | / | / | / | -| float16_array | 53 | Float16List | ndarray(float16) | / | `fory::float16_t[n]/vector` | `[n]float16/[]T` | `Vec` | -| bfloat16_array | 54 | Bfloat16List | / | / | / | / | / | -| float32_array | 55 | float[] | ndarray(float32) | / | `float[n]/vector` | `[n]float32/[]T` | `Vec` | -| float64_array | 56 | double[] | ndarray(float64) | / | `double[n]/vector` | `[n]float64/[]T` | `Vec` | - -## Type info(not implemented currently) +| Fory Type | Fory Type ID | Java | Python | Javascript | C++ | Golang | Rust | +| ----------------------- | ------------ | --------------- | ------------------------ | ------------------------------ | --------------------------------------------------- | ---------------------------------------------- | --------------------------------- | +| bool | 1 | bool/Boolean | bool | Boolean | bool | bool | bool | +| int8 | 2 | byte/Byte | int/pyfory.int8 | Type.int8() | int8_t | int8 | i8 | +| int16 | 3 | short/Short | int/pyfory.int16 | Type.int16() | int16_t | int16 | i16 | +| int32 | 4 | int/Integer | int/pyfory.fixed_int32 | Type.int32() | int32_t | int32 | i32 | +| varint32 | 5 | int/Integer | int/pyfory.int32 | Type.varint32() | int32_t | int32 | i32 | +| int64 | 6 | long/Long | int/pyfory.fixed_int64 | Type.int64() | int64_t | int64 | i64 | +| varint64 | 7 | long/Long | int/pyfory.int64 | Type.varint64() | int64_t | int64 | i64 | +| tagged_int64 | 8 | long/Long | int/pyfory.tagged_int64 | Type.tagged_int64() | int64_t | int64 | i64 | +| uint8 | 9 | short/Short | int/pyfory.uint8 | Type.uint8() | uint8_t | uint8 | u8 | +| uint16 | 10 | int/Integer | int/pyfory.uint16 | Type.uint16() | uint16_t | uint16 | u16 | +| uint32 | 11 | long/Long | int/pyfory.fixed_uint32 | Type.uint32() | uint32_t | uint32 | u32 | +| var_uint32 | 12 | long/Long | int/pyfory.uint32 | Type.varUInt32() | uint32_t | uint32 | u32 | +| uint64 | 13 | long/Long | int/pyfory.fixed_uint64 | Type.uint64() | uint64_t | uint64 | u64 | +| var_uint64 | 14 | long/Long | int/pyfory.uint64 | Type.varUInt64() | uint64_t | uint64 | u64 | +| tagged_uint64 | 15 | long/Long | int/pyfory.tagged_uint64 | Type.taggedUInt64() | uint64_t | uint64 | u64 | +| float8 | 16 | / | / | / | / | / | / | +| float16 | 17 | Float16 | float/pyfory.float16 | `number` | `fory::float16_t` | `float16.Float16` | `fory::f16` | +| bfloat16 | 18 | BFloat16 | pyfory.bfloat16 | `BFloat16` / `number` | `fory::bfloat16_t` | `bfloat16.BFloat16` | `BFloat16` | +| float32 | 19 | float/Float | float/pyfory.float32 | Type.float32() | float | float32 | f32 | +| float64 | 20 | double/Double | float/pyfory.float64 | Type.float64() | double | float64 | f64 | +| string | 21 | String | str | String | string | string | String/str | +| list | 22 | List/Collection | list/tuple | array | vector | slice | Vec | +| set | 23 | Set | set | / | set | fory.Set | Set | +| map | 24 | Map | dict | Map | unordered_map | map | HashMap | +| enum | 25 | Enum subclasses | enum subclasses | / | enum | / | enum | +| named_enum | 26 | Enum subclasses | enum subclasses | / | enum | / | enum | +| struct | 27 | pojo/record | data class | object | struct/class | struct | struct | +| compatible_struct | 28 | pojo/record | data class | object | struct/class | struct | struct | +| named_struct | 29 | pojo/record | data class | object | struct/class | struct | struct | +| named_compatible_struct | 30 | pojo/record | data class | object | struct/class | struct | struct | +| ext | 31 | pojo/record | data class | object | struct/class | struct | struct | +| named_ext | 32 | pojo/record | data class | object | struct/class | struct | struct | +| union | 33 | Union | typing.Union | / | `std::variant` | / | tagged union enum | +| none | 36 | null | None | null | `std::monostate` | nil | `()` | +| duration | 37 | Duration | timedelta | Number | duration | Duration | Duration | +| timestamp | 38 | Instant | datetime | Number | std::chrono::nanoseconds | Time | DateTime | +| date | 39 | LocalDate | datetime.date | Date | fory::serialization::Date | fory.Date | chrono::NaiveDate | +| decimal | 40 | BigDecimal | Decimal | Decimal | / | fory.Decimal | fory::Decimal | +| binary | 41 | byte[] | bytes | / | `uint8_t[n]/vector` | `[n]uint8/[]T` | `Vec` | +| bool_array | 43 | bool[] | ndarray(np.bool\_) | / | `bool[n]` | `[n]bool/[]T` | `Vec` | +| int8_array | 44 | byte[] | ndarray(int8) | / | `int8_t[n]/vector` | `[n]int8/[]T` | `Vec` | +| int16_array | 45 | short[] | ndarray(int16) | / | `int16_t[n]/vector` | `[n]int16/[]T` | `Vec` | +| int32_array | 46 | int[] | ndarray(int32) | / | `int32_t[n]/vector` | `[n]int32/[]T` | `Vec` | +| int64_array | 47 | long[] | ndarray(int64) | / | `int64_t[n]/vector` | `[n]int64/[]T` | `Vec` | +| uint8_array | 48 | short[] | ndarray(uint8) | / | `uint8_t[n]/vector` | `[n]uint8/[]T` | `Vec` | +| uint16_array | 49 | int[] | ndarray(uint16) | / | `uint16_t[n]/vector` | `[n]uint16/[]T` | `Vec` | +| uint32_array | 50 | long[] | ndarray(uint32) | / | `uint32_t[n]/vector` | `[n]uint32/[]T` | `Vec` | +| uint64_array | 51 | long[] | ndarray(uint64) | / | `uint64_t[n]/vector` | `[n]uint64/[]T` | `Vec` | +| float8_array | 52 | / | / | / | / | / | / | +| float16_array | 53 | Float16List | ndarray(float16) | `number[]` | `fory::float16_t[n]/std::vector` | `[N]float16.Float16` / `[]float16.Float16` | `Vec` / `[Float16; N]` | +| bfloat16_array | 54 | BFloat16List | pyfory.bfloat16array | `BFloat16Array` / `BFloat16[]` | `fory::bfloat16_t[n]/std::vector` | `[N]bfloat16.BFloat16` / `[]bfloat16.BFloat16` | `Vec` / `[BFloat16; N]` | +| float32_array | 55 | float[] | ndarray(float32) | / | `float[n]/vector` | `[n]float32/[]T` | `Vec` | +| float64_array | 56 | double[] | ndarray(float64) | / | `double[n]/vector` | `[n]float64/[]T` | `Vec` | + +Notes: + +- `Float16List` and `BFloat16List` are the xlang `float16_array` and `bfloat16_array` carriers. +- `Float16[]` and `BFloat16[]` remain object arrays in xlang mode and serialize with the `list` wire type. +- `ARRAY (42)` is reserved for a future dedicated multi-dimensional array encoding and is not part + of the current xlang type-mapping surface. +- Current xlang uses `*_ARRAY` for one-dimensional primitive arrays and nested `list` for + multi-dimensional arrays. + +## Type info Due to differences between type systems of languages, those types can't be mapped one-to-one between languages. @@ -130,9 +138,9 @@ Here is en example: ```java class Foo { - @Int32Type(varint = true) + @Int32Type(compress = false) int f1; - List<@Int32Type(varint = true) Integer> f2; + List f2; } ``` @@ -140,6 +148,6 @@ Here is en example: ```python class Foo: - f1: pyfory.varint32 - f2: List[pyfory.varint32] + f1: pyfory.fixed_int32 + f2: List[pyfory.int32] ``` diff --git a/go/fory/codegen/utils.go b/go/fory/codegen/utils.go index 7062ee1bc3..4bdb094039 100644 --- a/go/fory/codegen/utils.go +++ b/go/fory/codegen/utils.go @@ -455,12 +455,12 @@ func sortFields(fields []*FieldInfo) { }) } -// Field group constants for sorting +// Field group constants for sorting. // This matches reflection's field ordering in field_info.go: -// primitives → internal built-in → list/set → map → other +// primitives → built-in non-container (including primitive arrays) → list/set → map → other const ( groupPrimitive = 0 // primitive and nullable primitive fields - groupInternalBuiltin = 1 // built-in types (STRING/BINARY/arrays/etc.) sorted by typeId then name + groupInternalBuiltin = 1 // built-in non-container types sorted by typeId then name groupListSet = 2 // LIST/SET sorted by typeId then name groupMap = 3 // MAP sorted by typeId then name groupOther = 4 // structs, enums, and unknown types - sorted by name diff --git a/go/fory/field_info.go b/go/fory/field_info.go index 246ba0e7f8..d5522d0903 100644 --- a/go/fory/field_info.go +++ b/go/fory/field_info.go @@ -174,14 +174,14 @@ func GroupFields(fields []FieldInfo) FieldGroup { } } - // Sort fixedFields: size desc, typeId desc, name asc + // Sort fixedFields: size desc, typeId asc, name asc sort.SliceStable(g.FixedFields, func(i, j int) bool { fi, fj := &g.FixedFields[i], &g.FixedFields[j] if fi.Meta.FixedSize != fj.Meta.FixedSize { return fi.Meta.FixedSize > fj.Meta.FixedSize // size descending } if fi.Meta.TypeId != fj.Meta.TypeId { - return fi.Meta.TypeId > fj.Meta.TypeId // typeId descending + return fi.Meta.TypeId < fj.Meta.TypeId // typeId ascending } return getFieldSortKey(fi) < getFieldSortKey(fj) // tag ID or name ascending }) @@ -200,8 +200,8 @@ func GroupFields(fields []FieldInfo) FieldGroup { g.FixedSize += g.FixedFields[i].Meta.FixedSize } - // Sort varintFields: underlying type size desc, typeId desc, name asc - // Note: Java uses primitive type size (8 for long, 4 for int), not encoding max size + // Sort varintFields: underlying type size desc, typeId asc, name asc + // Note: xlang uses primitive type size (8 for long, 4 for int), not encoding max size sort.SliceStable(g.VarintFields, func(i, j int) bool { fi, fj := &g.VarintFields[i], &g.VarintFields[j] sizeI := getUnderlyingTypeSize(fi.DispatchId) @@ -210,7 +210,7 @@ func GroupFields(fields []FieldInfo) FieldGroup { return sizeI > sizeJ // size descending } if fi.Meta.TypeId != fj.Meta.TypeId { - return fi.Meta.TypeId > fj.Meta.TypeId // typeId descending + return fi.Meta.TypeId < fj.Meta.TypeId // typeId ascending } return getFieldSortKey(fi) < getFieldSortKey(fj) // tag ID or name ascending }) @@ -229,7 +229,8 @@ func GroupFields(fields []FieldInfo) FieldGroup { } // Sort remainingFields: nullable primitives first (by primitiveComparator), - // then other internal types (typeId, name), then lists, sets, maps, other (by name) + // then built-in scalar types (typeId, name), then lists, sets, maps, + // then primitive arrays and other fields by sort key. sort.SliceStable(g.RemainingFields, func(i, j int) bool { fi, fj := &g.RemainingFields[i], &g.RemainingFields[j] catI, catJ := getFieldCategory(fi), getFieldCategory(fj) @@ -240,7 +241,7 @@ func GroupFields(fields []FieldInfo) FieldGroup { if catI == 0 { return comparePrimitiveFields(fi, fj) } - // Within internal/build-in or collection categories, sort by typeId then sort key. + // Within built-in scalar or collection categories, sort by typeId then sort key. if catI == 1 || catI == 2 || catI == 3 { if fi.Meta.TypeId != fj.Meta.TypeId { return fi.Meta.TypeId < fj.Meta.TypeId @@ -283,7 +284,7 @@ func isEnumField(field *FieldInfo) bool { // getFieldCategory returns the category for sorting remainingFields: // 0: nullable primitives (sorted by primitiveComparator) -// 1: internal build-in types (sorted by typeId, then sort key) +// 1: internal built-in non-container types, including primitive arrays (sorted by typeId, then sort key) // 2: list/set collections (sorted by typeId, then sort key) // 3: map collections (sorted by typeId, then sort key) // 4: struct, enum, and all other types (sorted by sort key) @@ -304,27 +305,30 @@ func getFieldCategory(field *FieldInfo) int { if typeId == MAP { return 3 } - // Internal build-in types: sorted by typeId, then sort key (matches Java build-in group) + if isPrimitiveArrayType(typeId) { + return 1 + } + // Internal built-in non-container types: sorted by typeId, then sort key. return 1 } // comparePrimitiveFields compares two nullable primitive fields using Java's primitiveComparator logic: -// fixed before varint, then underlying type size desc, typeId desc, name asc +// fixed before varint, then underlying type size desc, typeId asc, name asc func comparePrimitiveFields(fi, fj *FieldInfo) bool { iFixed := isNullableFixedSizePrimitive(fi.DispatchId) jFixed := isNullableFixedSizePrimitive(fj.DispatchId) if iFixed != jFixed { return iFixed // fixed before varint } - // Same category: compare by underlying type size desc, typeId desc, name asc - // Note: Java uses primitive type size (8, 4, 2, 1), not encoding size + // Same category: compare by underlying type size desc, typeId asc, name asc + // Note: xlang uses primitive type size (8, 4, 2, 1), not encoding size sizeI := getUnderlyingTypeSize(fi.DispatchId) sizeJ := getUnderlyingTypeSize(fj.DispatchId) if sizeI != sizeJ { return sizeI > sizeJ // size descending } if fi.Meta.TypeId != fj.Meta.TypeId { - return fi.Meta.TypeId > fj.Meta.TypeId // typeId descending + return fi.Meta.TypeId < fj.Meta.TypeId // typeId ascending } return getFieldSortKey(fi) < getFieldSortKey(fj) // tag ID or name ascending } @@ -686,7 +690,7 @@ func sortFields( primitives = append(primitives, t) } case isPrimitiveArrayType(t.typeID): - // Primitive arrays: built-in non-container types (sorted by typeId then name) + // Primitive arrays are built-in non-container types in xlang field ordering. otherInternalTypeFields = append(otherInternalTypeFields, t) case isListType(t.typeID), isSetType(t.typeID): // LIST, SET: collection group @@ -704,7 +708,7 @@ func sortFields( } } // Sort primitives (non-nullable) - same logic as boxed - // Java sorts by: compressed (varint) types last, then by size (largest first), then by type ID (descending) + // Xlang sorts by: compressed (varint) types last, then by size (largest first), then by type ID (ascending) // Fixed types: BOOL, INT8, UINT8, INT16, UINT16, INT32, UINT32, INT64, UINT64, FLOAT32, FLOAT64 // Varint types: VARINT32, VARINT64, VAR_UINT32, VAR_UINT64, TAGGED_INT64, TAGGED_UINT64 isVarintTypeId := func(typeID TypeId) bool { @@ -724,9 +728,9 @@ func sortFields( if szI != szJ { return szI > szJ } - // Tie-breaker: type ID descending (higher type ID first), then field name + // Tie-breaker: type ID ascending (lower type ID first), then field name if ai.typeID != aj.typeID { - return ai.typeID > aj.typeID + return ai.typeID < aj.typeID } return ai.getSortKey() < aj.getSortKey() }) @@ -751,11 +755,11 @@ func sortFields( sortByTypeIDThenName(otherInternalTypeFields) sortByTypeIDThenName(listSet) sortByTypeIDThenName(maps) - // Merge all category 2 fields (primitive arrays, userDefined, others) and sort by name - // This matches GroupFields' getFieldCategory which sorts all category 4 fields together + // Merge primitive arrays, user-defined types, and unknown types into the same + // sort-key-ordered tail group. otherGroup := make([]triple, 0, len(userDefined)+len(others)) - otherGroup = append(otherGroup, userDefined...) // structs, enums, ext - otherGroup = append(otherGroup, others...) // unknown types + otherGroup = append(otherGroup, userDefined...) + otherGroup = append(otherGroup, others...) sortTuple(otherGroup) // Order: primitives, boxed, built-in non-container, list/set, map, other (by name) @@ -763,7 +767,7 @@ func sortFields( all := make([]triple, 0, len(fieldNames)) all = append(all, primitives...) all = append(all, boxed...) - all = append(all, otherInternalTypeFields...) // STRING, BINARY, primitive arrays, time, unions, etc. + all = append(all, otherInternalTypeFields...) // STRING, BINARY, time, unions, etc. all = append(all, listSet...) all = append(all, maps...) all = append(all, otherGroup...) @@ -922,6 +926,12 @@ func typeIdFromKind(type_ reflect.Type) TypeId { case reflect.Int16: return INT16_ARRAY case reflect.Uint16: + if type_.Elem() == float16Type { + return FLOAT16_ARRAY + } + if type_.Elem() == bfloat16Type { + return BFLOAT16_ARRAY + } return UINT16_ARRAY case reflect.Int32: return INT32_ARRAY @@ -952,6 +962,12 @@ func typeIdFromKind(type_ reflect.Type) TypeId { case reflect.Int16: return INT16_ARRAY case reflect.Uint16: + if type_.Elem() == float16Type { + return FLOAT16_ARRAY + } + if type_.Elem() == bfloat16Type { + return BFLOAT16_ARRAY + } return UINT16_ARRAY case reflect.Int32: return INT32_ARRAY diff --git a/go/fory/skip.go b/go/fory/skip.go index 3fde813532..9d99f65683 100644 --- a/go/fory/skip.go +++ b/go/fory/skip.go @@ -613,23 +613,23 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo } _ = ctx.buffer.ReadBinary(length, err) case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY, BFLOAT16_ARRAY: - length := ctx.ReadBinaryLength() + size := ctx.ReadBinaryLength() if ctx.HasError() { return } - _ = ctx.buffer.ReadBinary(length*2, err) + _ = ctx.buffer.ReadBinary(size, err) case INT32_ARRAY, UINT32_ARRAY, FLOAT32_ARRAY: - length := ctx.ReadBinaryLength() + size := ctx.ReadBinaryLength() if ctx.HasError() { return } - _ = ctx.buffer.ReadBinary(length*4, err) + _ = ctx.buffer.ReadBinary(size, err) case INT64_ARRAY, UINT64_ARRAY, FLOAT64_ARRAY: - length := ctx.ReadBinaryLength() + size := ctx.ReadBinaryLength() if ctx.HasError() { return } - _ = ctx.buffer.ReadBinary(length*8, err) + _ = ctx.buffer.ReadBinary(size, err) // Date/Time types case DATE: diff --git a/go/fory/struct_test.go b/go/fory/struct_test.go index 0dfbb1a1c5..d4d42c130a 100644 --- a/go/fory/struct_test.go +++ b/go/fory/struct_test.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/apache/fory/go/fory/bfloat16" "github.com/apache/fory/go/fory/float16" "github.com/apache/fory/go/fory/optional" "github.com/stretchr/testify/require" @@ -484,6 +485,69 @@ func TestSetFieldTypeId(t *testing.T) { } } +func TestReducedPrecisionPrimitiveArraysUseBuiltInOrdering(t *testing.T) { + type TestStruct struct { + Float16Value float16.Float16 + Bfloat16Value bfloat16.BFloat16 + Bfloat16Array []bfloat16.BFloat16 + Float16Array []float16.Float16 + } + + f := New(WithXlang(true), WithCompatible(false)) + require.NoError(t, f.RegisterStruct(TestStruct{}, 1004)) + + typeInfo, err := f.typeResolver.getTypeInfo(reflect.ValueOf(TestStruct{}), false) + require.NoError(t, err) + + structSer, ok := typeInfo.Serializer.(*structSerializer) + require.True(t, ok) + require.NoError(t, structSer.initialize(f.typeResolver)) + + require.Len(t, structSer.fieldGroup.FixedFields, 2) + require.Equal(t, "float16_value", structSer.fieldGroup.FixedFields[0].Meta.Name) + require.Equal(t, "bfloat16_value", structSer.fieldGroup.FixedFields[1].Meta.Name) + + require.Len(t, structSer.fieldGroup.RemainingFields, 2) + require.Equal(t, "float16_array", structSer.fieldGroup.RemainingFields[0].Meta.Name) + require.Equal(t, "bfloat16_array", structSer.fieldGroup.RemainingFields[1].Meta.Name) +} + +func TestCompatibleSkipReducedPrecisionArrays(t *testing.T) { + type Source struct { + Float16Value float16.Float16 + Bfloat16Value bfloat16.BFloat16 + Bfloat16Array []bfloat16.BFloat16 + Float16Array []float16.Float16 + } + type Empty struct{} + + writer := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, writer.RegisterStruct(Source{}, 1005)) + + data, err := writer.Serialize(&Source{ + Float16Value: float16.Float16FromBits(0x3E00), + Bfloat16Value: bfloat16.BFloat16FromBits(0x3FC0), + Bfloat16Array: []bfloat16.BFloat16{ + bfloat16.BFloat16FromBits(0x0000), + bfloat16.BFloat16FromBits(0x3F80), + bfloat16.BFloat16FromBits(0xBF80), + }, + Float16Array: []float16.Float16{ + float16.Float16FromBits(0x0000), + float16.Float16FromBits(0x3C00), + float16.Float16FromBits(0xBC00), + }, + }) + require.NoError(t, err) + + reader := New(WithXlang(true), WithCompatible(true)) + require.NoError(t, reader.RegisterStruct(Empty{}, 1005)) + + var out any + require.NoError(t, testDeserialize(t, reader, data, &out)) + require.NotNil(t, out) +} + func TestSkipAnyValueReadsSharedTypeMeta(t *testing.T) { type First struct { ID int diff --git a/go/fory/tests/structs_fory_gen.go b/go/fory/tests/structs_fory_gen.go index e26cf7608d..3971e5d364 100644 --- a/go/fory/tests/structs_fory_gen.go +++ b/go/fory/tests/structs_fory_gen.go @@ -1,6 +1,6 @@ // Code generated by forygen. DO NOT EDIT. // source: structs.go -// generated at: 2026-01-26T23:14:39+08:00 +// generated at: 2026-04-22T20:36:47+08:00 package fory @@ -56,7 +56,12 @@ func (g *DynamicSliceDemo_ForyGenSerializer) Write(ctx *fory.WriteContext, refMo ctx.Buffer().WriteInt8(-1) // NotNullValueFlag } if writeType { - ctx.Buffer().WriteVarUint32(uint32(fory.NAMED_STRUCT)) + typeInfo, err := ctx.TypeResolver().GetTypeInfo(value, true) + if err != nil { + ctx.SetError(fory.FromError(err)) + return + } + ctx.TypeResolver().WriteTypeInfo(ctx.Buffer(), typeInfo, ctx.Err()) } g.WriteData(ctx, value) } @@ -180,7 +185,7 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v isXlang := ctx.TypeResolver().IsXlang() if isXlang { // xlang mode: slices are not nullable, read directly without null flag - sliceLen := int(buf.ReadVarUint32(err)) + sliceLen := ctx.ReadCollectionLength() if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { @@ -199,7 +204,7 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if nullFlag == -3 { v.DynamicSlice = nil } else { - sliceLen := int(buf.ReadVarUint32(err)) + sliceLen := ctx.ReadCollectionLength() if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { @@ -289,7 +294,12 @@ func (g *MapDemo_ForyGenSerializer) Write(ctx *fory.WriteContext, refMode fory.R ctx.Buffer().WriteInt8(-1) // NotNullValueFlag } if writeType { - ctx.Buffer().WriteVarUint32(uint32(fory.NAMED_STRUCT)) + typeInfo, err := ctx.TypeResolver().GetTypeInfo(value, true) + if err != nil { + ctx.SetError(fory.FromError(err)) + return + } + ctx.TypeResolver().WriteTypeInfo(ctx.Buffer(), typeInfo, ctx.Err()) } g.WriteData(ctx, value) } @@ -633,7 +643,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) isXlang := ctx.TypeResolver().IsXlang() if isXlang { // xlang mode: maps are not nullable, read directly without null flag - mapLen := int(buf.ReadVarUint32(err)) + mapLen := ctx.ReadCollectionLength() if mapLen == 0 { v.IntMap = make(map[int]int) } else { @@ -667,7 +677,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if nullFlag == -3 { v.IntMap = nil } else { - mapLen := int(buf.ReadVarUint32(err)) + mapLen := ctx.ReadCollectionLength() if mapLen == 0 { v.IntMap = make(map[int]int) } else { @@ -703,7 +713,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) isXlang := ctx.TypeResolver().IsXlang() if isXlang { // xlang mode: maps are not nullable, read directly without null flag - mapLen := int(buf.ReadVarUint32(err)) + mapLen := ctx.ReadCollectionLength() if mapLen == 0 { v.MixedMap = make(map[string]int) } else { @@ -737,7 +747,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if nullFlag == -3 { v.MixedMap = nil } else { - mapLen := int(buf.ReadVarUint32(err)) + mapLen := ctx.ReadCollectionLength() if mapLen == 0 { v.MixedMap = make(map[string]int) } else { @@ -773,7 +783,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) isXlang := ctx.TypeResolver().IsXlang() if isXlang { // xlang mode: maps are not nullable, read directly without null flag - mapLen := int(buf.ReadVarUint32(err)) + mapLen := ctx.ReadCollectionLength() if mapLen == 0 { v.StringMap = make(map[string]string) } else { @@ -807,7 +817,7 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if nullFlag == -3 { v.StringMap = nil } else { - mapLen := int(buf.ReadVarUint32(err)) + mapLen := ctx.ReadCollectionLength() if mapLen == 0 { v.StringMap = make(map[string]string) } else { @@ -912,7 +922,12 @@ func (g *SliceDemo_ForyGenSerializer) Write(ctx *fory.WriteContext, refMode fory ctx.Buffer().WriteInt8(-1) // NotNullValueFlag } if writeType { - ctx.Buffer().WriteVarUint32(uint32(fory.NAMED_STRUCT)) + typeInfo, err := ctx.TypeResolver().GetTypeInfo(value, true) + if err != nil { + ctx.SetError(fory.FromError(err)) + return + } + ctx.TypeResolver().WriteTypeInfo(ctx.Buffer(), typeInfo, ctx.Err()) } g.WriteData(ctx, value) } @@ -1138,7 +1153,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD isXlang := ctx.TypeResolver().IsXlang() if isXlang { // xlang mode: slices are not nullable, read directly without null flag - sliceLen := int(buf.ReadVarUint32(err)) + sliceLen := ctx.ReadCollectionLength() if sliceLen == 0 { v.StringSlice = make([]string, 0) } else { @@ -1176,7 +1191,7 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD if nullFlag == -3 { v.StringSlice = nil } else { - sliceLen := int(buf.ReadVarUint32(err)) + sliceLen := ctx.ReadCollectionLength() if sliceLen == 0 { v.StringSlice = make([]string, 0) } else { @@ -1285,7 +1300,12 @@ func (g *ValidationDemo_ForyGenSerializer) Write(ctx *fory.WriteContext, refMode ctx.Buffer().WriteInt8(-1) // NotNullValueFlag } if writeType { - ctx.Buffer().WriteVarUint32(uint32(fory.NAMED_STRUCT)) + typeInfo, err := ctx.TypeResolver().GetTypeInfo(value, true) + if err != nil { + ctx.SetError(fory.FromError(err)) + return + } + ctx.TypeResolver().WriteTypeInfo(ctx.Buffer(), typeInfo, ctx.Err()) } g.WriteData(ctx, value) } diff --git a/go/fory/tests/xlang/xlang_test_main.go b/go/fory/tests/xlang/xlang_test_main.go index b2600fce98..8935502312 100644 --- a/go/fory/tests/xlang/xlang_test_main.go +++ b/go/fory/tests/xlang/xlang_test_main.go @@ -26,6 +26,8 @@ import ( "runtime" "github.com/apache/fory/go/fory" + "github.com/apache/fory/go/fory/bfloat16" + "github.com/apache/fory/go/fory/float16" ) // ============================================================================ @@ -116,6 +118,17 @@ func getTwoEnumFieldStruct(obj any) TwoEnumFieldStruct { } } +func getReducedPrecisionFloatStruct(obj any) ReducedPrecisionFloatStruct { + switch v := obj.(type) { + case ReducedPrecisionFloatStruct: + return v + case *ReducedPrecisionFloatStruct: + return *v + default: + panic(fmt.Sprintf("expected ReducedPrecisionFloatStruct, got %T", obj)) + } +} + func getNullableComprehensiveSchemaConsistent(obj any) NullableComprehensiveSchemaConsistent { switch v := obj.(type) { case NullableComprehensiveSchemaConsistent: @@ -277,6 +290,13 @@ func (FixedOverrideStruct) ForyEvolving() bool { return false } +type ReducedPrecisionFloatStruct struct { + Float16Value float16.Float16 + Bfloat16Value bfloat16.BFloat16 + Float16Array []float16.Float16 + Bfloat16Array []bfloat16.BFloat16 +} + type StructWithList struct { Items []string } @@ -1483,6 +1503,70 @@ func testSchemaEvolutionCompatibleReverse() { writeFile(dataFile, serialized) } +func testReducedPrecisionFloatStruct() { + dataFile := getDataFile() + data := readFile(dataFile) + + f := fory.New(fory.WithXlang(true), fory.WithCompatible(false)) + f.RegisterStruct(ReducedPrecisionFloatStruct{}, 213) + + buf := fory.NewByteBuffer(data) + var obj any + err := f.DeserializeWithCallbackBuffers(buf, &obj, nil) + if err != nil { + panic(fmt.Sprintf("Failed to deserialize: %v", err)) + } + + result := getReducedPrecisionFloatStruct(obj) + if result.Float16Value.Bits() != 0x3E00 { + panic(fmt.Sprintf("float16_value mismatch: expected 0x3E00, got 0x%04x", result.Float16Value.Bits())) + } + if result.Bfloat16Value.Bits() != 0x3FC0 { + panic(fmt.Sprintf("bfloat16_value mismatch: expected 0x3FC0, got 0x%04x", result.Bfloat16Value.Bits())) + } + if len(result.Float16Array) != 3 || + result.Float16Array[0].Bits() != 0x0000 || + result.Float16Array[1].Bits() != 0x3C00 || + result.Float16Array[2].Bits() != 0xBC00 { + panic(fmt.Sprintf("float16_array mismatch: got %#v", result.Float16Array)) + } + if len(result.Bfloat16Array) != 3 || + result.Bfloat16Array[0].Bits() != 0x0000 || + result.Bfloat16Array[1].Bits() != 0x3F80 || + result.Bfloat16Array[2].Bits() != 0xBF80 { + panic(fmt.Sprintf("bfloat16_array mismatch: got %#v", result.Bfloat16Array)) + } + + serialized, err := f.Serialize(&result) + if err != nil { + panic(fmt.Sprintf("Failed to serialize: %v", err)) + } + + writeFile(dataFile, serialized) +} + +func testReducedPrecisionFloatStructCompatibleSkip() { + dataFile := getDataFile() + data := readFile(dataFile) + + f := fory.New(fory.WithXlang(true), fory.WithCompatible(true)) + f.RegisterStruct(EmptyStruct{}, 213) + + buf := fory.NewByteBuffer(data) + var obj any + err := f.DeserializeWithCallbackBuffers(buf, &obj, nil) + if err != nil { + panic(fmt.Sprintf("Failed to deserialize as EmptyStruct: %v", err)) + } + + serialized, err := f.Serialize(obj) + if err != nil { + panic(fmt.Sprintf("Failed to serialize: %v", err)) + } + + writeFile(dataFile, serialized) +} + // Enum field tests func testOneEnumFieldSchemaConsistent() { dataFile := getDataFile() @@ -2601,6 +2685,10 @@ func main() { testSchemaEvolutionCompatible() case "test_schema_evolution_compatible_reverse": testSchemaEvolutionCompatibleReverse() + case "test_reduced_precision_float_struct": + testReducedPrecisionFloatStruct() + case "test_reduced_precision_float_struct_compatible_skip": + testReducedPrecisionFloatStructCompatibleSkip() case "test_one_enum_field_schema": testOneEnumFieldSchemaConsistent() case "test_one_enum_field_compatible": diff --git a/go/fory/type_resolver.go b/go/fory/type_resolver.go index 540114735f..4d8b1abbf1 100644 --- a/go/fory/type_resolver.go +++ b/go/fory/type_resolver.go @@ -1219,8 +1219,16 @@ func (r *TypeResolver) getTypeInfo(value reflect.Value, create bool) (*TypeInfo, serializer = int32ArraySerializer{arrayType: type_} } case reflect.Uint16: - arrayTypeID = UINT16_ARRAY - serializer = uint16ArraySerializer{arrayType: type_} + if type_.Elem() == float16Type { + arrayTypeID = FLOAT16_ARRAY + serializer = float16ArraySerializer{arrayType: type_} + } else if type_.Elem() == bfloat16Type { + arrayTypeID = BFLOAT16_ARRAY + serializer = bfloat16ArraySerializer{arrayType: type_} + } else { + arrayTypeID = UINT16_ARRAY + serializer = uint16ArraySerializer{arrayType: type_} + } case reflect.Uint32: arrayTypeID = UINT32_ARRAY serializer = uint32ArraySerializer{arrayType: type_} @@ -1900,6 +1908,14 @@ func (r *TypeResolver) GetSliceSerializer(sliceType reflect.Type) (Serializer, e return int64SliceSerializer{}, nil case reflect.Uint8: return byteSliceSerializer{}, nil + case reflect.Uint16: + if elemType == float16Type { + return float16SliceSerializer{}, nil + } + if elemType == bfloat16Type { + return bfloat16SliceSerializer{}, nil + } + return uint16SliceSerializer{}, nil case reflect.Float32: return float32SliceSerializer{}, nil case reflect.Float64: @@ -1948,6 +1964,14 @@ func (r *TypeResolver) GetArraySerializer(arrayType reflect.Type) (Serializer, e return int64ArraySerializer{arrayType: arrayType}, nil case reflect.Uint8: return uint8ArraySerializer{arrayType: arrayType}, nil + case reflect.Uint16: + if elemType == float16Type { + return float16ArraySerializer{arrayType: arrayType}, nil + } + if elemType == bfloat16Type { + return bfloat16ArraySerializer{arrayType: arrayType}, nil + } + return uint16ArraySerializer{arrayType: arrayType}, nil case reflect.Float32: return float32ArraySerializer{arrayType: arrayType}, nil case reflect.Float64: diff --git a/go/fory/type_test.go b/go/fory/type_test.go index e8978e5930..a0206d0747 100644 --- a/go/fory/type_test.go +++ b/go/fory/type_test.go @@ -156,3 +156,44 @@ func TestCreateSerializerArrayTypes(t *testing.T) { } } } + +func TestGetSliceSerializerReducedPrecisionTypes(t *testing.T) { + fory := NewFory() + r := newTypeResolver(fory) + + serializer, err := r.GetSliceSerializer(reflect.TypeOf([]float16.Float16{})) + require.NoError(t, err) + require.IsType(t, float16SliceSerializer{}, serializer) + + serializer, err = r.GetSliceSerializer(reflect.TypeOf([]bfloat16.BFloat16{})) + require.NoError(t, err) + require.IsType(t, bfloat16SliceSerializer{}, serializer) +} + +func TestGetArraySerializerReducedPrecisionTypes(t *testing.T) { + fory := NewFory() + r := newTypeResolver(fory) + + serializer, err := r.GetArraySerializer(reflect.TypeOf([4]float16.Float16{})) + require.NoError(t, err) + require.IsType(t, float16ArraySerializer{}, serializer) + + serializer, err = r.GetArraySerializer(reflect.TypeOf([4]bfloat16.BFloat16{})) + require.NoError(t, err) + require.IsType(t, bfloat16ArraySerializer{}, serializer) +} + +func TestGetTypeInfoReducedPrecisionArrayTypeIDs(t *testing.T) { + fory := NewFory(WithXlang(true)) + r := newTypeResolver(fory) + + float16Info, err := r.GetTypeInfo(reflect.ValueOf([2]float16.Float16{}), true) + require.NoError(t, err) + require.Equal(t, uint32(FLOAT16_ARRAY), float16Info.TypeID) + require.IsType(t, float16ArraySerializer{}, float16Info.Serializer) + + bfloat16Info, err := r.GetTypeInfo(reflect.ValueOf([2]bfloat16.BFloat16{}), true) + require.NoError(t, err) + require.Equal(t, uint32(BFLOAT16_ARRAY), bfloat16Info.TypeID) + require.IsType(t, bfloat16ArraySerializer{}, bfloat16Info.Serializer) +} diff --git a/go/fory/types.go b/go/fory/types.go index 38b1b71b4b..50e34145ad 100644 --- a/go/fory/types.go +++ b/go/fory/types.go @@ -381,6 +381,7 @@ const ( Float32SliceDispatchId Float64SliceDispatchId Float16SliceDispatchId + BFloat16SliceDispatchId BoolSliceDispatchId StringSliceDispatchId @@ -462,6 +463,9 @@ func GetDispatchId(t reflect.Type) DispatchId { if t.Elem().Name() == "Float16" && (t.Elem().PkgPath() == "github.com/apache/fory/go/fory/float16" || strings.HasSuffix(t.Elem().PkgPath(), "/float16")) { return Float16SliceDispatchId } + if t.Elem().Name() == "BFloat16" && (t.Elem().PkgPath() == "github.com/apache/fory/go/fory/bfloat16" || strings.HasSuffix(t.Elem().PkgPath(), "/bfloat16")) { + return BFloat16SliceDispatchId + } return Uint16SliceDispatchId case reflect.Uint32: return Uint32SliceDispatchId diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java index 557e5635c7..e7aee5ee69 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java @@ -138,6 +138,7 @@ import org.apache.fory.serializer.collection.CollectionFlags; import org.apache.fory.serializer.collection.CollectionLikeSerializer; import org.apache.fory.serializer.collection.MapLikeSerializer; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Descriptor; import org.apache.fory.type.DispatchId; import org.apache.fory.type.Float16; @@ -551,6 +552,8 @@ private Expression serializePrimitiveField( return new Invoke(buffer, "writeFloat64", inputObject); case DispatchId.FLOAT16: return new Invoke(buffer, "writeInt16", new Invoke(inputObject, "toBits", SHORT_TYPE)); + case DispatchId.BFLOAT16: + return new Invoke(buffer, "writeInt16", new Invoke(inputObject, "toBits", SHORT_TYPE)); default: throw new IllegalStateException("Unsupported dispatchId: " + dispatchId); } @@ -714,7 +717,9 @@ protected int getNumericDescriptorDispatchId(Descriptor descriptor) { descriptor, d -> DispatchId.getDispatchId(typeResolver, d)); Class rawType = descriptor.getRawType(); Preconditions.checkArgument( - TypeUtils.unwrap(rawType).isPrimitive() || dispatchId == DispatchId.FLOAT16); + TypeUtils.unwrap(rawType).isPrimitive() + || dispatchId == DispatchId.FLOAT16 + || dispatchId == DispatchId.BFLOAT16); return dispatchId; } @@ -722,7 +727,7 @@ private boolean isPrimitiveLikeDescriptor(Descriptor descriptor, Class rawTyp if (isPrimitive(rawType) || isBoxed(rawType)) { return true; } - return rawType == Float16.class; + return rawType == Float16.class || rawType == BFloat16.class; } /** @@ -2149,6 +2154,9 @@ private Expression deserializePrimitiveField(Expression buffer, Descriptor descr case DispatchId.FLOAT16: return new StaticInvoke( Float16.class, "fromBits", TypeRef.of(Float16.class), readInt16(buffer)); + case DispatchId.BFLOAT16: + return new StaticInvoke( + BFloat16.class, "fromBits", TypeRef.of(BFloat16.class), readInt16(buffer)); default: throw new IllegalStateException("Unsupported dispatchId: " + dispatchId); } diff --git a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java index 574ee61ea0..cca88079c4 100644 --- a/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecBuilder.java @@ -29,8 +29,8 @@ import static org.apache.fory.type.TypeUtils.PRIMITIVE_INT_TYPE; import static org.apache.fory.type.TypeUtils.PRIMITIVE_LONG_TYPE; import static org.apache.fory.type.TypeUtils.PRIMITIVE_VOID_TYPE; +import static org.apache.fory.type.TypeUtils.SHORT_TYPE; import static org.apache.fory.type.TypeUtils.getRawType; -import static org.apache.fory.type.TypeUtils.getSizeOfPrimitiveType; import java.util.ArrayList; import java.util.Collection; @@ -60,10 +60,13 @@ import org.apache.fory.reflect.ObjectCreators; import org.apache.fory.reflect.TypeRef; import org.apache.fory.serializer.ObjectSerializer; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Descriptor; import org.apache.fory.type.DescriptorGrouper; import org.apache.fory.type.DispatchId; +import org.apache.fory.type.Float16; import org.apache.fory.type.TypeUtils; +import org.apache.fory.type.Types; import org.apache.fory.util.StringUtils; import org.apache.fory.util.function.SerializableSupplier; import org.apache.fory.util.record.RecordUtils; @@ -287,6 +290,13 @@ private List serializePrimitivesUnCompressed( } else if (dispatchId == DispatchId.INT16 || dispatchId == DispatchId.UINT16) { groupExpressions.add(unsafePutShort(base, getWriterPos(writerAddr, acc), fieldValue)); acc += 2; + } else if (dispatchId == DispatchId.FLOAT16 || dispatchId == DispatchId.BFLOAT16) { + groupExpressions.add( + unsafePutShort( + base, + getWriterPos(writerAddr, acc), + new Invoke(fieldValue, "toBits", SHORT_TYPE))); + acc += 2; } else if (dispatchId == DispatchId.INT32 || dispatchId == DispatchId.UINT32) { groupExpressions.add(unsafePutInt(base, getWriterPos(writerAddr, acc), fieldValue)); acc += 4; @@ -377,6 +387,13 @@ private List serializePrimitivesCompressed( } else if (dispatchId == DispatchId.INT16 || dispatchId == DispatchId.UINT16) { groupExpressions.add(unsafePutShort(base, getWriterPos(writerAddr, acc), fieldValue)); acc += 2; + } else if (dispatchId == DispatchId.FLOAT16 || dispatchId == DispatchId.BFLOAT16) { + groupExpressions.add( + unsafePutShort( + base, + getWriterPos(writerAddr, acc), + new Invoke(fieldValue, "toBits", SHORT_TYPE))); + acc += 2; } else if (dispatchId == DispatchId.FLOAT32) { groupExpressions.add(unsafePutFloat(base, getWriterPos(writerAddr, acc), fieldValue)); acc += 4; @@ -453,7 +470,14 @@ private void addIncWriterIndexExpr(ListExpression expressions, Expression buffer private int getTotalSizeOfPrimitives(List> primitiveGroups) { return primitiveGroups.stream() .flatMap(Collection::stream) - .mapToInt(d -> getSizeOfPrimitiveType(TypeUtils.unwrap(d.getRawType()))) + .mapToInt( + d -> { + Class rawType = d.getRawType(); + if (TypeUtils.isPrimitive(rawType) || TypeUtils.isBoxed(rawType)) { + return TypeUtils.getSizeOfPrimitiveType(TypeUtils.unwrap(rawType)); + } + return Types.getPrimitiveTypeSize(Types.getDescriptorTypeId(typeResolver, d)); + }) .sum(); } @@ -680,6 +704,22 @@ private List deserializeUnCompressedPrimitives( } else if (dispatchId == DispatchId.INT16 || dispatchId == DispatchId.UINT16) { fieldValue = unsafeGetShort(heapBuffer, getReaderAddress(readerAddr, acc)); acc += 2; + } else if (dispatchId == DispatchId.FLOAT16) { + fieldValue = + new StaticInvoke( + Float16.class, + "fromBits", + TypeRef.of(Float16.class), + unsafeGetShort(heapBuffer, getReaderAddress(readerAddr, acc))); + acc += 2; + } else if (dispatchId == DispatchId.BFLOAT16) { + fieldValue = + new StaticInvoke( + BFloat16.class, + "fromBits", + TypeRef.of(BFloat16.class), + unsafeGetShort(heapBuffer, getReaderAddress(readerAddr, acc))); + acc += 2; } else if (dispatchId == DispatchId.INT32 || dispatchId == DispatchId.UINT32) { fieldValue = unsafeGetInt(heapBuffer, getReaderAddress(readerAddr, acc)); acc += 4; @@ -745,6 +785,22 @@ private List deserializeCompressedPrimitives( } else if (dispatchId == DispatchId.INT16 || dispatchId == DispatchId.UINT16) { fieldValue = unsafeGetShort(heapBuffer, getReaderAddress(readerAddr, acc)); acc += 2; + } else if (dispatchId == DispatchId.FLOAT16) { + fieldValue = + new StaticInvoke( + Float16.class, + "fromBits", + TypeRef.of(Float16.class), + unsafeGetShort(heapBuffer, getReaderAddress(readerAddr, acc))); + acc += 2; + } else if (dispatchId == DispatchId.BFLOAT16) { + fieldValue = + new StaticInvoke( + BFloat16.class, + "fromBits", + TypeRef.of(BFloat16.class), + unsafeGetShort(heapBuffer, getReaderAddress(readerAddr, acc))); + acc += 2; } else if (dispatchId == DispatchId.FLOAT32) { fieldValue = unsafeGetFloat(heapBuffer, getReaderAddress(readerAddr, acc)); acc += 4; diff --git a/java/fory-core/src/main/java/org/apache/fory/collection/BFloat16List.java b/java/fory-core/src/main/java/org/apache/fory/collection/BFloat16List.java new file mode 100644 index 0000000000..1255cf6590 --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/collection/BFloat16List.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.collection; + +import java.util.AbstractList; +import java.util.Arrays; +import java.util.Objects; +import java.util.RandomAccess; +import org.apache.fory.type.BFloat16; + +/** + * Dense {@link java.util.List} carrier for xlang {@code bfloat16_array} payloads. + * + *

The list stores packed 16-bit values in a primitive {@code short[]} so the runtime can + * serialize and deserialize {@code bfloat16_array} without per-element boxing overhead. + */ +public final class BFloat16List extends AbstractList implements RandomAccess { + private static final int DEFAULT_CAPACITY = 10; + + private short[] array; + private int size; + + public BFloat16List() { + this(DEFAULT_CAPACITY); + } + + public BFloat16List(int initialCapacity) { + if (initialCapacity < 0) { + throw new IllegalArgumentException("Illegal capacity: " + initialCapacity); + } + this.array = new short[initialCapacity]; + this.size = 0; + } + + public BFloat16List(short[] array) { + this.array = array; + this.size = array.length; + } + + @Override + public BFloat16 get(int index) { + checkIndex(index); + return BFloat16.fromBits(array[index]); + } + + @Override + public int size() { + return size; + } + + @Override + public BFloat16 set(int index, BFloat16 element) { + checkIndex(index); + Objects.requireNonNull(element, "element"); + short prev = array[index]; + array[index] = element.toBits(); + return BFloat16.fromBits(prev); + } + + public void set(int index, short bits) { + checkIndex(index); + array[index] = bits; + } + + public void set(int index, float value) { + checkIndex(index); + array[index] = BFloat16.toBits(value); + } + + @Override + public void add(int index, BFloat16 element) { + checkPositionIndex(index); + ensureCapacity(size + 1); + System.arraycopy(array, index, array, index + 1, size - index); + array[index] = element.toBits(); + size++; + modCount++; + } + + @Override + public boolean add(BFloat16 element) { + Objects.requireNonNull(element, "element"); + ensureCapacity(size + 1); + array[size++] = element.toBits(); + modCount++; + return true; + } + + public boolean add(short bits) { + ensureCapacity(size + 1); + array[size++] = bits; + modCount++; + return true; + } + + public boolean add(float value) { + ensureCapacity(size + 1); + array[size++] = BFloat16.toBits(value); + modCount++; + return true; + } + + public float getFloat(int index) { + checkIndex(index); + return BFloat16.toFloat(array[index]); + } + + public short getShort(int index) { + checkIndex(index); + return array[index]; + } + + public boolean hasArray() { + return array != null; + } + + public short[] getArray() { + return array; + } + + public short[] copyArray() { + return Arrays.copyOf(array, size); + } + + private void ensureCapacity(int minCapacity) { + if (array.length >= minCapacity) { + return; + } + int newCapacity = array.length + (array.length >> 1) + 1; + if (newCapacity < minCapacity) { + newCapacity = minCapacity; + } + array = Arrays.copyOf(array, newCapacity); + } + + private void checkIndex(int index) { + if (index < 0 || index >= size) { + throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size); + } + } + + private void checkPositionIndex(int index) { + if (index < 0 || index > size) { + throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size); + } + } +} diff --git a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java index d62696dcb4..e0a0e4fdcc 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/ReadContext.java @@ -33,6 +33,7 @@ import org.apache.fory.serializer.PrimitiveSerializers.LongSerializer; import org.apache.fory.serializer.Serializer; import org.apache.fory.serializer.StringSerializer; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Float16; import org.apache.fory.type.Generics; import org.apache.fory.type.Types; @@ -191,6 +192,16 @@ public Float16 readFloat16() { return Float16.fromBits(buffer.readInt16()); } + /** + * Reads a 16-bit bfloat16 value encoded through its raw IEEE 754 bfloat16 bits. + * + *

If a caller needs multiple primitive reads, fetch the buffer once through {@link + * #getBuffer()} and invoke {@link MemoryBuffer#readInt16()} directly for better performance. + */ + public BFloat16 readBFloat16() { + return BFloat16.fromBits(buffer.readInt16()); + } + /** * Reads a 32-bit floating-point value directly from the current buffer. * diff --git a/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java b/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java index 696294f58e..f4b74dabed 100644 --- a/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java +++ b/java/fory-core/src/main/java/org/apache/fory/context/WriteContext.java @@ -35,6 +35,7 @@ import org.apache.fory.serializer.Serializer; import org.apache.fory.serializer.StringSerializer; import org.apache.fory.serializer.UnknownClass.UnknownStruct; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Float16; import org.apache.fory.type.Generics; import org.apache.fory.type.Types; @@ -188,6 +189,17 @@ public void writeFloat16(Float16 value) { buffer.writeInt16(value.toBits()); } + /** + * Writes a 16-bit bfloat16 value encoded through its raw IEEE 754 bfloat16 bits. + * + *

If a caller needs multiple primitive writes, fetch the buffer once through {@link + * #getBuffer()} and invoke {@link MemoryBuffer#writeInt16(short)} directly for better + * performance. + */ + public void writeBFloat16(BFloat16 value) { + buffer.writeInt16(value.toBits()); + } + /** * Writes a 32-bit floating-point value directly to the current buffer. * diff --git a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java index eed70a9286..c44d76c6b3 100644 --- a/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java +++ b/java/fory-core/src/main/java/org/apache/fory/meta/FieldTypes.java @@ -32,6 +32,7 @@ import java.lang.reflect.Field; import java.util.Objects; import org.apache.fory.annotation.ForyField; +import org.apache.fory.collection.BFloat16List; import org.apache.fory.collection.BoolList; import org.apache.fory.collection.Float16List; import org.apache.fory.collection.Float32List; @@ -55,7 +56,6 @@ import org.apache.fory.resolver.XtypeResolver; import org.apache.fory.serializer.UnknownClass; import org.apache.fory.type.Descriptor; -import org.apache.fory.type.Float16; import org.apache.fory.type.GenericType; import org.apache.fory.type.TypeUtils; import org.apache.fory.type.Types; @@ -647,7 +647,9 @@ private static Class getPrimitiveArrayClass(int typeId) { case Types.FLOAT32_ARRAY: return float[].class; case Types.FLOAT16_ARRAY: - return Float16[].class; + return Float16List.class; + case Types.BFLOAT16_ARRAY: + return BFloat16List.class; case Types.FLOAT64_ARRAY: return double[].class; default: @@ -679,6 +681,8 @@ private static Class getPrimitiveListClass(int typeId) { return Float32List.class; case Types.FLOAT16_ARRAY: return Float16List.class; + case Types.BFLOAT16_ARRAY: + return BFloat16List.class; case Types.FLOAT64_ARRAY: return Float64List.class; default: diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java index af4ec7f73d..e53910260a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/ClassResolver.java @@ -72,6 +72,7 @@ import org.apache.fory.annotation.Internal; import org.apache.fory.builder.CodecUtils; import org.apache.fory.builder.JITContext; +import org.apache.fory.collection.BFloat16List; import org.apache.fory.collection.BoolList; import org.apache.fory.collection.Float16List; import org.apache.fory.collection.Float32List; @@ -151,6 +152,7 @@ import org.apache.fory.serializer.scala.SingletonObjectSerializer; import org.apache.fory.serializer.shim.ProtobufDispatcher; import org.apache.fory.serializer.shim.ShimDispatcher; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Descriptor; import org.apache.fory.type.Float16; import org.apache.fory.type.GenericType; @@ -271,6 +273,7 @@ public void initialize() { registerInternal(Long.class, Types.INT64); registerInternal(Double.class, Types.FLOAT64); registerInternal(Float16.class, Types.FLOAT16); + registerInternal(BFloat16.class, Types.BFLOAT16); registerInternal(String.class, Types.STRING); registerInternal(Uint8.class, Types.UINT8); registerInternal(Uint16.class, Types.UINT16); @@ -285,6 +288,7 @@ public void initialize() { registerInternal(long[].class, PRIMITIVE_LONG_ARRAY_ID); registerInternal(double[].class, PRIMITIVE_DOUBLE_ARRAY_ID); registerInternal(Float16[].class); + registerInternal(BFloat16[].class); registerInternal(String[].class, STRING_ARRAY_ID); registerInternal(Object[].class, OBJECT_ARRAY_ID); registerInternal(BoolList.class, Types.BOOL_ARRAY); @@ -299,6 +303,7 @@ public void initialize() { registerInternal(Float32List.class, Types.FLOAT32_ARRAY); registerInternal(Float64List.class, Types.FLOAT64_ARRAY); registerInternal(Float16List.class, Types.FLOAT16_ARRAY); + registerInternal(BFloat16List.class, Types.BFLOAT16_ARRAY); registerInternal(ArrayList.class, ARRAYLIST_ID); registerInternal(HashMap.class, HASHMAP_ID); registerInternal(HashSet.class, HASHSET_ID); diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java index ef42e563ab..0117e879e1 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/TypeResolver.java @@ -19,7 +19,6 @@ package org.apache.fory.resolver; -import static org.apache.fory.type.TypeUtils.getSizeOfPrimitiveType; import static org.apache.fory.type.Types.INVALID_USER_TYPE_ID; import java.lang.reflect.AnnotatedType; @@ -401,6 +400,11 @@ protected DescriptorGrouper configureDescriptorGrouper(DescriptorGrouper descrip return descriptorGrouper; } + public boolean usesPrimitiveFieldOrdering(Descriptor descriptor) { + Class rawType = descriptor.getRawType(); + return TypeUtils.isPrimitive(rawType) || TypeUtils.isBoxed(rawType); + } + public abstract boolean isMonomorphic(Descriptor descriptor); public abstract boolean isMonomorphic(Class clz); @@ -1355,6 +1359,7 @@ private DescriptorGrouper buildDescriptorGrouper( Function descriptorUpdator) { return configureDescriptorGrouper( DescriptorGrouper.createDescriptorGrouper( + this::usesPrimitiveFieldOrdering, this::isBuildIn, descriptors, descriptorsGroupedOrdered, @@ -1446,16 +1451,14 @@ protected static String getFieldSortKey(Descriptor descriptor) { */ public Comparator getPrimitiveComparator() { return (d1, d2) -> { - Class t1 = TypeUtils.unwrap(d1.getRawType()); - Class t2 = TypeUtils.unwrap(d2.getRawType()); int typeId1 = Types.getDescriptorTypeId(this, d1); int typeId2 = Types.getDescriptorTypeId(this, d2); boolean t1Compress = Types.isCompressedType(typeId1); boolean t2Compress = Types.isCompressedType(typeId2); if ((t1Compress && t2Compress) || (!t1Compress && !t2Compress)) { - int c = getSizeOfPrimitiveType(t2) - getSizeOfPrimitiveType(t1); + int c = getPrimitiveFieldSize(d2) - getPrimitiveFieldSize(d1); if (c == 0) { - c = typeId2 - typeId1; + c = isCrossLanguage() ? typeId1 - typeId2 : typeId2 - typeId1; // noinspection Duplicates if (c == 0) { c = getFieldSortKey(d1).compareTo(getFieldSortKey(d2)); @@ -1481,6 +1484,14 @@ public Comparator getPrimitiveComparator() { }; } + private int getPrimitiveFieldSize(Descriptor descriptor) { + Class rawType = descriptor.getRawType(); + if (TypeUtils.isPrimitive(rawType) || TypeUtils.isBoxed(rawType)) { + return TypeUtils.getSizeOfPrimitiveType(TypeUtils.unwrap(rawType)); + } + return Types.getPrimitiveTypeSize(Types.getDescriptorTypeId(this, descriptor)); + } + /** * Get the nullable flag for a field, respecting xlang mode. * diff --git a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java index 06737b2c97..d1864e91fa 100644 --- a/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java +++ b/java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java @@ -51,6 +51,7 @@ import org.apache.fory.annotation.ForyField; import org.apache.fory.annotation.Internal; import org.apache.fory.builder.JITContext; +import org.apache.fory.collection.BFloat16List; import org.apache.fory.collection.BoolList; import org.apache.fory.collection.Float16List; import org.apache.fory.collection.Float32List; @@ -107,6 +108,7 @@ import org.apache.fory.serializer.collection.MapSerializer; import org.apache.fory.serializer.collection.MapSerializers.XlangMapSerializer; import org.apache.fory.serializer.collection.PrimitiveListSerializers; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Descriptor; import org.apache.fory.type.DescriptorGrouper; import org.apache.fory.type.Float16; @@ -521,9 +523,15 @@ private int determineTypeIdForClass(Class type) { if (type == Float16.class) { return Types.FLOAT16; } - if (type == Float16[].class || type == Float16List.class) { + if (type == BFloat16.class) { + return Types.BFLOAT16; + } + if (type == Float16List.class) { return Types.FLOAT16_ARRAY; } + if (type == BFloat16List.class) { + return Types.BFLOAT16_ARRAY; + } if (type.isArray()) { Class componentType = type.getComponentType(); if (componentType.isPrimitive()) { @@ -925,6 +933,10 @@ private void registerDefaultTypes() { Types.FLOAT16, Float16.class, new PrimitiveSerializers.Float16Serializer(config, Float16.class)); + registerType( + Types.BFLOAT16, + BFloat16.class, + new PrimitiveSerializers.BFloat16Serializer(config, BFloat16.class)); registerType( Types.FLOAT64, Double.class, @@ -992,9 +1004,6 @@ private void registerDefaultTypes() { Types.FLOAT32_ARRAY, float[].class, new ArraySerializers.FloatArraySerializer(this)); registerType( Types.FLOAT64_ARRAY, double[].class, new ArraySerializers.DoubleArraySerializer(this)); - registerType( - Types.FLOAT16_ARRAY, Float16[].class, new ArraySerializers.Float16ArraySerializer(this)); - // Primitive lists registerType( Types.BOOL_ARRAY, BoolList.class, new PrimitiveListSerializers.BoolListSerializer(this)); @@ -1032,6 +1041,10 @@ private void registerDefaultTypes() { Types.FLOAT16_ARRAY, Float16List.class, new PrimitiveListSerializers.Float16ListSerializer(this)); + registerType( + Types.BFLOAT16_ARRAY, + BFloat16List.class, + new PrimitiveListSerializers.BFloat16ListSerializer(this)); // Collections registerType(Types.LIST, ArrayList.class, new ArrayListSerializer(this)); @@ -1279,6 +1292,15 @@ protected DescriptorGrouper configureDescriptorGrouper(DescriptorGrouper descrip Comparator.comparing(TypeResolver::getFieldSortKey)); } + @Override + public boolean usesPrimitiveFieldOrdering(Descriptor descriptor) { + if (super.usesPrimitiveFieldOrdering(descriptor)) { + return true; + } + int typeId = Types.getDescriptorTypeId(this, descriptor); + return typeId == Types.FLOAT16 || typeId == Types.BFLOAT16; + } + private byte getInternalTypeId(Descriptor descriptor) { Class cls = descriptor.getRawType(); if (cls.isArray() && cls.getComponentType().isPrimitive()) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java index 048bdb38d4..68d3ce48a8 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java @@ -45,6 +45,7 @@ import org.apache.fory.resolver.TypeInfo; import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.FieldGroups.SerializationFieldInfo; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Descriptor; import org.apache.fory.type.DescriptorGrouper; import org.apache.fory.type.DispatchId; @@ -466,6 +467,9 @@ static void writeNotPrimitiveFieldValue( case DispatchId.FLOAT16: buffer.writeInt16(((Float16) fieldValue).toBits()); return; + case DispatchId.BFLOAT16: + buffer.writeInt16(((BFloat16) fieldValue).toBits()); + return; default: writeField( writeContext, typeResolver, refWriter, fieldInfo, RefMode.NONE, buffer, fieldValue); @@ -726,6 +730,8 @@ private static Object readNotNullBuildInFieldValue( return buffer.readFloat64(); case DispatchId.FLOAT16: return Float16.fromBits(buffer.readInt16()); + case DispatchId.BFLOAT16: + return BFloat16.fromBits(buffer.readInt16()); case DispatchId.STRING: return readContext.readString(); default: @@ -949,6 +955,9 @@ private static void readNotPrimitiveFieldValue( case DispatchId.FLOAT16: fieldAccessor.putObject(targetObject, Float16.fromBits(buffer.readInt16())); return; + case DispatchId.BFLOAT16: + fieldAccessor.putObject(targetObject, BFloat16.fromBits(buffer.readInt16())); + return; case DispatchId.STRING: fieldAccessor.putObject(targetObject, readContext.readString()); return; @@ -1109,6 +1118,7 @@ private static void copySetNotPrimitiveField( case DispatchId.FLOAT32: case DispatchId.FLOAT64: case DispatchId.FLOAT16: + case DispatchId.BFLOAT16: case DispatchId.STRING: Platform.putObject(newObj, fieldOffset, Platform.getObject(originObj, fieldOffset)); break; @@ -1175,6 +1185,7 @@ private Object copyNotPrimitiveField( case DispatchId.TAGGED_UINT64: case DispatchId.FLOAT64: case DispatchId.FLOAT16: + case DispatchId.BFLOAT16: case DispatchId.STRING: return Platform.getObject(targetObject, fieldOffset); default: diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index 9350471376..31816c388e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -35,6 +35,7 @@ import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.collection.CollectionFlags; import org.apache.fory.serializer.collection.ForyArrayAsListSerializer; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Float16; import org.apache.fory.type.GenericType; import org.apache.fory.type.TypeUtils; @@ -944,6 +945,69 @@ public Float16[] read(ReadContext readContext) { } } + public static final class BFloat16ArraySerializer extends PrimitiveArraySerializer { + public BFloat16ArraySerializer(TypeResolver typeResolver) { + super(typeResolver, BFloat16[].class); + } + + @Override + public void write(WriteContext writeContext, BFloat16[] value) { + MemoryBuffer buffer = writeContext.getBuffer(); + int length = value.length; + for (int i = 0; i < length; i++) { + if (value[i] == null) { + throw new IllegalArgumentException( + "BFloat16[] doesn't support null elements at index " + i); + } + } + writeNonNull(buffer, value, length); + } + + private void writeNonNull(MemoryBuffer buffer, BFloat16[] value, int length) { + int size = length * 2; + buffer.writeVarUint32Small7(size); + + if (Platform.IS_LITTLE_ENDIAN) { + int writerIndex = buffer.writerIndex(); + buffer.ensure(writerIndex + size); + for (int i = 0; i < length; i++) { + buffer._unsafePutInt16(writerIndex + i * 2, value[i].toBits()); + } + buffer._unsafeWriterIndex(writerIndex + size); + } else { + for (int i = 0; i < length; i++) { + buffer.writeInt16(value[i].toBits()); + } + } + } + + @Override + public BFloat16[] copy(CopyContext copyContext, BFloat16[] originArray) { + return Arrays.copyOf(originArray, originArray.length); + } + + @Override + public BFloat16[] read(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); + int size = buffer.readVarUint32Small7(); + int numElements = size / 2; + BFloat16[] values = new BFloat16[numElements]; + if (Platform.IS_LITTLE_ENDIAN) { + int readerIndex = buffer.readerIndex(); + buffer.checkReadableBytes(size); + for (int i = 0; i < numElements; i++) { + values[i] = BFloat16.fromBits(buffer._unsafeGetInt16(readerIndex + i * 2)); + } + buffer._increaseReaderIndexUnsafe(size); + } else { + for (int i = 0; i < numElements; i++) { + values[i] = BFloat16.fromBits(buffer.readInt16()); + } + } + return values; + } + } + public static final class StringArraySerializer extends Serializer { private final Config config; private final ForyArrayAsListSerializer collectionSerializer; @@ -1072,6 +1136,7 @@ public static void registerDefaultSerializers(TypeResolver resolver) { resolver.registerInternalSerializer( Double[].class, new ObjectArraySerializer<>(resolver, Double[].class)); resolver.registerInternalSerializer(Float16[].class, new Float16ArraySerializer(resolver)); + resolver.registerInternalSerializer(BFloat16[].class, new BFloat16ArraySerializer(resolver)); resolver.registerInternalSerializer(boolean[].class, new BooleanArraySerializer(resolver)); resolver.registerInternalSerializer( Boolean[].class, new ObjectArraySerializer<>(resolver, Boolean[].class)); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/FieldGroups.java b/java/fory-core/src/main/java/org/apache/fory/serializer/FieldGroups.java index 13e0d6c880..2eb96e10e7 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/FieldGroups.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/FieldGroups.java @@ -85,6 +85,7 @@ static DescriptorGrouper buildDescriptorGrouper( boolean descriptorsGroupedOrdered, Function descriptorUpdator) { return DescriptorGrouper.createDescriptorGrouper( + typeResolver::usesPrimitiveFieldOrdering, typeResolver::isBuildIn, descriptors, descriptorsGroupedOrdered, @@ -102,7 +103,8 @@ public static FieldGroups buildFieldInfos(TypeResolver typeResolver, DescriptorG Collection buildIn = grouper.getBuildInDescriptors(); List regularBuildIn = new ArrayList<>(buildIn.size()); for (Descriptor d : buildIn) { - if (DispatchId.getDispatchId(typeResolver, d) == DispatchId.FLOAT16) { + int dispatchId = DispatchId.getDispatchId(typeResolver, d); + if (dispatchId == DispatchId.FLOAT16 || dispatchId == DispatchId.BFLOAT16) { if (d.isNullable()) { boxed.add(d); } else { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/FieldSkipper.java b/java/fory-core/src/main/java/org/apache/fory/serializer/FieldSkipper.java index a659f24cdc..a4daed090f 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/FieldSkipper.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/FieldSkipper.java @@ -79,6 +79,7 @@ static void skipField( case DispatchId.UINT16: case DispatchId.EXT_UINT16: case DispatchId.FLOAT16: + case DispatchId.BFLOAT16: buffer.increaseReaderIndex(2); break; case DispatchId.INT32: diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveSerializers.java index 8537456e2d..b3454aac80 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveSerializers.java @@ -30,6 +30,7 @@ import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.Platform; import org.apache.fory.resolver.TypeResolver; +import org.apache.fory.type.BFloat16; import org.apache.fory.type.Float16; import org.apache.fory.util.Preconditions; @@ -325,6 +326,23 @@ public Float16 read(ReadContext readContext) { } } + public static final class BFloat16Serializer extends ImmutableSerializer + implements Shareable { + public BFloat16Serializer(Config config, Class cls) { + super(config, (Class) cls, false); + } + + @Override + public void write(WriteContext writeContext, BFloat16 value) { + writeContext.getBuffer().writeInt16(value.toBits()); + } + + @Override + public BFloat16 read(ReadContext readContext) { + return BFloat16.fromBits(readContext.getBuffer().readInt16()); + } + } + public static void registerDefaultSerializers(TypeResolver resolver) { // primitive types will be boxed. Config config = resolver.getConfig(); @@ -349,5 +367,7 @@ public static void registerDefaultSerializers(TypeResolver resolver) { resolver.registerInternalSerializer(Double.class, new DoubleSerializer(config, Double.class)); resolver.registerInternalSerializer( Float16.class, new Float16Serializer(config, Float16.class)); + resolver.registerInternalSerializer( + BFloat16.class, new BFloat16Serializer(config, BFloat16.class)); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java index a0143717c7..c4ca51c65b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java @@ -20,6 +20,7 @@ package org.apache.fory.serializer.collection; import java.util.Collection; +import org.apache.fory.collection.BFloat16List; import org.apache.fory.collection.BoolList; import org.apache.fory.collection.Float16List; import org.apache.fory.collection.Float32List; @@ -665,6 +666,49 @@ public Float16List copy(CopyContext copyContext, Float16List value) { } } + public static final class BFloat16ListSerializer extends PrimitiveListSerializer { + public BFloat16ListSerializer(TypeResolver typeResolver) { + super(typeResolver, BFloat16List.class); + } + + @Override + public void write(WriteContext writeContext, BFloat16List value) { + MemoryBuffer buffer = writeContext.getBuffer(); + int size = value.size(); + int byteSize = size * 2; + buffer.writeVarUint32Small7(byteSize); + short[] array = value.getArray(); + if (Platform.IS_LITTLE_ENDIAN) { + buffer.writePrimitiveArray(array, Platform.SHORT_ARRAY_OFFSET, byteSize); + } else { + for (int i = 0; i < size; i++) { + buffer.writeInt16(array[i]); + } + } + } + + @Override + public BFloat16List read(ReadContext readContext) { + MemoryBuffer buffer = readContext.getBuffer(); + int byteSize = buffer.readVarUint32Small7(); + int size = byteSize / 2; + short[] array = new short[size]; + if (Platform.IS_LITTLE_ENDIAN) { + buffer.readToUnsafe(array, Platform.SHORT_ARRAY_OFFSET, byteSize); + } else { + for (int i = 0; i < size; i++) { + array[i] = buffer.readInt16(); + } + } + return new BFloat16List(array); + } + + @Override + public BFloat16List copy(CopyContext copyContext, BFloat16List value) { + return new BFloat16List(value.copyArray()); + } + } + public static void registerDefaultSerializers(TypeResolver resolver) { resolver.registerInternalSerializer(BoolList.class, new BoolListSerializer(resolver)); resolver.registerInternalSerializer(Int8List.class, new Int8ListSerializer(resolver)); @@ -678,5 +722,6 @@ public static void registerDefaultSerializers(TypeResolver resolver) { resolver.registerInternalSerializer(Float32List.class, new Float32ListSerializer(resolver)); resolver.registerInternalSerializer(Float64List.class, new Float64ListSerializer(resolver)); resolver.registerInternalSerializer(Float16List.class, new Float16ListSerializer(resolver)); + resolver.registerInternalSerializer(BFloat16List.class, new BFloat16ListSerializer(resolver)); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/type/BFloat16.java b/java/fory-core/src/main/java/org/apache/fory/type/BFloat16.java new file mode 100644 index 0000000000..fb0d905ffb --- /dev/null +++ b/java/fory-core/src/main/java/org/apache/fory/type/BFloat16.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.fory.type; + +import java.io.Serializable; + +/** + * Public carrier for xlang {@code bfloat16} values. + * + *

This type stores the exact 16-bit wire representation and converts to and from {@code float} + * using round-to-nearest-even semantics. Use {@link #fromBits(short)} and {@link #toBits()} when + * you need bit-preserving behavior, or {@link #valueOf(float)} and {@link #toFloat()} when you want + * numeric conversion. + */ +public final class BFloat16 extends Number implements Comparable, Serializable { + private static final long serialVersionUID = 1L; + + private static final int SIGN_MASK = 0x8000; + private static final int EXP_MASK = 0x7F80; + private static final int MANT_MASK = 0x007F; + + private static final short BITS_NAN = (short) 0x7FC0; + private static final short BITS_POS_INF = (short) 0x7F80; + private static final short BITS_NEG_INF = (short) 0xFF80; + private static final short BITS_NEG_ZERO = (short) 0x8000; + private static final short BITS_MAX = (short) 0x7F7F; + private static final short BITS_ONE = (short) 0x3F80; + private static final short BITS_MIN_NORMAL = (short) 0x0080; + private static final short BITS_MIN_VALUE = (short) 0x0001; + + public static final BFloat16 NaN = new BFloat16(BITS_NAN); + + public static final BFloat16 POSITIVE_INFINITY = new BFloat16(BITS_POS_INF); + + public static final BFloat16 NEGATIVE_INFINITY = new BFloat16(BITS_NEG_INF); + + public static final BFloat16 ZERO = new BFloat16((short) 0); + + public static final BFloat16 NEGATIVE_ZERO = new BFloat16(BITS_NEG_ZERO); + + public static final BFloat16 ONE = new BFloat16(BITS_ONE); + + public static final BFloat16 MAX_VALUE = new BFloat16(BITS_MAX); + + public static final BFloat16 MIN_NORMAL = new BFloat16(BITS_MIN_NORMAL); + + public static final BFloat16 MIN_VALUE = new BFloat16(BITS_MIN_VALUE); + + public static final int SIZE_BITS = 16; + + public static final int SIZE_BYTES = 2; + + private final short bits; + + private BFloat16(short bits) { + this.bits = bits; + } + + public static BFloat16 fromBits(short bits) { + return new BFloat16(bits); + } + + public static BFloat16 valueOf(float value) { + return new BFloat16(floatToBFloat16Bits(value)); + } + + public static short toBits(float value) { + return floatToBFloat16Bits(value); + } + + public short toBits() { + return bits; + } + + public static float toFloat(short bits) { + return bfloat16BitsToFloat(bits); + } + + public float toFloat() { + return floatValue(); + } + + private static short floatToBFloat16Bits(float f32) { + int bits32 = Float.floatToRawIntBits(f32); + if ((bits32 & 0x7F800000) == 0x7F800000 && (bits32 & 0x007FFFFF) != 0) { + return BITS_NAN; + } + int lsb = (bits32 >>> 16) & 1; + int roundingBias = 0x7FFF + lsb; + return (short) ((bits32 + roundingBias) >>> 16); + } + + private static float bfloat16BitsToFloat(short bits16) { + return Float.intBitsToFloat((bits16 & 0xFFFF) << 16); + } + + public boolean isNaN() { + return (bits & EXP_MASK) == EXP_MASK && (bits & MANT_MASK) != 0; + } + + public boolean isInfinite() { + return (bits & EXP_MASK) == EXP_MASK && (bits & MANT_MASK) == 0; + } + + public boolean isFinite() { + return (bits & EXP_MASK) != EXP_MASK; + } + + public boolean isZero() { + return (bits & (EXP_MASK | MANT_MASK)) == 0; + } + + public boolean isNormal() { + int exp = bits & EXP_MASK; + return exp != 0 && exp != EXP_MASK; + } + + public boolean isSubnormal() { + return (bits & EXP_MASK) == 0 && (bits & MANT_MASK) != 0; + } + + public boolean signbit() { + return (bits & SIGN_MASK) != 0; + } + + public BFloat16 add(BFloat16 other) { + return valueOf(floatValue() + other.floatValue()); + } + + public BFloat16 subtract(BFloat16 other) { + return valueOf(floatValue() - other.floatValue()); + } + + public BFloat16 multiply(BFloat16 other) { + return valueOf(floatValue() * other.floatValue()); + } + + public BFloat16 divide(BFloat16 other) { + return valueOf(floatValue() / other.floatValue()); + } + + public BFloat16 negate() { + return fromBits((short) (bits ^ SIGN_MASK)); + } + + public BFloat16 abs() { + return fromBits((short) (bits & ~SIGN_MASK)); + } + + @Override + public float floatValue() { + return bfloat16BitsToFloat(bits); + } + + @Override + public double doubleValue() { + return floatValue(); + } + + @Override + public int intValue() { + return (int) floatValue(); + } + + @Override + public long longValue() { + return (long) floatValue(); + } + + @Override + public byte byteValue() { + return (byte) floatValue(); + } + + @Override + public short shortValue() { + return (short) floatValue(); + } + + public boolean isNumericEqual(BFloat16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + if (isZero() && other.isZero()) { + return true; + } + return bits == other.bits; + } + + public boolean equalsValue(BFloat16 other) { + return isNumericEqual(other); + } + + public boolean lessThan(BFloat16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + return floatValue() < other.floatValue(); + } + + public boolean lessThanOrEqual(BFloat16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + return floatValue() <= other.floatValue(); + } + + public boolean greaterThan(BFloat16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + return floatValue() > other.floatValue(); + } + + public boolean greaterThanOrEqual(BFloat16 other) { + if (isNaN() || other.isNaN()) { + return false; + } + return floatValue() >= other.floatValue(); + } + + public static int compare(BFloat16 a, BFloat16 b) { + if (a.bits == b.bits) { + return 0; + } + int compare = Float.compare(a.floatValue(), b.floatValue()); + if (compare != 0) { + return compare; + } + return Integer.compare(a.bits & 0xFFFF, b.bits & 0xFFFF); + } + + public static BFloat16 parse(String s) { + return valueOf(Float.parseFloat(s)); + } + + @Override + public int compareTo(BFloat16 other) { + return compare(this, other); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof BFloat16)) { + return false; + } + BFloat16 other = (BFloat16) obj; + return bits == other.bits; + } + + @Override + public int hashCode() { + return Short.hashCode(bits); + } + + @Override + public String toString() { + return Float.toString(floatValue()); + } +} diff --git a/java/fory-core/src/main/java/org/apache/fory/type/DescriptorGrouper.java b/java/fory-core/src/main/java/org/apache/fory/type/DescriptorGrouper.java index c499c91f8f..a4a1341bef 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/DescriptorGrouper.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/DescriptorGrouper.java @@ -56,6 +56,7 @@ public class DescriptorGrouper { private final Collection descriptors; + private final Predicate usesPrimitiveFieldOrdering; private final Predicate isBuildIn; private final Function descriptorUpdater; private final boolean descriptorsGroupedOrdered; @@ -81,12 +82,14 @@ public class DescriptorGrouper { * @param comparator comparator for non-primitive fields. */ private DescriptorGrouper( + Predicate usesPrimitiveFieldOrdering, Predicate isBuildIn, Collection descriptors, boolean descriptorsGroupedOrdered, Function descriptorUpdater, Comparator primitiveComparator, Comparator comparator) { + this.usesPrimitiveFieldOrdering = usesPrimitiveFieldOrdering; this.descriptors = descriptors; this.isBuildIn = isBuildIn; this.descriptorUpdater = descriptorUpdater; @@ -116,13 +119,7 @@ public DescriptorGrouper sort() { return this; } for (Descriptor descriptor : descriptors) { - if (TypeUtils.isPrimitive(descriptor.getRawType())) { - if (!descriptor.isNullable()) { - primitiveDescriptors.add(descriptorUpdater.apply(descriptor)); - } else { - boxedDescriptors.add(descriptorUpdater.apply(descriptor)); - } - } else if (TypeUtils.isBoxed(descriptor.getRawType())) { + if (usesPrimitiveFieldOrdering.test(descriptor)) { if (!descriptor.isNullable()) { primitiveDescriptors.add(descriptorUpdater.apply(descriptor)); } else { @@ -203,7 +200,28 @@ public static DescriptorGrouper createDescriptorGrouper( Function descriptorUpdator, Comparator primitiveComparator, Comparator comparator) { + return createDescriptorGrouper( + descriptor -> + TypeUtils.isPrimitive(descriptor.getRawType()) + || TypeUtils.isBoxed(descriptor.getRawType()), + isBuildIn, + descriptors, + descriptorsGroupedOrdered, + descriptorUpdator, + primitiveComparator, + comparator); + } + + public static DescriptorGrouper createDescriptorGrouper( + Predicate usesPrimitiveFieldOrdering, + Predicate isBuildIn, + Collection descriptors, + boolean descriptorsGroupedOrdered, + Function descriptorUpdator, + Comparator primitiveComparator, + Comparator comparator) { return new DescriptorGrouper( + usesPrimitiveFieldOrdering, isBuildIn, descriptors, descriptorsGroupedOrdered, diff --git a/java/fory-core/src/main/java/org/apache/fory/type/DispatchId.java b/java/fory-core/src/main/java/org/apache/fory/type/DispatchId.java index 6ed9fa05d0..284004632b 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/DispatchId.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/DispatchId.java @@ -50,15 +50,16 @@ public class DispatchId { public static final int VAR_UINT32 = 15; public static final int UINT64 = 16; public static final int FLOAT16 = 17; - public static final int VAR_UINT64 = 18; - public static final int TAGGED_UINT64 = 19; - public static final int EXT_UINT8 = 20; - public static final int EXT_UINT16 = 21; - public static final int EXT_UINT32 = 22; - public static final int EXT_VAR_UINT32 = 23; - public static final int EXT_UINT64 = 24; - public static final int EXT_VAR_UINT64 = 25; - public static final int STRING = 26; + public static final int BFLOAT16 = 18; + public static final int VAR_UINT64 = 19; + public static final int TAGGED_UINT64 = 20; + public static final int EXT_UINT8 = 21; + public static final int EXT_UINT16 = 22; + public static final int EXT_UINT32 = 23; + public static final int EXT_VAR_UINT32 = 24; + public static final int EXT_UINT64 = 25; + public static final int EXT_VAR_UINT64 = 26; + public static final int STRING = 27; public static int getDispatchId(TypeResolver resolver, Descriptor d) { int typeId = Types.getDescriptorTypeId(resolver, d); @@ -66,6 +67,9 @@ public static int getDispatchId(TypeResolver resolver, Descriptor d) { if (rawType == Float16.class) { return FLOAT16; } + if (rawType == BFloat16.class) { + return BFLOAT16; + } if (resolver.isCrossLanguage()) { return adjustUnsignedDispatchId(typeId, rawType, xlangTypeIdToDispatchId(typeId)); } else { @@ -107,6 +111,8 @@ private static int xlangTypeIdToDispatchId(int typeId) { return TAGGED_UINT64; case Types.FLOAT16: return FLOAT16; + case Types.BFLOAT16: + return BFLOAT16; case Types.FLOAT32: return FLOAT32; case Types.FLOAT64: diff --git a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java index aa941e8824..7366fd71dd 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/TypeUtils.java @@ -67,6 +67,7 @@ import java.util.WeakHashMap; import java.util.stream.Collectors; import org.apache.fory.annotation.Ref; +import org.apache.fory.collection.BFloat16List; import org.apache.fory.collection.BoolList; import org.apache.fory.collection.Float16List; import org.apache.fory.collection.Float32List; @@ -712,6 +713,7 @@ public static boolean isPrimitiveListClass(Class cls) { || cls == Uint32List.class || cls == Uint64List.class || cls == Float16List.class + || cls == BFloat16List.class || cls == Float32List.class || cls == Float64List.class; } diff --git a/java/fory-core/src/main/java/org/apache/fory/type/Types.java b/java/fory-core/src/main/java/org/apache/fory/type/Types.java index 322fbed48a..afe8d7eadf 100644 --- a/java/fory-core/src/main/java/org/apache/fory/type/Types.java +++ b/java/fory-core/src/main/java/org/apache/fory/type/Types.java @@ -84,7 +84,7 @@ public class Types { /** float16: a 16-bit floating point number. */ public static final int FLOAT16 = 17; - /** bfloat16: a 16-bit brain floating point number. */ + /** BFloat16: a 16-bit brain floating point number. */ public static final int BFLOAT16 = 18; /** float32: a 32-bit floating point number. */ @@ -211,7 +211,7 @@ public class Types { /** One dimensional float16 array. */ public static final int FLOAT16_ARRAY = 53; - /** One dimensional bfloat16 array. */ + /** One dimensional BFloat16 array. */ public static final int BFLOAT16_ARRAY = 54; /** One dimensional float32 array. */ @@ -349,6 +349,37 @@ public static int getPrimitiveArrayTypeId(int typeId) { } } + public static int getPrimitiveTypeSize(int typeId) { + switch (typeId) { + case BOOL: + case INT8: + case UINT8: + case FLOAT8: + return 1; + case INT16: + case UINT16: + case FLOAT16: + case BFLOAT16: + return 2; + case INT32: + case VARINT32: + case UINT32: + case VAR_UINT32: + case FLOAT32: + return 4; + case INT64: + case VARINT64: + case TAGGED_INT64: + case UINT64: + case VAR_UINT64: + case TAGGED_UINT64: + case FLOAT64: + return 8; + default: + throw new IllegalArgumentException("Type id " + typeId + " must be primitive"); + } + } + public static int getDescriptorTypeId(TypeResolver resolver, Field field) { Annotation annotation = Descriptor.getAnnotation(field); Class rawType = field.getType(); @@ -448,8 +479,9 @@ public static Class getClassForTypeId(int typeId) { case TAGGED_UINT64: return Long.class; case FLOAT8: - case BFLOAT16: return Float.class; + case BFLOAT16: + return BFloat16.class; case FLOAT16: return Float16.class; case FLOAT32: diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/Float16SerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/Float16SerializerTest.java index f0d6842f9a..201fd223e4 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/Float16SerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/Float16SerializerTest.java @@ -229,7 +229,7 @@ public void testFloat16XlangTopLevelSerialization() { Float16[] array = new Float16[] {Float16.ONE, Float16.valueOf(-0.5f), Float16.MIN_VALUE}; bytes = fory.serialize(array); - Float16[] arrayResult = (Float16[]) fory.deserialize(bytes); + Float16[] arrayResult = fory.deserialize(bytes, Float16[].class); assertEquals(arrayResult.length, array.length); for (int i = 0; i < array.length; i++) { assertEquals(arrayResult[i].toBits(), array[i].toBits(), "Index " + i + " should match"); @@ -237,17 +237,8 @@ public void testFloat16XlangTopLevelSerialization() { Float16List list = buildFloat16List(); bytes = fory.serialize(list); - Object listResult = fory.deserialize(bytes); - if (listResult instanceof Float16List) { - assertFloat16ListBits(list, (Float16List) listResult); - } else { - Float16[] arrayResultFromList = (Float16[]) listResult; - assertEquals(arrayResultFromList.length, list.size()); - for (int i = 0; i < arrayResultFromList.length; i++) { - assertEquals( - arrayResultFromList[i].toBits(), list.getShort(i), "Index " + i + " should match"); - } - } + Float16List listResult = (Float16List) fory.deserialize(bytes); + assertFloat16ListBits(list, listResult); } @Test diff --git a/java/fory-core/src/test/java/org/apache/fory/type/DescriptorGrouperTest.java b/java/fory-core/src/test/java/org/apache/fory/type/DescriptorGrouperTest.java index c769dad2b5..440abe02c9 100644 --- a/java/fory-core/src/test/java/org/apache/fory/type/DescriptorGrouperTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/type/DescriptorGrouperTest.java @@ -37,6 +37,7 @@ import java.util.stream.Collectors; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; +import org.apache.fory.config.Language; import org.apache.fory.reflect.ReflectionUtils; import org.apache.fory.reflect.TypeRef; import org.apache.fory.resolver.ClassResolver; @@ -163,6 +164,25 @@ public void testPrimitiveCompressedComparator() { assertEquals(classes, expected); } + @Test + public void testXlangPrimitiveComparatorUsesAscendingTypeIdTieBreaker() { + Fory fory = Fory.builder().withLanguage(Language.XLANG).build(); + List descriptors = new ArrayList<>(); + descriptors.add( + createDescriptor(TypeRef.of(Short.class), "shortValue", -1, "TestClass", false)); + descriptors.add( + createDescriptor(TypeRef.of(Float16.class), "float16Value", -1, "TestClass", false)); + descriptors.add( + createDescriptor(TypeRef.of(BFloat16.class), "bfloat16Value", -1, "TestClass", false)); + + Collections.shuffle(descriptors, new Random(11)); + descriptors.sort(fory.getTypeResolver().getPrimitiveComparator()); + + List> classes = + descriptors.stream().map(Descriptor::getRawType).collect(Collectors.toList()); + assertEquals(classes, Arrays.asList(Short.class, Float16.class, BFloat16.class)); + } + @Test public void testGrouper() { Fory fory = Fory.builder().build(); diff --git a/java/fory-core/src/test/java/org/apache/fory/type/TypesTest.java b/java/fory-core/src/test/java/org/apache/fory/type/TypesTest.java index b59b35ebf3..d1553f9c2c 100644 --- a/java/fory-core/src/test/java/org/apache/fory/type/TypesTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/type/TypesTest.java @@ -28,6 +28,6 @@ public class TypesTest { public void testGetClassForFloatTypeIds() { assertSame(Types.getClassForTypeId(Types.FLOAT8), Float.class); assertSame(Types.getClassForTypeId(Types.FLOAT16), Float16.class); - assertSame(Types.getClassForTypeId(Types.BFLOAT16), Float.class); + assertSame(Types.getClassForTypeId(Types.BFLOAT16), BFloat16.class); } } diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/CPPXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/CPPXlangTest.java index 884a0af119..34eb466622 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/CPPXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/CPPXlangTest.java @@ -237,6 +237,17 @@ public void testStructVersionCheck(boolean enableCodegen) throws java.io.IOExcep super.testStructVersionCheck(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStruct(boolean enableCodegen) throws java.io.IOException { + super.testReducedPrecisionFloatStruct(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStructCompatibleFieldSkip(boolean enableCodegen) + throws java.io.IOException { + super.testReducedPrecisionFloatStructCompatibleFieldSkip(enableCodegen); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testPolymorphicList(boolean enableCodegen) throws java.io.IOException { super.testPolymorphicList(enableCodegen); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java index 0618680a0f..42f1a4f5ab 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/CSharpXlangTest.java @@ -231,6 +231,17 @@ public void testStructVersionCheck(boolean enableCodegen) throws java.io.IOExcep super.testStructVersionCheck(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStruct(boolean enableCodegen) throws java.io.IOException { + super.testReducedPrecisionFloatStruct(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStructCompatibleFieldSkip(boolean enableCodegen) + throws java.io.IOException { + super.testReducedPrecisionFloatStructCompatibleFieldSkip(enableCodegen); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testPolymorphicList(boolean enableCodegen) throws java.io.IOException { super.testPolymorphicList(enableCodegen); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java index 96ed9aeaf2..45cfb9f876 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/GoXlangTest.java @@ -245,6 +245,17 @@ public void testSchemaEvolutionCompatible(boolean enableCodegen) throws java.io. super.testSchemaEvolutionCompatible(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStruct(boolean enableCodegen) throws java.io.IOException { + super.testReducedPrecisionFloatStruct(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStructCompatibleFieldSkip(boolean enableCodegen) + throws java.io.IOException { + super.testReducedPrecisionFloatStructCompatibleFieldSkip(enableCodegen); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testOneEnumFieldSchemaConsistent(boolean enableCodegen) throws java.io.IOException { super.testOneEnumFieldSchemaConsistent(enableCodegen); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java index ebfedf37a4..d501642baa 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/PythonXlangTest.java @@ -198,6 +198,19 @@ public void testStructVersionCheck(boolean enableCodegen) throws IOException { super.testStructVersionCheck(enableCodegen); } + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStruct(boolean enableCodegen) throws IOException { + super.testReducedPrecisionFloatStruct(enableCodegen); + } + + @Override + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStructCompatibleFieldSkip(boolean enableCodegen) + throws IOException { + super.testReducedPrecisionFloatStructCompatibleFieldSkip(enableCodegen); + } + @Override @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testPolymorphicList(boolean enableCodegen) throws IOException { diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java b/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java index 9e2e4f9e2b..1e5974dc56 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/RustXlangTest.java @@ -199,6 +199,17 @@ public void testStructVersionCheck(boolean enableCodegen) throws java.io.IOExcep super.testStructVersionCheck(enableCodegen); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStruct(boolean enableCodegen) throws java.io.IOException { + super.testReducedPrecisionFloatStruct(enableCodegen); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStructCompatibleFieldSkip(boolean enableCodegen) + throws java.io.IOException { + super.testReducedPrecisionFloatStructCompatibleFieldSkip(enableCodegen); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testPolymorphicList(boolean enableCodegen) throws java.io.IOException { super.testPolymorphicList(enableCodegen); diff --git a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java index 459b7885e6..d01ba21d67 100644 --- a/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java +++ b/java/fory-core/src/test/java/org/apache/fory/xlang/XlangTestBase.java @@ -44,6 +44,8 @@ import org.apache.fory.annotation.Uint32Type; import org.apache.fory.annotation.Uint64Type; import org.apache.fory.annotation.Uint8Type; +import org.apache.fory.collection.BFloat16List; +import org.apache.fory.collection.Float16List; import org.apache.fory.config.CompatibleMode; import org.apache.fory.config.Language; import org.apache.fory.config.LongEncoding; @@ -56,6 +58,8 @@ import org.apache.fory.resolver.TypeResolver; import org.apache.fory.serializer.Serializer; import org.apache.fory.test.TestUtils; +import org.apache.fory.type.BFloat16; +import org.apache.fory.type.Float16; import org.apache.fory.type.Types; import org.apache.fory.util.MurmurHash3; import org.testng.Assert; @@ -1520,6 +1524,43 @@ static class TwoStringFieldStruct { String f2; } + @Data + static class ReducedPrecisionFloatStruct { + Float16 float16Value; + BFloat16 bfloat16Value; + Float16List float16Array; + BFloat16List bfloat16Array; + } + + protected static ReducedPrecisionFloatStruct newReducedPrecisionFloatStruct() { + ReducedPrecisionFloatStruct value = new ReducedPrecisionFloatStruct(); + value.float16Value = Float16.fromBits((short) 0x3E00); + value.bfloat16Value = BFloat16.fromBits((short) 0x3FC0); + value.float16Array = + new Float16List(new short[] {(short) 0x0000, (short) 0x3C00, (short) 0xBC00}); + value.bfloat16Array = + new BFloat16List(new short[] {(short) 0x0000, (short) 0x3F80, (short) 0xBF80}); + return value; + } + + protected static void assertReducedPrecisionFloatStruct(ReducedPrecisionFloatStruct value) { + Assert.assertNotNull(value); + Assert.assertNotNull(value.float16Value); + Assert.assertNotNull(value.bfloat16Value); + Assert.assertEquals(value.float16Value.toBits(), (short) 0x3E00); + Assert.assertEquals(value.bfloat16Value.toBits(), (short) 0x3FC0); + Assert.assertNotNull(value.float16Array); + Assert.assertNotNull(value.bfloat16Array); + Assert.assertEquals(value.float16Array.size(), 3); + Assert.assertEquals(value.bfloat16Array.size(), 3); + Assert.assertEquals(value.float16Array.getShort(0), (short) 0x0000); + Assert.assertEquals(value.float16Array.getShort(1), (short) 0x3C00); + Assert.assertEquals(value.float16Array.getShort(2), (short) 0xBC00); + Assert.assertEquals(value.bfloat16Array.getShort(0), (short) 0x0000); + Assert.assertEquals(value.bfloat16Array.getShort(1), (short) 0x3F80); + Assert.assertEquals(value.bfloat16Array.getShort(2), (short) 0xBF80); + } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") public void testOneStringFieldSchemaConsistent(boolean enableCodegen) throws java.io.IOException { String caseName = "test_one_string_field_schema"; @@ -1668,6 +1709,65 @@ public void testSchemaEvolutionCompatible(boolean enableCodegen) throws java.io. "Expected null or empty string but got: " + result2.f2); } + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStruct(boolean enableCodegen) throws java.io.IOException { + String caseName = "test_reduced_precision_float_struct"; + Fory fory = + Fory.builder() + .withLanguage(Language.XLANG) + .withCompatibleMode(CompatibleMode.SCHEMA_CONSISTENT) + .withCodegen(enableCodegen) + .build(); + fory.register(ReducedPrecisionFloatStruct.class, 213); + + ReducedPrecisionFloatStruct obj = newReducedPrecisionFloatStruct(); + assertReducedPrecisionFloatStruct((ReducedPrecisionFloatStruct) xserDe(fory, obj)); + + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(128); + fory.serialize(buffer, obj); + + ExecutionContext ctx = prepareExecution(caseName, buffer.getBytes(0, buffer.writerIndex())); + runPeer(ctx); + + MemoryBuffer buffer2 = readBuffer(ctx.dataFile()); + ReducedPrecisionFloatStruct result = (ReducedPrecisionFloatStruct) fory.deserialize(buffer2); + assertReducedPrecisionFloatStruct(result); + } + + @Test(groups = "xlang", dataProvider = "enableCodegenParallel") + public void testReducedPrecisionFloatStructCompatibleFieldSkip(boolean enableCodegen) + throws java.io.IOException { + String caseName = "test_reduced_precision_float_struct_compatible_skip"; + Fory fory = + Fory.builder() + .withLanguage(Language.XLANG) + .withCompatibleMode(CompatibleMode.COMPATIBLE) + .withCodegen(enableCodegen) + .withMetaCompressor(new NoOpMetaCompressor()) + .build(); + fory.register(ReducedPrecisionFloatStruct.class, 213); + + Fory foryEmpty = + Fory.builder() + .withLanguage(Language.XLANG) + .withCompatibleMode(CompatibleMode.COMPATIBLE) + .withCodegen(enableCodegen) + .withMetaCompressor(new NoOpMetaCompressor()) + .build(); + foryEmpty.register(EmptyStruct.class, 213); + + ReducedPrecisionFloatStruct obj = newReducedPrecisionFloatStruct(); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(128); + fory.serialize(buffer, obj); + + ExecutionContext ctx = prepareExecution(caseName, buffer.getBytes(0, buffer.writerIndex())); + runPeer(ctx); + + MemoryBuffer buffer2 = readBuffer(ctx.dataFile()); + EmptyStruct result = (EmptyStruct) foryEmpty.deserialize(buffer2); + Assert.assertNotNull(result); + } + // Enum field structs for testing enum TestEnum { VALUE_A, diff --git a/javascript/packages/core/lib/meta/TypeMeta.ts b/javascript/packages/core/lib/meta/TypeMeta.ts index 78110c8ca7..b04d38aa58 100644 --- a/javascript/packages/core/lib/meta/TypeMeta.ts +++ b/javascript/packages/core/lib/meta/TypeMeta.ts @@ -852,7 +852,7 @@ export class TypeMeta { // Sort functions const primitiveComparator = (a: T, b: T) => { - // Sort by type_id descending, then by name ascending + // Sort by type_id ascending, then by name ascending const t1Compress = TypeId.isCompressedType(a.typeId); const t2Compress = TypeId.isCompressedType(b.typeId); @@ -863,7 +863,7 @@ export class TypeMeta { let c = sizeb - sizea; if (c === 0) { - c = b.typeId - a.typeId; + c = a.typeId - b.typeId; // noinspection Duplicates if (c == 0) { return nameSorter(a, b); diff --git a/python/README.md b/python/README.md index ee5215f0ec..8fe89dd4f5 100644 --- a/python/README.md +++ b/python/README.md @@ -25,6 +25,7 @@ - **Polymorphism support** for customized types with automatic type dispatching - **Schema evolution** support for backward/forward compatibility when using dataclasses in cross-language mode - **Out-of-band buffer support** for zero-copy serialization of large data structures like NumPy arrays and Pandas DataFrames, compatible with pickle protocol 5 +- **Cython-only reduced-precision carriers** for `float16`, `float16array`, `bfloat16`, and `bfloat16array` via the compiled `pyfory.serialization` extension; there is no pure-Python fallback ### ⚡ **Blazing Fast Performance** diff --git a/python/pyfory/__init__.py b/python/pyfory/__init__.py index 2b614f8296..7d4ada2001 100644 --- a/python/pyfory/__init__.py +++ b/python/pyfory/__init__.py @@ -31,7 +31,7 @@ if ENABLE_FORY_CYTHON_SERIALIZATION: from pyfory.serialization import Fory, TypeInfo # noqa: F401,F811 -from pyfory.serialization import Buffer # noqa: F401 # pylint: disable=unused-import +from pyfory.serialization import Buffer, bfloat16, bfloat16array, float16, float16array # noqa: F401 # pylint: disable=unused-import from pyfory.serializer import ( # noqa: F401 # pylint: disable=unused-import Serializer, @@ -50,6 +50,8 @@ Uint64Serializer, VarUint64Serializer, TaggedUint64Serializer, + Float16Serializer, + Float16ArraySerializer, Float32Serializer, Float64Serializer, StringSerializer, @@ -68,6 +70,8 @@ MethodSerializer, ReduceSerializer, StatefulSerializer, + BFloat16Serializer, + BFloat16ArraySerializer, ) from pyfory.struct import DataClassSerializer from pyfory.field import dataclass, field # noqa: F401 # pylint: disable=unused-import @@ -127,6 +131,14 @@ "ThreadSafeFory", "TypeInfo", "Buffer", + "float16", + "float16array", + "bfloat16", + "bfloat16array", + "Float16Serializer", + "Float16ArraySerializer", + "BFloat16Serializer", + "BFloat16ArraySerializer", "DeserializationPolicy", # Field metadata "field", diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py index 4d0beadf56..3fc55e6ec0 100644 --- a/python/pyfory/_serializer.py +++ b/python/pyfory/_serializer.py @@ -19,11 +19,14 @@ import logging import platform import time +import array from abc import ABC from pyfory._fory import NOT_NULL_INT64_FLAG from pyfory.resolver import NOT_NULL_VALUE_FLAG, NULL_FLAG +from pyfory.serialization import bfloat16, bfloat16array, float16, float16array from pyfory.types import is_primitive_type +from pyfory.utils import is_little_endian try: import numpy as np @@ -239,6 +242,72 @@ def read(self, read_context): return read_context.read_float64() +def _coerce_float16_bits(value): + if isinstance(value, float16): + return value.to_bits() + return float16(value).to_bits() + + +def _coerce_bfloat16_bits(value): + if isinstance(value, bfloat16): + return value.to_bits() + return bfloat16(value).to_bits() + + +class Float16Serializer(Serializer): + def write(self, write_context, value): + write_context.write_uint16(_coerce_float16_bits(value)) + + def read(self, read_context): + return float16.from_bits(read_context.read_uint16()) + + +class Float16ArraySerializer(Serializer): + def write(self, write_context, value): + safe = float16array() if value is None else value + buffer = safe.to_buffer() + write_context.write_var_uint32(len(buffer) * 2) + if is_little_endian: + write_context.buffer.write_buffer(buffer) + else: + swapped = array.array("H", buffer) + swapped.byteswap() + write_context.buffer.write_buffer(swapped) + + def read(self, read_context): + payload_size = read_context.read_var_uint32() + if payload_size & 1: + raise ValueError("float16 array payload size mismatch") + return float16array.from_buffer(read_context.read_bytes(payload_size)) + + +class BFloat16Serializer(Serializer): + def write(self, write_context, value): + write_context.write_uint16(_coerce_bfloat16_bits(value)) + + def read(self, read_context): + return bfloat16.from_bits(read_context.read_uint16()) + + +class BFloat16ArraySerializer(Serializer): + def write(self, write_context, value): + safe = bfloat16array() if value is None else value + buffer = safe.to_buffer() + write_context.write_var_uint32(len(buffer) * 2) + if is_little_endian: + write_context.buffer.write_buffer(buffer) + else: + swapped = array.array("H", buffer) + swapped.byteswap() + write_context.buffer.write_buffer(swapped) + + def read(self, read_context): + payload_size = read_context.read_var_uint32() + if payload_size & 1: + raise ValueError("bfloat16 array payload size mismatch") + return bfloat16array.from_buffer(read_context.read_bytes(payload_size)) + + class StringSerializer(Serializer): def __init__(self, type_resolver, type_): super().__init__(type_resolver, type_) diff --git a/python/pyfory/bfloat16.pxi b/python/pyfory/bfloat16.pxi new file mode 100644 index 0000000000..3eb5e7d359 --- /dev/null +++ b/python/pyfory/bfloat16.pxi @@ -0,0 +1,304 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Cython-only public carrier types for xlang bfloat16 values and one-dimensional arrays. + +The compiled ``pyfory.serialization`` extension provides these names directly. There is no +pure-Python fallback module for the bfloat16 API surface. +""" + +import array as _py_array +from cpython.buffer cimport Py_buffer, PyBuffer_Release, PyObject_GetBuffer +from libc.string cimport memcpy +from pyfory.utils import is_little_endian + + +cdef inline uint16_t _bfloat16_float_to_bits(float value): + cdef uint32_t bits32 + cdef uint32_t lsb + memcpy(&bits32, &value, sizeof(float)) + if (bits32 & 0x7F800000) == 0x7F800000 and (bits32 & 0x007FFFFF) != 0: + return 0x7FC0 + lsb = (bits32 >> 16) & 1 + return (((bits32 + 0x7FFF + lsb) >> 16) & 0xFFFF) + + +cdef inline float _bfloat16_bits_to_float(uint16_t bits): + cdef uint32_t bits32 = (bits) << 16 + cdef float value + memcpy(&value, &bits32, sizeof(float)) + return value + + +cdef inline bint _bfloat16_try_coerce_float(object value, float* out_value): + try: + out_value[0] = float(value) + return True + except (TypeError, ValueError, OverflowError): + return False + + +@cython.final +cdef class bfloat16: + """Exact IEEE 754 bfloat16 value carrier with reduced-precision arithmetic operators.""" + + cdef uint16_t bits + + def __cinit__(self, value=0.0): + if isinstance(value, bfloat16): + self.bits = (value).bits + else: + self.bits = _bfloat16_float_to_bits(value) + + @staticmethod + def from_bits(bits): + cdef bfloat16 value = bfloat16.__new__(bfloat16) + value.bits = bits + return value + + @staticmethod + def from_float(value): + return bfloat16(value) + + def to_bits(self): + return self.bits + + def to_float(self): + return _bfloat16_bits_to_float(self.bits) + + def __float__(self): + return self.to_float() + + def __int__(self): + return int(self.to_float()) + + def __repr__(self): + return f"bfloat16(bits=0x{self.bits:04x}, value={self.to_float()!r})" + + def __hash__(self): + return hash(self.bits) + + def __eq__(self, other): + if isinstance(other, bfloat16): + return self.bits == (other).bits + return self.to_float() == float(other) + + def __lt__(self, other): + cdef float rhs + if not _bfloat16_try_coerce_float(other, &rhs): + return NotImplemented + return _bfloat16_bits_to_float(self.bits) < rhs + + def __le__(self, other): + cdef float rhs + if not _bfloat16_try_coerce_float(other, &rhs): + return NotImplemented + return _bfloat16_bits_to_float(self.bits) <= rhs + + def __gt__(self, other): + cdef float rhs + if not _bfloat16_try_coerce_float(other, &rhs): + return NotImplemented + return _bfloat16_bits_to_float(self.bits) > rhs + + def __ge__(self, other): + cdef float rhs + if not _bfloat16_try_coerce_float(other, &rhs): + return NotImplemented + return _bfloat16_bits_to_float(self.bits) >= rhs + + def __add__(self, other): + cdef float rhs + if not _bfloat16_try_coerce_float(other, &rhs): + return NotImplemented + return bfloat16(_bfloat16_bits_to_float(self.bits) + rhs) + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + cdef float rhs + if not _bfloat16_try_coerce_float(other, &rhs): + return NotImplemented + return bfloat16(_bfloat16_bits_to_float(self.bits) - rhs) + + def __rsub__(self, other): + cdef float lhs + if not _bfloat16_try_coerce_float(other, &lhs): + return NotImplemented + return bfloat16(lhs - _bfloat16_bits_to_float(self.bits)) + + def __mul__(self, other): + cdef float rhs + if not _bfloat16_try_coerce_float(other, &rhs): + return NotImplemented + return bfloat16(_bfloat16_bits_to_float(self.bits) * rhs) + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + cdef float rhs + if not _bfloat16_try_coerce_float(other, &rhs): + return NotImplemented + return bfloat16(_bfloat16_bits_to_float(self.bits) / rhs) + + def __rtruediv__(self, other): + cdef float lhs + if not _bfloat16_try_coerce_float(other, &lhs): + return NotImplemented + return bfloat16(lhs / _bfloat16_bits_to_float(self.bits)) + + def __neg__(self): + return bfloat16.from_bits(self.bits ^ 0x8000) + + def __pos__(self): + return bfloat16.from_bits(self.bits) + + def __abs__(self): + return bfloat16.from_bits(self.bits & 0x7FFF) + + +cdef inline uint16_t _coerce_bfloat16_bits(value): + if isinstance(value, bfloat16): + return (value).bits + return _bfloat16_float_to_bits(value) + + +@cython.final +cdef class bfloat16array: + """Packed one-dimensional carrier for xlang ``bfloat16_array`` payloads.""" + + cdef object _data + + def __cinit__(self, values=()): + self._data = _py_array.array("H") + self.extend(values) + + @staticmethod + def from_values(values): + return bfloat16array(values) + + @staticmethod + def from_buffer(buffer): + cdef bfloat16array value = bfloat16array.__new__(bfloat16array) + cdef object data = buffer + value._data = _py_array.array("H") + if isinstance(data, memoryview): + if data.itemsize == 2 and data.format in ("H", "@H"): + value._data = _py_array.array("H", data) + return value + data = data.tobytes() + if isinstance(data, (bytes, bytearray)): + if len(data) & 1: + raise ValueError("bfloat16 bits payload size mismatch") + value._data.frombytes(data) + if not is_little_endian and len(data) > 0: + value._data.byteswap() + return value + if isinstance(data, _py_array.array) and data.typecode == "H": + value._data = _py_array.array("H", data) + return value + raise TypeError( + f"bfloat16array.from_buffer expects a buffer-compatible value, got {type(buffer)!r}" + ) + + def append(self, value): + self._data.append(_coerce_bfloat16_bits(value)) + + def extend(self, values): + for value in values: + self.append(value) + + def to_buffer(self): + return _py_array.array("H", self._data) + + def tolist(self): + return [bfloat16.from_bits(bits) for bits in self._data] + + def __len__(self): + return len(self._data) + + def __iter__(self): + for bits in self._data: + yield bfloat16.from_bits(bits) + + def __getitem__(self, index): + if isinstance(index, slice): + return bfloat16array.from_buffer(self._data[index]) + return bfloat16.from_bits(self._data[index]) + + def __getbuffer__(self, Py_buffer *buffer, int flags): + if PyObject_GetBuffer(self._data, buffer, flags) < 0: + raise BufferError("bfloat16array failed to export buffer") + + def __releasebuffer__(self, Py_buffer *buffer): + PyBuffer_Release(buffer) + + def __repr__(self): + return f"bfloat16array({self.tolist()!r})" + + def __eq__(self, other): + if isinstance(other, bfloat16array): + return self._data == (other)._data + try: + return self.tolist() == list(other) + except TypeError: + return False + + +@cython.final +cdef class BFloat16Serializer(Serializer): + """Serializer for xlang ``bfloat16`` scalar values.""" + + cpdef inline write(self, WriteContext write_context, value): + write_context.write_uint16(_coerce_bfloat16_bits(value)) + + cpdef inline read(self, ReadContext read_context): + return bfloat16.from_bits(read_context.read_uint16()) + + +@cython.final +cdef class BFloat16ArraySerializer(Serializer): + """Serializer for xlang ``bfloat16_array`` payloads.""" + + cpdef write(self, WriteContext write_context, value): + cdef bfloat16array safe + cdef object swapped + if value is None: + safe = bfloat16array() + else: + safe = value + write_context.write_var_uint32(len(safe) * 2) + if is_little_endian: + write_context.write_buffer(safe._data) + else: + swapped = _py_array.array("H", safe._data) + swapped.byteswap() + write_context.write_buffer(swapped) + + cpdef read(self, ReadContext read_context): + cdef uint32_t payload_size = read_context.read_var_uint32() + cdef bfloat16array values + if payload_size & 1: + raise ValueError("bfloat16 array payload size mismatch") + values = bfloat16array() + values._data.frombytes(read_context.read_bytes(payload_size)) + if not is_little_endian and payload_size > 0: + values._data.byteswap() + return values diff --git a/python/pyfory/float16.pxi b/python/pyfory/float16.pxi new file mode 100644 index 0000000000..fe2f76f91a --- /dev/null +++ b/python/pyfory/float16.pxi @@ -0,0 +1,366 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import array as _py_array +from cpython.buffer cimport Py_buffer, PyBuffer_Release, PyObject_GetBuffer +from libc.string cimport memcpy +from pyfory.utils import is_little_endian + + +cdef inline uint16_t _float16_float_to_bits(float value): + cdef uint32_t bits32 + cdef uint32_t sign + cdef uint32_t exp + cdef uint32_t mant + cdef int32_t new_exp + cdef uint32_t out_exp + cdef uint32_t out_mant + cdef uint32_t full_mant + cdef int32_t shift + cdef int32_t net_shift + cdef uint32_t round_bit + cdef uint32_t sticky + memcpy(&bits32, &value, sizeof(float)) + sign = (bits32 >> 16) & 0x8000 + exp = (bits32 >> 23) & 0xFF + mant = bits32 & 0x7FFFFF + + if exp == 0xFF: + out_exp = 0x1F + if mant != 0: + out_mant = 0x200 | ((mant >> 13) & 0x1FF) + if out_mant == 0x200: + out_mant = 0x201 + else: + out_mant = 0 + elif exp == 0: + out_exp = 0 + out_mant = 0 + else: + new_exp = exp - 127 + 15 + if new_exp >= 31: + out_exp = 0x1F + out_mant = 0 + elif new_exp <= 0: + full_mant = mant | 0x800000 + shift = 1 - new_exp + net_shift = 13 + shift + if net_shift >= 24: + out_exp = 0 + out_mant = 0 + else: + out_exp = 0 + round_bit = (full_mant >> (net_shift - 1)) & 1 + sticky = full_mant & ((1 << (net_shift - 1)) - 1) + out_mant = full_mant >> net_shift + if round_bit == 1 and (sticky != 0 or (out_mant & 1) == 1): + out_mant += 1 + else: + out_exp = new_exp + out_mant = mant >> 13 + round_bit = (mant >> 12) & 1 + sticky = mant & 0xFFF + if round_bit == 1 and (sticky != 0 or (out_mant & 1) == 1): + out_mant += 1 + if out_mant > 0x3FF: + out_mant = 0 + out_exp += 1 + if out_exp >= 31: + out_exp = 0x1F + + return (sign | (out_exp << 10) | out_mant) + + +cdef inline float _float16_bits_to_float(uint16_t bits): + cdef uint32_t sign = (((bits >> 15) & 0x1)) << 31 + cdef uint32_t exp = (bits >> 10) & 0x1F + cdef uint32_t mant = bits & 0x3FF + cdef uint32_t out_bits = sign + cdef int32_t shift = 0 + cdef float value + + if exp == 0x1F: + out_bits |= 0xFF << 23 + if mant != 0: + out_bits |= mant << 13 + elif exp == 0: + if mant != 0: + while (mant & 0x400) == 0: + mant <<= 1 + shift += 1 + mant &= 0x3FF + out_bits |= (1 - 15 - shift + 127) << 23 + out_bits |= mant << 13 + else: + out_bits |= (exp - 15 + 127) << 23 + out_bits |= mant << 13 + + memcpy(&value, &out_bits, sizeof(float)) + return value + + +cdef inline bint _float16_try_coerce_float(object value, float* out_value): + try: + out_value[0] = float(value) + return True + except (TypeError, ValueError, OverflowError): + return False + + +@cython.final +cdef class float16: + """Exact IEEE 754 binary16 value carrier with reduced-precision arithmetic operators.""" + + cdef uint16_t bits + + def __cinit__(self, value=0.0): + if isinstance(value, float16): + self.bits = (value).bits + else: + self.bits = _float16_float_to_bits(value) + + @staticmethod + def from_bits(bits): + cdef float16 value = float16.__new__(float16) + value.bits = bits + return value + + @staticmethod + def from_float(value): + return float16(value) + + def to_bits(self): + return self.bits + + def to_float(self): + return _float16_bits_to_float(self.bits) + + def __float__(self): + return self.to_float() + + def __int__(self): + return int(self.to_float()) + + def __repr__(self): + return f"float16(bits=0x{self.bits:04x}, value={self.to_float()!r})" + + def __hash__(self): + return hash(self.bits) + + def __eq__(self, other): + if isinstance(other, float16): + return self.bits == (other).bits + return self.to_float() == float(other) + + def __lt__(self, other): + cdef float rhs + if not _float16_try_coerce_float(other, &rhs): + return NotImplemented + return _float16_bits_to_float(self.bits) < rhs + + def __le__(self, other): + cdef float rhs + if not _float16_try_coerce_float(other, &rhs): + return NotImplemented + return _float16_bits_to_float(self.bits) <= rhs + + def __gt__(self, other): + cdef float rhs + if not _float16_try_coerce_float(other, &rhs): + return NotImplemented + return _float16_bits_to_float(self.bits) > rhs + + def __ge__(self, other): + cdef float rhs + if not _float16_try_coerce_float(other, &rhs): + return NotImplemented + return _float16_bits_to_float(self.bits) >= rhs + + def __add__(self, other): + cdef float rhs + if not _float16_try_coerce_float(other, &rhs): + return NotImplemented + return float16(_float16_bits_to_float(self.bits) + rhs) + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + cdef float rhs + if not _float16_try_coerce_float(other, &rhs): + return NotImplemented + return float16(_float16_bits_to_float(self.bits) - rhs) + + def __rsub__(self, other): + cdef float lhs + if not _float16_try_coerce_float(other, &lhs): + return NotImplemented + return float16(lhs - _float16_bits_to_float(self.bits)) + + def __mul__(self, other): + cdef float rhs + if not _float16_try_coerce_float(other, &rhs): + return NotImplemented + return float16(_float16_bits_to_float(self.bits) * rhs) + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + cdef float rhs + if not _float16_try_coerce_float(other, &rhs): + return NotImplemented + return float16(_float16_bits_to_float(self.bits) / rhs) + + def __rtruediv__(self, other): + cdef float lhs + if not _float16_try_coerce_float(other, &lhs): + return NotImplemented + return float16(lhs / _float16_bits_to_float(self.bits)) + + def __neg__(self): + return float16.from_bits(self.bits ^ 0x8000) + + def __pos__(self): + return float16.from_bits(self.bits) + + def __abs__(self): + return float16.from_bits(self.bits & 0x7FFF) + + +cdef inline uint16_t _coerce_float16_bits(value): + if isinstance(value, float16): + return (value).bits + return _float16_float_to_bits(value) + + +@cython.final +cdef class float16array: + """Packed one-dimensional carrier for xlang ``float16_array`` payloads.""" + + cdef object _data + + def __cinit__(self, values=()): + self._data = _py_array.array("H") + self.extend(values) + + @staticmethod + def from_values(values): + return float16array(values) + + @staticmethod + def from_buffer(buffer): + cdef float16array value = float16array.__new__(float16array) + cdef object data = buffer + value._data = _py_array.array("H") + if isinstance(data, memoryview): + if data.itemsize == 2 and data.format in ("H", "@H"): + value._data = _py_array.array("H", data) + return value + data = data.tobytes() + if isinstance(data, (bytes, bytearray)): + if len(data) & 1: + raise ValueError("float16 bits payload size mismatch") + value._data.frombytes(data) + if not is_little_endian and len(data) > 0: + value._data.byteswap() + return value + if isinstance(data, _py_array.array) and data.typecode == "H": + value._data = _py_array.array("H", data) + return value + raise TypeError(f"float16array.from_buffer expects a buffer-compatible value, got {type(buffer)!r}") + + def append(self, value): + self._data.append(_coerce_float16_bits(value)) + + def extend(self, values): + for value in values: + self.append(value) + + def to_buffer(self): + return _py_array.array("H", self._data) + + def tolist(self): + return [float16.from_bits(bits) for bits in self._data] + + def __len__(self): + return len(self._data) + + def __iter__(self): + for bits in self._data: + yield float16.from_bits(bits) + + def __getitem__(self, index): + if isinstance(index, slice): + return float16array.from_buffer(self._data[index]) + return float16.from_bits(self._data[index]) + + def __getbuffer__(self, Py_buffer *buffer, int flags): + if PyObject_GetBuffer(self._data, buffer, flags) < 0: + raise BufferError("float16array failed to export buffer") + + def __releasebuffer__(self, Py_buffer *buffer): + PyBuffer_Release(buffer) + + def __repr__(self): + return f"float16array({self.tolist()!r})" + + def __eq__(self, other): + if isinstance(other, float16array): + return self._data == (other)._data + try: + return self.tolist() == list(other) + except TypeError: + return False + + +@cython.final +cdef class Float16Serializer(Serializer): + cpdef inline write(self, WriteContext write_context, value): + write_context.write_uint16(_coerce_float16_bits(value)) + + cpdef inline read(self, ReadContext read_context): + return float16.from_bits(read_context.read_uint16()) + + +@cython.final +cdef class Float16ArraySerializer(Serializer): + cpdef write(self, WriteContext write_context, value): + cdef float16array safe + cdef object swapped + if value is None: + safe = float16array() + else: + safe = value + write_context.write_var_uint32(len(safe) * 2) + if is_little_endian: + write_context.write_buffer(safe._data) + else: + swapped = _py_array.array("H", safe._data) + swapped.byteswap() + write_context.write_buffer(swapped) + + cpdef read(self, ReadContext read_context): + cdef uint32_t payload_size = read_context.read_var_uint32() + cdef float16array values + if payload_size & 1: + raise ValueError("float16 array payload size mismatch") + values = float16array() + values._data.frombytes(read_context.read_bytes(payload_size)) + if not is_little_endian and payload_size > 0: + values._data.byteswap() + return values diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py index da1fab51ce..6b3bee2ccd 100644 --- a/python/pyfory/registry.py +++ b/python/pyfory/registry.py @@ -56,8 +56,12 @@ Uint64Serializer, VarUint64Serializer, TaggedUint64Serializer, + Float16Serializer, + Float16ArraySerializer, Float32Serializer, Float64Serializer, + BFloat16Serializer, + BFloat16ArraySerializer, StringSerializer, DecimalSerializer, DateSerializer, @@ -83,6 +87,13 @@ PickleBufferSerializer, UnionSerializer, ) +from pyfory.serialization import ( + Serializer as CythonSerializer, + bfloat16, + bfloat16array, + float16, + float16array, +) from pyfory.meta.metastring import MetaStringEncoder, MetaStringDecoder from pyfory.meta.meta_compressor import DeflaterMetaCompressor from pyfory.context import EncodedMetaString @@ -146,8 +157,10 @@ TypeId.UINT64, TypeId.VAR_UINT64, TypeId.TAGGED_UINT64, + TypeId.FLOAT16, TypeId.FLOAT32, TypeId.FLOAT64, + TypeId.BFLOAT16, } ) @@ -161,7 +174,7 @@ def _accepts_n_positional_args(factory, nargs: int) -> bool: signature = inspect.signature(factory.__init__) parameters = tuple(signature.parameters.values())[1:] except (AttributeError, TypeError, ValueError): - if inspect.isclass(factory) and issubclass(factory, Serializer): + if inspect.isclass(factory) and issubclass(factory, (Serializer, CythonSerializer)): return nargs == 2 raise TypeError(f"Unable to inspect serializer constructor for {factory!r}") min_args = 0 @@ -443,11 +456,27 @@ def _initialize_common(self): serializer=Float64Serializer, ) register(float, type_id=TypeId.FLOAT64, serializer=Float64Serializer) + register( + float16, + type_id=TypeId.FLOAT16, + serializer=Float16Serializer, + ) + register(bfloat16, type_id=TypeId.BFLOAT16, serializer=BFloat16Serializer) register(str, type_id=TypeId.STRING, serializer=StringSerializer) register(decimal.Decimal, type_id=TypeId.DECIMAL, serializer=DecimalSerializer) register(datetime.datetime, type_id=TypeId.TIMESTAMP, serializer=TimestampSerializer) register(datetime.date, type_id=TypeId.DATE, serializer=DateSerializer) register(bytes, type_id=TypeId.BINARY, serializer=BytesSerializer) + register( + float16array, + type_id=TypeId.FLOAT16_ARRAY, + serializer=Float16ArraySerializer, + ) + register( + bfloat16array, + type_id=TypeId.BFLOAT16_ARRAY, + serializer=BFloat16ArraySerializer, + ) for itemsize, ftype, typeid in PyArraySerializer.typecode_dict.values(): register( ftype, diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 0ccea8963b..717104b364 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -1056,6 +1056,8 @@ cdef class Fory: self.reset_read() include "primitive.pxi" +include "float16.pxi" +include "bfloat16.pxi" include "collection.pxi" include "struct.pxi" diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 30101757f6..a50fa9a38e 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -39,6 +39,7 @@ NOT_NULL_INT64_FLAG, BufferObject, ) +from pyfory.serialization import bfloat16, bfloat16array, float16, float16array _WINDOWS = os.name == "nt" @@ -71,6 +72,8 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + Float16Serializer, + Float16ArraySerializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -82,6 +85,10 @@ MapSerializer, EnumSerializer, SliceSerializer, + bfloat16, + bfloat16array, + BFloat16Serializer, + BFloat16ArraySerializer, ) from pyfory.union import UnionSerializer # noqa: F401 else: @@ -106,6 +113,10 @@ TaggedUint64Serializer, Float32Serializer, Float64Serializer, + Float16Serializer, + Float16ArraySerializer, + BFloat16Serializer, + BFloat16ArraySerializer, StringSerializer, DateSerializer, TimestampSerializer, @@ -1419,8 +1430,16 @@ def read(self, read_context): "Uint64Serializer", "VarUint64Serializer", "TaggedUint64Serializer", + "float16", + "float16array", + "Float16Serializer", + "Float16ArraySerializer", "Float32Serializer", "Float64Serializer", + "bfloat16", + "bfloat16array", + "BFloat16Serializer", + "BFloat16ArraySerializer", "StringSerializer", "DateSerializer", "TimestampSerializer", diff --git a/python/pyfory/struct.py b/python/pyfory/struct.py index 06141d8468..1658c28fc0 100644 --- a/python/pyfory/struct.py +++ b/python/pyfory/struct.py @@ -64,6 +64,7 @@ ) from pyfory.serialization import Buffer from pyfory.serialization import ENABLE_FORY_CYTHON_SERIALIZATION +from pyfory.serialization import bfloat16, float16 from pyfory.error import TypeNotCompatibleError from pyfory.resolver import NULL_FLAG, NOT_NULL_VALUE_FLAG from pyfory.field import ( @@ -615,6 +616,8 @@ def _replace(self): fixed_uint64, tagged_uint64, # Floats + float16, + bfloat16, float32, float64, # Python native types @@ -779,16 +782,15 @@ def numeric_sorter(item): TypeId.VAR_UINT64, TypeId.TAGGED_UINT64, } - # Sort by: compress flag, -size (largest first), -type_id (higher type ID first), field_name - # Java sorts by size (largest first), then by primitive type ID (descending) - return int(compress), -get_primitive_type_size(id_), -id_, item[3] + # Sort by: compress flag, -size (largest first), type_id (lower first), field_name + return int(compress), -get_primitive_type_size(id_), id_, item[3] boxed_types = sorted(boxed_types, key=numeric_sorter) nullable_boxed_types = sorted(nullable_boxed_types, key=numeric_sorter) - collection_types = sorted(collection_types, key=sorter) - set_types = sorted(set_types, key=sorter) + collection_types = sorted(collection_types, key=lambda item: item[3]) + set_types = sorted(set_types, key=lambda item: item[3]) internal_types = sorted(internal_types, key=sorter) - map_types = sorted(map_types, key=sorter) + map_types = sorted(map_types, key=lambda item: item[3]) other_types = sorted(other_types, key=lambda item: item[3]) return (boxed_types, nullable_boxed_types, internal_types, collection_types, set_types, map_types, other_types) diff --git a/python/pyfory/tests/test_serializer.py b/python/pyfory/tests/test_serializer.py index 04da47ca27..1beed93e80 100644 --- a/python/pyfory/tests/test_serializer.py +++ b/python/pyfory/tests/test_serializer.py @@ -172,6 +172,134 @@ def test_basic_serializer(xlang): assert ser_de(fory, set_) == set_ +def test_float16_round_trip(): + fory = Fory(xlang=True, ref=False) + value = pyfory.float16.from_bits(0x3C00) + typeinfo = fory.type_resolver.get_type_info(pyfory.float16) + assert isinstance(typeinfo.serializer, pyfory.Float16Serializer) + assert typeinfo.type_id == TypeId.FLOAT16 + decoded = ser_de(fory, value) + assert isinstance(decoded, pyfory.float16) + assert decoded.to_bits() == 0x3C00 + + +def test_bfloat16_round_trip(): + fory = Fory(xlang=True, ref=False) + value = pyfory.bfloat16.from_bits(0x3FC0) + typeinfo = fory.type_resolver.get_type_info(pyfory.bfloat16) + assert isinstance(typeinfo.serializer, pyfory.BFloat16Serializer) + assert typeinfo.type_id == TypeId.BFLOAT16 + decoded = ser_de(fory, value) + assert isinstance(decoded, pyfory.bfloat16) + assert decoded.to_bits() == 0x3FC0 + + +def test_float16_arithmetic(): + value = pyfory.float16(1.5) + half = pyfory.float16(0.5) + + assert isinstance(value + half, pyfory.float16) + assert (value + half).to_bits() == 0x4000 + assert (value - half).to_bits() == 0x3C00 + assert (value * 2).to_bits() == 0x4200 + assert (3 + half).to_bits() == 0x4300 + assert (value / half).to_bits() == 0x4200 + with pytest.raises(ZeroDivisionError): + _ = value / 0.0 + assert (-value).to_bits() == 0xBE00 + assert (+value).to_bits() == value.to_bits() + assert abs(pyfory.float16(-1.5)).to_bits() == 0x3E00 + assert pyfory.float16(1.5) < 2.0 + assert pyfory.float16(1.5) <= pyfory.float16(1.5) + assert pyfory.float16(1.5) > 1.0 + assert pyfory.float16(1.5) >= pyfory.float16(1.5) + assert not (pyfory.float16.from_bits(0x7E00) < 1.0) + with pytest.raises(TypeError): + _ = value + "x" + + +def test_bfloat16_arithmetic(): + value = pyfory.bfloat16(1.5) + half = pyfory.bfloat16(0.5) + + assert isinstance(value + half, pyfory.bfloat16) + assert (value + half).to_bits() == 0x4000 + assert (value - half).to_bits() == 0x3F80 + assert (value * 2).to_bits() == 0x4040 + assert (3 + half).to_bits() == 0x4060 + assert (value / half).to_bits() == 0x4040 + with pytest.raises(ZeroDivisionError): + _ = value / 0.0 + assert (-value).to_bits() == 0xBFC0 + assert (+value).to_bits() == value.to_bits() + assert abs(pyfory.bfloat16(-1.5)).to_bits() == 0x3FC0 + assert pyfory.bfloat16(1.5) < 2.0 + assert pyfory.bfloat16(1.5) <= pyfory.bfloat16(1.5) + assert pyfory.bfloat16(1.5) > 1.0 + assert pyfory.bfloat16(1.5) >= pyfory.bfloat16(1.5) + assert not (pyfory.bfloat16.from_bits(0x7FC0) < 1.0) + with pytest.raises(TypeError): + _ = value + "x" + + +def test_float16_array_round_trip(): + fory = Fory(xlang=True, ref=False) + values = pyfory.float16array.from_values([0.0, 1.0, -2.0]) + typeinfo = fory.type_resolver.get_type_info(pyfory.float16array) + assert isinstance(typeinfo.serializer, pyfory.Float16ArraySerializer) + assert typeinfo.type_id == TypeId.FLOAT16_ARRAY + decoded = ser_de(fory, values) + assert isinstance(decoded, pyfory.float16array) + assert list(decoded.to_buffer()) == [0x0000, 0x3C00, 0xC000] + + +def test_float16_array_from_values(): + values = pyfory.float16array.from_values([0.0, 1.0, -2.0]) + assert list(values.to_buffer()) == [0x0000, 0x3C00, 0xC000] + + +def test_float16_array_from_buffer(): + values = pyfory.float16array.from_buffer(memoryview(bytes.fromhex("0000003c00c0"))) + assert list(values.to_buffer()) == [0x0000, 0x3C00, 0xC000] + + +def test_float16_array_buffer_protocol(): + values = pyfory.float16array.from_values([0.0, 1.0, -2.0]) + view = memoryview(values) + assert view.format in ("H", "@H") + assert view.itemsize == 2 + assert list(pyfory.float16array.from_buffer(view).to_buffer()) == [0x0000, 0x3C00, 0xC000] + + +def test_bfloat16_array_round_trip(): + fory = Fory(xlang=True, ref=False) + values = pyfory.bfloat16array.from_values([0.0, 1.0, -2.0]) + typeinfo = fory.type_resolver.get_type_info(pyfory.bfloat16array) + assert isinstance(typeinfo.serializer, pyfory.BFloat16ArraySerializer) + assert typeinfo.type_id == TypeId.BFLOAT16_ARRAY + decoded = ser_de(fory, values) + assert isinstance(decoded, pyfory.bfloat16array) + assert list(decoded.to_buffer()) == [0x0000, 0x3F80, 0xC000] + + +def test_bfloat16_array_from_values(): + values = pyfory.bfloat16array.from_values([0.0, 1.0, -2.0]) + assert list(values.to_buffer()) == [0x0000, 0x3F80, 0xC000] + + +def test_bfloat16_array_from_buffer(): + values = pyfory.bfloat16array.from_buffer(memoryview(bytes.fromhex("0000803f00c0"))) + assert list(values.to_buffer()) == [0x0000, 0x3F80, 0xC000] + + +def test_bfloat16_array_buffer_protocol(): + values = pyfory.bfloat16array.from_values([0.0, 1.0, -2.0]) + view = memoryview(values) + assert view.format in ("H", "@H") + assert view.itemsize == 2 + assert list(pyfory.bfloat16array.from_buffer(view).to_buffer()) == [0x0000, 0x3F80, 0xC000] + + @pytest.mark.parametrize("xlang", [True, False]) def test_date_serializer_uses_xlang_varint64_and_native_int32(xlang): fory = Fory(xlang=xlang, ref=False) diff --git a/python/pyfory/tests/test_struct.py b/python/pyfory/tests/test_struct.py index 92b00c3507..23e24e7f90 100644 --- a/python/pyfory/tests/test_struct.py +++ b/python/pyfory/tests/test_struct.py @@ -186,9 +186,8 @@ class TestClass: fory = Fory(xlang=True, ref=True) serializer = DataClassSerializer(fory.type_resolver, TestClass) # Sorting order: - # 1. Non-compressed primitives (compress=0) by -size, then name: - # float64(8), float32(4), bool(1), int8(1) => f13, f5, f11, f7 - # (f11 < f7 alphabetically since '1' < '7') + # 1. Non-compressed primitives (compress=0) by -size, then ascending type_id, then name: + # float64(8), float32(4), bool(1), int8(1) => f13, f5, f7, f11 # 2. Compressed primitives (compress=1) by -size, then name: # int64(8), int32(4) => f12, f1 # 3. Internal types by type_id, then name: str, datetime, bytes => f4, f15, f6 @@ -196,7 +195,7 @@ class TestClass: # 5. Set types by type_id, then name: set => f14 # 6. Map types by type_id, then name: dict => f3, f9 # 7. Other types (polymorphic/any) by name: any => f8 - assert serializer._field_names == ["f13", "f5", "f11", "f7", "f12", "f1", "f4", "f15", "f6", "f10", "f2", "f14", "f3", "f9", "f8"] + assert serializer._field_names == ["f13", "f5", "f7", "f11", "f12", "f1", "f4", "f15", "f6", "f10", "f2", "f14", "f3", "f9", "f8"] @pytest.mark.parametrize( diff --git a/python/pyfory/tests/xlang_test_main.py b/python/pyfory/tests/xlang_test_main.py index a8e5a526e3..4e58da65b6 100644 --- a/python/pyfory/tests/xlang_test_main.py +++ b/python/pyfory/tests/xlang_test_main.py @@ -187,6 +187,14 @@ class TwoStringFieldStruct: f2: str = "" +@dataclass +class ReducedPrecisionFloatStruct: + float16_value: pyfory.float16 = None + bfloat16_value: pyfory.bfloat16 = None + float16_array: pyfory.float16array = None + bfloat16_array: pyfory.bfloat16array = None + + class TestEnum(enum.Enum): VALUE_A = 0 VALUE_B = 1 @@ -810,6 +818,47 @@ def test_schema_evolution_compatible_reverse(): f.write(new_bytes) +def _assert_reduced_precision_float_struct(obj: ReducedPrecisionFloatStruct): + assert obj.float16_value.to_bits() == 0x3E00, f"float16_value bits: {obj.float16_value.to_bits():#06x}" + assert obj.bfloat16_value.to_bits() == 0x3FC0, f"bfloat16_value bits: {obj.bfloat16_value.to_bits():#06x}" + assert list(obj.float16_array.to_buffer()) == [0x0000, 0x3C00, 0xBC00], f"float16_array bits: {list(obj.float16_array.to_buffer())}" + assert list(obj.bfloat16_array.to_buffer()) == [0x0000, 0x3F80, 0xBF80], f"bfloat16_array bits: {list(obj.bfloat16_array.to_buffer())}" + + +def test_reduced_precision_float_struct(): + data_file = get_data_file() + with open(data_file, "rb") as f: + data_bytes = f.read() + + fory = pyfory.Fory(xlang=True, compatible=False) + fory.register_type(ReducedPrecisionFloatStruct, type_id=213) + + obj = fory.deserialize(data_bytes) + debug_print(f"Deserialized: {obj}") + _assert_reduced_precision_float_struct(obj) + + new_bytes = fory.serialize(obj) + with open(data_file, "wb") as f: + f.write(new_bytes) + + +def test_reduced_precision_float_struct_compatible_skip(): + data_file = get_data_file() + with open(data_file, "rb") as f: + data_bytes = f.read() + + fory = pyfory.Fory(xlang=True, compatible=True, meta_compressor=NoOpMetaCompressor()) + fory.register_type(EmptyStruct, type_id=213) + + obj = fory.deserialize(data_bytes) + debug_print(f"Deserialized empty struct: {obj}") + assert isinstance(obj, EmptyStruct) + + new_bytes = fory.serialize(obj) + with open(data_file, "wb") as f: + f.write(new_bytes) + + def test_one_enum_field_schema(): """Test one enum field struct with schema consistent mode.""" data_file = get_data_file() diff --git a/python/pyfory/types.py b/python/pyfory/types.py index b680d393c6..919d323b8f 100644 --- a/python/pyfory/types.py +++ b/python/pyfory/types.py @@ -243,6 +243,13 @@ def __class_getitem__(cls, params): float64, } + +def _is_special_compiled_primitive_type(type_) -> bool: + from pyfory.serialization import bfloat16, float16 + + return type_ is float16 or type_ is bfloat16 + + _primitive_types_ids = { TypeId.BOOL, # Signed integers @@ -276,7 +283,7 @@ def __class_getitem__(cls, params): def is_primitive_type(type_) -> bool: if type(type_) is int: return type_ in _primitive_types_ids - return type_ in _primitive_types + return type_ in _primitive_types or _is_special_compiled_primitive_type(type_) _primitive_type_sizes = { @@ -375,6 +382,12 @@ def get_primitive_type_size(type_id) -> int: _primitive_array_types = _py_array_types.union(_np_array_types) +def _is_special_compiled_primitive_array_type(type_) -> bool: + from pyfory.serialization import bfloat16array, float16array + + return type_ is float16array or type_ is bfloat16array + + def is_py_array_type(type_) -> bool: return type_ in _py_array_types @@ -389,6 +402,8 @@ def is_py_array_type(type_) -> bool: TypeId.UINT16_ARRAY, TypeId.UINT32_ARRAY, TypeId.UINT64_ARRAY, + TypeId.FLOAT16_ARRAY, + TypeId.BFLOAT16_ARRAY, TypeId.FLOAT32_ARRAY, TypeId.FLOAT64_ARRAY, } @@ -397,7 +412,7 @@ def is_py_array_type(type_) -> bool: def is_primitive_array_type(type_) -> bool: if type(type_) is int: return type_ in _primitive_array_type_ids - return type_ in _primitive_array_types + return type_ in _primitive_array_types or _is_special_compiled_primitive_array_type(type_) def is_list_type(type_): diff --git a/rust/README.md b/rust/README.md index 34648dc922..897e244085 100644 --- a/rust/README.md +++ b/rust/README.md @@ -16,6 +16,7 @@ The Rust implementation provides versatile and high-performance serialization wi - **🔄 Circular References**: Automatic tracking of shared and circular references with `Rc`/`Arc` and weak pointers - **🧬 Polymorphic**: Serialize trait objects with `Box`, `Rc`, and `Arc` - **📦 Schema Evolution**: Compatible mode for independent schema changes +- **🔢 Reduced-Precision Types**: `Float16` and `BFloat16` scalars with `Vec` / `Vec` arrays - **⚡ Two Formats**: Object graph serialization and zero-copy row-based format ## 📦 Crates diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs index 9bd7310476..2d4a6b4666 100644 --- a/rust/fory-core/src/buffer.rs +++ b/rust/fory-core/src/buffer.rs @@ -16,6 +16,7 @@ // under the License. use crate::error::Error; +use crate::types::bfloat16::bfloat16; use crate::types::float16::float16; use crate::util::buffer_rw_string::read_latin1_simd; use byteorder::{ByteOrder, LittleEndian}; @@ -397,6 +398,12 @@ impl<'a> Writer<'a> { self.write_u16(value.to_bits()); } + // ============ BFLOAT16 (TypeId = 18) ============ + #[inline(always)] + pub fn write_bf16(&mut self, value: bfloat16) { + self.write_u16(value.to_bits()); + } + // ============ FLOAT64 (TypeId = 18) ============ #[inline(always)] @@ -876,6 +883,14 @@ impl<'a> Reader<'a> { Ok(float16::from_bits(bits)) } + #[inline(always)] + pub fn read_bf16(&mut self) -> Result { + self.check_bound(2)?; + let bits = LittleEndian::read_u16(&self.bf[self.cursor..self.cursor + 2]); + self.cursor += 2; + Ok(bfloat16::from_bits(bits)) + } + pub fn read_f64(&mut self) -> Result { self.check_bound(8)?; let result = LittleEndian::read_f64(&self.bf[self.cursor..self.cursor + 8]); diff --git a/rust/fory-core/src/lib.rs b/rust/fory-core/src/lib.rs index eb8083097d..272cbd0973 100644 --- a/rust/fory-core/src/lib.rs +++ b/rust/fory-core/src/lib.rs @@ -32,7 +32,7 @@ //! - **`serializer`**: Type-specific serialization implementations //! - **`resolver`**: Type resolution and metadata management //! - **`meta`**: Metadata handling for schema evolution -//! - **`types`**: Runtime value carriers such as decimal, float16, and weak refs +//! - **`types`**: Runtime value carriers such as decimal, Float16, BFloat16, and weak refs //! - **`type_id`**: Type IDs and protocol header helpers //! - **`error`**: Error handling and result types //! - **`util`**: Utility functions and helpers @@ -203,5 +203,6 @@ pub use crate::meta::{compute_field_hash, compute_struct_hash}; pub use crate::resolver::{RefFlag, RefMode, TypeInfo, TypeResolver}; pub use crate::serializer::{read_data, write_data, ForyDefault, Serializer, StructSerializer}; pub use crate::type_id::TypeId; +pub use crate::types::bfloat16::bfloat16 as BFloat16; pub use crate::types::float16::float16 as Float16; pub use crate::types::{ArcWeak, Decimal, RcWeak}; diff --git a/rust/fory-core/src/meta/type_meta.rs b/rust/fory-core/src/meta/type_meta.rs index 00f896020d..035e7cb328 100644 --- a/rust/fory-core/src/meta/type_meta.rs +++ b/rust/fory-core/src/meta/type_meta.rs @@ -796,7 +796,7 @@ impl TypeMeta { .cmp(&b_nullable) // non-nullable first .then_with(|| compress_a.cmp(&compress_b)) // fixed-size (false) first, then variable-size (true) last .then_with(|| size_b.cmp(&size_a)) // when same compress status: larger size first - .then_with(|| b_id.cmp(&a_id)) // when same size: larger type id first + .then_with(|| a_id.cmp(&b_id)) // when same size: smaller type id first .then_with(|| a_field_name.cmp(b_field_name)) // when same id: lexicographic name } fn type_then_name_sorter(a: &FieldInfo, b: &FieldInfo) -> std::cmp::Ordering { diff --git a/rust/fory-core/src/resolver/type_resolver.rs b/rust/fory-core/src/resolver/type_resolver.rs index bbd6b3a547..6045e3c825 100644 --- a/rust/fory-core/src/resolver/type_resolver.rs +++ b/rust/fory-core/src/resolver/type_resolver.rs @@ -733,6 +733,7 @@ impl TypeResolver { self.register_internal_serializer::(TypeId::FLOAT32)?; self.register_internal_serializer::(TypeId::FLOAT64)?; self.register_internal_serializer::(TypeId::FLOAT16)?; + self.register_internal_serializer::(TypeId::BFLOAT16)?; self.register_internal_serializer::(TypeId::UINT8)?; self.register_internal_serializer::(TypeId::UINT16)?; self.register_internal_serializer::(TypeId::VAR_UINT32)?; @@ -754,6 +755,9 @@ impl TypeResolver { self.register_internal_serializer::>( TypeId::FLOAT16_ARRAY, )?; + self.register_internal_serializer::>( + TypeId::BFLOAT16_ARRAY, + )?; self.register_internal_serializer::>(TypeId::BINARY)?; self.register_internal_serializer::>(TypeId::UINT16_ARRAY)?; self.register_internal_serializer::>(TypeId::UINT32_ARRAY)?; diff --git a/rust/fory-core/src/serializer/list.rs b/rust/fory-core/src/serializer/list.rs index 36ec760a4f..435ec97e9e 100644 --- a/rust/fory-core/src/serializer/list.rs +++ b/rust/fory-core/src/serializer/list.rs @@ -44,6 +44,7 @@ pub(super) fn get_primitive_type_id() -> TypeId { // Handle INT64, VARINT64, and TAGGED_INT64 (i64 uses VARINT64 in xlang mode) TypeId::INT64 | TypeId::VARINT64 | TypeId::TAGGED_INT64 => TypeId::INT64_ARRAY, TypeId::FLOAT16 => TypeId::FLOAT16_ARRAY, + TypeId::BFLOAT16 => TypeId::BFLOAT16_ARRAY, TypeId::FLOAT32 => TypeId::FLOAT32_ARRAY, TypeId::FLOAT64 => TypeId::FLOAT64_ARRAY, TypeId::UINT8 => TypeId::BINARY, @@ -77,6 +78,7 @@ pub(super) fn is_primitive_type() -> bool { | TypeId::TAGGED_INT64 | TypeId::INT128 | TypeId::FLOAT16 + | TypeId::BFLOAT16 | TypeId::FLOAT32 | TypeId::FLOAT64 | TypeId::UINT8 diff --git a/rust/fory-core/src/serializer/number.rs b/rust/fory-core/src/serializer/number.rs index 50f83d0958..d42d4b219d 100644 --- a/rust/fory-core/src/serializer/number.rs +++ b/rust/fory-core/src/serializer/number.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::types::bfloat16::bfloat16; use crate::types::float16::float16; use crate::buffer::{Reader, Writer}; @@ -150,6 +151,54 @@ impl ForyDefault for float16 { float16::ZERO } } + +impl Serializer for bfloat16 { + #[inline(always)] + fn fory_write_data(&self, context: &mut WriteContext) -> Result<(), Error> { + Writer::write_bf16(&mut context.writer, *self); + Ok(()) + } + #[inline(always)] + fn fory_read_data(context: &mut ReadContext) -> Result { + Reader::read_bf16(&mut context.reader) + } + #[inline(always)] + fn fory_reserved_space() -> usize { + std::mem::size_of::() + } + #[inline(always)] + fn fory_get_type_id(_: &TypeResolver) -> Result { + Ok(TypeId::BFLOAT16) + } + #[inline(always)] + fn fory_type_id_dyn(&self, _: &TypeResolver) -> Result { + Ok(TypeId::BFLOAT16) + } + #[inline(always)] + fn fory_static_type_id() -> TypeId { + TypeId::BFLOAT16 + } + #[inline(always)] + fn as_any(&self) -> &dyn std::any::Any { + self + } + #[inline(always)] + fn fory_write_type_info(context: &mut WriteContext) -> Result<(), Error> { + context.writer.write_var_u32(TypeId::BFLOAT16 as u32); + Ok(()) + } + #[inline(always)] + fn fory_read_type_info(context: &mut ReadContext) -> Result<(), Error> { + read_basic_type_info::(context) + } +} + +impl ForyDefault for bfloat16 { + #[inline(always)] + fn fory_default() -> Self { + bfloat16::ZERO + } +} impl_num_serializer!(i128, Writer::write_i128, Reader::read_i128, TypeId::INT128); impl_num_serializer!( isize, diff --git a/rust/fory-core/src/serializer/skip.rs b/rust/fory-core/src/serializer/skip.rs index 9c51dd60f3..9e139b0e2f 100644 --- a/rust/fory-core/src/serializer/skip.rs +++ b/rust/fory-core/src/serializer/skip.rs @@ -544,6 +544,11 @@ fn skip_value( ::fory_read_data(context)?; } + // ============ BFLOAT16 (TypeId = 18) ============ + types::BFLOAT16 => { + ::fory_read_data(context)?; + } + // ============ FLOAT32 (TypeId = 17) ============ types::FLOAT32 => { ::fory_read_data(context)?; @@ -708,6 +713,11 @@ fn skip_value( as Serializer>::fory_read_data(context)?; } + // ============ BFLOAT16_ARRAY (TypeId = 54) ============ + types::BFLOAT16_ARRAY => { + as Serializer>::fory_read_data(context)?; + } + // ============ FLOAT32_ARRAY (TypeId = 51) ============ types::FLOAT32_ARRAY => { as Serializer>::fory_read_data(context)?; diff --git a/rust/fory-core/src/type_id.rs b/rust/fory-core/src/type_id.rs index ecb51628c3..597ddfae77 100644 --- a/rust/fory-core/src/type_id.rs +++ b/rust/fory-core/src/type_id.rs @@ -272,7 +272,7 @@ pub static PRIMITIVE_ARRAY_TYPES: [u32; 19] = [ TypeId::USIZE_ARRAY as u32, TypeId::ISIZE_ARRAY as u32, ]; -pub static BASIC_TYPE_NAMES: [&str; 19] = [ +pub static BASIC_TYPE_NAMES: [&str; 20] = [ "bool", "i8", "i16", @@ -289,6 +289,7 @@ pub static BASIC_TYPE_NAMES: [&str; 19] = [ "u32", "u64", "float16", + "bfloat16", "u128", "usize", "isize", @@ -309,6 +310,7 @@ pub static PRIMITIVE_ARRAY_TYPE_MAP: &[(&str, u32, &str)] = &[ ("u32", TypeId::UINT32_ARRAY as u32, "Vec"), ("u64", TypeId::UINT64_ARRAY as u32, "Vec"), ("float16", TypeId::FLOAT16_ARRAY as u32, "Vec"), + ("bfloat16", TypeId::BFLOAT16_ARRAY as u32, "Vec"), ("f32", TypeId::FLOAT32_ARRAY as u32, "Vec"), ("f64", TypeId::FLOAT64_ARRAY as u32, "Vec"), // Rust-specific diff --git a/rust/fory-core/src/types/bfloat16.rs b/rust/fory-core/src/types/bfloat16.rs new file mode 100644 index 0000000000..ec649d4660 --- /dev/null +++ b/rust/fory-core/src/types/bfloat16.rs @@ -0,0 +1,271 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! IEEE 754 bfloat16 floating-point type. +//! +//! This module provides a `bfloat16` type that stores an IEEE 754 bfloat16 +//! payload exactly. The type is a transparent wrapper around `u16` and +//! provides round-to-nearest-even conversion from `f32`, exact expansion back +//! to `f32`, classification helpers, and arithmetic through `f32`. +//! +//! The type is re-exported from `fory_core` as `BFloat16`, and `Vec` +//! / `Vec` is the canonical dense carrier for xlang `bfloat16_array` +//! payloads. + +use std::cmp::Ordering; +use std::fmt; +use std::hash::{Hash, Hasher}; +use std::ops::{Add, Div, Mul, Neg, Sub}; + +#[repr(transparent)] +#[derive(Copy, Clone, Default)] +#[allow(non_camel_case_types)] +pub struct bfloat16(u16); + +const SIGN_MASK: u16 = 0x8000; +const EXP_MASK: u16 = 0x7F80; +const MANTISSA_MASK: u16 = 0x007F; +const INFINITY_BITS: u16 = 0x7F80; +const NEG_INFINITY_BITS: u16 = 0xFF80; +const QUIET_NAN_BITS: u16 = 0x7FC0; + +impl bfloat16 { + #[inline(always)] + pub const fn from_bits(bits: u16) -> Self { + Self(bits) + } + + #[inline(always)] + pub const fn to_bits(self) -> u16 { + self.0 + } + + pub const ZERO: Self = Self(0x0000); + pub const NEG_ZERO: Self = Self(0x8000); + pub const INFINITY: Self = Self(INFINITY_BITS); + pub const NEG_INFINITY: Self = Self(NEG_INFINITY_BITS); + pub const NAN: Self = Self(QUIET_NAN_BITS); + pub const MAX: Self = Self(0x7F7F); + pub const MIN_POSITIVE: Self = Self(0x0080); + pub const MIN_POSITIVE_SUBNORMAL: Self = Self(0x0001); + + #[inline(always)] + pub fn from_f32(value: f32) -> Self { + let bits = value.to_bits(); + if (bits & 0x7F80_0000) == 0x7F80_0000 && (bits & 0x007F_FFFF) != 0 { + return Self(QUIET_NAN_BITS); + } + let lsb = (bits >> 16) & 1; + let rounding_bias = 0x7FFF + lsb; + Self(((bits + rounding_bias) >> 16) as u16) + } + + #[inline(always)] + pub fn to_f32(self) -> f32 { + f32::from_bits((self.0 as u32) << 16) + } + + #[inline(always)] + pub fn is_nan(self) -> bool { + (self.0 & EXP_MASK) == EXP_MASK && (self.0 & MANTISSA_MASK) != 0 + } + + #[inline(always)] + pub fn is_infinite(self) -> bool { + (self.0 & EXP_MASK) == EXP_MASK && (self.0 & MANTISSA_MASK) == 0 + } + + #[inline(always)] + pub fn is_finite(self) -> bool { + (self.0 & EXP_MASK) != EXP_MASK + } + + #[inline(always)] + pub fn is_zero(self) -> bool { + (self.0 & !SIGN_MASK) == 0 + } + + #[inline(always)] + pub fn is_normal(self) -> bool { + let exp = self.0 & EXP_MASK; + exp != 0 && exp != EXP_MASK + } + + #[inline(always)] + pub fn is_subnormal(self) -> bool { + (self.0 & EXP_MASK) == 0 && (self.0 & MANTISSA_MASK) != 0 + } + + #[inline(always)] + pub fn is_sign_negative(self) -> bool { + (self.0 & SIGN_MASK) != 0 + } + + #[inline(always)] + pub fn eq_value(self, other: Self) -> bool { + if self.is_nan() || other.is_nan() { + return false; + } + if self.is_zero() && other.is_zero() { + return true; + } + self.0 == other.0 + } + + /// Add two `bfloat16` values (via `f32`). + #[inline(always)] + #[allow(clippy::should_implement_trait)] + pub fn add(self, rhs: Self) -> Self { + Self::from_f32(self.to_f32() + rhs.to_f32()) + } + + /// Subtract two `bfloat16` values (via `f32`). + #[inline(always)] + #[allow(clippy::should_implement_trait)] + pub fn sub(self, rhs: Self) -> Self { + Self::from_f32(self.to_f32() - rhs.to_f32()) + } + + /// Multiply two `bfloat16` values (via `f32`). + #[inline(always)] + #[allow(clippy::should_implement_trait)] + pub fn mul(self, rhs: Self) -> Self { + Self::from_f32(self.to_f32() * rhs.to_f32()) + } + + /// Divide two `bfloat16` values (via `f32`). + #[inline(always)] + #[allow(clippy::should_implement_trait)] + pub fn div(self, rhs: Self) -> Self { + Self::from_f32(self.to_f32() / rhs.to_f32()) + } + + /// Negate this `bfloat16` value. + #[inline(always)] + #[allow(clippy::should_implement_trait)] + pub fn neg(self) -> Self { + Self(self.0 ^ SIGN_MASK) + } + + /// Absolute value. + #[inline(always)] + pub fn abs(self) -> Self { + Self(self.0 & !SIGN_MASK) + } +} + +impl PartialEq for bfloat16 { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for bfloat16 {} + +impl Hash for bfloat16 { + #[inline] + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +impl PartialOrd for bfloat16 { + #[inline] + fn partial_cmp(&self, other: &Self) -> Option { + self.to_f32().partial_cmp(&other.to_f32()) + } +} + +impl Add for bfloat16 { + type Output = Self; + #[inline] + fn add(self, rhs: Self) -> Self { + Self::add(self, rhs) + } +} + +impl Sub for bfloat16 { + type Output = Self; + #[inline] + fn sub(self, rhs: Self) -> Self { + Self::sub(self, rhs) + } +} + +impl Mul for bfloat16 { + type Output = Self; + #[inline] + fn mul(self, rhs: Self) -> Self { + Self::mul(self, rhs) + } +} + +impl Div for bfloat16 { + type Output = Self; + #[inline] + fn div(self, rhs: Self) -> Self { + Self::div(self, rhs) + } +} + +impl Neg for bfloat16 { + type Output = Self; + #[inline] + fn neg(self) -> Self { + Self::neg(self) + } +} + +impl fmt::Display for bfloat16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.to_f32()) + } +} + +impl fmt::Debug for bfloat16 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "bfloat16({})", self.to_f32()) + } +} + +#[cfg(test)] +mod tests { + use super::bfloat16; + + #[test] + fn test_basic_conversion() { + assert_eq!(bfloat16::from_f32(1.0).to_bits(), 0x3F80); + assert_eq!(bfloat16::from_f32(-1.0).to_bits(), 0xBF80); + assert_eq!(bfloat16::from_f32(0.0), bfloat16::ZERO); + assert_eq!(bfloat16::from_f32(-0.0), bfloat16::NEG_ZERO); + } + + #[test] + fn test_special_values() { + assert_eq!(bfloat16::INFINITY.to_bits(), 0x7F80); + assert_eq!(bfloat16::NEG_INFINITY.to_bits(), 0xFF80); + assert_eq!(bfloat16::NAN.to_bits(), 0x7FC0); + assert!(bfloat16::from_f32(f32::NAN).is_nan()); + } + + #[test] + fn test_round_to_nearest_even() { + assert_eq!(bfloat16::from_f32(1.0 + 1.0 / 256.0).to_bits(), 0x3F80); + assert_eq!(bfloat16::from_f32(1.0 + 3.0 / 256.0).to_bits(), 0x3F82); + } +} diff --git a/rust/fory-core/src/types/mod.rs b/rust/fory-core/src/types/mod.rs index a73f92366b..c204d91f18 100644 --- a/rust/fory-core/src/types/mod.rs +++ b/rust/fory-core/src/types/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod bfloat16; pub mod decimal; pub mod float16; pub mod weak; diff --git a/rust/fory-derive/src/object/read.rs b/rust/fory-derive/src/object/read.rs index dc16783ce5..1e37eb6ad1 100644 --- a/rust/fory-derive/src/object/read.rs +++ b/rust/fory-derive/src/object/read.rs @@ -168,6 +168,14 @@ fn need_declared_by_option(field: &Field) -> bool { type_name == "Option" || !is_primitive_type(type_name.as_str()) } +fn is_float16_like(type_name: &str) -> bool { + type_name == "float16" || type_name == "Float16" +} + +fn is_bfloat16_like(type_name: &str) -> bool { + type_name == "bfloat16" || type_name == "BFloat16" +} + pub(crate) fn declare_var(source_fields: &[SourceField<'_>]) -> Vec { source_fields .iter() @@ -186,10 +194,14 @@ pub(crate) fn declare_var(source_fields: &[SourceField<'_>]) -> Vec quote! { let mut #var_name: Option<#ty> = None; } - } else if extract_type_name(&field.ty) == "float16" { + } else if is_float16_like(extract_type_name(&field.ty).as_str()) { quote! { let mut #var_name: fory_core::types::float16::float16 = fory_core::types::float16::float16::ZERO; } + } else if is_bfloat16_like(extract_type_name(&field.ty).as_str()) { + quote! { + let mut #var_name: fory_core::types::bfloat16::bfloat16 = fory_core::types::bfloat16::bfloat16::ZERO; + } } else if extract_type_name(&field.ty) == "bool" { quote! { let mut #var_name: bool = false; diff --git a/rust/fory-derive/src/object/util.rs b/rust/fory-derive/src/object/util.rs index cb9adc4e9e..53047951f5 100644 --- a/rust/fory-derive/src/object/util.rs +++ b/rust/fory-derive/src/object/util.rs @@ -531,7 +531,12 @@ pub(super) fn generic_tree_to_tokens(node: &TypeNode) -> TokenStream { "i32" => quote! { fory_core::type_id::TypeId::INT32_ARRAY as u32 }, "i64" => quote! { fory_core::type_id::TypeId::INT64_ARRAY as u32 }, "i128" => quote! { fory_core::type_id::TypeId::INT128_ARRAY as u32 }, - "float16" => quote! { fory_core::type_id::TypeId::FLOAT16_ARRAY as u32 }, + "float16" | "Float16" => { + quote! { fory_core::type_id::TypeId::FLOAT16_ARRAY as u32 } + } + "bfloat16" | "BFloat16" => { + quote! { fory_core::type_id::TypeId::BFLOAT16_ARRAY as u32 } + } "f32" => quote! { fory_core::type_id::TypeId::FLOAT32_ARRAY as u32 }, "f64" => quote! { fory_core::type_id::TypeId::FLOAT64_ARRAY as u32 }, "u8" => quote! { fory_core::type_id::TypeId::BINARY as u32 }, @@ -697,9 +702,9 @@ fn extract_option_inner(s: &str) -> Option<&str> { s.strip_prefix("Option<")?.strip_suffix(">") } -const PRIMITIVE_TYPE_NAMES: [&str; 14] = [ - "bool", "i8", "i16", "i32", "i64", "i128", "float16", "f32", "f64", "u8", "u16", "u32", "u64", - "u128", +const PRIMITIVE_TYPE_NAMES: [&str; 17] = [ + "bool", "i8", "i16", "i32", "i64", "i128", "float16", "Float16", "bfloat16", "BFloat16", "f32", + "f64", "u8", "u16", "u32", "u64", "u128", ]; fn get_primitive_type_id(ty: &str) -> u32 { @@ -711,7 +716,8 @@ fn get_primitive_type_id(ty: &str) -> u32 { "i32" => TypeId::VARINT32 as u32, // Use VARINT64 for i64 to match Java xlang mode and Rust type resolver registration "i64" => TypeId::VARINT64 as u32, - "float16" => TypeId::FLOAT16 as u32, + "float16" | "Float16" => TypeId::FLOAT16 as u32, + "bfloat16" | "BFloat16" => TypeId::BFLOAT16 as u32, "f32" => TypeId::FLOAT32 as u32, "f64" => TypeId::FLOAT64 as u32, "u8" => TypeId::UINT8 as u32, @@ -986,7 +992,8 @@ pub(crate) fn get_type_id_by_name(ty: &str) -> u32 { "Vec" => return TypeId::INT32_ARRAY as u32, "Vec" => return TypeId::INT64_ARRAY as u32, "Vec" => return TypeId::INT128_ARRAY as u32, - "Vec" => return TypeId::FLOAT16_ARRAY as u32, + "Vec" | "Vec" => return TypeId::FLOAT16_ARRAY as u32, + "Vec" | "Vec" => return TypeId::BFLOAT16_ARRAY as u32, "Vec" => return TypeId::FLOAT32_ARRAY as u32, "Vec" => return TypeId::FLOAT64_ARRAY as u32, "Vec" => return TypeId::UINT16_ARRAY as u32, @@ -1008,7 +1015,8 @@ pub(crate) fn get_type_id_by_name(ty: &str) -> u32 { "i32" => return TypeId::INT32_ARRAY as u32, "i64" => return TypeId::INT64_ARRAY as u32, "i128" => return TypeId::INT128_ARRAY as u32, - "float16" => return TypeId::FLOAT16_ARRAY as u32, + "float16" | "Float16" => return TypeId::FLOAT16_ARRAY as u32, + "bfloat16" | "BFloat16" => return TypeId::BFLOAT16_ARRAY as u32, "f32" => return TypeId::FLOAT32_ARRAY as u32, "f64" => return TypeId::FLOAT64_ARRAY as u32, "u16" => return TypeId::UINT16_ARRAY as u32, @@ -1223,8 +1231,7 @@ fn group_fields_by_type(fields: &[&Field]) -> FieldGroups { compress_a .cmp(&compress_b) .then_with(|| size_b.cmp(&size_a)) - // Use descending type_id order to match Java's COMPARATOR_BY_PRIMITIVE_TYPE_ID - .then_with(|| b.2.cmp(&a.2)) + .then_with(|| a.2.cmp(&b.2)) // Field identifier (tag ID or name) as tie-breaker .then_with(|| a.1.cmp(&b.1)) // Deterministic fallback for duplicate identifiers diff --git a/rust/tests/tests/test_array.rs b/rust/tests/tests/test_array.rs index 54076eeca6..b8ca692579 100644 --- a/rust/tests/tests/test_array.rs +++ b/rust/tests/tests/test_array.rs @@ -407,3 +407,41 @@ fn test_array_float16_special_values() { assert_eq!(obj[2].to_bits(), float16::MAX.to_bits()); assert!(obj[4].is_subnormal()); } + +#[test] +fn test_array_bfloat16() { + use fory_core::types::bfloat16::bfloat16; + let fory = fory_core::fory::Fory::default(); + let arr = [ + bfloat16::from_f32(1.0), + bfloat16::from_f32(2.5), + bfloat16::from_f32(-1.5), + bfloat16::ZERO, + ]; + let bin = fory.serialize(&arr).unwrap(); + let obj: [bfloat16; 4] = fory.deserialize(&bin).expect("deserialize bfloat16 array"); + for (a, b) in arr.iter().zip(obj.iter()) { + assert_eq!(a.to_bits(), b.to_bits()); + } +} + +#[test] +fn test_array_bfloat16_special_values() { + use fory_core::types::bfloat16::bfloat16; + let fory = fory_core::fory::Fory::default(); + let arr = [ + bfloat16::INFINITY, + bfloat16::NEG_INFINITY, + bfloat16::MAX, + bfloat16::MIN_POSITIVE, + bfloat16::MIN_POSITIVE_SUBNORMAL, + ]; + let bin = fory.serialize(&arr).unwrap(); + let obj: [bfloat16; 5] = fory + .deserialize(&bin) + .expect("deserialize bfloat16 array specials"); + assert!(obj[0].is_infinite() && !obj[0].is_sign_negative()); + assert!(obj[1].is_infinite() && obj[1].is_sign_negative()); + assert_eq!(obj[2].to_bits(), bfloat16::MAX.to_bits()); + assert!(obj[4].is_subnormal()); +} diff --git a/rust/tests/tests/test_cross_language.rs b/rust/tests/tests/test_cross_language.rs index f3da1c030a..59d4a94cf9 100644 --- a/rust/tests/tests/test_cross_language.rs +++ b/rust/tests/tests/test_cross_language.rs @@ -22,7 +22,7 @@ use fory_core::resolver::TypeResolver; use fory_core::serializer::{ForyDefault, Serializer}; use fory_core::type_id::TypeId; use fory_core::util::murmurhash3_x64_128; -use fory_core::{read_data, write_data, Decimal, Fory}; +use fory_core::{read_data, write_data, BFloat16, Decimal, Float16, Fory}; use fory_core::{ReadContext, WriteContext}; use fory_derive::ForyObject; use num_bigint::BigInt; @@ -878,6 +878,14 @@ struct TwoStringFieldStruct { f2: String, } +#[derive(ForyObject, Debug, PartialEq)] +struct ReducedPrecisionFloatStruct { + float16_value: Float16, + bfloat16_value: BFloat16, + float16_array: Vec, + bfloat16_array: Vec, +} + #[allow(non_camel_case_types)] #[derive(ForyObject, Debug, PartialEq, Default, Clone)] enum TestEnum { @@ -1282,6 +1290,54 @@ fn test_schema_evolution_compatible_reverse() { fs::write(&data_file_path, new_bytes).unwrap(); } +#[test] +#[ignore] +fn test_reduced_precision_float_struct() { + let data_file_path = get_data_file(); + let bytes = fs::read(&data_file_path).unwrap(); + + let mut fory = Fory::builder().compatible(false).xlang(true).build(); + fory.register::(213).unwrap(); + + let value: ReducedPrecisionFloatStruct = fory.deserialize(&bytes).unwrap(); + assert_eq!(value.float16_value.to_bits(), 0x3E00); + assert_eq!(value.bfloat16_value.to_bits(), 0x3FC0); + assert_eq!( + value + .float16_array + .iter() + .map(|v| v.to_bits()) + .collect::>(), + vec![0x0000, 0x3C00, 0xBC00] + ); + assert_eq!( + value + .bfloat16_array + .iter() + .map(|v| v.to_bits()) + .collect::>(), + vec![0x0000, 0x3F80, 0xBF80] + ); + + let new_bytes = fory.serialize(&value).unwrap(); + fs::write(&data_file_path, new_bytes).unwrap(); +} + +#[test] +#[ignore] +fn test_reduced_precision_float_struct_compatible_skip() { + let data_file_path = get_data_file(); + let bytes = fs::read(&data_file_path).unwrap(); + + let mut fory = Fory::builder().compatible(true).xlang(true).build(); + fory.register::(213).unwrap(); + + let value: EmptyStructEvolution = fory.deserialize(&bytes).unwrap(); + + let new_bytes = fory.serialize(&value).unwrap(); + fs::write(&data_file_path, new_bytes).unwrap(); +} + // ============================================================================ // Schema Evolution Tests - Enum Fields // ============================================================================ diff --git a/rust/tests/tests/test_list.rs b/rust/tests/tests/test_list.rs index 6edc8faff5..de29e5f7a3 100644 --- a/rust/tests/tests/test_list.rs +++ b/rust/tests/tests/test_list.rs @@ -191,6 +191,48 @@ fn test_vec_float16_empty() { assert_eq!(obj.len(), 0); } +#[test] +fn test_vec_bfloat16_basic() { + use fory_core::types::bfloat16::bfloat16; + let fory = fory_core::fory::Fory::default(); + let vec: Vec = vec![ + bfloat16::from_f32(1.0), + bfloat16::from_f32(2.5), + bfloat16::from_f32(-3.0), + bfloat16::ZERO, + ]; + let bin = fory.serialize(&vec).unwrap(); + let obj: Vec = fory.deserialize(&bin).expect("deserialize bfloat16 vec"); + assert_eq!(vec.len(), obj.len()); + for (a, b) in vec.iter().zip(obj.iter()) { + assert_eq!(a.to_bits(), b.to_bits()); + } +} + +#[test] +fn test_vec_bfloat16_special_values() { + use fory_core::types::bfloat16::bfloat16; + let fory = fory_core::fory::Fory::default(); + let vec: Vec = vec![ + bfloat16::INFINITY, + bfloat16::NEG_INFINITY, + bfloat16::NAN, + bfloat16::MAX, + bfloat16::MIN_POSITIVE, + bfloat16::MIN_POSITIVE_SUBNORMAL, + ]; + let bin = fory.serialize(&vec).unwrap(); + let obj: Vec = fory + .deserialize(&bin) + .expect("deserialize bfloat16 special"); + assert_eq!(vec.len(), obj.len()); + assert!(obj[0].is_infinite() && !obj[0].is_sign_negative()); + assert!(obj[1].is_infinite() && obj[1].is_sign_negative()); + assert!(obj[2].is_nan()); + assert_eq!(obj[3].to_bits(), bfloat16::MAX.to_bits()); + assert!(obj[5].is_subnormal()); +} + #[test] fn test_vec_max_collection_size_guardrail() { let fory = Fory::default(); diff --git a/rust/tests/tests/test_simple_struct.rs b/rust/tests/tests/test_simple_struct.rs index 583ba1a92a..02ff9725a2 100644 --- a/rust/tests/tests/test_simple_struct.rs +++ b/rust/tests/tests/test_simple_struct.rs @@ -259,3 +259,45 @@ fn test_struct_with_float16_fields() { assert_eq!(obj2.arr_field[1].to_bits(), float16::MAX.to_bits()); assert_eq!(obj2.arr_field[2].to_bits(), float16::ZERO.to_bits()); } + +#[test] +fn test_struct_with_bfloat16_fields() { + use fory_core::types::bfloat16::bfloat16; + + #[derive(ForyObject, Debug)] + struct BFloat16Data { + scalar: bfloat16, + vec_field: Vec, + arr_field: [bfloat16; 3], + } + + let mut fory = Fory::default(); + fory.register::(201).unwrap(); + + let obj = BFloat16Data { + scalar: bfloat16::from_f32(1.5), + vec_field: vec![ + bfloat16::from_f32(1.0), + bfloat16::from_f32(2.0), + bfloat16::INFINITY, + ], + arr_field: [bfloat16::from_f32(-1.0), bfloat16::MAX, bfloat16::ZERO], + }; + + let bin = fory.serialize(&obj).unwrap(); + let obj2: BFloat16Data = fory.deserialize(&bin).expect("deserialize BFloat16Data"); + + assert_eq!(obj2.scalar.to_bits(), bfloat16::from_f32(1.5).to_bits()); + assert_eq!(obj2.vec_field.len(), 3); + assert_eq!( + obj2.vec_field[0].to_bits(), + bfloat16::from_f32(1.0).to_bits() + ); + assert!(obj2.vec_field[2].is_infinite() && !obj2.vec_field[2].is_sign_negative()); + assert_eq!( + obj2.arr_field[0].to_bits(), + bfloat16::from_f32(-1.0).to_bits() + ); + assert_eq!(obj2.arr_field[1].to_bits(), bfloat16::MAX.to_bits()); + assert_eq!(obj2.arr_field[2].to_bits(), bfloat16::ZERO.to_bits()); +} diff --git a/swift/Package.swift b/swift/Package.swift index c5a7f8a617..163c3776ff 100644 --- a/swift/Package.swift +++ b/swift/Package.swift @@ -35,8 +35,7 @@ let package = Package( .target( name: "Fory", dependencies: ["ForyMacro"], - path: "Sources", - exclude: ["ForyMacro"] + path: "Sources/Fory" ), .executableTarget( name: "ForyXlangTests", diff --git a/swift/Sources/TypeId.swift b/swift/Sources/Fory/TypeId.swift similarity index 100% rename from swift/Sources/TypeId.swift rename to swift/Sources/Fory/TypeId.swift diff --git a/swift/Sources/ForyMacro/ForyObjectMacro.swift b/swift/Sources/ForyMacro/ForyObjectMacro.swift index f1b022a8ec..47ba880d56 100644 --- a/swift/Sources/ForyMacro/ForyObjectMacro.swift +++ b/swift/Sources/ForyMacro/ForyObjectMacro.swift @@ -966,7 +966,7 @@ private func sortFields(_ fields: [ParsedField]) -> [ParsedField] { return lhs.primitiveSize > rhs.primitiveSize } if lhs.typeID != rhs.typeID { - return lhs.typeID > rhs.typeID + return lhs.typeID < rhs.typeID } if let identifierOrder = compareFieldIdentifier(lhs, rhs) { return identifierOrder @@ -1591,6 +1591,10 @@ private func classifyType( return .init(typeID: 12, isPrimitive: true, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: true, primitiveSize: 4) case "UInt64", "UInt": return .init(typeID: 14, isPrimitive: true, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: true, primitiveSize: 8) + case "Float16": + return .init(typeID: 17, isPrimitive: true, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: false, primitiveSize: 2) + case "BFloat16": + return .init(typeID: 18, isPrimitive: true, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: false, primitiveSize: 2) case "Float": return .init(typeID: 19, isPrimitive: true, isBuiltIn: true, isCollection: false, isMap: false, isCompressedNumeric: false, primitiveSize: 4) case "Double": @@ -1670,6 +1674,18 @@ private func classifyType( isCompressedNumeric: false, primitiveSize: 0 ) } + if elem.typeID == 17 { + return .init( + typeID: 53, isPrimitive: false, isBuiltIn: true, isCollection: false, isMap: false, + isCompressedNumeric: false, primitiveSize: 0 + ) + } + if elem.typeID == 18 { + return .init( + typeID: 54, isPrimitive: false, isBuiltIn: true, isCollection: false, isMap: false, + isCompressedNumeric: false, primitiveSize: 0 + ) + } if elem.typeID == 19 { return .init( typeID: 55, isPrimitive: false, isBuiltIn: true, isCollection: false, isMap: false, diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift index 60cf231646..aee2bad787 100644 --- a/swift/Tests/ForyTests/ForySwiftTests.swift +++ b/swift/Tests/ForyTests/ForySwiftTests.swift @@ -53,6 +53,14 @@ struct EncodedNumberFields: Equatable { var u64Tagged: UInt64 } +@ForyObject +struct ReducedPrecisionMacroFields: Equatable { + var float16Value: Float16 + var bfloat16Value: BFloat16 + var float16Array: [Float16] + var bfloat16Array: [BFloat16] +} + @ForyObject struct FieldIdConfigured: Equatable { @ForyField(id: 2) @@ -895,6 +903,19 @@ func macroFieldEncodingOverridesCompatibleTypeMeta() throws { #expect(fields[1].fieldType.typeID == TypeId.taggedUInt64.rawValue) } +@Test +func macroReducedPrecisionFieldsUseXlangTypeIDs() { + let fields = ReducedPrecisionMacroFields.foryFieldsInfo(trackRef: false) + #expect(fields.count == 4) + #expect(fields.map(\.fieldName) == ["float16Value", "bfloat16Value", "float16Array", "bfloat16Array"]) + #expect(fields.map(\.fieldType.typeID) == [ + TypeId.float16.rawValue, + TypeId.bfloat16.rawValue, + TypeId.float16Array.rawValue, + TypeId.bfloat16Array.rawValue + ]) +} + @Test func macroFieldIDsPopulateCompatibleTypeMeta() { let fields = FieldIdConfigured.foryFieldsInfo(trackRef: false) diff --git a/swift/Tests/ForyXlangTests/main.swift b/swift/Tests/ForyXlangTests/main.swift index 89579c9c56..b884cf40b0 100644 --- a/swift/Tests/ForyXlangTests/main.swift +++ b/swift/Tests/ForyXlangTests/main.swift @@ -104,6 +104,14 @@ private struct TwoStringFieldStruct { var f2: String = "" } +@ForyObject +private struct ReducedPrecisionFloatStruct { + var float16Value: Float16 = 0 + var bfloat16Value: BFloat16 = .init() + var float16Array: [Float16] = [] + var bfloat16Array: [BFloat16] = [] +} + @ForyObject private struct OneEnumFieldStruct { var f1: PeerTestEnum = .valueA @@ -973,6 +981,18 @@ private func handleUnsignedSchemaCompatible(_ bytes: [UInt8]) throws -> [UInt8] return try roundTripSingle(bytes, fory: fory, as: UnsignedSchemaCompatible.self) } +private func handleReducedPrecisionFloatStruct(_ bytes: [UInt8]) throws -> [UInt8] { + let fory = Fory(config: .init(xlang: true, trackRef: false, compatible: false)) + fory.register(ReducedPrecisionFloatStruct.self, id: 213) + return try roundTripSingle(bytes, fory: fory, as: ReducedPrecisionFloatStruct.self) +} + +private func handleReducedPrecisionFloatStructCompatibleSkip(_ bytes: [UInt8]) throws -> [UInt8] { + let fory = Fory(config: .init(xlang: true, trackRef: false, compatible: true)) + fory.register(EmptyStructEvolution.self, id: 213) + return try roundTripSingle(bytes, fory: fory, as: EmptyStructEvolution.self) +} + private func rewritePayload(caseName: String, bytes: [UInt8]) throws -> [UInt8] { switch caseName { case "test_buffer", "test_buffer_var": @@ -1059,6 +1079,10 @@ private func rewritePayload(caseName: String, bytes: [UInt8]) throws -> [UInt8] return try handleUnsignedSchemaConsistent(bytes) case "test_unsigned_schema_compatible": return try handleUnsignedSchemaCompatible(bytes) + case "test_reduced_precision_float_struct": + return try handleReducedPrecisionFloatStruct(bytes) + case "test_reduced_precision_float_struct_compatible_skip": + return try handleReducedPrecisionFloatStructCompatibleSkip(bytes) default: throw PeerError.unsupportedCase(caseName) }