Skip to content

Commit b6020e3

Browse files
authored
[JAX] Fix bug with pre scale bias (NVIDIA#2300)
* fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> * fix Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> --------- Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
1 parent 77a0063 commit b6020e3

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

transformer_engine/jax/flax/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def __call__(
197197
fused_scale_factor = scale_factor
198198
if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS:
199199
attn_weights += bias
200+
bias = None
200201

201202
def apply_swa_mask(original_mask: Array) -> Array:
202203
"""Apply the sliding window mask to a given mask"""

0 commit comments

Comments
 (0)