Skip to content

Commit 60a1a72

Browse files
re-implement support for cnlpt models
1 parent 2bd994a commit 60a1a72

1 file changed

Lines changed: 2 additions & 15 deletions

File tree

src/cnlpt/train_system.py

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -478,25 +478,13 @@ def main(
478478
# in this case we're looking at a fine-tuned model (?)
479479
character_level=data_args.character_level,
480480
)
481-
482481
if training_args.do_train:
483482
# Setting 1) only load weights from the encoder
484-
raise NotImplementedError(
485-
"This functionality has not been restored yet"
486-
)
487483
model = CnlpModelForClassification(
488-
model_path=model_args.encoder_name,
489484
config=config,
490-
cache_dir=model_args.cache_dir,
491-
tagger=tagger,
492-
relations=relations,
493485
class_weights=dataset.class_weights,
494486
final_task_weight=training_args.final_task_weight,
495-
use_prior_tasks=model_args.use_prior_tasks,
496-
argument_regularization=model_args.arg_reg,
497487
)
498-
delattr(model, "classifiers")
499-
delattr(model, "feature_extractors")
500488
if training_args.do_train:
501489
tempmodel = tempfile.NamedTemporaryFile(dir=model_args.cache_dir)
502490
torch.save(model.state_dict(), tempmodel)
@@ -511,7 +499,6 @@ def main(
511499
freeze=training_args.freeze,
512500
bias_fit=training_args.bias_fit,
513501
)
514-
515502
else:
516503
# This only works when model_args.encoder_name is one of the
517504
# model cards from https://huggingface.co/models
@@ -675,7 +662,7 @@ def compute_metrics_fn(p: EvalPrediction):
675662
model.best_eval_results = metrics
676663
if trainer.is_world_process_zero():
677664
if training_args.do_train:
678-
trainer.save_model()
665+
trainer.save_model() # NOTE: a RobertaConfig is loaded here. why?
679666
tokenizer.save_pretrained(training_args.output_dir)
680667
if model_name == "cnn" or model_name == "lstm":
681668
with open(
@@ -884,7 +871,7 @@ def compute_metrics_fn(p: EvalPrediction):
884871

885872
out_table = process_prediction(
886873
task_names=dataset.tasks,
887-
error_analysis=False,
874+
error_analysis=training_args.error_analysis,
888875
output_prob=training_args.output_prob,
889876
character_level=data_args.character_level,
890877
task_to_label_packet=task_to_label_packet,

0 commit comments

Comments
 (0)