Skip to content

Commit 1ad7199

Browse files
authored
Fix handling of attention_mask in encoders (#14)
1 parent 039b5d0 commit 1ad7199

3 files changed

Lines changed: 11 additions & 8 deletions

File tree

mmlearn/modules/encoders/clip_encoders.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,8 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> Tuple[torch.Tensor
327327
The text embeddings. Will be a tuple with a single element.
328328
"""
329329
input_ids = inputs[Modalities.TEXT]
330-
attention_mask = inputs.get("attention_mask") or inputs.get(
331-
Modalities.TEXT.attention_mask
330+
attention_mask: Optional[torch.Tensor] = inputs.get(
331+
"attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
332332
)
333333
position_ids = inputs.get("position_ids")
334334

@@ -568,8 +568,9 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
568568
"""
569569
output = self.model(
570570
input_ids=inputs[Modalities.TEXT],
571-
attention_mask=inputs.get("attention_mask")
572-
or inputs.get(Modalities.TEXT.attention_mask),
571+
attention_mask=inputs.get(
572+
"attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
573+
),
573574
inputs_embeds=inputs.get("inputs_embeds"),
574575
output_attentions=inputs.get("output_attentions"),
575576
output_hidden_states=True,

mmlearn/modules/encoders/hf_text_encoders.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
156156
"""
157157
outputs = self.model(
158158
input_ids=inputs[Modalities.TEXT],
159-
attention_mask=inputs.get("attention_mask")
160-
or inputs.get(Modalities.TEXT.attention_mask),
159+
attention_mask=inputs.get(
160+
"attention_mask", inputs.get(Modalities.TEXT.attention_mask, None)
161+
),
161162
position_ids=inputs.get("position_ids"),
162163
output_attentions=inputs.get("output_attentions"),
163164
return_dict=True,

projects/bioscan_clip/encoders.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ def forward(self, inputs: Dict[Union[str, Modality], Any]) -> BaseModelOutput:
140140
"""Run the forward pass."""
141141
outputs = self.model(
142142
input_ids=inputs[Modalities.DNA],
143-
attention_mask=inputs.get("attention_mask")
144-
or inputs.get(Modalities.DNA.attention_mask),
143+
attention_mask=inputs.get(
144+
"attention_mask", inputs.get(Modalities.DNA.attention_mask, None)
145+
),
145146
position_ids=inputs.get("position_ids"),
146147
output_attentions=inputs.get("output_attentions"),
147148
return_dict=True,

0 commit comments

Comments
 (0)