Skip to content

Commit e84a713

Browse files
committed
Fix argument name in unit test
1 parent 71c8057 commit e84a713

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

tests/test_loss_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_log_cosh_of_log_transformed(dtype=torch.float32):
6363
y = 0.5 * x.clone().squeeze() # Shape [N,]
6464

6565
log_cosh_loss = LogCoshLoss()
66-
log_cosh_of_log_transformed_loss = LogCoshLoss(transform_output=lambda x: torch.log10(x))
66+
log_cosh_of_log_transformed_loss = LogCoshLoss(transform_prediction_and_target=lambda x: torch.log10(x))
6767
assert torch.allclose(
6868
log_cosh_loss(torch.log10(x), torch.log10(y), return_elements=True),
6969
log_cosh_of_log_transformed_loss(x, y, return_elements=True),

0 commit comments

Comments
 (0)