Skip to content

Commit bc13ba5

Browse files
committed
add xlang bfloat16 support
1 parent 0870613 commit bc13ba5

86 files changed

Lines changed: 3433 additions & 93 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cpp/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ The C++ implementation provides high-performance serialization with compile-time
1313
- **Type-Safe**: Compile-time type checking with template specialization
1414
- **Shared References**: Automatic tracking of shared and circular references
1515
- **Schema Evolution**: Compatible mode for independent schema changes
16+
- **Reduced-Precision Types**: `fory::float16_t` and `fory::bfloat16_t` scalars with dense `std::vector<...>` array carriers
1617
- **Two Formats**: Object graph serialization and zero-copy row-based format
1718
- **Modern C++17**: Clean API using modern C++ features
1819

cpp/fory/serialization/array_serializer.h

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,5 +402,118 @@ template <size_t N> struct Serializer<std::array<float16_t, N>> {
402402
}
403403
};
404404

405+
/// Serializer for std::array<bfloat16_t, N>
406+
template <size_t N> struct Serializer<std::array<bfloat16_t, N>> {
407+
static constexpr TypeId type_id = TypeId::BFLOAT16_ARRAY;
408+
409+
static inline void write_type_info(WriteContext &ctx) {
410+
ctx.write_uint8(static_cast<uint8_t>(type_id));
411+
}
412+
413+
static inline void read_type_info(ReadContext &ctx) {
414+
uint32_t actual = ctx.read_uint8(ctx.error());
415+
if (FORY_PREDICT_FALSE(ctx.has_error())) {
416+
return;
417+
}
418+
if (!type_id_matches(actual, static_cast<uint32_t>(type_id))) {
419+
ctx.set_error(
420+
Error::type_mismatch(actual, static_cast<uint32_t>(type_id)));
421+
}
422+
}
423+
424+
static inline void write(const std::array<bfloat16_t, N> &arr,
425+
WriteContext &ctx, RefMode ref_mode, bool write_type,
426+
bool has_generics = false) {
427+
write_not_null_ref_flag(ctx, ref_mode);
428+
if (write_type) {
429+
ctx.write_uint8(static_cast<uint8_t>(type_id));
430+
}
431+
write_data(arr, ctx);
432+
}
433+
434+
static inline void write_data(const std::array<bfloat16_t, N> &arr,
435+
WriteContext &ctx) {
436+
Buffer &buffer = ctx.buffer();
437+
constexpr size_t max_size = 8 + N * sizeof(bfloat16_t);
438+
buffer.grow(static_cast<uint32_t>(max_size));
439+
uint32_t writer_index = buffer.writer_index();
440+
writer_index += buffer.put_var_uint32(
441+
writer_index, static_cast<uint32_t>(N * sizeof(bfloat16_t)));
442+
if constexpr (N > 0) {
443+
if constexpr (FORY_LITTLE_ENDIAN) {
444+
buffer.unsafe_put(writer_index, arr.data(), N * sizeof(bfloat16_t));
445+
} else {
446+
for (size_t i = 0; i < N; ++i) {
447+
uint16_t bits = util::to_little_endian(arr[i].to_bits());
448+
buffer.unsafe_put(writer_index + i * sizeof(bfloat16_t), &bits,
449+
sizeof(bfloat16_t));
450+
}
451+
}
452+
}
453+
buffer.writer_index(writer_index + N * sizeof(bfloat16_t));
454+
}
455+
456+
static inline void write_data_generic(const std::array<bfloat16_t, N> &arr,
457+
WriteContext &ctx, bool has_generics) {
458+
write_data(arr, ctx);
459+
}
460+
461+
static inline std::array<bfloat16_t, N>
462+
read(ReadContext &ctx, RefMode ref_mode, bool read_type) {
463+
bool has_value = read_null_only_flag(ctx, ref_mode);
464+
if (ctx.has_error() || !has_value) {
465+
return std::array<bfloat16_t, N>();
466+
}
467+
if (read_type) {
468+
uint32_t type_id_read = ctx.read_uint8(ctx.error());
469+
if (FORY_PREDICT_FALSE(ctx.has_error())) {
470+
return std::array<bfloat16_t, N>();
471+
}
472+
if (type_id_read != static_cast<uint32_t>(type_id)) {
473+
ctx.set_error(
474+
Error::type_mismatch(type_id_read, static_cast<uint32_t>(type_id)));
475+
return std::array<bfloat16_t, N>();
476+
}
477+
}
478+
return read_data(ctx);
479+
}
480+
481+
static inline std::array<bfloat16_t, N> read_data(ReadContext &ctx) {
482+
uint32_t size_bytes = ctx.read_var_uint32(ctx.error());
483+
if (FORY_PREDICT_FALSE(ctx.has_error())) {
484+
return std::array<bfloat16_t, N>();
485+
}
486+
uint32_t length = size_bytes / sizeof(bfloat16_t);
487+
if (length != N) {
488+
ctx.set_error(Error::invalid_data("Array size mismatch: expected " +
489+
std::to_string(N) + " but got " +
490+
std::to_string(length)));
491+
return std::array<bfloat16_t, N>();
492+
}
493+
std::array<bfloat16_t, N> arr;
494+
if constexpr (N > 0) {
495+
if constexpr (FORY_LITTLE_ENDIAN) {
496+
ctx.read_bytes(arr.data(), N * sizeof(bfloat16_t), ctx.error());
497+
} else {
498+
for (size_t i = 0; i < N; ++i) {
499+
uint16_t bits;
500+
ctx.read_bytes(&bits, sizeof(bfloat16_t), ctx.error());
501+
if (FORY_PREDICT_FALSE(ctx.has_error())) {
502+
return arr;
503+
}
504+
arr[i] = bfloat16_t::from_bits(util::to_little_endian(bits));
505+
}
506+
}
507+
}
508+
return arr;
509+
}
510+
511+
static inline std::array<bfloat16_t, N>
512+
read_with_type_info(ReadContext &ctx, RefMode ref_mode,
513+
const TypeInfo &type_info) {
514+
return read(ctx, ref_mode, false);
515+
}
516+
};
517+
405518
} // namespace serialization
406519
} // namespace fory

