Skip to content

Commit 031a7c4

Browse files
gneculacopybara-github
authored andcommitted
[consts] Adapt MPMD code to the new JAX handling of closed-over constants
PiperOrigin-RevId: 897465166
1 parent 4d4e625 commit 031a7c4

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

shardy/integrations/python/jax/mpmd/stages.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from shardy.integrations.python.jax.mpmd import types as mpmd_types
3939
from shardy.integrations.python.jax.mpmd import utils
4040

41+
4142
PyTree = jaxtyping.PyTree
4243
FunctionNamedShardings = utils.FunctionNamedShardings
4344

@@ -135,6 +136,16 @@ def __init__(
135136
self._kept_var_idx = kept_inputs_indices
136137
self._donated_inputs_indices = donated_inputs_indices
137138

139+
logging.info('XXX openxla MpmdExecutable.init func_name: %s', func_name)
140+
logging.info('XXX MpmdExecutable.init nr_const_args: %s', nr_const_args)
141+
logging.info('XXX MpmdExecutable.init kept_inputs_indices[%d]: %s',
142+
len(kept_inputs_indices), kept_inputs_indices)
143+
logging.info('donated_inputs_indices[%d]: %s',
144+
len(donated_inputs_indices), donated_inputs_indices)
145+
logging.info('flat_in_avals[%d]: %s', len(flat_in_avals), flat_in_avals)
146+
logging.info('in_shardings[%d]: %s', len(in_shardings), in_shardings)
147+
logging.info('module_ir: %s', module_ir)
148+
138149
# Hacks for the export flow.
139150
unloaded_executable = NamedTuple(
140151
'MpmdUnloadedExecutableInfo',
@@ -167,6 +178,24 @@ def call(self, *args) -> Sequence[Any]:
167178
'MPMD Executable called with inputs that are not a `jax.Array`. '
168179
f'Errors: {str_errors}'
169180
)
181+
logging.info('XXX openxla MpmdExecutable.call: args[%d]', len(args))
182+
logging.info('MpmdExecutable.call: const_args[%d]',
183+
self.nr_const_args)
184+
logging.info('MpmdExecutable.call: kept_var_idx[%d]: %s',
185+
len(self._kept_var_idx), self._kept_var_idx)
186+
logging.info('MpmdExecutable.call: kept_in_avals[%d]: %s',
187+
len(self._kept_in_avals), self._kept_in_avals)
188+
logging.info('MpmdExecutable.call: kept_in_shardings[%d]: %s',
189+
len(self._kept_in_shardings), self._kept_in_shardings)
190+
logging.info('MpmdExecutable.call: kept_in_shardings_paths[%d]: %s',
191+
len(self._kept_in_shardings_paths),
192+
self._kept_in_shardings_paths)
193+
logging.info('MpmdExecutable.call: donated_inputs_indices[%d]: %s',
194+
len(self._donated_inputs_indices),
195+
self._donated_inputs_indices)
196+
logging.info('MpmdExecutable.call: topology: %s', self._topology)
197+
logging.info('MpmdExecutable.call: avals[%d]: %s', len(args),
198+
[a.aval for a in args])
170199

171200
if self.nr_const_args > 0:
172201
const_args = args[:self.nr_const_args]
@@ -177,12 +206,17 @@ def call(self, *args) -> Sequence[Any]:
177206
const_shardings = [
178207
getattr(c, 'sharding', replicated_sharding)
179208
for c in const_args]
209+
logging.info('MpmdExecutable.call: const_shardings[%d]: %s',
210+
len(const_shardings), const_shardings)
180211
const_layouts = pjit.const_args_layouts(
181212
const_args, const_args_avals, const_shardings)
182213
const_args_sharded = pxla.shard_args(
183214
const_shardings, const_layouts,
184215
[xla_client.ArrayCopySemantics.REUSE_INPUT] * self.nr_const_args,
185216
const_args)
217+
logging.info('MpmdExecutable.call: const_args_sharded[%d]: %s',
218+
len(const_args_sharded),
219+
[type(a) for a in const_args_sharded])
186220
else:
187221
const_args_sharded = []
188222

@@ -201,6 +235,10 @@ def call(self, *args) -> Sequence[Any]:
201235
# The argument is donated and not used by the program. It is safe
202236
# to delete the argument.
203237
arg.delete()
238+
logging.info('MpmdExecutable.call: kept_args[%d]: %s',
239+
len(kept_args), {i: v.aval for i, v in kept_args.items()})
240+
logging.info('MpmdExecutable.call: arg_avals[%d]: %s',
241+
len(arg_avals), arg_avals)
204242
pxla.check_arg_avals_for_call(
205243
self._kept_in_avals, arg_avals[self.nr_const_args:], self._debug_info
206244
)
@@ -214,6 +252,8 @@ def call(self, *args) -> Sequence[Any]:
214252
tuple((k, v) for k, v in self._topology.items()),
215253
)
216254
kept_args_tuple = tuple(v for _, v in sorted(kept_args.items()))
255+
logging.info('MpmdExecutable.call: execute args[%d]: %s',
256+
len(kept_args_tuple), [type(a) for a in kept_args_tuple])
217257
return self._executable.execute(kept_args_tuple)
218258

219259
def create_cpp_call(
@@ -269,6 +309,7 @@ def _xla_in_layouts(self) -> Sequence[layout.Layout | None]:
269309
"""Returns the input layouts for used inputs only."""
270310
input_xla_layouts = self._executable.input_layouts()
271311
input_shardings = self._in_shardings
312+
assert not self.nr_const_args
272313
# Remove shardings for unused inputs.
273314
if len(input_xla_layouts) < self.nr_const_args + len(input_shardings):
274315
iter_layouts = iter(input_xla_layouts)
@@ -480,6 +521,9 @@ def __init__(
480521

481522
meshes_and_specs = partitioning_result.module_io_sharding_specs_and_meshes
482523
nr_const_args = len(self.const_args)
524+
logging.info('XXX openxla MpmdLowered.init nr_const_args: %s', nr_const_args)
525+
logging.info('MpmdLowered.const_args[%d]: %s',
526+
len(self.const_args), self.const_args)
483527
if len(jax_fn_info.kept_inputs_indices) != nr_const_args + len(
484528
jax_fn_info.global_flat_input_abstract_values
485529
):
@@ -582,6 +626,7 @@ def compile(
582626
Raises:
583627
ValueError: if the function has already been compiled.
584628
"""
629+
logging.info('XXX openxla MpmdLowered.compile')
585630
if device_assignment is None:
586631
in_shardings = self.function_named_shardings.input_specs
587632
flat_out_shardings = jax.tree.leaves(

0 commit comments

Comments
 (0)