Skip to content

Commit 8a145e7

Browse files
committed
addressing review comments
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 3b02068 commit 8a145e7

2 files changed

Lines changed: 35 additions & 21 deletions

File tree

models/esm2/src/esm/collator.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,17 @@ def __call__(self, features, return_tensors=None):
171171
batch = self.flattening_collator(features, return_tensors)
172172

173173
special_tokens_mask = batch.pop("special_tokens_mask", None)
174-
batch["input_ids"], batch["labels"] = self.mlm_collator.torch_mask_tokens(
175-
batch["input_ids"], special_tokens_mask=special_tokens_mask
176-
)
174+
175+
if return_tensors == "pt":
176+
batch["input_ids"], batch["labels"] = self.mlm_collator.torch_mask_tokens(
177+
batch["input_ids"], special_tokens_mask=special_tokens_mask
178+
)
179+
elif return_tensors == "np":
180+
batch["input_ids"], batch["labels"] = self.mlm_collator.numpy_mask_tokens(
181+
batch["input_ids"], special_tokens_mask=special_tokens_mask
182+
)
183+
else:
184+
raise ValueError(f'return_tensors must be one of ("pt", "np"), {return_tensors=} not suported')
177185

178186
return batch
179187

models/esm2/src/esm/modeling_esm_te.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -176,33 +176,39 @@ def forward(
176176
raise ValueError(
177177
"cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k must be provided when using THD inputs."
178178
)
179+
assert hidden_states.dim() == 3 and hidden_states.size(0) == 1, (
180+
"THD expects embeddings shaped [1, total_tokens, hidden_size]."
181+
)
182+
hidden_states = hidden_states.squeeze(0)
179183

180184
elif self.config.attn_input_format == "bshd":
181185
if any(x is not None for x in [cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k]):
182186
raise ValueError(
183187
"cu_seq_lens_q, cu_seq_lens_k, max_length_q, and max_length_k are not allowed when using BSHD inputs."
184188
)
185189

186-
if self.config.attn_input_format == "bshd" and self.te_rope_emb is not None:
187-
te_rope_emb = self.te_rope_emb.to(
188-
device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
189-
)
190-
seq_len = hidden_states.shape[1]
191-
if te_rope_emb.size(0) < seq_len:
192-
raise RuntimeError(
193-
f"ROPE length {te_rope_emb.size(0)} < input seq length {seq_len}. "
194-
f"Increase max_position_embeddings."
190+
te_rope_emb = None
191+
if self.config.position_embedding_type == "rotary":
192+
if self.config.attn_input_format == "bshd":
193+
te_rope_emb = self.te_rope_emb.to(
194+
device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
195+
)
196+
seq_len = hidden_states.shape[1]
197+
if te_rope_emb.size(0) < seq_len:
198+
raise RuntimeError(
199+
f"ROPE length {te_rope_emb.size(0)} < input seq length {seq_len}. "
200+
f"Increase max_position_embeddings."
201+
)
202+
te_rope_emb = te_rope_emb[:seq_len]
203+
204+
elif self.config.attn_input_format == "thd":
205+
assert cu_seq_lens_q is not None
206+
te_rope_emb = self.rotary_embeddings(max_seq_len=cu_seq_lens_q[-1]).to(
207+
device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
195208
)
196-
te_rope_emb = te_rope_emb[:seq_len]
197209

198-
elif self.config.attn_input_format == "thd":
199-
assert cu_seq_lens_q is not None
200-
te_rope_emb = self.rotary_embeddings(max_seq_len=cu_seq_lens_q[-1]).to(
201-
device=hidden_states.device, dtype=hidden_states.dtype, non_blocking=True
202-
)
203-
hidden_states = hidden_states.squeeze(0)
204-
else:
205-
te_rope_emb = None
210+
else:
211+
raise ValueError(f"Unsupported attention input format: {self.config.attn_input_format}")
206212

207213
for layer_module in self.layers:
208214
if output_hidden_states:

0 commit comments

Comments
 (0)