cpp/fory/serialization/basic_serializer.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "fory/serialization/context.h"
2323
#include "fory/serialization/serializer_traits.h"
2424
#include "fory/type/type.h"
25+
#include "fory/util/bfloat16.h"
2526
#include "fory/util/error.h"
2627
#include "fory/util/float16.h"
2728
#include <cstdint>
@@ -603,6 +604,77 @@ template <> struct Serializer<float16_t> {
603604
}
604605
};
605606

607+
/// bfloat16_t serializer
608+
template <> struct Serializer<bfloat16_t> {
609+
static constexpr TypeId type_id = TypeId::BFLOAT16;
610+
611+
static inline void write_type_info(WriteContext &ctx) {
612+
ctx.write_uint8(static_cast<uint8_t>(type_id));
613+
}
614+
615+
static inline void read_type_info(ReadContext &ctx) {
616+
uint32_t actual = ctx.read_uint8(ctx.error());
617+
if (FORY_PREDICT_FALSE(ctx.has_error())) {
618+
return;
619+
}
620+
if (actual != static_cast<uint32_t>(type_id)) {
621+
ctx.set_error(
622+
Error::type_mismatch(actual, static_cast<uint32_t>(type_id)));
623+
}
624+
}
625+
626+
static inline void write(bfloat16_t value, WriteContext &ctx,
627+
RefMode ref_mode, bool write_type, bool = false) {
628+
write_not_null_ref_flag(ctx, ref_mode);
629+
if (write_type) {
630+
ctx.write_uint8(static_cast<uint8_t>(type_id));
631+
}
632+
write_data(value, ctx);
633+
}
634+
635+
static inline void write_data(bfloat16_t value, WriteContext &ctx) {
636+
ctx.write_bytes(&value, sizeof(bfloat16_t));
637+
}
638+
639+
static inline void write_data_generic(bfloat16_t value, WriteContext &ctx,
640+
bool) {
641+
write_data(value, ctx);
642+
}
643+
644+
static inline bfloat16_t read(ReadContext &ctx, RefMode ref_mode,
645+
bool read_type) {
646+
bool has_value = read_null_only_flag(ctx, ref_mode);
647+
if (ctx.has_error() || !has_value) {
648+
return bfloat16_t::from_bits(0);
649+
}
650+
if (read_type) {
651+
uint32_t type_id_read = ctx.read_uint8(ctx.error());
652+
if (FORY_PREDICT_FALSE(ctx.has_error())) {
653+
return bfloat16_t::from_bits(0);
654+
}
655+
if (type_id_read != static_cast<uint32_t>(type_id)) {
656+
ctx.set_error(
657+
Error::type_mismatch(type_id_read, static_cast<uint32_t>(type_id)));
658+
return bfloat16_t::from_bits(0);
659+
}
660+
}
661+
return ctx.read_bf16(ctx.error());
662+
}
663+
664+
static inline bfloat16_t read_data(ReadContext &ctx) {
665+
return ctx.read_bf16(ctx.error());
666+
}
667+
668+
static inline bfloat16_t read_data_generic(ReadContext &ctx, bool) {
669+
return read_data(ctx);
670+
}
671+
672+
static inline bfloat16_t
673+
read_with_type_info(ReadContext &ctx, RefMode ref_mode, const TypeInfo &) {
674+
return read(ctx, ref_mode, false);
675+
}
676+
};
677+
606678
// ============================================================================
607679
// Character Type Serializers (C++ native only, not supported in xlang mode)
608680
// ============================================================================

