Skip to content

Commit 47b3874

Browse files
taufeeque9ernestum
authored andcommitted
Fix test errors
1 parent 53c1212 commit 47b3874

2 files changed

Lines changed: 17 additions & 10 deletions

File tree

src/imitation/algorithms/adversarial/common.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def __init__(self, adversarial_trainer, *args, **kwargs):
102102
"""Builds TrainDiscriminatorCallback.
103103
104104
Args:
105+
adversarial_trainer: The AdversarialTrainer instance in which
106+
this callback will be called.
105107
*args: Passed through to `callbacks.BaseCallback`.
106108
**kwargs: Passed through to `callbacks.BaseCallback`.
107109
"""
@@ -277,7 +279,7 @@ def __init__(
277279
# Would use an identity reward fn here, but RewardFns can't see rewards.
278280
self.venv_wrapped = self.venv_buffering
279281
self.gen_callback: List[callbacks.BaseCallback] = [
280-
self.disc_trainer_callback
282+
self.disc_trainer_callback,
281283
]
282284
else:
283285
self.venv_wrapped = reward_wrapper.RewardVecEnvWrapper(
@@ -370,7 +372,7 @@ def update_rewards_of_rollouts(self) -> None:
370372
buffer = self.gen_algo.rollout_buffer
371373
assert buffer is not None
372374
reward_fn_inputs = replay_buffer_wrapper._rollout_buffer_to_reward_fn_input(
373-
self.gen_algo.rollout_buffer
375+
self.gen_algo.rollout_buffer,
374376
)
375377
rewards = self._reward_net.predict(**reward_fn_inputs)
376378
rewards = rewards.reshape(buffer.rewards.shape)
@@ -381,13 +383,14 @@ def update_rewards_of_rollouts(self) -> None:
381383
last_dones = last_values == 0.0
382384
self.gen_algo.rollout_buffer.rewards[:] = rewards
383385
self.gen_algo.rollout_buffer.compute_returns_and_advantage(
384-
th.tensor(last_values), last_dones
386+
th.tensor(last_values),
387+
last_dones,
385388
)
386389
elif isinstance(self.gen_algo, off_policy_algorithm.OffPolicyAlgorithm):
387390
buffer = self.gen_algo.replay_buffer
388391
assert buffer is not None
389392
reward_fn_inputs = replay_buffer_wrapper._replay_buffer_to_reward_fn_input(
390-
buffer
393+
buffer,
391394
)
392395
rewards = self._reward_net.predict(**reward_fn_inputs)
393396
buffer.rewards[:] = rewards.reshape(buffer.rewards.shape)
@@ -466,13 +469,15 @@ def train_disc(
466469

467470
return train_stats
468471

469-
def train_gen(
472+
def train_gen_with_disc(
470473
self,
471474
total_timesteps: Optional[int] = None,
472475
learn_kwargs: Optional[Mapping] = None,
473476
) -> None:
474477
"""Trains the generator to maximize the discriminator loss.
475478
479+
The discriminator is also trained after the rollouts are collected and before
480+
the generator is trained.
476481
After the end of training populates the generator replay buffer (used in
477482
discriminator training) with `self.disc_batch_size` transitions.
478483
@@ -509,7 +514,7 @@ def train(
509514
) -> None:
510515
"""Alternates between training the generator and discriminator.
511516
512-
Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
517+
Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
513518
a call to `train_disc`, and finally a call to `callback(round)`.
514519
515520
Training ends once an additional "round" would cause the number of transitions
@@ -529,7 +534,7 @@ def train(
529534
f"total_timesteps={total_timesteps})!"
530535
)
531536
for r in tqdm.tqdm(range(0, n_rounds), desc="round"):
532-
self.train_gen(self.gen_train_timesteps)
537+
self.train_gen_with_disc(self.gen_train_timesteps)
533538
if callback:
534539
callback(r)
535540
self.logger.dump(self._global_step)
@@ -621,7 +626,8 @@ def _make_disc_train_batches(
621626
if gen_samples is None:
622627
if self._gen_replay_buffer.size() == 0:
623628
raise RuntimeError(
624-
"No generator samples for training. " "Call `train_gen()` first.",
629+
"No generator samples for training. "
630+
"Call `train_gen_with_disc()` first.",
625631
)
626632
gen_samples_dataclass = self._gen_replay_buffer.sample(batch_size)
627633
gen_samples = types.dataclass_quick_asdict(gen_samples_dataclass)

tests/algorithms/test_adversarial.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,9 @@ def test_train_gen_train_disc_no_crash(
231231
trainer_parametrized: common.AdversarialTrainer,
232232
n_updates: int = 2,
233233
) -> None:
234-
trainer_parametrized.train_gen(n_updates * trainer_parametrized.gen_train_timesteps)
235-
trainer_parametrized.train_disc()
234+
trainer_parametrized.train_gen_with_disc(
235+
n_updates * trainer_parametrized.gen_train_timesteps
236+
)
236237

237238

238239
@pytest.fixture

0 commit comments

Comments
 (0)