3838from shardy .integrations .python .jax .mpmd import types as mpmd_types
3939from shardy .integrations .python .jax .mpmd import utils
4040
41+
4142PyTree = jaxtyping .PyTree
4243FunctionNamedShardings = utils .FunctionNamedShardings
4344
@@ -135,6 +136,16 @@ def __init__(
135136 self ._kept_var_idx = kept_inputs_indices
136137 self ._donated_inputs_indices = donated_inputs_indices
137138
139+ logging .info ('XXX openxla MpmdExecutable.init func_name: %s' , func_name )
140+ logging .info ('XXX MpmdExecutable.init nr_const_args: %s' , nr_const_args )
141+ logging .info ('XXX MpmdExecutable.init kept_inputs_indices[%d]: %s' ,
142+ len (kept_inputs_indices ), kept_inputs_indices )
143+ logging .info ('donated_inputs_indices[%d]: %s' ,
144+ len (donated_inputs_indices ), donated_inputs_indices )
145+ logging .info ('flat_in_avals[%d]: %s' , len (flat_in_avals ), flat_in_avals )
146+ logging .info ('in_shardings[%d]: %s' , len (in_shardings ), in_shardings )
147+ logging .info ('module_ir: %s' , module_ir )
148+
138149 # Hacks for the export flow.
139150 unloaded_executable = NamedTuple (
140151 'MpmdUnloadedExecutableInfo' ,
@@ -167,6 +178,24 @@ def call(self, *args) -> Sequence[Any]:
167178 'MPMD Executable called with inputs that are not a `jax.Array`. '
168179 f'Errors: { str_errors } '
169180 )
181+ logging .info ('XXX openxla MpmdExecutable.call: args[%d]' , len (args ))
182+ logging .info ('MpmdExecutable.call: const_args[%d]' ,
183+ self .nr_const_args )
184+ logging .info ('MpmdExecutable.call: kept_var_idx[%d]: %s' ,
185+ len (self ._kept_var_idx ), self ._kept_var_idx )
186+ logging .info ('MpmdExecutable.call: kept_in_avals[%d]: %s' ,
187+ len (self ._kept_in_avals ), self ._kept_in_avals )
188+ logging .info ('MpmdExecutable.call: kept_in_shardings[%d]: %s' ,
189+ len (self ._kept_in_shardings ), self ._kept_in_shardings )
190+ logging .info ('MpmdExecutable.call: kept_in_shardings_paths[%d]: %s' ,
191+ len (self ._kept_in_shardings_paths ),
192+ self ._kept_in_shardings_paths )
193+ logging .info ('MpmdExecutable.call: donated_inputs_indices[%d]: %s' ,
194+ len (self ._donated_inputs_indices ),
195+ self ._donated_inputs_indices )
196+ logging .info ('MpmdExecutable.call: topology: %s' , self ._topology )
197+ logging .info ('MpmdExecutable.call: avals[%d]: %s' , len (args ),
198+ [a .aval for a in args ])
170199
171200 if self .nr_const_args > 0 :
172201 const_args = args [:self .nr_const_args ]
@@ -177,12 +206,17 @@ def call(self, *args) -> Sequence[Any]:
177206 const_shardings = [
178207 getattr (c , 'sharding' , replicated_sharding )
179208 for c in const_args ]
209+ logging .info ('MpmdExecutable.call: const_shardings[%d]: %s' ,
210+ len (const_shardings ), const_shardings )
180211 const_layouts = pjit .const_args_layouts (
181212 const_args , const_args_avals , const_shardings )
182213 const_args_sharded = pxla .shard_args (
183214 const_shardings , const_layouts ,
184215 [xla_client .ArrayCopySemantics .REUSE_INPUT ] * self .nr_const_args ,
185216 const_args )
217+ logging .info ('MpmdExecutable.call: const_args_sharded[%d]: %s' ,
218+ len (const_args_sharded ),
219+ [type (a ) for a in const_args_sharded ])
186220 else :
187221 const_args_sharded = []
188222
@@ -201,6 +235,10 @@ def call(self, *args) -> Sequence[Any]:
201235 # The argument is donated and not used by the program. It is safe
202236 # to delete the argument.
203237 arg .delete ()
238+ logging .info ('MpmdExecutable.call: kept_args[%d]: %s' ,
239+ len (kept_args ), {i : v .aval for i , v in kept_args .items ()})
240+ logging .info ('MpmdExecutable.call: arg_avals[%d]: %s' ,
241+ len (arg_avals ), arg_avals )
204242 pxla .check_arg_avals_for_call (
205243 self ._kept_in_avals , arg_avals [self .nr_const_args :], self ._debug_info
206244 )
@@ -214,6 +252,8 @@ def call(self, *args) -> Sequence[Any]:
214252 tuple ((k , v ) for k , v in self ._topology .items ()),
215253 )
216254 kept_args_tuple = tuple (v for _ , v in sorted (kept_args .items ()))
255+ logging .info ('MpmdExecutable.call: execute args[%d]: %s' ,
256+ len (kept_args_tuple ), [type (a ) for a in kept_args_tuple ])
217257 return self ._executable .execute (kept_args_tuple )
218258
219259 def create_cpp_call (
@@ -480,6 +520,9 @@ def __init__(
480520
481521 meshes_and_specs = partitioning_result .module_io_sharding_specs_and_meshes
482522 nr_const_args = len (self .const_args )
523+ logging .info ('XXX openxla MpmdLowered.init nr_const_args: %s' , nr_const_args )
524+ logging .info ('MpmdLowered.const_args[%d]: %s' ,
525+ len (self .const_args ), self .const_args )
483526 if len (jax_fn_info .kept_inputs_indices ) != nr_const_args + len (
484527 jax_fn_info .global_flat_input_abstract_values
485528 ):
@@ -582,6 +625,7 @@ def compile(
582625 Raises:
583626 ValueError: if the function has already been compiled.
584627 """
628+ logging .info ('XXX openxla MpmdLowered.compile' )
585629 if device_assignment is None :
586630 in_shardings = self .function_named_shardings .input_specs
587631 flat_out_shardings = jax .tree .leaves (
0 commit comments