cpp/fory/serialization/collection_serializer.h

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,112 @@ template <typename Alloc> struct Serializer<std::vector<float16_t, Alloc>> {
751751
}
752752
};
753753

754+
/// Vector serializer for bfloat16_t — typed array path (BFLOAT16_ARRAY).
755+
template <typename Alloc> struct Serializer<std::vector<bfloat16_t, Alloc>> {
756+
static constexpr TypeId type_id = TypeId::BFLOAT16_ARRAY;
757+
758+
static inline void write_type_info(WriteContext &ctx) {
759+
ctx.write_uint8(static_cast<uint8_t>(type_id));
760+
}
761+
762+
static inline void read_type_info(ReadContext &ctx) {
763+
uint32_t actual = ctx.read_uint8(ctx.error());
764+
if (FORY_PREDICT_FALSE(ctx.has_error())) {
765+
return;
766+
}
767+
if (!type_id_matches(actual, static_cast<uint32_t>(type_id))) {
768+
ctx.set_error(
769+
Error::type_mismatch(actual, static_cast<uint32_t>(type_id)));
770+
}
771+
}
772+
773+
static inline void write(const std::vector<bfloat16_t, Alloc> &vec,
774+
WriteContext &ctx, RefMode ref_mode, bool write_type,
775+
bool has_generics = false) {
776+
write_not_null_ref_flag(ctx, ref_mode);
777+
if (write_type) {
778+
ctx.write_uint8(static_cast<uint8_t>(type_id));
779+
}
780+
write_data(vec, ctx);
781+
}
782+
783+
static inline void write_data(const std::vector<bfloat16_t, Alloc> &vec,
784+
WriteContext &ctx) {
785+
uint64_t total_bytes =
786+
static_cast<uint64_t>(vec.size()) * sizeof(bfloat16_t);
787+
if (total_bytes > std::numeric_limits<uint32_t>::max()) {
788+
ctx.set_error(Error::invalid("Vector byte size exceeds uint32_t range"));
789+
return;
790+
}
791+
Buffer &buffer = ctx.buffer();
792+
size_t max_size = 8 + total_bytes;
793+
buffer.grow(static_cast<uint32_t>(max_size));
794+
uint32_t writer_index = buffer.writer_index();
795+
writer_index +=
796+
buffer.put_var_uint32(writer_index, static_cast<uint32_t>(total_bytes));
797+
if (total_bytes > 0) {
798+
buffer.unsafe_put(writer_index, vec.data(),
799+
static_cast<uint32_t>(total_bytes));
800+
}
801+
buffer.writer_index(writer_index + static_cast<uint32_t>(total_bytes));
802+
}
803+
804+
static inline void
805+
write_data_generic(const std::vector<bfloat16_t, Alloc> &vec,
806+
WriteContext &ctx, bool has_generics) {
807+
write_data(vec, ctx);
808+
}
809+
810+
static inline std::vector<bfloat16_t, Alloc>
811+
read(ReadContext &ctx, RefMode ref_mode, bool read_type) {
812+
bool has_value = read_null_only_flag(ctx, ref_mode);
813+
if (ctx.has_error() || !has_value) {
814+
return std::vector<bfloat16_t, Alloc>();
815+
}
816+
if (read_type) {
817+
uint32_t type_id_read = ctx.read_uint8(ctx.error());
818+
if (FORY_PREDICT_FALSE(ctx.has_error())) {
819+
return std::vector<bfloat16_t, Alloc>();
820+
}
821+
if (type_id_read != static_cast<uint32_t>(type_id)) {
822+
ctx.set_error(
823+
Error::type_mismatch(type_id_read, static_cast<uint32_t>(type_id)));
824+
return std::vector<bfloat16_t, Alloc>();
825+
}
826+
}
827+
return read_data(ctx);
828+
}
829+
830+
static inline std::vector<bfloat16_t, Alloc>
831+
read_with_type_info(ReadContext &ctx, RefMode ref_mode,
832+
const TypeInfo &type_info) {
833+
return read(ctx, ref_mode, false);
834+
}
835+
836+
static inline std::vector<bfloat16_t, Alloc> read_data(ReadContext &ctx) {
837+
uint32_t total_bytes_u32 = ctx.read_var_uint32(ctx.error());
838+
if (FORY_PREDICT_FALSE(ctx.has_error())) {
839+
return std::vector<bfloat16_t, Alloc>();
840+
}
841+
if (FORY_PREDICT_FALSE(total_bytes_u32 > ctx.config().max_binary_size)) {
842+
ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size"));
843+
return std::vector<bfloat16_t, Alloc>();
844+
}
845+
size_t elem_count = total_bytes_u32 / sizeof(bfloat16_t);
846+
if (total_bytes_u32 % sizeof(bfloat16_t) != 0) {
847+
ctx.set_error(Error::invalid_data(
848+
"Vector byte size not aligned with bfloat16_t element size"));
849+
return std::vector<bfloat16_t, Alloc>();
850+
}
851+
std::vector<bfloat16_t, Alloc> result(elem_count);
852+
if (total_bytes_u32 > 0) {
853+
ctx.read_bytes(result.data(), static_cast<uint32_t>(total_bytes_u32),
854+
ctx.error());
855+
}
856+
return result;
857+
}
858+
};
859+
754860
/// Vector serializer for non-bool, non-arithmetic types
755861
template <typename T, typename Alloc>
756862
struct Serializer<

