diff --git a/mlx/array.h b/mlx/array.h index 60d5e50bbd..4255f1041b 100644 --- a/mlx/array.h +++ b/mlx/array.h @@ -615,6 +615,12 @@ void array::init(It src) { case int64: std::copy(src, src + size(), data()); break; + case float8: + std::copy(src, src + size(), data()); + break; + case bfloat8: + std::copy(src, src + size(), data()); + break; case float16: std::copy(src, src + size(), data()); break; diff --git a/mlx/dtype.h b/mlx/dtype.h index 744ca5879e..7630df549c 100644 --- a/mlx/dtype.h +++ b/mlx/dtype.h @@ -8,6 +8,7 @@ #include "mlx/api.h" #include "mlx/types/complex.h" #include "mlx/types/half_types.h" +#include "mlx/types/quarter_types.h" namespace mlx::core { @@ -22,9 +23,11 @@ struct Dtype { int16, int32, int64, + float8, float16, float32, float64, + bfloat8, bfloat16, complex64, }; @@ -78,9 +81,11 @@ inline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)}; inline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)}; inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)}; +inline constexpr Dtype float8{Dtype::Val::float8, sizeof(uint8_t)}; inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)}; inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)}; inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)}; +inline constexpr Dtype bfloat8{Dtype::Val::bfloat8, sizeof(uint8_t)}; inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)}; inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)}; diff --git a/mlx/dtype_utils.h b/mlx/dtype_utils.h index 47c6ed6612..0c8d51b1f1 100644 --- a/mlx/dtype_utils.h +++ b/mlx/dtype_utils.h @@ -28,6 +28,8 @@ const char* dtype_to_string(Dtype arg); MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t) #define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float8, float8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat8, bfloat8_t); \ MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \ MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \ MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \ diff --git a/mlx/types/bf8.h b/mlx/types/bf8.h new file mode 100644 index 0000000000..5c70805bd5 --- /dev/null +++ b/mlx/types/bf8.h @@ -0,0 +1,241 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include + +// BFloat8 (E5M2): 1 sign, 5 exponent (bias 15), 2 mantissa bits +#define __MLX_BF8_NAN__ 0x7F +#define __MLX_BF8_ONE__ 0x3C + +namespace mlx::core { + +namespace { +union float_bits_bf8 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_BFloat8 { + uint8_t bits_; + + // Default constructor + _MLX_BFloat8() = default; + + // Default copy constructor + _MLX_BFloat8(_MLX_BFloat8 const&) = default; + + // Appease std::vector for being special + _MLX_BFloat8& operator=(std::vector::reference x) { + bits_ = (x) ? __MLX_BF8_ONE__ : 0; + return (*this); + } + + _MLX_BFloat8& operator=(const float& x) { + return (*this = _MLX_BFloat8(x)); + } + + // From float32 + _MLX_BFloat8(const float& x) : bits_(0) { + if (std::isnan(x)) { + bits_ = __MLX_BF8_NAN__; + } else if (std::isinf(x)) { + // Infinity: exp=11111, mantissa=00, preserve sign + bits_ = (x > 0) ? 0x7C : 0xFC; + } else { + float_bits_bf8 in; + in.f = x; + + // Extract sign + uint8_t sign = (in.u >> 24) & 0x80; + + // Extract float32 exponent and re-bias for E5M2 + int32_t exp = (int32_t)((in.u >> 23) & 0xFF) - 127 + 15; + + // Extract top 2 mantissa bits + uint32_t mantissa = (in.u >> 21) & 0x03; + + // Round to nearest even + uint32_t round_bit = (in.u >> 20) & 1; + uint32_t sticky = (in.u & 0x000FFFFF) != 0; + if (round_bit && (sticky || (mantissa & 1))) { + mantissa++; + if (mantissa > 3) { + mantissa = 0; + exp++; + } + } + + if (exp >= 31) { + // Overflow to infinity + bits_ = sign | 0x7C; + } else if (exp <= 0) { + if (exp >= -2) { + // Denormalized: shift mantissa with implicit 1 + mantissa = (mantissa | 0x04) >> (1 - exp); + bits_ = sign | (uint8_t)(mantissa & 0x03); + } else { + // Underflow to zero + bits_ = sign; + } + } else { + bits_ = sign | ((uint8_t)exp << 2) | (uint8_t)mantissa; + } + } + } + + // To float32 + operator float() const { + float_bits_bf8 out; + + uint32_t sign = (bits_ >> 7) & 1; + uint32_t exp = (bits_ >> 2) & 0x1F; + uint32_t mantissa = bits_ & 0x03; + + if (exp == 0x1F) { + // Inf (mantissa==0) or NaN (mantissa!=0) + out.u = (sign << 31) | 0x7F800000 | (mantissa << 21); + } else if (exp == 0) { + if (mantissa == 0) { + // Zero (signed) + out.u = sign << 31; + } else { + // Denormalized: normalize + while ((mantissa & 0x04) == 0) { + mantissa <<= 1; + exp--; + } + mantissa &= 0x03; // remove implicit 1 + uint32_t f_exp = (uint32_t)((int32_t)exp + 127 - 15 + 1); + out.u = (sign << 31) | (f_exp << 23) | (mantissa << 21); + } + } else { + // Normalized + uint32_t f_exp = exp - 15 + 127; + out.u = (sign << 31) | (f_exp << 23) | (mantissa << 21); + } + + return out.f; + } +}; + +#define bf8_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bf8_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_BFloat8 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_BFloat8 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define bf8_binop(__op__, __operator__) \ + bf8_binop_base( \ + __op__, __operator__, _MLX_BFloat8, _MLX_BFloat8, _MLX_BFloat8, float); \ + bf8_binop_helper(__op__, __operator__, float, float, float); \ + bf8_binop_helper(__op__, __operator__, double, double, double); \ + bf8_binop_helper(__op__, __operator__, _MLX_BFloat8, bool, float); \ + bf8_binop_helper(__op__, __operator__, _MLX_BFloat8, int32_t, float); \ + bf8_binop_helper(__op__, __operator__, _MLX_BFloat8, uint32_t, float); \ + bf8_binop_helper(__op__, __operator__, _MLX_BFloat8, int64_t, float); \ + bf8_binop_helper(__op__, __operator__, _MLX_BFloat8, uint64_t, float); + +bf8_binop(+, operator+); +bf8_binop(-, operator-); +bf8_binop(*, operator*); +bf8_binop(/, operator/); + +#undef bf8_binop + +// Comparison ops +#define bf8_compop(__op__, __operator__) \ + bf8_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat8, _MLX_BFloat8, float); \ + bf8_binop_helper(__op__, __operator__, bool, float, float); \ + bf8_binop_helper(__op__, __operator__, bool, double, double); \ + bf8_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bf8_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bf8_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bf8_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bf8_compop(>, operator>); +bf8_compop(<, operator<); +bf8_compop(>=, operator>=); +bf8_compop(<=, operator<=); +bf8_compop(==, operator==); +bf8_compop(!=, operator!=); + +#undef bf8_compop + +// Negative +inline _MLX_BFloat8 operator-(_MLX_BFloat8 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define bf8_inplace_op(__op__, __operator__) \ + inline _MLX_BFloat8& __operator__(_MLX_BFloat8& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_BFloat8 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +bf8_inplace_op(+, operator+=); +bf8_inplace_op(-, operator-=); +bf8_inplace_op(*, operator*=); +bf8_inplace_op(/, operator/=); + +#undef bf8_inplace_op + +// Bitwise ops + +#define bf8_bitop(__op__, __operator__) \ + inline _MLX_BFloat8 __operator__(_MLX_BFloat8 lhs, _MLX_BFloat8 rhs) { \ + _MLX_BFloat8 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_BFloat8 __operator__(_MLX_BFloat8 lhs, uint8_t rhs) { \ + _MLX_BFloat8 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_BFloat8 __operator__(uint8_t lhs, _MLX_BFloat8 rhs) { \ + _MLX_BFloat8 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +bf8_bitop(|, operator|); +bf8_bitop(&, operator&); +bf8_bitop(^, operator^); + +#undef bf8_bitop + +#define bf8_inplace_bitop(__op__, __operator__) \ + inline _MLX_BFloat8& __operator__(_MLX_BFloat8& lhs, _MLX_BFloat8 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_BFloat8& __operator__(_MLX_BFloat8& lhs, uint8_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +bf8_inplace_bitop(|, operator|=); +bf8_inplace_bitop(&, operator&=); +bf8_inplace_bitop(^, operator^=); + +#undef bf8_inplace_bitop + +} // namespace mlx::core diff --git a/mlx/types/fp8.h b/mlx/types/fp8.h new file mode 100644 index 0000000000..d748149d0a --- /dev/null +++ b/mlx/types/fp8.h @@ -0,0 +1,248 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include + +// Float8 E4M3FN: 1 sign, 4 exponent (bias 7), 3 mantissa bits +// No infinities; NaN is exp=1111, mantissa=111 only +#define __MLX_FP8_NAN__ 0x7F +#define __MLX_FP8_ONE__ 0x38 + +namespace mlx::core { + +namespace { +union float_bits_fp8 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_Float8 { + uint8_t bits_; + + // Default constructor + _MLX_Float8() = default; + + // Default copy constructor + _MLX_Float8(_MLX_Float8 const&) = default; + + // Appease std::vector for being special + _MLX_Float8& operator=(std::vector::reference x) { + bits_ = (x) ? __MLX_FP8_ONE__ : 0; + return (*this); + } + + _MLX_Float8& operator=(const float& x) { + return (*this = _MLX_Float8(x)); + } + + // From float32 + _MLX_Float8(const float& x) : bits_(0) { + if (std::isnan(x)) { + bits_ = __MLX_FP8_NAN__; + } else { + float_bits_fp8 in; + in.f = x; + + // Extract sign + uint8_t sign = (in.u >> 24) & 0x80; + + // Handle infinity and values beyond E4M3 max (448.0) + float abs_x = std::abs(x); + if (abs_x > 448.0f) { + // Clamp to max representable: exp=1111, mantissa=110 = 0x7E + bits_ = sign | 0x7E; + return; + } + + // Extract float32 exponent and re-bias for E4M3 (bias 7) + int32_t exp = (int32_t)((in.u >> 23) & 0xFF) - 127 + 7; + + // Extract top 3 mantissa bits + uint32_t mantissa = (in.u >> 20) & 0x07; + + // Round to nearest even + uint32_t round_bit = (in.u >> 19) & 1; + uint32_t sticky = (in.u & 0x0007FFFF) != 0; + if (round_bit && (sticky || (mantissa & 1))) { + mantissa++; + if (mantissa > 7) { + mantissa = 0; + exp++; + } + } + + // Check for overflow after rounding (max normal is exp=15, mantissa=110) + if (exp >= 15 && (exp > 15 || mantissa > 6)) { + // Clamp to max (not NaN, which is exp=15 mantissa=7) + bits_ = sign | 0x7E; + } else if (exp <= 0) { + if (exp >= -3) { + // Denormalized: shift mantissa with implicit 1 + mantissa = (mantissa | 0x08) >> (1 - exp); + bits_ = sign | (uint8_t)(mantissa & 0x07); + } else { + // Underflow to zero + bits_ = sign; + } + } else { + bits_ = sign | ((uint8_t)exp << 3) | (uint8_t)mantissa; + } + } + } + + // To float32 + operator float() const { + float_bits_fp8 out; + + uint32_t sign = (bits_ >> 7) & 1; + uint32_t exp = (bits_ >> 3) & 0x0F; + uint32_t mantissa = bits_ & 0x07; + + // NaN check: exp=1111, mantissa=111 + if (exp == 0x0F && mantissa == 0x07) { + out.u = (sign << 31) | 0x7F800000 | (1 << 22); // quiet NaN + } else if (exp == 0) { + if (mantissa == 0) { + // Zero (signed) + out.u = sign << 31; + } else { + // Denormalized: normalize + while ((mantissa & 0x08) == 0) { + mantissa <<= 1; + exp--; + } + mantissa &= 0x07; // remove implicit 1 + uint32_t f_exp = (uint32_t)((int32_t)exp + 127 - 7 + 1); + out.u = (sign << 31) | (f_exp << 23) | (mantissa << 20); + } + } else { + // Normalized (includes exp=15 mantissa=0..6 which are valid values) + uint32_t f_exp = exp - 7 + 127; + out.u = (sign << 31) | (f_exp << 23) | (mantissa << 20); + } + + return out.f; + } +}; + +#define fp8_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define fp8_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_Float8 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_Float8 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define fp8_binop(__op__, __operator__) \ + fp8_binop_base( \ + __op__, __operator__, _MLX_Float8, _MLX_Float8, _MLX_Float8, float); \ + fp8_binop_helper(__op__, __operator__, float, float, float); \ + fp8_binop_helper(__op__, __operator__, double, double, double); \ + fp8_binop_helper(__op__, __operator__, _MLX_Float8, bool, float); \ + fp8_binop_helper(__op__, __operator__, _MLX_Float8, int32_t, float); \ + fp8_binop_helper(__op__, __operator__, _MLX_Float8, uint32_t, float); \ + fp8_binop_helper(__op__, __operator__, _MLX_Float8, int64_t, float); \ + fp8_binop_helper(__op__, __operator__, _MLX_Float8, uint64_t, float); + +fp8_binop(+, operator+); +fp8_binop(-, operator-); +fp8_binop(*, operator*); +fp8_binop(/, operator/); + +#undef fp8_binop + +// Comparison ops +#define fp8_compop(__op__, __operator__) \ + fp8_binop_base( \ + __op__, __operator__, bool, _MLX_Float8, _MLX_Float8, float); \ + fp8_binop_helper(__op__, __operator__, bool, float, float); \ + fp8_binop_helper(__op__, __operator__, bool, double, double); \ + fp8_binop_helper(__op__, __operator__, bool, int32_t, float); \ + fp8_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + fp8_binop_helper(__op__, __operator__, bool, int64_t, float); \ + fp8_binop_helper(__op__, __operator__, bool, uint64_t, float); + +fp8_compop(>, operator>); +fp8_compop(<, operator<); +fp8_compop(>=, operator>=); +fp8_compop(<=, operator<=); +fp8_compop(==, operator==); +fp8_compop(!=, operator!=); + +#undef fp8_compop + +// Negative +inline _MLX_Float8 operator-(_MLX_Float8 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define fp8_inplace_op(__op__, __operator__) \ + inline _MLX_Float8& __operator__(_MLX_Float8& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_Float8 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +fp8_inplace_op(+, operator+=); +fp8_inplace_op(-, operator-=); +fp8_inplace_op(*, operator*=); +fp8_inplace_op(/, operator/=); + +#undef fp8_inplace_op + +// Bitwise ops + +#define fp8_bitop(__op__, __operator__) \ + inline _MLX_Float8 __operator__(_MLX_Float8 lhs, _MLX_Float8 rhs) { \ + _MLX_Float8 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_Float8 __operator__(_MLX_Float8 lhs, uint8_t rhs) { \ + _MLX_Float8 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_Float8 __operator__(uint8_t lhs, _MLX_Float8 rhs) { \ + _MLX_Float8 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +fp8_bitop(|, operator|); +fp8_bitop(&, operator&); +fp8_bitop(^, operator^); + +#undef fp8_bitop + +#define fp8_inplace_bitop(__op__, __operator__) \ + inline _MLX_Float8& __operator__(_MLX_Float8& lhs, _MLX_Float8 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_Float8& __operator__(_MLX_Float8& lhs, uint8_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +fp8_inplace_bitop(|, operator|=); +fp8_inplace_bitop(&, operator&=); +fp8_inplace_bitop(^, operator^=); + +#undef fp8_inplace_bitop + +} // namespace mlx::core diff --git a/mlx/types/quarter_types.h b/mlx/types/quarter_types.h new file mode 100644 index 0000000000..b485ad68a3 --- /dev/null +++ b/mlx/types/quarter_types.h @@ -0,0 +1,90 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/types/half_types.h" + +#include "mlx/types/fp8.h" +namespace mlx::core { +typedef struct _MLX_Float8 float8_t; +} // namespace mlx::core + +#include "mlx/types/bf8.h" +namespace mlx::core { +typedef struct _MLX_BFloat8 bfloat8_t; +} // namespace mlx::core + +namespace mlx::core { + +// clang-format off +#define fp8_bf8_binop_helper(__op__, __operator__) \ + inline float __operator__(float8_t lhs, bfloat8_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(bfloat8_t lhs, float8_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +fp8_bf8_binop_helper(+, operator+) +fp8_bf8_binop_helper(-, operator-) +fp8_bf8_binop_helper(*, operator*) +fp8_bf8_binop_helper(/, operator/) + +// Cross-type ops: float8_t <-> float16_t +#define fp8_fp16_binop_helper(__op__, __operator__) \ + inline float __operator__(float8_t lhs, float16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(float16_t lhs, float8_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +fp8_fp16_binop_helper(+, operator+) +fp8_fp16_binop_helper(-, operator-) +fp8_fp16_binop_helper(*, operator*) +fp8_fp16_binop_helper(/, operator/) + +// Cross-type ops: float8_t <-> bfloat16_t +#define fp8_bf16_binop_helper(__op__, __operator__) \ + inline float __operator__(float8_t lhs, bfloat16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(bfloat16_t lhs, float8_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +fp8_bf16_binop_helper(+, operator+) +fp8_bf16_binop_helper(-, operator-) +fp8_bf16_binop_helper(*, operator*) +fp8_bf16_binop_helper(/, operator/) + +// Cross-type ops: bfloat8_t <-> float16_t +#define bf8_fp16_binop_helper(__op__, __operator__) \ + inline float __operator__(bfloat8_t lhs, float16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(float16_t lhs, bfloat8_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +bf8_fp16_binop_helper(+, operator+) +bf8_fp16_binop_helper(-, operator-) +bf8_fp16_binop_helper(*, operator*) +bf8_fp16_binop_helper(/, operator/) + +// Cross-type ops: bfloat8_t <-> bfloat16_t +#define bf8_bf16_binop_helper(__op__, __operator__) \ + inline float __operator__(bfloat8_t lhs, bfloat16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(bfloat16_t lhs, bfloat8_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +bf8_bf16_binop_helper(+, operator+) +bf8_bf16_binop_helper(-, operator-) +bf8_bf16_binop_helper(*, operator*) +bf8_bf16_binop_helper(/, operator/) +// clang-format on + +} // namespace mlx::core diff --git a/mlx/utils.cpp b/mlx/utils.cpp index cf0e0f38db..8de3b051bf 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -56,6 +56,12 @@ inline void PrintFormatter::print(std::ostream& os, int64_t val) { inline void PrintFormatter::print(std::ostream& os, uint64_t val) { os << val; } +inline void PrintFormatter::print(std::ostream& os, float8_t val) { + os << static_cast(val); +} +inline void PrintFormatter::print(std::ostream& os, bfloat8_t val) { + os << static_cast(val); +} inline void PrintFormatter::print(std::ostream& os, float16_t val) { os << val; } diff --git a/mlx/utils.h b/mlx/utils.h index 62aa82b658..ef85eb62b4 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -46,6 +46,8 @@ struct PrintFormatter { inline void print(std::ostream& os, uint32_t val); inline void print(std::ostream& os, int64_t val); inline void print(std::ostream& os, uint64_t val); + inline void print(std::ostream& os, float8_t val); + inline void print(std::ostream& os, bfloat8_t val); inline void print(std::ostream& os, float16_t val); inline void print(std::ostream& os, bfloat16_t val); inline void print(std::ostream& os, float val);