@@ -402,5 +402,118 @@ template <size_t N> struct Serializer<std::array<float16_t, N>> {
402402 }
403403};
404404
405+ // / Serializer for std::array<bfloat16_t, N>
406+ template <size_t N> struct Serializer <std::array<bfloat16_t , N>> {
407+ static constexpr TypeId type_id = TypeId::BFLOAT16_ARRAY;
408+
409+ static inline void write_type_info (WriteContext &ctx) {
410+ ctx.write_uint8 (static_cast <uint8_t >(type_id));
411+ }
412+
413+ static inline void read_type_info (ReadContext &ctx) {
414+ uint32_t actual = ctx.read_uint8 (ctx.error ());
415+ if (FORY_PREDICT_FALSE (ctx.has_error ())) {
416+ return ;
417+ }
418+ if (!type_id_matches (actual, static_cast <uint32_t >(type_id))) {
419+ ctx.set_error (
420+ Error::type_mismatch (actual, static_cast <uint32_t >(type_id)));
421+ }
422+ }
423+
424+ static inline void write (const std::array<bfloat16_t , N> &arr,
425+ WriteContext &ctx, RefMode ref_mode, bool write_type,
426+ bool has_generics = false ) {
427+ write_not_null_ref_flag (ctx, ref_mode);
428+ if (write_type) {
429+ ctx.write_uint8 (static_cast <uint8_t >(type_id));
430+ }
431+ write_data (arr, ctx);
432+ }
433+
434+ static inline void write_data (const std::array<bfloat16_t , N> &arr,
435+ WriteContext &ctx) {
436+ Buffer &buffer = ctx.buffer ();
437+ constexpr size_t max_size = 8 + N * sizeof (bfloat16_t );
438+ buffer.grow (static_cast <uint32_t >(max_size));
439+ uint32_t writer_index = buffer.writer_index ();
440+ writer_index += buffer.put_var_uint32 (
441+ writer_index, static_cast <uint32_t >(N * sizeof (bfloat16_t )));
442+ if constexpr (N > 0 ) {
443+ if constexpr (FORY_LITTLE_ENDIAN) {
444+ buffer.unsafe_put (writer_index, arr.data (), N * sizeof (bfloat16_t ));
445+ } else {
446+ for (size_t i = 0 ; i < N; ++i) {
447+ uint16_t bits = util::to_little_endian (arr[i].to_bits ());
448+ buffer.unsafe_put (writer_index + i * sizeof (bfloat16_t ), &bits,
449+ sizeof (bfloat16_t ));
450+ }
451+ }
452+ }
453+ buffer.writer_index (writer_index + N * sizeof (bfloat16_t ));
454+ }
455+
456+ static inline void write_data_generic (const std::array<bfloat16_t , N> &arr,
457+ WriteContext &ctx, bool has_generics) {
458+ write_data (arr, ctx);
459+ }
460+
461+ static inline std::array<bfloat16_t , N>
462+ read (ReadContext &ctx, RefMode ref_mode, bool read_type) {
463+ bool has_value = read_null_only_flag (ctx, ref_mode);
464+ if (ctx.has_error () || !has_value) {
465+ return std::array<bfloat16_t , N>();
466+ }
467+ if (read_type) {
468+ uint32_t type_id_read = ctx.read_uint8 (ctx.error ());
469+ if (FORY_PREDICT_FALSE (ctx.has_error ())) {
470+ return std::array<bfloat16_t , N>();
471+ }
472+ if (type_id_read != static_cast <uint32_t >(type_id)) {
473+ ctx.set_error (
474+ Error::type_mismatch (type_id_read, static_cast <uint32_t >(type_id)));
475+ return std::array<bfloat16_t , N>();
476+ }
477+ }
478+ return read_data (ctx);
479+ }
480+
481+ static inline std::array<bfloat16_t , N> read_data (ReadContext &ctx) {
482+ uint32_t size_bytes = ctx.read_var_uint32 (ctx.error ());
483+ if (FORY_PREDICT_FALSE (ctx.has_error ())) {
484+ return std::array<bfloat16_t , N>();
485+ }
486+ uint32_t length = size_bytes / sizeof (bfloat16_t );
487+ if (length != N) {
488+ ctx.set_error (Error::invalid_data (" Array size mismatch: expected " +
489+ std::to_string (N) + " but got " +
490+ std::to_string (length)));
491+ return std::array<bfloat16_t , N>();
492+ }
493+ std::array<bfloat16_t , N> arr;
494+ if constexpr (N > 0 ) {
495+ if constexpr (FORY_LITTLE_ENDIAN) {
496+ ctx.read_bytes (arr.data (), N * sizeof (bfloat16_t ), ctx.error ());
497+ } else {
498+ for (size_t i = 0 ; i < N; ++i) {
499+ uint16_t bits;
500+ ctx.read_bytes (&bits, sizeof (bfloat16_t ), ctx.error ());
501+ if (FORY_PREDICT_FALSE (ctx.has_error ())) {
502+ return arr;
503+ }
504+ arr[i] = bfloat16_t::from_bits (util::to_little_endian (bits));
505+ }
506+ }
507+ }
508+ return arr;
509+ }
510+
511+ static inline std::array<bfloat16_t , N>
512+ read_with_type_info (ReadContext &ctx, RefMode ref_mode,
513+ const TypeInfo &type_info) {
514+ return read (ctx, ref_mode, false );
515+ }
516+ };
517+
405518} // namespace serialization
406519} // namespace fory
0 commit comments