@@ -102,6 +102,8 @@ def __init__(self, adversarial_trainer, *args, **kwargs):
102102 """Builds TrainDiscriminatorCallback.
103103
104104 Args:
105+ adversarial_trainer: The AdversarialTrainer instance in which
106+ this callback will be called.
105107 *args: Passed through to `callbacks.BaseCallback`.
106108 **kwargs: Passed through to `callbacks.BaseCallback`.
107109 """
@@ -277,7 +279,7 @@ def __init__(
277279 # Would use an identity reward fn here, but RewardFns can't see rewards.
278280 self .venv_wrapped = self .venv_buffering
279281 self .gen_callback : List [callbacks .BaseCallback ] = [
280- self .disc_trainer_callback
282+ self .disc_trainer_callback ,
281283 ]
282284 else :
283285 self .venv_wrapped = reward_wrapper .RewardVecEnvWrapper (
@@ -370,7 +372,7 @@ def update_rewards_of_rollouts(self) -> None:
370372 buffer = self .gen_algo .rollout_buffer
371373 assert buffer is not None
372374 reward_fn_inputs = replay_buffer_wrapper ._rollout_buffer_to_reward_fn_input (
373- self .gen_algo .rollout_buffer
375+ self .gen_algo .rollout_buffer ,
374376 )
375377 rewards = self ._reward_net .predict (** reward_fn_inputs )
376378 rewards = rewards .reshape (buffer .rewards .shape )
@@ -381,13 +383,14 @@ def update_rewards_of_rollouts(self) -> None:
381383 last_dones = last_values == 0.0
382384 self .gen_algo .rollout_buffer .rewards [:] = rewards
383385 self .gen_algo .rollout_buffer .compute_returns_and_advantage (
384- th .tensor (last_values ), last_dones
386+ th .tensor (last_values ),
387+ last_dones ,
385388 )
386389 elif isinstance (self .gen_algo , off_policy_algorithm .OffPolicyAlgorithm ):
387390 buffer = self .gen_algo .replay_buffer
388391 assert buffer is not None
389392 reward_fn_inputs = replay_buffer_wrapper ._replay_buffer_to_reward_fn_input (
390- buffer
393+ buffer ,
391394 )
392395 rewards = self ._reward_net .predict (** reward_fn_inputs )
393396 buffer .rewards [:] = rewards .reshape (buffer .rewards .shape )
@@ -466,13 +469,15 @@ def train_disc(
466469
467470 return train_stats
468471
469- def train_gen (
472+ def train_gen_with_disc (
470473 self ,
471474 total_timesteps : Optional [int ] = None ,
472475 learn_kwargs : Optional [Mapping ] = None ,
473476 ) -> None :
474477 """Trains the generator to maximize the discriminator loss.
475478
479+ The discriminator is also trained after the rollouts are collected and before
480+ the generator is trained.
476481 After the end of training populates the generator replay buffer (used in
477482 discriminator training) with `self.disc_batch_size` transitions.
478483
@@ -509,7 +514,7 @@ def train(
509514 ) -> None :
510515 """Alternates between training the generator and discriminator.
511516
512- Every "round" consists of a call to `train_gen(self.gen_train_timesteps)`,
517+ Every "round" consists of a call to `train_gen_with_disc(self.gen_train_timesteps)`,
513518 a call to `train_disc`, and finally a call to `callback(round)`.
514519
515520 Training ends once an additional "round" would cause the number of transitions
@@ -529,7 +534,7 @@ def train(
529534 f"total_timesteps={ total_timesteps } )!"
530535 )
531536 for r in tqdm .tqdm (range (0 , n_rounds ), desc = "round" ):
532- self .train_gen (self .gen_train_timesteps )
537+ self .train_gen_with_disc (self .gen_train_timesteps )
533538 if callback :
534539 callback (r )
535540 self .logger .dump (self ._global_step )
@@ -621,7 +626,8 @@ def _make_disc_train_batches(
621626 if gen_samples is None :
622627 if self ._gen_replay_buffer .size () == 0 :
623628 raise RuntimeError (
624- "No generator samples for training. " "Call `train_gen()` first." ,
629+ "No generator samples for training. "
630+ "Call `train_gen_with_disc()` first." ,
625631 )
626632 gen_samples_dataclass = self ._gen_replay_buffer .sample (batch_size )
627633 gen_samples = types .dataclass_quick_asdict (gen_samples_dataclass )
0 commit comments