Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlx/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,12 @@ void array::init(It src) {
case int64:
std::copy(src, src + size(), data<int64_t>());
break;
case float8:
std::copy(src, src + size(), data<float8_t>());
break;
case bfloat8:
std::copy(src, src + size(), data<bfloat8_t>());
break;
case float16:
std::copy(src, src + size(), data<float16_t>());
break;
Expand Down
5 changes: 5 additions & 0 deletions mlx/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "mlx/api.h"
#include "mlx/types/complex.h"
#include "mlx/types/half_types.h"
#include "mlx/types/quarter_types.h"

namespace mlx::core {

Expand All @@ -22,9 +23,11 @@ struct Dtype {
int16,
int32,
int64,
float8,
float16,
float32,
float64,
bfloat8,
bfloat16,
complex64,
};
Expand Down Expand Up @@ -78,9 +81,11 @@ inline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
// Each constant pairs a Dtype::Val tag with the sizeof() of a storage type.
// NOTE(review): the second initializer is presumably the element size in
// bytes -- confirm against the Dtype constructor.
inline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};

// Floating-point dtypes. The 8-bit and 16-bit entries take their size from
// the same-width unsigned integer type rather than from the float8_t /
// float16_t style wrapper types.
inline constexpr Dtype float8{Dtype::Val::float8, sizeof(uint8_t)};
inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)};
inline constexpr Dtype bfloat8{Dtype::Val::bfloat8, sizeof(uint8_t)};
inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};

Expand Down
2 changes: 2 additions & 0 deletions mlx/dtype_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ const char* dtype_to_string(Dtype arg);
MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t)

