@@ -51,7 +51,7 @@ def __init__(
5151 If ``None``, the :class:`torch.optim.Adam` optimizer is used.
5252 Default is ``None``.
5353 :param Optimizer optimizer_discriminator: The optimizer for the
54- discriminator. If ``None``, the :class:`torch.optim.Adam`
54+ discriminator. If ``None``, the :class:`torch.optim.Adam`
5555 optimizer is used. Default is ``None``.
5656 :param Scheduler scheduler_generator: The learning rate scheduler for
5757 the generator.
@@ -88,7 +88,7 @@ def __init__(
8888 check_consistency (
8989 loss , (LossInterface , _Loss , torch .nn .Module ), subclass = False
9090 )
91- self ._loss = loss
91+ self ._loss_fn = loss
9292
9393 # set automatic optimization for GANs
9494 self .automatic_optimization = False
@@ -157,10 +157,11 @@ def _train_generator(self, parameters, snapshots):
157157 generated_snapshots = self .sample (parameters )
158158
159159 # generator loss
160- r_loss = self ._loss (snapshots , generated_snapshots )
160+ r_loss = self ._loss_fn (snapshots , generated_snapshots )
161161 d_fake = self .discriminator ([generated_snapshots , parameters ])
162162 g_loss = (
163- self ._loss (d_fake , generated_snapshots ) + self .regularizer * r_loss
163+ self ._loss_fn (d_fake , generated_snapshots )
164+ + self .regularizer * r_loss
164165 )
165166
166167 # backward step
@@ -189,8 +190,8 @@ def _train_discriminator(self, parameters, snapshots):
189190 d_fake = self .discriminator ([generated_snapshots , parameters ])
190191
191192 # evaluate loss
192- d_loss_real = self ._loss (d_real , snapshots )
193- d_loss_fake = self ._loss (d_fake , generated_snapshots .detach ())
193+ d_loss_real = self ._loss_fn (d_real , snapshots )
194+ d_loss_fake = self ._loss_fn (d_fake , generated_snapshots .detach ())
194195 d_loss = d_loss_real - self .k * d_loss_fake
195196
196197 # backward step
@@ -270,7 +271,7 @@ def validation_step(self, batch):
270271 points ["target" ],
271272 )
272273 snapshots_gen = self .generator (parameters )
273- condition_loss [condition_name ] = self ._loss (
274+ condition_loss [condition_name ] = self ._loss_fn (
274275 snapshots , snapshots_gen
275276 )
276277 loss = self .weighting .aggregate (condition_loss )
@@ -293,7 +294,7 @@ def test_step(self, batch):
293294 points ["target" ],
294295 )
295296 snapshots_gen = self .generator (parameters )
296- condition_loss [condition_name ] = self ._loss (
297+ condition_loss [condition_name ] = self ._loss_fn (
297298 snapshots , snapshots_gen
298299 )
299300 loss = self .weighting .aggregate (condition_loss )
0 commit comments