@@ -112,15 +112,21 @@ def __init__(self, *args, class_weights=None, **kwargs):
112112 self .class_weights = class_weights
113113 print ("CustomTrainer initialized with class weights:" , self .class_weights )
114114
def compute_loss(  # ty: ignore[invalid-method-override]
    self, model, inputs, return_outputs=False, **kwargs
):
    """Compute a (optionally class-weighted) cross-entropy loss for one batch.

    Args:
        model: Callable invoked as ``model(**inputs)``; its result must
            expose the logits through ``.get("logits")``.
        inputs: Mapping of model inputs. Must contain ``"labels"``; the
            labels are reduced with ``argmax`` over the last dimension
            (presumably one-hot encoded -- confirm against the dataset
            builder).
        return_outputs: When true, return ``(loss, outputs)`` instead of
            only the loss.
        **kwargs: Accepted for compatibility with the caller and ignored.

    Returns:
        The loss tensor, or a ``(loss, outputs)`` tuple when
        ``return_outputs`` is true.
    """
    labels = inputs.get("labels")
    # NOTE: ``inputs`` still contains "labels" here, so they are forwarded
    # to the model as well.
    outputs = model(**inputs)
    logits = outputs.get("logits")

    # Convert one-hot labels to class indices for CrossEntropyLoss.
    label_indices = labels.argmax(dim=-1)

    # ``class_weights`` defaults to None in ``__init__``; fall back to an
    # unweighted loss instead of failing on ``None.to(...)``.
    weight = None
    if self.class_weights is not None:
        # The weights must live on the same device as the logits.
        weight = self.class_weights.to(logits.device)  # ty: ignore[unresolved-attribute]
    loss_fct = nn.CrossEntropyLoss(weight=weight)

    # Keep the suppression on the line holding the flagged attribute access.
    num_labels = self.model.config.num_labels  # ty: ignore[unresolved-attribute]
    loss = loss_fct(logits.view(-1, num_labels), label_indices.view(-1))
    return (loss, outputs) if return_outputs else loss
125131
126132
@@ -503,7 +509,9 @@ def __call__(self) -> EventsModel:
503509 self .logger = self .__init_logger (log_path )
504510 device = get_device ()
505511
506- self .df = self .__check_data (self .df , self .col_label , self .col_text ) # ty: ignore[invalid-argument-type]
512+ self .df = self .__check_data (
513+ self .df , self .col_label , self .col_text
514+ ) # ty: ignore[invalid-argument-type]
507515 labels , label2id , id2label = self .__retrieve_labels (self .scheme_labels )
508516 self .ds = self .__transform_to_dataset (
509517 self .training_kind , self .df , self .col_label , self .col_text , label2id
@@ -549,7 +557,7 @@ def __call__(self) -> EventsModel:
549557 task_timer .stop ("setup" )
550558
551559 task_timer .start ("train" )
552- trainer .train () # type: ignore[attr-defined] # ty: ignore[unresolved-attribute]
560+ trainer .train () # type: ignore[attr-defined]
553561 self .logger .info (f"Model trained { current_path } " )
554562 task_timer .stop ("train" )
555563
@@ -558,14 +566,21 @@ def __call__(self) -> EventsModel:
558566 predictions_train = trainer .predict (self .ds ["train" ]) # type: ignore[attr-defined] # ty: ignore[invalid-argument-type]
559567
560568 # Compute the metrics
561- df_train_results = self .ds ["train" ].to_pandas ().set_index ("id" ) # ty: ignore[unresolved-attribute]
569+ df_train_results = (
570+ self .ds ["train" ].to_pandas ().set_index ("id" )
571+ ) # ty: ignore[unresolved-attribute]
562572
563- df_train_results ["true_label-matrix" ] = predictions_train .label_ids .tolist () # ty: ignore[unresolved-attribute]
573+ df_train_results ["true_label-matrix" ] = (
574+ predictions_train .label_ids .tolist ()
575+ ) # ty: ignore[unresolved-attribute]
564576 df_train_results ["true_label" ] = [
565- "|" .join (matrix_to_label (row , id2label )) for row in predictions_train .label_ids # ty: ignore[invalid-argument-type, not-iterable]
577+ "|" .join (matrix_to_label (row , id2label ))
578+ for row in predictions_train .label_ids # ty: ignore[invalid-argument-type, not-iterable]
566579 ]
567580
568- y_prob_pred = logits_to_probs (predictions_train .predictions , self .training_kind ) # ty: ignore[invalid-argument-type]
581+ y_prob_pred = logits_to_probs (
582+ predictions_train .predictions , self .training_kind
583+ ) # ty: ignore[invalid-argument-type]
569584
570585 if self .training_kind == "multiclass" :
571586 labels_predicted = activate_probs (
@@ -606,14 +621,21 @@ def __call__(self) -> EventsModel:
606621
607622 if "test" in self .ds :
608623 predictions_test = trainer .predict (self .ds ["test" ]) # type: ignore[attr-defined] # ty: ignore[invalid-argument-type]
609- df_test_results = self .ds ["test" ].to_pandas ().set_index ("id" ) # ty: ignore[unresolved-attribute]
624+ df_test_results = (
625+ self .ds ["test" ].to_pandas ().set_index ("id" )
626+ ) # ty: ignore[unresolved-attribute]
610627
611- df_test_results ["true_label-matrix" ] = predictions_test .label_ids .tolist () # ty: ignore[unresolved-attribute]
628+ df_test_results ["true_label-matrix" ] = (
629+ predictions_test .label_ids .tolist ()
630+ ) # ty: ignore[unresolved-attribute]
612631 df_test_results ["true_label" ] = [
613- "|" .join (matrix_to_label (row , id2label )) for row in predictions_test .label_ids # ty: ignore[invalid-argument-type, not-iterable]
632+ "|" .join (matrix_to_label (row , id2label ))
633+ for row in predictions_test .label_ids # ty: ignore[invalid-argument-type, not-iterable]
614634 ]
615635
616- y_prob_pred = logits_to_probs (predictions_test .predictions , kind = self .training_kind ) # ty: ignore[invalid-argument-type]
636+ y_prob_pred = logits_to_probs (
637+ predictions_test .predictions , kind = self .training_kind
638+ ) # ty: ignore[invalid-argument-type]
617639 if self .training_kind == "multiclass" :
618640 y_label_pred = activate_probs (
619641 y_prob_pred , strategy = "max" , force_max_1_per_row = True
0 commit comments