Skip to content

Commit de7f386

Browse files
committed
pipeline fix
1 parent 24ee27a commit de7f386

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1359,7 +1359,7 @@ def __call__(
13591359

13601360
# If weights are on CPU (from a previous offload), restore them to TPU with correct sharding
13611361
sample_leaf = jax.tree_util.tree_leaves(state)[0]
1362-
if sample_leaf.device().platform == "cpu":
1362+
if list(sample_leaf.devices())[0].platform == "cpu":
13631363
logical_state_spec = nnx.get_partition_spec(state)
13641364
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, self.mesh, self.config.logical_axis_rules)
13651365
state = jax.tree_util.tree_map(lambda x, sharding: jax.device_put(x, sharding), state, logical_state_sharding)

0 commit comments

Comments
 (0)