Skip to content

Commit 6a6fbc4

Browse files
hsuan-lun-chiangecnal-cienet
authored andcommitted
tests/unit/max_utils_test.py::UnscanTest::test_unscan_train_state_params
1 parent 2c8efe0 commit 6a6fbc4

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

tests/unit/max_utils_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,15 @@ def test_unscan_train_state_params(self):
205205
self.assertEqual(unstacked_shape, expected_shape)
206206

207207
# Check that the original state is unchanged.
208-
if hasattr(state, "model"):
208+
209+
if hasattr(state, 'model'):
209210
_, params_state, _ = nnx.split(state.model, nnx.Param, ...)
210211
state_decoder_params = params_state.to_pure_dict()["decoder"]
211212
self.assertIn("layers", state_decoder_params)
212213
else:
213214
self.assertIn("layers", state.params["params"]["decoder"])
214-
if hasattr(state, "model"):
215+
216+
if hasattr(state, 'model'):
215217
self.assertNotIn("layers_0", state_decoder_params)
216218
else:
217219
self.assertNotIn("layers_0", state.params["params"]["decoder"])

0 commit comments

Comments
 (0)