File tree Expand file tree Collapse file tree
models/amplify/src/amplify
Expand file tree Collapse file tree
Original file line number | Diff line number | Diff line change
@@ -212,7 +212,6 @@ def forward(
212212 output_hidden_states = False ,
213213 output_attentions = False ,
214214 labels = None ,
215- ** kwargs ,
216215 ) -> BaseModelOutput :
217216 """Forward pass of the AMPLIFY model.
218217
@@ -222,7 +221,6 @@ def forward(
222221 output_hidden_states (bool): Whether to output the hidden states.
223222 output_attentions (bool): Whether to output the attention weights.
224223 labels (torch.Tensor): The labels.
225- **kwargs: Additional arguments.
226224
227225 Returns:
228226 BaseModelOutput: The output of the model.
@@ -296,7 +294,6 @@ def forward(
296294 output_hidden_states = False ,
297295 output_attentions = False ,
298296 labels = None ,
299- ** kwargs ,
300297 ) -> MaskedLMOutput :
301298 """Forward pass of the AMPLIFYForMaskedLM model.
302299
@@ -306,7 +303,6 @@ def forward(
306303 output_hidden_states (bool): Whether to output the hidden states.
307304 output_attentions (bool): Whether to output the attention weights.
308305 labels (torch.Tensor): The labels.
309- **kwargs: Additional arguments.
310306
311307 Returns:
312308 MaskedLMOutput: The output of the model.
@@ -317,7 +313,6 @@ def forward(
317313 output_hidden_states ,
318314 output_attentions ,
319315 labels ,
320- ** kwargs ,
321316 )
322317
323318 # Classification head with layer norm
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def main(args: DictConfig):
4646 )
4747
4848 config = AutoConfig .from_pretrained (args .model_tag , trust_remote_code = True )
49- model = AutoModelForMaskedLM .from_config (config , trust_remote_code = True , torch_dtype = torch .bfloat16 )
49+ model = AutoModelForMaskedLM .from_config (config , trust_remote_code = True , dtype = torch .bfloat16 )
5050
5151 train_dataset , eval_dataset , data_collator = create_datasets_and_collator (
5252 tokenizer_name = args .model_tag ,
@@ -72,6 +72,8 @@ def main(args: DictConfig):
7272 logger .info ("Resuming from checkpoint: %s" , last_checkpoint )
7373 else :
7474 logger .info ("No checkpoint found, starting from scratch" )
75+ if state .is_main_process :
76+ breakpoint ()
7577 train_result = trainer .train (resume_from_checkpoint = last_checkpoint )
7678 logger .info ("Training complete. Metrics: %s" , train_result .metrics )
7779 trainer .save_metrics ("train" , train_result .metrics )
You can’t perform that action at this time.
0 commit comments