@@ -478,25 +478,13 @@ def main(
478478 # in this case we're looking at a fine-tuned model (?)
479479 character_level = data_args .character_level ,
480480 )
481-
482481 if training_args .do_train :
483482 # Setting 1) only load weights from the encoder
484- raise NotImplementedError (
485- "This functionality has not been restored yet"
486- )
487483 model = CnlpModelForClassification (
488- model_path = model_args .encoder_name ,
489484 config = config ,
490- cache_dir = model_args .cache_dir ,
491- tagger = tagger ,
492- relations = relations ,
493485 class_weights = dataset .class_weights ,
494486 final_task_weight = training_args .final_task_weight ,
495- use_prior_tasks = model_args .use_prior_tasks ,
496- argument_regularization = model_args .arg_reg ,
497487 )
498- delattr (model , "classifiers" )
499- delattr (model , "feature_extractors" )
500488 if training_args .do_train :
501489 tempmodel = tempfile .NamedTemporaryFile (dir = model_args .cache_dir )
502490 torch .save (model .state_dict (), tempmodel )
@@ -511,7 +499,6 @@ def main(
511499 freeze = training_args .freeze ,
512500 bias_fit = training_args .bias_fit ,
513501 )
514-
515502 else :
516503 # This only works when model_args.encoder_name is one of the
517504 # model card from https://huggingface.co/models
@@ -675,7 +662,7 @@ def compute_metrics_fn(p: EvalPrediction):
675662 model .best_eval_results = metrics
676663 if trainer .is_world_process_zero ():
677664 if training_args .do_train :
678- trainer .save_model ()
665+ trainer .save_model () # NOTE: a RobertaConfig is loaded here. why?
679666 tokenizer .save_pretrained (training_args .output_dir )
680667 if model_name == "cnn" or model_name == "lstm" :
681668 with open (
@@ -884,7 +871,7 @@ def compute_metrics_fn(p: EvalPrediction):
884871
885872 out_table = process_prediction (
886873 task_names = dataset .tasks ,
887- error_analysis = False ,
874+ error_analysis = training_args . error_analysis ,
888875 output_prob = training_args .output_prob ,
889876 character_level = data_args .character_level ,
890877 task_to_label_packet = task_to_label_packet ,
0 commit comments