Hi @stas00
Based on our discussion on X, I looked around, and someone on the EleutherAI Discord pointed out the scaled_mm API (https://github.com/pytorch/pytorch/blob/88c77db9c862573f9d7a8eda58ae735415bc740d/torch/nn/functional.py#L6752), which TorchAO uses under the hood. They also pointed to these benchmark functions in the TorchAO repo, which seem to imply that post-2.10.0 PyTorch lets you use MXFP4 with scaled_mm: https://github.com/pytorch/ao/blob/fe986580eaafc87f532534a8f222c7d11af18702/benchmarks/float8/bench_matmul.py#L166
I'm thinking of adding both MXFP4 and NVFP4 support, using the function parameters in the second link as a reference.
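For reference while wiring this up, here is a minimal pure-Python emulation of what an MXFP4 scaled matmul should compute. It could serve as a CPU correctness oracle. It assumes the OCP Microscaling (MX) v1.0 layout: blocks of 32 FP4 E2M1 elements along K that share one power-of-two (E8M0) scale, with the scale exponent chosen as floor(log2(amax)) - 2, round-to-nearest-even, and saturation at +/-6. It does not call scaled_mm or TorchAO, and the helper names (round_to_fp4, quantize_mx_block, mxfp4_dot) are hypothetical. NVFP4 would differ mainly in using 16-element blocks with FP8 E4M3 block scales plus a per-tensor FP32 scale.

```python
import math

# FP4 E2M1 representable magnitudes (OCP MX spec); even index <=> mantissa bit 0.
FP4_E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_EMAX = 2        # exponent of the largest normal E2M1 value (6.0 = 1.5 * 2**2)
MX_BLOCK_SIZE = 32  # assumption: MXFP4 shares one E8M0 scale across 32 elements


def round_to_fp4(x: float) -> float:
    """Round to the nearest E2M1 value, ties to even, saturating at +/-6."""
    mag = min(abs(x), FP4_E2M1_GRID[-1])
    best_idx = 0
    for i, v in enumerate(FP4_E2M1_GRID):
        d = abs(mag - v)
        best_d = abs(mag - FP4_E2M1_GRID[best_idx])
        # Ties only occur between neighbours; prefer the even-mantissa one.
        if d < best_d or (d == best_d and i % 2 == 0):
            best_idx = i
    return math.copysign(FP4_E2M1_GRID[best_idx], x)


def quantize_mx_block(block):
    """Return (shared_exponent, fp4_values) for one block of floats."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    # frexp gives amax = m * 2**e with m in [0.5, 1), so floor(log2(amax)) == e - 1.
    _, e = math.frexp(amax)
    shared_exp = (e - 1) - FP4_EMAX
    shared_exp = max(-127, min(127, shared_exp))  # E8M0 exponent range
    scale = 2.0 ** shared_exp
    return shared_exp, [round_to_fp4(v / scale) for v in block]


def dequantize_mx_block(shared_exp, values):
    """Map FP4 values back to floats by applying the power-of-two block scale."""
    return [v * 2.0 ** shared_exp for v in values]


def mxfp4_dot(a, b):
    """Emulated dot product of two vectors quantized to MXFP4 along K."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    total = 0.0
    for start in range(0, len(a), MX_BLOCK_SIZE):
        ea, qa = quantize_mx_block(a[start:start + MX_BLOCK_SIZE])
        eb, qb = quantize_mx_block(b[start:start + MX_BLOCK_SIZE])
        # Accumulate the FP4 products per block, then apply both block scales.
        partial = sum(x * y for x, y in zip(qa, qb))
        total += partial * 2.0 ** (ea + eb)
    return total


if __name__ == "__main__":
    print(mxfp4_dot([1.0] * 64, [0.5] * 64))
```

Comparing a real scaled_mm output against this emulation (fed the same quantized inputs) is one way to sanity-check scale layout and rounding, though exact agreement would also depend on the kernel's accumulation order and precision.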