[JAX] Return max_logit and softmax aux stats from TE JAX fused attn#3112
Draft
KshitijLakhani wants to merge 4 commits into
Draft
[JAX] Return max_logit and softmax aux stats from TE JAX fused attn#3112KshitijLakhani wants to merge 4 commits into
KshitijLakhani wants to merge 4 commits into
Loading