Add basic fp8/bfp8 support, future implementations of 8-bit floating …#3374
Add basic fp8/bfp8 support, future implementations of 8-bit floating …#3374Geramy wants to merge 1 commit intoml-explore:mainfrom
Conversation
…point functions next.
|
Rather than JAX's float8 types: https://github.com/jax-ml/ml_dtypes I think we should have a consensus on the type names with the team first. |
Sounds good, I actually think that is a lot better since we have to think about compatibility between all possible backends. It could turn into a big problem later I would assume. Let me know what you guys decide on I 100% agree on changing the names upfront. |
…point functions next.
Proposed changes
The proposed changes are to include fp8 and bfp8 support.
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes