Skip to content

Commit 41e5d57

Browse files
check ty types
1 parent 0c15137 commit 41e5d57

1 file changed

Lines changed: 27 additions & 22 deletions

File tree

api/activetigger/tasks/train_bert.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,12 @@ def compute_loss(
122122
# Convert one-hot labels to class indices for CrossEntropyLoss
123123
label_indices = labels.argmax(dim=-1)
124124
loss_fct = nn.CrossEntropyLoss(
125-
weight=self.class_weights.to(logits.device)
126-
) # ty: ignore[unresolved-attribute]
125+
weight=self.class_weights.to(logits.device), # ty: ignore[unresolved-attribute]
126+
)
127127
loss = loss_fct(
128-
logits.view(-1, self.model.config.num_labels), label_indices.view(-1)
129-
) # ty: ignore[unresolved-attribute]
128+
logits.view(-1, self.model.config.num_labels), # ty: ignore[unresolved-attribute]
129+
label_indices.view(-1),
130+
)
130131
return (loss, outputs) if return_outputs else loss
131132

132133

@@ -510,8 +511,10 @@ def __call__(self) -> EventsModel:
510511
device = get_device()
511512

512513
self.df = self.__check_data(
513-
self.df, self.col_label, self.col_text
514-
) # ty: ignore[invalid-argument-type]
514+
self.df, # ty: ignore[invalid-argument-type]
515+
self.col_label,
516+
self.col_text,
517+
)
515518
labels, label2id, id2label = self.__retrieve_labels(self.scheme_labels)
516519
self.ds = self.__transform_to_dataset(
517520
self.training_kind, self.df, self.col_label, self.col_text, label2id
@@ -567,20 +570,21 @@ def __call__(self) -> EventsModel:
567570

568571
# Compute the metrics
569572
df_train_results = (
570-
self.ds["train"].to_pandas().set_index("id")
571-
) # ty: ignore[unresolved-attribute]
573+
self.ds["train"].to_pandas().set_index("id") # ty: ignore[unresolved-attribute]
574+
)
572575

573576
df_train_results["true_label-matrix"] = (
574-
predictions_train.label_ids.tolist()
575-
) # ty: ignore[unresolved-attribute]
577+
predictions_train.label_ids.tolist() # ty: ignore[unresolved-attribute]
578+
)
576579
df_train_results["true_label"] = [
577-
"|".join(matrix_to_label(row, id2label))
578-
for row in predictions_train.label_ids # ty: ignore[invalid-argument-type, not-iterable]
580+
"|".join(matrix_to_label(row, id2label)) # ty: ignore[invalid-argument-type]
581+
for row in predictions_train.label_ids # ty: ignore[not-iterable]
579582
]
580583

581584
y_prob_pred = logits_to_probs(
582-
predictions_train.predictions, self.training_kind
583-
) # ty: ignore[invalid-argument-type]
585+
predictions_train.predictions, # ty: ignore[invalid-argument-type]
586+
self.training_kind,
587+
)
584588

585589
if self.training_kind == "multiclass":
586590
labels_predicted = activate_probs(
@@ -622,20 +626,21 @@ def __call__(self) -> EventsModel:
622626
if "test" in self.ds:
623627
predictions_test = trainer.predict(self.ds["test"]) # type: ignore[attr-defined] # ty: ignore[invalid-argument-type]
624628
df_test_results = (
625-
self.ds["test"].to_pandas().set_index("id")
626-
) # ty: ignore[unresolved-attribute]
629+
self.ds["test"].to_pandas().set_index("id") # ty: ignore[unresolved-attribute]
630+
)
627631

628632
df_test_results["true_label-matrix"] = (
629-
predictions_test.label_ids.tolist()
630-
) # ty: ignore[unresolved-attribute]
633+
predictions_test.label_ids.tolist() # ty: ignore[unresolved-attribute]
634+
)
631635
df_test_results["true_label"] = [
632-
"|".join(matrix_to_label(row, id2label))
633-
for row in predictions_test.label_ids # ty: ignore[invalid-argument-type, not-iterable]
636+
"|".join(matrix_to_label(row, id2label)) # ty: ignore[invalid-argument-type]
637+
for row in predictions_test.label_ids # ty: ignore[not-iterable]
634638
]
635639

636640
y_prob_pred = logits_to_probs(
637-
predictions_test.predictions, kind=self.training_kind
638-
) # ty: ignore[invalid-argument-type]
641+
predictions_test.predictions, # ty: ignore[invalid-argument-type]
642+
kind=self.training_kind,
643+
)
639644
if self.training_kind == "multiclass":
640645
y_label_pred = activate_probs(
641646
y_prob_pred, strategy="max", force_max_1_per_row=True

0 commit comments

Comments
 (0)