Skip to content

Commit 4979038

Browse files
committed
Fix mtl_backward test of create transform and check keys
1 parent ac2b31b commit 4979038

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

tests/unit/autojac/test_mtl_backward.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ def test_check_create_transform():
2020
y1 = f1 * p1[0] + f2 * p1[1]
2121
y2 = f1 * p2[0] + f2 * p2[1]
2222

23-
_create_transform([y1, y2], [f1, f2], Mean(), [[p1], [p2]], {p0}, False, None)
23+
transform = _create_transform([y1, y2], [f1, f2], Mean(), [[p1], [p2]], {p0}, False, None)
24+
transform.check_keys()
2425

2526

2627
@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()])

0 commit comments

Comments
 (0)