Commit 13becf1

Allow passing existing causal attention masks
Since we create them in the T5 data loader, why not use them?
1 parent b2fc665 commit 13becf1

1 file changed

Lines changed: 1 addition & 2 deletions

megatron/model/fused_softmax.py

@@ -214,8 +214,7 @@ def forward_torch_softmax(self, input, mask):
         if self.scale is not None:
             input = input * self.scale
 
-        if self.attn_mask_type == AttnMaskType.causal:
-            assert mask is None
+        if self.attn_mask_type == AttnMaskType.causal and mask is None:
             assert input.shape[2] == input.shape[3]
             mask = self.get_causal_mask(input.shape[2])
