We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 50dfba9 commit f91bec1Copy full SHA for f91bec1
1 file changed
pathwaysutils/experimental/concatenate_by_mesh_axis.py
@@ -117,7 +117,11 @@ def _get_output_sharding(
117
) -> jax.sharding.NamedSharding:
118
reference_sharding = _get_named_sharding(arrays[0])
119
reference_spec = reference_sharding.spec
120
- return jax.sharding.NamedSharding(concatenated_mesh, reference_spec)
+ return jax.sharding.NamedSharding(
121
+ concatenated_mesh,
122
+ reference_spec,
123
+ memory_kind=reference_sharding.memory_kind,
124
+ )
125
126
def _sharded_dim_idx_for_sharding(
127
sharding: jax.sharding.NamedSharding,
0 commit comments