Skip to content

Commit f91bec1

Browse files
lukebaumanncopybara-github
authored andcommitted
Update concatenate_by_mesh_axis to preserve memory kind
PiperOrigin-RevId: 904207252
1 parent 50dfba9 commit f91bec1

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

pathwaysutils/experimental/concatenate_by_mesh_axis.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,11 @@ def _get_output_sharding(
117117
) -> jax.sharding.NamedSharding:
118118
reference_sharding = _get_named_sharding(arrays[0])
119119
reference_spec = reference_sharding.spec
120-
return jax.sharding.NamedSharding(concatenated_mesh, reference_spec)
120+
return jax.sharding.NamedSharding(
121+
concatenated_mesh,
122+
reference_spec,
123+
memory_kind=reference_sharding.memory_kind,
124+
)
121125

122126
def _sharded_dim_idx_for_sharding(
123127
sharding: jax.sharding.NamedSharding,

0 commit comments

Comments
 (0)