We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 71c8057 commit e84a713Copy full SHA for e84a713
1 file changed
tests/test_loss_functions.py
@@ -63,7 +63,7 @@ def test_log_cosh_of_log_transformed(dtype=torch.float32):
63
y = 0.5 * x.clone().squeeze() # Shape [N,]
64
65
log_cosh_loss = LogCoshLoss()
66
- log_cosh_of_log_transformed_loss = LogCoshLoss(transform_output=lambda x: torch.log10(x))
+ log_cosh_of_log_transformed_loss = LogCoshLoss(transform_prediction_and_target=lambda x: torch.log10(x))
67
assert torch.allclose(
68
log_cosh_loss(torch.log10(x), torch.log10(y), return_elements=True),
69
log_cosh_of_log_transformed_loss(x, y, return_elements=True),
0 commit comments