#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \
MLX_INTERNAL_DTYPE_SWITCH_CASE(float8, float8_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat8, bfloat8_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \
MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \
Expand Down
241 changes: 241 additions & 0 deletions mlx/types/bf8.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// BFloat8 (E5M2): 1 sign, 5 exponent (bias 15), 2 mantissa bits
// 0x7F = 0 11111 11: all-ones exponent with a non-zero mantissa, i.e. a NaN.
#define __MLX_BF8_NAN__ 0x7F
// 0x3C = 0 01111 00: biased exponent 15 with a zero mantissa, i.e. 1.0.
#define __MLX_BF8_ONE__ 0x3C

namespace mlx::core {

namespace {
// Overlays a float and a uint32_t in the same storage so that a value written
// through one member can be inspected through the other.
// NOTE(review): reading the member that was not written last is not
// sanctioned by the C++ standard (it works as a compiler extension);
// std::memcpy is the portable way to reinterpret the bits.
// NOTE(review): an unnamed namespace in a header gives every translation unit
// that includes it its own distinct copy of this type.
union float_bits_bf8 {
float f;
uint32_t u;
};
} // namespace

// 8-bit floating point value in E5M2 layout: 1 sign bit, 5 exponent bits
// (bias 15) and 2 mantissa bits, stored in a single byte.
//
// Conversion from float32 rounds to nearest, ties to even, over the whole
// range (normal and subnormal results); magnitudes too large for the format
// become infinity and magnitudes below half of the smallest subnormal become
// a signed zero. Every NaN input maps to the single pattern 0x7F.
struct _MLX_BFloat8 {
  // Raw E5M2 encoding.
  uint8_t bits_;

  // Default constructor (leaves bits_ uninitialized).
  _MLX_BFloat8() = default;

  // Default copy constructor
  _MLX_BFloat8(_MLX_BFloat8 const&) = default;

  // Appease std::vector<bool> for being special: its proxy reference is
  // mapped straight to 1.0 (0x3C, the value of __MLX_BF8_ONE__) or 0.
  _MLX_BFloat8& operator=(std::vector<bool>::reference x) {
    constexpr uint8_t kOne = 0x3C;
    bits_ = x ? kOne : 0;
    return (*this);
  }

  // Assign from float32 using the converting constructor below.
  _MLX_BFloat8& operator=(const float& x) {
    return (*this = _MLX_BFloat8(x));
  }

  // From float32
  _MLX_BFloat8(const float& x) : bits_(0) {
    constexpr uint8_t kNaN = 0x7F; // same value as __MLX_BF8_NAN__
    constexpr uint8_t kInf = 0x7C; // exponent 11111, mantissa 00

    if (std::isnan(x)) {
      bits_ = kNaN;
      return;
    }

    // memcpy is the well-defined way to look at the bits of a float.
    uint32_t in = 0;
    std::memcpy(&in, &x, sizeof(in));

    const uint8_t sign = static_cast<uint8_t>((in >> 24) & 0x80);

    if (std::isinf(x)) {
      bits_ = static_cast<uint8_t>(sign | kInf);
      return;
    }

    // float32 exponent re-biased for E5M2. Zero and float32 subnormals land
    // at -112 and fall into the underflow branch.
    const int32_t exp = static_cast<int32_t>((in >> 23) & 0xFF) - 127 + 15;

    if (exp < -2) {
      // Strictly below half of the smallest subnormal (2^-17): signed zero.
      bits_ = sign;
      return;
    }

    // 24-bit significand with the implicit leading one made explicit.
    const uint32_t significand = (in & 0x007FFFFF) | 0x00800000;

    // A normal result keeps the top 3 bits (implicit one + 2 mantissa bits).
    // A subnormal result has a fixed scale of 2^-16 per unit, so one more
    // bit is dropped for every step the exponent lies below 1.
    const int shift = (exp >= 1) ? 21 : (22 - exp); // 21..24
    uint32_t kept = significand >> shift;
    const uint32_t remainder = significand & ((1u << shift) - 1u);
    const uint32_t half = 1u << (shift - 1);

    // Round to nearest, ties to even, applied once to the exact value.
    if (remainder > half || (remainder == half && (kept & 1u))) {
      ++kept;
    }

    // For normals `kept` still contains the implicit one (4..8), so adding it
    // to (exp - 1) << 2 both inserts the mantissa and carries a rounding
    // overflow into the exponent. For subnormals `kept` (0..4) is already the
    // encoding; 4 is the smallest normal number.
    const uint32_t base = (exp >= 1) ? (static_cast<uint32_t>(exp - 1) << 2) : 0u;
    uint32_t magnitude = base + kept;
    if (magnitude >= kInf) {
      // Overflow to infinity
      magnitude = kInf;
    }

    bits_ = static_cast<uint8_t>(sign | magnitude);
  }

  // To float32 (exact: every E5M2 value is representable as a float32).
  operator float() const {
    const uint32_t sign = static_cast<uint32_t>(bits_ & 0x80) << 24;
    const uint32_t exp = (bits_ >> 2) & 0x1F;
    const uint32_t mantissa = bits_ & 0x03;

    uint32_t out = 0;
    if (exp == 0x1F) {
      // Inf (mantissa==0) or NaN (mantissa!=0)
      out = sign | 0x7F800000 | (mantissa << 21);
    } else if (exp != 0) {
      // Normalized: only the exponent bias changes.
      out = sign | ((exp - 15 + 127) << 23) | (mantissa << 21);
    } else if (mantissa == 0) {
      // Zero (signed)
      out = sign;
    } else {
      // Denormalized: shift the leading one up to the implicit position
      // (bit 2), lowering the float32 exponent once per shift. The exponent
      // is tracked as a signed value so it never wraps.
      uint32_t significand = mantissa;
      int32_t f_exp = 127 - 15 + 1;
      while ((significand & 0x04) == 0) {
        significand <<= 1;
        --f_exp;
      }
      out = sign | (static_cast<uint32_t>(f_exp) << 23) |
          ((significand & 0x03) << 21);
    }

    float result = 0.0f;
    std::memcpy(&result, &out, sizeof(result));
    return result;
  }
};

// Emits one free binary operator `name_(atype, btype) -> otype`. Both operands
// are cast to `ctype`, `op_` is applied, and the result is converted to
// `otype` by the return statement.
#define bf8_binop_base(op_, name_, otype, atype, btype, ctype)    \
  inline otype name_(atype lhs, btype rhs) {                      \
    return static_cast<ctype>(lhs) op_ static_cast<ctype>(rhs);   \
  }

// Emits the two mixed overloads of `name_` that pair _MLX_BFloat8 with
// `itype`: first (_MLX_BFloat8, itype), then (itype, _MLX_BFloat8).
#define bf8_binop_helper(op_, name_, otype, itype, ctype)          \
  bf8_binop_base(op_, name_, otype, _MLX_BFloat8, itype, ctype)    \
  bf8_binop_base(op_, name_, otype, itype, _MLX_BFloat8, ctype)

// Arithmetic operators. All math is carried out in float (double when the
// other operand is a double). Mixing with float/double yields that type;
// every other combination is converted back to _MLX_BFloat8.
#define bf8_binop(op_, name_)                                                  \
  bf8_binop_base(op_, name_, _MLX_BFloat8, _MLX_BFloat8, _MLX_BFloat8, float)  \
  bf8_binop_helper(op_, name_, float, float, float)                            \
  bf8_binop_helper(op_, name_, double, double, double)                         \
  bf8_binop_helper(op_, name_, _MLX_BFloat8, bool, float)                      \
  bf8_binop_helper(op_, name_, _MLX_BFloat8, int32_t, float)                   \
  bf8_binop_helper(op_, name_, _MLX_BFloat8, uint32_t, float)                  \
  bf8_binop_helper(op_, name_, _MLX_BFloat8, int64_t, float)                   \
  bf8_binop_helper(op_, name_, _MLX_BFloat8, uint64_t, float)

bf8_binop(+, operator+);
bf8_binop(-, operator-);
bf8_binop(*, operator*);
bf8_binop(/, operator/);

#undef bf8_binop

// Comparison operators. Operands are compared as float (as double when the
// other operand is a double) and the result is always bool.
#define bf8_compop(op_, name_)                                         \
  bf8_binop_base(op_, name_, bool, _MLX_BFloat8, _MLX_BFloat8, float)  \
  bf8_binop_helper(op_, name_, bool, float, float)                     \
  bf8_binop_helper(op_, name_, bool, double, double)                   \
  bf8_binop_helper(op_, name_, bool, int32_t, float)                   \
  bf8_binop_helper(op_, name_, bool, uint32_t, float)                  \
  bf8_binop_helper(op_, name_, bool, int64_t, float)                   \
  bf8_binop_helper(op_, name_, bool, uint64_t, float)

bf8_compop(>, operator>);
bf8_compop(<, operator<);
bf8_compop(>=, operator>=);
bf8_compop(<=, operator<=);
bf8_compop(==, operator==);
bf8_compop(!=, operator!=);

#undef bf8_compop

// Unary minus: negate the float32 value and convert the result back.
inline _MLX_BFloat8 operator-(_MLX_BFloat8 lhs) {
  const float negated = -static_cast<float>(lhs);
  return negated;
}

// Compound arithmetic assignment. Each one evaluates the matching binary
// operator on (lhs, rhs), stores the result back into lhs and returns lhs.
#define bf8_inplace_op(op_, name_)                                  \
  inline _MLX_BFloat8& name_(_MLX_BFloat8& lhs, const float& rhs) { \
    return lhs = lhs op_ rhs;                                       \
  }                                                                 \
  inline float& name_(float& lhs, _MLX_BFloat8 rhs) {               \
    return lhs = lhs op_ rhs;                                       \
  }

bf8_inplace_op(+, operator+=);
bf8_inplace_op(-, operator-=);
bf8_inplace_op(*, operator*=);
bf8_inplace_op(/, operator/=);

#undef bf8_inplace_op

// Bitwise operators. These act on the raw encoding (bits_), not on the
// numeric value, and accept either another _MLX_BFloat8 or a plain byte.

#define bf8_bitop(op_, name_)                                        \
  inline _MLX_BFloat8 name_(_MLX_BFloat8 lhs, _MLX_BFloat8 rhs) {    \
    lhs.bits_ = lhs.bits_ op_ rhs.bits_;                             \
    return lhs;                                                      \
  }                                                                  \
  inline _MLX_BFloat8 name_(_MLX_BFloat8 lhs, uint8_t rhs) {         \
    lhs.bits_ = lhs.bits_ op_ rhs;                                   \
    return lhs;                                                      \
  }                                                                  \
  inline _MLX_BFloat8 name_(uint8_t lhs, _MLX_BFloat8 rhs) {         \
    rhs.bits_ = lhs op_ rhs.bits_;                                   \
    return rhs;                                                      \
  }

bf8_bitop(|, operator|);
bf8_bitop(&, operator&);
bf8_bitop(^, operator^);

#undef bf8_bitop

// Compound bitwise assignment on the raw encoding: combine lhs.bits_ with the
// right-hand bits, store the result in lhs and return lhs.
#define bf8_inplace_bitop(op_, name_)                                \
  inline _MLX_BFloat8& name_(_MLX_BFloat8& lhs, _MLX_BFloat8 rhs) {  \
    const uint8_t combined = lhs.bits_ op_ rhs.bits_;                \
    lhs.bits_ = combined;                                            \
    return lhs;                                                      \
  }                                                                  \
  inline _MLX_BFloat8& name_(_MLX_BFloat8& lhs, uint8_t rhs) {       \
    const uint8_t combined = lhs.bits_ op_ rhs;                      \
    lhs.bits_ = combined;                                            \
    return lhs;                                                      \
  }

bf8_inplace_bitop(|, operator|=);
bf8_inplace_bitop(&, operator&=);
bf8_inplace_bitop(^, operator^=);

#undef bf8_inplace_bitop

} // namespace mlx::core
Loading