Skip to content

Commit 6e622c5

Browse files
committed
fix lint errors
1 parent dd73a38 commit 6e622c5

2 files changed

Lines changed: 9 additions & 7 deletions

File tree

src/imitation/scripts/train_adversarial.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@
2424

2525

2626
class CheckpointCallback(BaseCallback):
27+
"""A callback for calling `save` at regular intervals."""
28+
2729
def __init__(
2830
self,
2931
trainer: common.AdversarialTrainer,
3032
log_dir: pathlib.Path,
31-
interval: int
33+
interval: int,
3234
):
35+
"""Creates new Checkpoint callback."""
3336
super().__init__(self)
3437
self.trainer = trainer
3538
self.log_dir = log_dir

tests/algorithms/test_adversarial.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -468,11 +468,10 @@ def test_regression_gail_with_sac(
468468

469469

470470
def test_gen_callback(trainer: common.AdversarialTrainer):
471-
learner = stable_baselines3.PPO("MlpPolicy", env=trainer.venv)
472-
473471
def make_fn_callback(calls, key):
474472
def cb(_a, _b):
475473
calls[key] += 1
474+
476475
return cb
477476

478477
class SB3Callback(BaseCallback):
@@ -490,10 +489,10 @@ def _on_step(self):
490489

491490
trainer.train(n_steps, callback=make_fn_callback(calls, "fn"))
492491
trainer.train(n_steps, callback=SB3Callback(calls, "sb3"))
493-
trainer.train(n_steps, callback=[
494-
SB3Callback(calls, "list.0"),
495-
SB3Callback(calls, "list.1")
496-
])
492+
trainer.train(
493+
n_steps,
494+
callback=[SB3Callback(calls, "list.0"), SB3Callback(calls, "list.1")],
495+
)
497496

498497
# Env steps for off-plicy algos (DQN) may exceed `total_timesteps`,
499498
# so we check if the callback was called *at least* that many times.

0 commit comments

Comments
 (0)