Skip to content

Commit 6835d5c

Browse files
fix type
1 parent a766660 commit 6835d5c

1 file changed

Lines changed: 35 additions & 13 deletions

File tree

api/activetigger/tasks/train_bert.py

Lines changed: 35 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -112,15 +112,21 @@ def __init__(self, *args, class_weights=None, **kwargs):
112112
self.class_weights = class_weights
113113
print("CustomTrainer initialized with class weights:", self.class_weights)
114114

115-
def compute_loss(self, model, inputs, return_outputs=False, **kwargs): # ty: ignore[invalid-method-override]
115+
def compute_loss(
116+
self, model, inputs, return_outputs=False, **kwargs
117+
): # ty: ignore[invalid-method-override]
116118
labels = inputs.get("labels")
117119
outputs = model(**inputs)
118120
logits = outputs.get("logits")
119121

120122
# Convert one-hot labels to class indices for CrossEntropyLoss
121123
label_indices = labels.argmax(dim=-1)
122-
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights.to(logits.device)) # ty: ignore[unresolved-attribute]
123-
loss = loss_fct(logits.view(-1, self.model.config.num_labels), label_indices.view(-1)) # ty: ignore[unresolved-attribute]
124+
loss_fct = nn.CrossEntropyLoss(
125+
weight=self.class_weights.to(logits.device)
126+
) # ty: ignore[unresolved-attribute]
127+
loss = loss_fct(
128+
logits.view(-1, self.model.config.num_labels), label_indices.view(-1)
129+
) # ty: ignore[unresolved-attribute]
124130
return (loss, outputs) if return_outputs else loss
125131

126132

@@ -503,7 +509,9 @@ def __call__(self) -> EventsModel:
503509
self.logger = self.__init_logger(log_path)
504510
device = get_device()
505511

506-
self.df = self.__check_data(self.df, self.col_label, self.col_text) # ty: ignore[invalid-argument-type]
512+
self.df = self.__check_data(
513+
self.df, self.col_label, self.col_text
514+
) # ty: ignore[invalid-argument-type]
507515
labels, label2id, id2label = self.__retrieve_labels(self.scheme_labels)
508516
self.ds = self.__transform_to_dataset(
509517
self.training_kind, self.df, self.col_label, self.col_text, label2id
@@ -549,7 +557,7 @@ def __call__(self) -> EventsModel:
549557
task_timer.stop("setup")
550558

551559
task_timer.start("train")
552-
trainer.train() # type: ignore[attr-defined] # ty: ignore[unresolved-attribute]
560+
trainer.train() # type: ignore[attr-defined]
553561
self.logger.info(f"Model trained {current_path}")
554562
task_timer.stop("train")
555563

@@ -558,14 +566,21 @@ def __call__(self) -> EventsModel:
558566
predictions_train = trainer.predict(self.ds["train"]) # type: ignore[attr-defined] # ty: ignore[invalid-argument-type]
559567

560568
# Compute the metrics
561-
df_train_results = self.ds["train"].to_pandas().set_index("id") # ty: ignore[unresolved-attribute]
569+
df_train_results = (
570+
self.ds["train"].to_pandas().set_index("id")
571+
) # ty: ignore[unresolved-attribute]
562572

563-
df_train_results["true_label-matrix"] = predictions_train.label_ids.tolist() # ty: ignore[unresolved-attribute]
573+
df_train_results["true_label-matrix"] = (
574+
predictions_train.label_ids.tolist()
575+
) # ty: ignore[unresolved-attribute]
564576
df_train_results["true_label"] = [
565-
"|".join(matrix_to_label(row, id2label)) for row in predictions_train.label_ids # ty: ignore[invalid-argument-type, not-iterable]
577+
"|".join(matrix_to_label(row, id2label))
578+
for row in predictions_train.label_ids # ty: ignore[invalid-argument-type, not-iterable]
566579
]
567580

568-
y_prob_pred = logits_to_probs(predictions_train.predictions, self.training_kind) # ty: ignore[invalid-argument-type]
581+
y_prob_pred = logits_to_probs(
582+
predictions_train.predictions, self.training_kind
583+
) # ty: ignore[invalid-argument-type]
569584

570585
if self.training_kind == "multiclass":
571586
labels_predicted = activate_probs(
@@ -606,14 +621,21 @@ def __call__(self) -> EventsModel:
606621

607622
if "test" in self.ds:
608623
predictions_test = trainer.predict(self.ds["test"]) # type: ignore[attr-defined] # ty: ignore[invalid-argument-type]
609-
df_test_results = self.ds["test"].to_pandas().set_index("id") # ty: ignore[unresolved-attribute]
624+
df_test_results = (
625+
self.ds["test"].to_pandas().set_index("id")
626+
) # ty: ignore[unresolved-attribute]
610627

611-
df_test_results["true_label-matrix"] = predictions_test.label_ids.tolist() # ty: ignore[unresolved-attribute]
628+
df_test_results["true_label-matrix"] = (
629+
predictions_test.label_ids.tolist()
630+
) # ty: ignore[unresolved-attribute]
612631
df_test_results["true_label"] = [
613-
"|".join(matrix_to_label(row, id2label)) for row in predictions_test.label_ids # ty: ignore[invalid-argument-type, not-iterable]
632+
"|".join(matrix_to_label(row, id2label))
633+
for row in predictions_test.label_ids # ty: ignore[invalid-argument-type, not-iterable]
614634
]
615635

616-
y_prob_pred = logits_to_probs(predictions_test.predictions, kind=self.training_kind) # ty: ignore[invalid-argument-type]
636+
y_prob_pred = logits_to_probs(
637+
predictions_test.predictions, kind=self.training_kind
638+
) # ty: ignore[invalid-argument-type]
617639
if self.training_kind == "multiclass":
618640
y_label_pred = activate_probs(
619641
y_prob_pred, strategy="max", force_max_1_per_row=True

0 commit comments

Comments (0)