Skip to content

Commit b51d21e

Browse files
committed
remove kwargs from amplify
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent f40e4b0 commit b51d21e

2 files changed

Lines changed: 3 additions & 6 deletions

File tree

models/amplify/src/amplify/amplify_te.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,6 @@ def forward(
212212
output_hidden_states=False,
213213
output_attentions=False,
214214
labels=None,
215-
**kwargs,
216215
) -> BaseModelOutput:
217216
"""Forward pass of the AMPLIFY model.
218217
@@ -222,7 +221,6 @@ def forward(
222221
output_hidden_states (bool): Whether to output the hidden states.
223222
output_attentions (bool): Whether to output the attention weights.
224223
labels (torch.Tensor): The labels.
225-
**kwargs: Additional arguments.
226224
227225
Returns:
228226
BaseModelOutput: The output of the model.
@@ -296,7 +294,6 @@ def forward(
296294
output_hidden_states=False,
297295
output_attentions=False,
298296
labels=None,
299-
**kwargs,
300297
) -> MaskedLMOutput:
301298
"""Forward pass of the AMPLIFYForMaskedLM model.
302299
@@ -306,7 +303,6 @@ def forward(
306303
output_hidden_states (bool): Whether to output the hidden states.
307304
output_attentions (bool): Whether to output the attention weights.
308305
labels (torch.Tensor): The labels.
309-
**kwargs: Additional arguments.
310306
311307
Returns:
312308
MaskedLMOutput: The output of the model.
@@ -317,7 +313,6 @@ def forward(
317313
output_hidden_states,
318314
output_attentions,
319315
labels,
320-
**kwargs,
321316
)
322317

323318
# Classification head with layer norm

recipes/esm2_accelerate/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def main(args: DictConfig):
4646
)
4747

4848
config = AutoConfig.from_pretrained(args.model_tag, trust_remote_code=True)
49-
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
49+
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, dtype=torch.bfloat16)
5050

5151
train_dataset, eval_dataset, data_collator = create_datasets_and_collator(
5252
tokenizer_name=args.model_tag,
@@ -72,6 +72,8 @@ def main(args: DictConfig):
7272
logger.info("Resuming from checkpoint: %s", last_checkpoint)
7373
else:
7474
logger.info("No checkpoint found, starting from scratch")
75+
if state.is_main_process:
76+
breakpoint()
7577
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
7678
logger.info("Training complete. Metrics: %s", train_result.metrics)
7779
trainer.save_metrics("train", train_result.metrics)

0 commit comments

Comments
 (0)