Skip to content

Commit 9d8605b

Browse files
committed
Fixing typing for transformers integration
1 parent 73e33ee commit 9d8605b

1 file changed

Lines changed: 4 additions & 8 deletions

File tree

dreadnode/integrations/transformers.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,8 @@
55

66
import typing as t
77

8-
from transformers.trainer_callback import ( # type: ignore [import-untyped]
9-
TrainerCallback,
10-
TrainerControl,
11-
TrainerState,
12-
TrainingArguments,
13-
)
8+
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
9+
from transformers.training_args import TrainingArguments
1410

1511
import dreadnode as dn
1612

@@ -28,7 +24,7 @@ def _clean_keys(data: dict[str, t.Any]) -> dict[str, t.Any]:
2824
return cleaned
2925

3026

31-
class DreadnodeCallback(TrainerCallback): # type: ignore [misc]
27+
class DreadnodeCallback(TrainerCallback):
3228
"""
3329
An implementation of the `TrainerCallback` interface for Dreadnode.
3430
@@ -124,7 +120,7 @@ def on_epoch_begin(
124120
control: TrainerControl,
125121
**kwargs: t.Any,
126122
) -> None:
127-
if self._run is None:
123+
if self._run is None or state.epoch is None:
128124
return
129125

130126
dn.log_metric("epoch", state.epoch)

0 commit comments

Comments
 (0)