cpp/fory/serialization/context.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,11 @@ class ReadContext {
563563
return buffer().read_f16(error);
564564
}
565565

566+
/// Read one bfloat16_t from the buffer. Sets error on failure.
///
/// Thin forwarding wrapper: the read itself is performed by
/// buffer().read_bf16(error).
///
/// @param error Error object passed through to the buffer read.
/// @return The value returned by the buffer read.
567+
FORY_ALWAYS_INLINE bfloat16_t read_bf16(Error &error) {
568+
return buffer().read_bf16(error);
569+
}
570+
566571
/// Read uint32_t value as varint from buffer. Sets error on failure.
567572
FORY_ALWAYS_INLINE uint32_t read_var_uint32(Error &error) {
568573
return buffer().read_var_uint32(error);

cpp/fory/serialization/skip.cc

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -630,33 +630,13 @@ void skip_field_value(ReadContext &ctx, const FieldType &field_type,
630630
return;
631631
}
632632
}
633-
// Read array length
634-
uint32_t len = ctx.read_var_uint32(ctx.error());
633+
// Typed primitive arrays encode payload size in bytes, not element count.
634+
uint32_t payload_size = ctx.read_var_uint32(ctx.error());
635635
if (FORY_PREDICT_FALSE(ctx.has_error())) {
636636
return;
637637
}
638638

639-
// Calculate element size
640-
size_t elem_size = 1;
641-
switch (tid) {
642-
case TypeId::INT16_ARRAY:
643-
case TypeId::FLOAT16_ARRAY:
644-
case TypeId::BFLOAT16_ARRAY:
645-
elem_size = 2;
646-
break;
647-
case TypeId::INT32_ARRAY:
648-
case TypeId::FLOAT32_ARRAY:
649-
elem_size = 4;
650-
break;
651-
case TypeId::INT64_ARRAY:
652-
case TypeId::FLOAT64_ARRAY:
653-
elem_size = 8;
654-
break;
655-
default:
656-
break;
657-
}
658-
659-
ctx.buffer().increase_reader_index(len * elem_size, ctx.error());
639+
ctx.buffer().increase_reader_index(payload_size, ctx.error());
660640
return;
661641
}
662642

0 commit comments

Comments
 (0)