diff --git a/src/maxtext/kernels/ragged/ragged_sort.py b/src/maxtext/kernels/ragged/ragged_sort.py index b56e517704..1a0786a932 100644 --- a/src/maxtext/kernels/ragged/ragged_sort.py +++ b/src/maxtext/kernels/ragged/ragged_sort.py @@ -92,9 +92,6 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local): shard_output_end, ) - valid_mask = (jnp.arange(x.shape[0]) >= shard_output_start) & (jnp.arange(x.shape[0]) < shard_output_end) - x = jnp.where(valid_mask[:, None], x, 0.0) - out = (x, group_sizes_local, topk_argsort_revert_indices) res = ( @@ -125,8 +122,9 @@ def _ring_ragged_sort_bwd(res, g_out): # only iterates over the populated prefix, so we hand it the mask directly # rather than materializing a (mostly-zero) dense buffer ourselves. n = topk_argsort_revert_indices.shape[0] - pos = jnp.arange(n) - valid_rows_mask = (pos >= shard_output_start) & (pos < shard_output_end) + valid_rows_mask = (topk_argsort_revert_indices >= shard_output_start) & ( + topk_argsort_revert_indices < shard_output_end + ) # The forward scatter-add over `token_indices_sorted` is equivalent to a # gather-reduce: each input token has exactly `topk` contributions located # at sorted positions `topk_argsort_revert_indices[t*topk:(t+1)*topk]`. @@ -209,7 +207,6 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort topk_argsort_revert_indices, shard_output_start, shard_output_end, - sorted_tokens_local.shape, ) return out, res @@ -236,9 +233,8 @@ def _ring_ragged_unsort_bwd(res, g_out): range of ``j``. The simpler equivalent: gather of g_hidden_states_local using the inverse permutation, masked. """ - topk_argsort_revert_indices, shard_output_start, shard_output_end, sorted_tokens_local_shape = res + topk_argsort_revert_indices, shard_output_start, shard_output_end = res g_hidden_states_local = g_out - num_rows = sorted_tokens_local_shape[0] # We want: g_sorted_tokens[j] = g_hidden_states_local[i] where revert[i]=j. # Build the inverse permutation idx_inv such that idx_inv[j] = i. @@ -250,12 +246,6 @@ def _ring_ragged_unsort_bwd(res, g_out): shard_output_start, shard_output_end, ) - # Outside [start, end), positions must be zero — which the ragged_gather - # already guarantees because untouched output rows are uninitialized; we - # explicitly zero them. - pos = jnp.arange(num_rows) - valid = (pos >= shard_output_start) & (pos < shard_output_end) - grad_sorted_tokens = jnp.where(valid[:, None], grad_sorted_tokens, jnp.zeros_like(grad_sorted_tokens)) return grad_sorted_tokens, None, None _ring_ragged_unsort.defvjp(_ring_ragged_unsort_fwd, _ring_ragged_unsort_bwd) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index baa951b4fe..abcced3a6a 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1171,9 +1171,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel): elif self.config.attention == "vllm_rpa": return group_sizes else: + num_groups = group_sizes.shape[0] return tokamax.RaggedDotGroupSizes( group_sizes, - (inputs.shape[0] // kernel.shape[0],) * kernel.shape[0], + (inputs.shape[0] // num_groups,) * num_groups, ) def get_quantization_dtypes(): @@ -1184,7 +1185,7 @@ def get_quantization_dtypes(): rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype() return lhs_quantize_dtype, rhs_quantize_dtype - def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes): + def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, group_offset): def extract_vma(tensor): # Parses the varying mesh axes from JAX's type string for a tensor inside shard_map. # jax.typeof(t) renders as e.g. 'f32[128,256]{V:(expert, fsdp)}'; this extracts @@ -1219,6 +1220,7 @@ def extract_vma(tensor): group_sizes=group_sizes, preferred_element_type=self.dtype, tiling=tiling, + group_offset=group_offset, lhs_quantize_dtype=lhs_quantize_dtype, rhs_quantize_dtype=rhs_quantize_dtype, use_qwix_quantization=self.config.use_qwix_quantization, @@ -1235,6 +1237,7 @@ def extract_vma(tensor): precision=jax.lax.Precision.DEFAULT, preferred_element_type=self.dtype, implementation="mosaic", + group_offset=group_offset, ) elif self.config.megablox: # Older forked megablox output = mblx.gmm( @@ -1243,6 +1246,7 @@ def extract_vma(tensor): group_sizes=group_sizes, preferred_element_type=self.dtype, tiling=tiling, + group_offset=group_offset, lhs_quantize_dtype=lhs_quantize_dtype, rhs_quantize_dtype=rhs_quantize_dtype, use_qwix_quantization=self.config.use_qwix_quantization, @@ -1382,11 +1386,6 @@ def route(x, logits, pre_bias_logits, rngs): rngs=rngs, ) - # Filter down to the group sizes that apply to only the experts in the - # current shard. - group_sizes = group_sizes[:num_experts_per_shard] - mask = jnp.arange(x.shape[0]) < jnp.sum(group_sizes) - x = jnp.where(mask[:, None], x, 0) else: x, sorted_selected_experts, weights, group_sizes, selected_experts, lb_loss, bias_updates = self.permute( x, logits, pre_bias_logits, self.config.use_custom_sort_vjp, rngs @@ -1559,7 +1558,16 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r if self.config.mlp_bias: w0_bias, w1_bias, wo_bias = self.transform_bias(routing.selected_experts, w0_bias, w1_bias, wo_bias) - gmm_fn = functools.partial(gmm, group_sizes=routing.group_sizes, expert_assignments=routing.selected_experts) + num_ep = self.get_expert_parallelism_size() + num_experts_per_shard = self.config.num_experts // num_ep + if self.config.use_ragged_sort and self.config.use_ring_of_experts: + experts_start = route_metadata.expert_shard_id * num_experts_per_shard + else: + experts_start = 0 + + gmm_fn = functools.partial( + gmm, group_sizes=routing.group_sizes, expert_assignments=routing.selected_experts, group_offset=experts_start + ) intermediate_layer = gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather) wo_gather_axes, wo_tile_size = get_wo_gmm_params() @@ -1578,10 +1586,6 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, r intermediate_output = adc.checkpoint_name(adc.checkpoint_name(intermediate_output, "mlpwo"), "moe_mlpwo") if self.config.use_ring_of_experts: - # Set the outputs of tokens which were not processed to 0. - mask = jnp.arange(intermediate_output.shape[0]) < jnp.sum(routing.group_sizes) - intermediate_output = jnp.where(mask[:, None], intermediate_output, 0) - # Unsort and deduplicate the outputs locally. output = self.unpermute( intermediate_output, diff --git a/tests/utils/reference_hlo_deepseek3.txt b/tests/utils/reference_hlo_deepseek3.txt index ffff9103ad..39c44fd147 100644 --- a/tests/utils/reference_hlo_deepseek3.txt +++ b/tests/utils/reference_hlo_deepseek3.txt @@ -10,21 +10,21 @@ StackFrames %region_46.56 (top_k.25: bf16[], top_k.26: bf16[], top_k.27: s32[], top_k.28: s32[]) -> pred[] { - %constant.1408 = s32[]{:T(128)} constant(0) - %constant.1409 = s32[]{:T(128)} constant(2147483647) + %constant.1424 = s32[]{:T(128)} constant(0) + %constant.1425 = s32[]{:T(128)} constant(2147483647) %top_k.25 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} %top_k.26 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} %top_k.27 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} %top_k.28 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} %convert.393 = f32[]{:T(128)S(6)} convert(%top_k.25), metadata={op_name="convert.18"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %bitcast-convert.39 = s32[]{:T(128)S(6)} bitcast-convert(%convert.393), metadata={op_name="bitcast-convert.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.144 = pred[]{:T(512)S(6)} compare(%bitcast-convert.39, %constant.1408), direction=LT, metadata={op_name="compare.38"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.40 = s32[]{:T(128)S(6)} xor(%constant.1409, %bitcast-convert.39), metadata={op_name="xor.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.144 = pred[]{:T(512)S(6)} compare(%bitcast-convert.39, %constant.1424), direction=LT, metadata={op_name="compare.38"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.40 = s32[]{:T(128)S(6)} xor(%constant.1425, %bitcast-convert.39), metadata={op_name="xor.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.127 = s32[]{:T(128)S(6)} select(%compare.144, %xor.40, %bitcast-convert.39), metadata={op_name="select.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} %convert.394 = f32[]{:T(128)S(6)} convert(%top_k.26), metadata={op_name="convert.19"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %bitcast-convert.40 = s32[]{:T(128)S(6)} bitcast-convert(%convert.394), metadata={op_name="bitcast-convert.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.145 = pred[]{:T(512)S(6)} compare(%bitcast-convert.40, %constant.1408), direction=LT, metadata={op_name="compare.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.41 = s32[]{:T(128)S(6)} xor(%constant.1409, %bitcast-convert.40), metadata={op_name="xor.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.145 = pred[]{:T(512)S(6)} compare(%bitcast-convert.40, %constant.1424), direction=LT, metadata={op_name="compare.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.41 = s32[]{:T(128)S(6)} xor(%constant.1425, %bitcast-convert.40), metadata={op_name="xor.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.128 = s32[]{:T(128)S(6)} select(%compare.145, %xor.41, %bitcast-convert.40), metadata={op_name="select.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} %compare.146 = pred[]{:T(512)S(6)} compare(%select.127, %select.128), direction=GT, metadata={op_name="compare.0"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %compare.147 = pred[]{:T(512)S(6)} compare(%select.128, %select.127), direction=GT, metadata={op_name="compare.117"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} @@ -78,19 +78,19 @@ StackFrames %region_107.126 (psum.6: bf16[], psum.9: bf16[]) -> bf16[] { %psum.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} %psum.9 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} - ROOT %add.1407 = bf16[]{:T(256)} add(%psum.6, %psum.9), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1417 = bf16[]{:T(256)} add(%psum.6, %psum.9), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_108.127 (psum.10: bf16[], psum.11: bf16[]) -> bf16[] { %psum.10 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} %psum.11 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} - ROOT %add.1408 = bf16[]{:T(256)} add(%psum.10, %psum.11), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1418 = bf16[]{:T(256)} add(%psum.10, %psum.11), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_109.128 (psum.14: bf16[], psum.15: bf16[]) -> bf16[] { %psum.14 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} %psum.15 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} - ROOT %add.1409 = bf16[]{:T(256)} add(%psum.14, %psum.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1419 = bf16[]{:T(256)} add(%psum.14, %psum.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_62.73 (reduce-window.111: s32[], reduce-window.112: s32[]) -> s32[] { @@ -211,167 +211,167 @@ StackFrames %param_0.17 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) %param_1.108 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.13 = s32[1024]{0:T(1024)} custom-call(%param_1.108), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %slice.920 = s32[512]{0:T(512)} slice(%custom-call.13), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %reshape.3318 = s32[4,128]{1,0:T(4,128)} reshape(%slice.920), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.847 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3318), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %gather.187 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} gather(%param_0.17, %transpose.847), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %transpose.846 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%gather.187), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %reshape.3317 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.846), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %slice.892 = s32[512]{0:T(512)} slice(%custom-call.13), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.3298 = s32[4,128]{1,0:T(4,128)} reshape(%slice.892), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.847 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3298), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.183 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} gather(%param_0.17, %transpose.847), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.846 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%gather.183), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.3297 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.846), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} } %fused_computation.6 (param_0.20: f32[163840,32], param_1.110: s32[1024]) -> f32[512,32] { %param_0.20 = f32[163840,32]{1,0:T(8,128)} parameter(0) %param_1.110 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.15 = s32[1024]{0:T(1024)} custom-call(%param_1.110), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %slice.922 = s32[512]{0:T(512)} slice(%custom-call.15), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %reshape.3326 = s32[4,128]{1,0:T(4,128)} reshape(%slice.922), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %transpose.853 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3326), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %gather.189 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.20, %transpose.853), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %transpose.852 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.189), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - ROOT %reshape.3325 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.852), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %slice.894 = s32[512]{0:T(512)} slice(%custom-call.15), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %reshape.3306 = s32[4,128]{1,0:T(4,128)} reshape(%slice.894), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %transpose.853 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3306), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %gather.185 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.20, %transpose.853), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %transpose.852 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.185), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + ROOT %reshape.3305 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.852), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} } %fused_computation.7 (param_0.23: f32[163840,32], param_1.112: s32[1024]) -> f32[512,32] { %param_0.23 = f32[163840,32]{1,0:T(8,128)} parameter(0) %param_1.112 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.17 = s32[1024]{0:T(1024)} custom-call(%param_1.112), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %slice.924 = s32[512]{0:T(512)} slice(%custom-call.17), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %reshape.3334 = s32[4,128]{1,0:T(4,128)} reshape(%slice.924), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %transpose.859 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3334), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %gather.191 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.23, %transpose.859), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %transpose.858 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.191), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - ROOT %reshape.3333 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.858), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %slice.896 = s32[512]{0:T(512)} slice(%custom-call.17), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %reshape.3314 = s32[4,128]{1,0:T(4,128)} reshape(%slice.896), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %transpose.859 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3314), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %gather.187 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.23, %transpose.859), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %transpose.858 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.187), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + ROOT %reshape.3313 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.858), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} } %fused_computation.8 (param_0.26: f32[163840,32], param_1.120: s32[1024]) -> f32[512,32] { %param_0.26 = f32[163840,32]{1,0:T(8,128)} parameter(0) %param_1.120 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.25 = s32[1024]{0:T(1024)} custom-call(%param_1.120), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %slice.932 = s32[512]{0:T(512)} slice(%custom-call.25), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %reshape.3342 = s32[4,128]{1,0:T(4,128)} reshape(%slice.932), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %transpose.865 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3342), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %gather.193 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.26, %transpose.865), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %transpose.864 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.193), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - ROOT %reshape.3341 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.864), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %slice.904 = s32[512]{0:T(512)} slice(%custom-call.25), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %reshape.3322 = s32[4,128]{1,0:T(4,128)} reshape(%slice.904), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %transpose.865 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3322), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %gather.189 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.26, %transpose.865), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %transpose.864 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.189), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + ROOT %reshape.3321 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.864), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} } %fused_computation.9 (param_0.29: f32[163840,32], param_1.122: s32[1024]) -> f32[512,32] { %param_0.29 = f32[163840,32]{1,0:T(8,128)} parameter(0) %param_1.122 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.27 = s32[1024]{0:T(1024)} custom-call(%param_1.122), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %slice.934 = s32[512]{0:T(512)} slice(%custom-call.27), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %reshape.3350 = s32[4,128]{1,0:T(4,128)} reshape(%slice.934), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %transpose.871 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3350), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %gather.195 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.29, %transpose.871), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %transpose.870 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.195), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - ROOT %reshape.3349 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.870), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %slice.906 = s32[512]{0:T(512)} slice(%custom-call.27), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %reshape.3330 = s32[4,128]{1,0:T(4,128)} reshape(%slice.906), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %transpose.871 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3330), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %gather.191 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.29, %transpose.871), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %transpose.870 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.191), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + ROOT %reshape.3329 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.870), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} } %fused_computation.10 (param_0.32: bf16[4096,512], param_1.126: s32[4096]) -> bf16[4096,512] { %param_0.32 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) %param_1.126 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.31 = s32[4096]{0:T(1024)} custom-call(%param_1.126), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %slice.938 = s32[4096]{0:T(1024)} slice(%custom-call.31), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3358 = s32[4096]{0:T(1024)} reshape(%slice.938), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.877 = s32[4096]{0:T(1024)} transpose(%reshape.3358), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.197 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.32, %transpose.877), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.876 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.197), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3357 = bf16[4096,512]{1,0:T(8,128)(2,1)} reshape(%transpose.876), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.910 = s32[4096]{0:T(1024)} slice(%custom-call.31), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3338 = s32[4096]{0:T(1024)} reshape(%slice.910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.877 = s32[4096]{0:T(1024)} transpose(%reshape.3338), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.193 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.32, %transpose.877), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.876 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.193), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3337 = bf16[4096,512]{1,0:T(8,128)(2,1)} reshape(%transpose.876), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.11 (param_0.35: bf16[4096,512], param_1.128: s32[4096]) -> bf16[4096,512] { %param_0.35 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) %param_1.128 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.33 = s32[4096]{0:T(1024)} custom-call(%param_1.128), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %slice.940 = s32[4096]{0:T(1024)} slice(%custom-call.33), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3366 = s32[4096]{0:T(1024)} reshape(%slice.940), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.883 = s32[4096]{0:T(1024)} transpose(%reshape.3366), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.199 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.35, %transpose.883), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.882 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.199), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3365 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.882), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.912 = s32[4096]{0:T(1024)} slice(%custom-call.33), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3346 = s32[4096]{0:T(1024)} reshape(%slice.912), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.883 = s32[4096]{0:T(1024)} transpose(%reshape.3346), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.195 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.35, %transpose.883), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.882 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.195), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3345 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.882), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.12 (param_0.38: bf16[4096,512], param_1.130: s32[4096]) -> bf16[4096,512] { %param_0.38 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) %param_1.130 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.35 = s32[4096]{0:T(1024)} custom-call(%param_1.130), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %slice.942 = s32[4096]{0:T(1024)} slice(%custom-call.35), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3374 = s32[4096]{0:T(1024)} reshape(%slice.942), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.889 = s32[4096]{0:T(1024)} transpose(%reshape.3374), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.201 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.38, %transpose.889), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.888 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.201), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3373 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.888), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.914 = s32[4096]{0:T(1024)} slice(%custom-call.35), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3354 = s32[4096]{0:T(1024)} reshape(%slice.914), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.889 = s32[4096]{0:T(1024)} transpose(%reshape.3354), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.197 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.38, %transpose.889), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.888 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.197), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3353 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.888), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.13 (param_0.41: bf16[4096,512], param_1.132: s32[4096]) -> bf16[4096,512] { %param_0.41 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) %param_1.132 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.37 = s32[4096]{0:T(1024)} custom-call(%param_1.132), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %slice.944 = s32[4096]{0:T(1024)} slice(%custom-call.37), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3382 = s32[4096]{0:T(1024)} reshape(%slice.944), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.895 = s32[4096]{0:T(1024)} transpose(%reshape.3382), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.203 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.41, %transpose.895), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.894 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.203), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3381 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.894), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.916 = s32[4096]{0:T(1024)} slice(%custom-call.37), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3362 = s32[4096]{0:T(1024)} reshape(%slice.916), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.895 = s32[4096]{0:T(1024)} transpose(%reshape.3362), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.199 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.41, %transpose.895), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.894 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.199), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3361 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.894), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.15 (param_0.47: s32[256], param_1.124: s32[1024]) -> s32[263] { %param_0.47 = s32[256]{0:T(256)S(1)} parameter(0) %param_1.124 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.29 = s32[1024]{0:T(1024)} custom-call(%param_1.124), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %slice.936 = s32[263]{0:T(512)} slice(%custom-call.29), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %reshape.3413 = s32[263]{0:T(512)} reshape(%slice.936), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %transpose.911 = s32[263]{0:T(512)} transpose(%reshape.3413), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %gather.208 = s32[263]{0:T(512)} gather(%param_0.47, %transpose.911), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %transpose.910 = s32[263]{0:T(512)} transpose(%gather.208), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - ROOT %reshape.3412 = s32[263]{0:T(512)S(1)} reshape(%transpose.910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %slice.908 = s32[263]{0:T(512)} slice(%custom-call.29), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3393 = s32[263]{0:T(512)} reshape(%slice.908), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.911 = s32[263]{0:T(512)} transpose(%reshape.3393), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.204 = s32[263]{0:T(512)} gather(%param_0.47, %transpose.911), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.910 = s32[263]{0:T(512)} transpose(%gather.204), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3392 = s32[263]{0:T(512)S(1)} reshape(%transpose.910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} } %fused_computation.16 (param_0.50: s32[256], param_1.134: s32[1024]) -> s32[263] { %param_0.50 = s32[256]{0:T(256)S(1)} parameter(0) %param_1.134 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.39 = s32[1024]{0:T(1024)} custom-call(%param_1.134), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - %slice.946 = s32[263]{0:T(512)} slice(%custom-call.39), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - %reshape.3436 = s32[263]{0:T(512)} reshape(%slice.946), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %transpose.921 = s32[263]{0:T(512)} transpose(%reshape.3436), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %gather.211 = s32[263]{0:T(512)} gather(%param_0.50, %transpose.921), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - %transpose.920 = s32[263]{0:T(512)} transpose(%gather.211), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - ROOT %reshape.3435 = s32[263]{0:T(512)S(1)} reshape(%transpose.920), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %slice.918 = s32[263]{0:T(512)} slice(%custom-call.39), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3416 = s32[263]{0:T(512)} reshape(%slice.918), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.921 = s32[263]{0:T(512)} transpose(%reshape.3416), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.207 = s32[263]{0:T(512)} gather(%param_0.50, %transpose.921), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.920 = s32[263]{0:T(512)} transpose(%gather.207), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3415 = s32[263]{0:T(512)S(1)} reshape(%transpose.920), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} } %region_173.198.clone (scatter-add.94: bf16[], scatter-add.96: bf16[]) -> bf16[] { %scatter-add.94 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} %scatter-add.96 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} - ROOT %add.1875 = bf16[]{:T(256)} add(%scatter-add.94, %scatter-add.96), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1885 = bf16[]{:T(256)} add(%scatter-add.94, %scatter-add.96), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %fused_computation.21 (param_0.55: bf16[129280,512], param_1.65: s32[512], param_2.24: bf16[512,512]) -> bf16[129280,512] { %param_0.55 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) %param_1.65 = s32[512]{0:T(512)S(1)} parameter(1) - %reshape.3490 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.65), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.954 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3490), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %reshape.3470 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.65), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.954 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3470), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} %param_2.24 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} parameter(2) - %reshape.3491 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} reshape(%param_2.24), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} - %transpose.955 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%reshape.3491), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} - ROOT %scatter.77 = bf16[129280,512]{1,0:T(8,128)(2,1)} scatter(%param_0.55, %transpose.954, %transpose.955), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_173.198.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} + %reshape.3471 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} reshape(%param_2.24), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} + %transpose.955 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%reshape.3471), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} + ROOT %scatter.73 = bf16[129280,512]{1,0:T(8,128)(2,1)} scatter(%param_0.55, %transpose.954, %transpose.955), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_173.198.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} } %region_12.18 (top_k.0: bf16[], top_k.6: bf16[], top_k.7: s32[], top_k.8: s32[]) -> pred[] { - %constant.1369 = s32[]{:T(128)} constant(0) - %constant.1370 = s32[]{:T(128)} constant(2147483647) + %constant.1385 = s32[]{:T(128)} constant(0) + %constant.1386 = s32[]{:T(128)} constant(2147483647) %top_k.0 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} %top_k.6 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} %top_k.7 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} %top_k.8 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} %convert.385 = f32[]{:T(128)S(6)} convert(%top_k.0), metadata={op_name="convert.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %bitcast-convert.35 = s32[]{:T(128)S(6)} bitcast-convert(%convert.385), metadata={op_name="bitcast-convert.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.128 = pred[]{:T(512)S(6)} compare(%bitcast-convert.35, %constant.1369), direction=LT, metadata={op_name="compare.35"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.36 = s32[]{:T(128)S(6)} xor(%constant.1370, %bitcast-convert.35), metadata={op_name="xor.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.128 = pred[]{:T(512)S(6)} compare(%bitcast-convert.35, %constant.1385), direction=LT, metadata={op_name="compare.35"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.36 = s32[]{:T(128)S(6)} xor(%constant.1386, %bitcast-convert.35), metadata={op_name="xor.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.118 = s32[]{:T(128)S(6)} select(%compare.128, %xor.36, %bitcast-convert.35), metadata={op_name="select.14"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} %convert.386 = f32[]{:T(128)S(6)} convert(%top_k.6), metadata={op_name="convert.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %bitcast-convert.36 = s32[]{:T(128)S(6)} bitcast-convert(%convert.386), metadata={op_name="bitcast-convert.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.129 = pred[]{:T(512)S(6)} compare(%bitcast-convert.36, %constant.1369), direction=LT, metadata={op_name="compare.36"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.37 = s32[]{:T(128)S(6)} xor(%constant.1370, %bitcast-convert.36), metadata={op_name="xor.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.129 = pred[]{:T(512)S(6)} compare(%bitcast-convert.36, %constant.1385), direction=LT, metadata={op_name="compare.36"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.37 = s32[]{:T(128)S(6)} xor(%constant.1386, %bitcast-convert.36), metadata={op_name="xor.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.119 = s32[]{:T(128)S(6)} select(%compare.129, %xor.37, %bitcast-convert.36), metadata={op_name="select.15"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} %compare.130 = pred[]{:T(512)S(6)} compare(%select.118, %select.119), direction=GT, metadata={op_name="compare.1"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %compare.131 = pred[]{:T(512)S(6)} compare(%select.119, %select.118), direction=GT, metadata={op_name="compare.108"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} @@ -420,12 +420,12 @@ StackFrames %param_0.68 = s32[256]{0:T(256)S(1)} parameter(0) %param_1.114 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.19 = s32[1024]{0:T(1024)} custom-call(%param_1.114), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %slice.926 = s32[263]{0:T(512)} slice(%custom-call.19), slice={[0:263]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %reshape.3634 = s32[263]{0:T(512)} reshape(%slice.926), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %transpose.1037 = s32[263]{0:T(512)} transpose(%reshape.3634), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %gather.213 = s32[263]{0:T(512)} gather(%param_0.68, %transpose.1037), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %transpose.1036 = s32[263]{0:T(512)} transpose(%gather.213), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - ROOT %reshape.3633 = s32[263]{0:T(512)S(1)} reshape(%transpose.1036), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %slice.898 = s32[263]{0:T(512)} slice(%custom-call.19), slice={[0:263]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3614 = s32[263]{0:T(512)} reshape(%slice.898), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.1037 = s32[263]{0:T(512)} transpose(%reshape.3614), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.209 = s32[263]{0:T(512)} gather(%param_0.68, %transpose.1037), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.1036 = s32[263]{0:T(512)} transpose(%gather.209), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3613 = s32[263]{0:T(512)S(1)} reshape(%transpose.1036), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} } %region_27.34.clone.1 (reduce-window.350: s32[], reduce-window.351: s32[]) -> s32[] { @@ -464,12 +464,12 @@ StackFrames %param_0.71 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) %param_1.116 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.21 = s32[4096]{0:T(1024)} custom-call(%param_1.116), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %slice.928 = s32[4096]{0:T(1024)} slice(%custom-call.21), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3657 = s32[4096]{0:T(1024)} reshape(%slice.928), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.1043 = s32[4096]{0:T(1024)} transpose(%reshape.3657), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.214 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.71, %transpose.1043), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.1042 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.214), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3656 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1042), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.900 = s32[4096]{0:T(1024)} slice(%custom-call.21), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3637 = s32[4096]{0:T(1024)} reshape(%slice.900), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.1043 = s32[4096]{0:T(1024)} transpose(%reshape.3637), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.210 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.71, %transpose.1043), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.1042 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.210), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3636 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1042), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %region_31.39 (sort.50: s32[], sort.51: s32[], sort.52: s32[], sort.53: s32[], sort.54: s32[], sort.55: s32[]) -> pred[] { @@ -490,12 +490,12 @@ StackFrames %param_0.72 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) %param_1.118 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.23 = s32[4096]{0:T(1024)} custom-call(%param_1.118), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %slice.930 = s32[4096]{0:T(1024)} slice(%custom-call.23), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3659 = s32[4096]{0:T(1024)} reshape(%slice.930), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.1045 = s32[4096]{0:T(1024)} transpose(%reshape.3659), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.215 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.72, %transpose.1045), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.1044 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.215), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3658 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1044), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.902 = s32[4096]{0:T(1024)} slice(%custom-call.23), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3639 = s32[4096]{0:T(1024)} reshape(%slice.902), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.1045 = s32[4096]{0:T(1024)} transpose(%reshape.3639), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.211 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.72, %transpose.1045), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.1044 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.211), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3638 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1044), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %compare (name: s32[], name.1: s32[], name.2: bf16[], name.3: bf16[]) -> pred[] { @@ -503,7 +503,7 @@ StackFrames %name.3 = bf16[] parameter(3) %name = s32[] parameter(0) %name.1 = s32[] parameter(1) - ROOT %compare.385 = pred[] compare(%name, %name.1), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %compare.377 = pred[] compare(%name, %name.1), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %compare.1 (name.4: s32[], name.5: s32[], name.6: f32[], name.7: f32[]) -> pred[] { @@ -511,7 +511,7 @@ StackFrames %name.7 = f32[] parameter(3) %name.4 = s32[] parameter(0) %name.5 = s32[] parameter(1) - ROOT %compare.386 = pred[] compare(%name.4, %name.5), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %compare.378 = pred[] compare(%name.4, %name.5), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %compare.2 (name.8: s32[], name.9: s32[], name.10: f32[], name.11: f32[]) -> pred[] { @@ -519,7 +519,7 @@ StackFrames %name.11 = f32[] parameter(3) %name.8 = s32[] parameter(0) %name.9 = s32[] parameter(1) - ROOT %compare.387 = pred[] compare(%name.8, %name.9), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %compare.379 = pred[] compare(%name.8, %name.9), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %compare.3 (name.12: s32[], name.13: s32[], name.14: f32[], name.15: f32[]) -> pred[] { @@ -527,7 +527,7 @@ StackFrames %name.15 = f32[] parameter(3) %name.12 = s32[] parameter(0) %name.13 = s32[] parameter(1) - ROOT %compare.388 = pred[] compare(%name.12, %name.13), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %compare.380 = pred[] compare(%name.12, %name.13), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %compare.4 (name.16: s32[], name.17: s32[], name.18: f32[], name.19: f32[]) -> pred[] { @@ -535,7 +535,7 @@ StackFrames %name.19 = f32[] parameter(3) %name.16 = s32[] parameter(0) %name.17 = s32[] parameter(1) - ROOT %compare.389 = pred[] compare(%name.16, %name.17), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %compare.381 = pred[] compare(%name.16, %name.17), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %called_computation.13 (param_0.4523: s32[256]) -> s32[256] { @@ -551,18 +551,18 @@ StackFrames %region_49.59 (scatter-add.14: s32[], scatter-add.15: s32[]) -> s32[] { %scatter-add.14 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.15 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1352 = s32[]{:T(128)S(7)} add(%scatter-add.14, %scatter-add.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1362 = s32[]{:T(128)S(7)} add(%scatter-add.14, %scatter-add.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.22.clone.clone (param_0.4525: s32[256], param_1.5325: s32[4096], param_2.4494: s32[4096]) -> s32[256] { %param_0.4525 = s32[256]{0:T(256)} parameter(0) %param_1.5325 = s32[4096]{0:T(1024)} parameter(1) - %reshape.3923 = s32[4096]{0:T(1024)} reshape(%param_1.5325), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} - %transpose.1100 = s32[4096]{0:T(1024)} transpose(%reshape.3923), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} + %reshape.3903 = s32[4096]{0:T(1024)} reshape(%param_1.5325), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} + %transpose.1100 = s32[4096]{0:T(1024)} transpose(%reshape.3903), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} %param_2.4494 = s32[4096]{0:T(1024)} parameter(2) - %reshape.3924 = s32[4096]{0:T(1024)} reshape(%param_2.4494), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - %transpose.1101 = s32[4096]{0:T(1024)} transpose(%reshape.3924), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.237 = s32[256]{0:T(256)} scatter(%param_0.4525, %transpose.1100, %transpose.1101), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_49.59, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} + %reshape.3904 = s32[4096]{0:T(1024)} reshape(%param_2.4494), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + %transpose.1101 = s32[4096]{0:T(1024)} transpose(%reshape.3904), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.231 = s32[256]{0:T(256)} scatter(%param_0.4525, %transpose.1100, %transpose.1101), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_49.59, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.14 (param_0.4526: s32[256], param_1.5326: s32[4096], param_2.4495: s32[4096]) -> s32[256] { @@ -611,18 +611,18 @@ StackFrames %region_61.72 (scatter-add.24: f32[], scatter-add.25: f32[]) -> f32[] { %scatter-add.24 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.25 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1358 = f32[]{:T(128)S(7)} add(%scatter-add.24, %scatter-add.25), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1368 = f32[]{:T(128)S(7)} add(%scatter-add.24, %scatter-add.25), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.24.clone.clone (param_0.4530: f32[9], param_1.5328: s32[256], param_2.4497: f32[256]) -> f32[9] { %param_0.4530 = f32[9]{0:T(128)} parameter(0) %param_1.5328 = s32[256]{0:T(256)} parameter(1) - %reshape.3925 = s32[256]{0:T(256)} reshape(%param_1.5328), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1102 = s32[256]{0:T(256)} transpose(%reshape.3925), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %reshape.3905 = s32[256]{0:T(256)} reshape(%param_1.5328), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.1102 = s32[256]{0:T(256)} transpose(%reshape.3905), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} %param_2.4497 = f32[256]{0:T(256)} parameter(2) - %reshape.3926 = f32[256]{0:T(256)} reshape(%param_2.4497), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1103 = f32[256]{0:T(256)} transpose(%reshape.3926), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.238 = f32[9]{0:T(128)} scatter(%param_0.4530, %transpose.1102, %transpose.1103), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_61.72, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %reshape.3906 = f32[256]{0:T(256)} reshape(%param_2.4497), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.1103 = f32[256]{0:T(256)} transpose(%reshape.3906), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.232 = f32[9]{0:T(128)} scatter(%param_0.4530, %transpose.1102, %transpose.1103), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_61.72, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.16 (param_0.4531: f32[9], param_1.5329: s32[256], param_2.4498: f32[256]) -> f32[9] { @@ -671,18 +671,18 @@ StackFrames %region_63.74 (scatter-add.28: s32[], scatter-add.29: s32[]) -> s32[] { %scatter-add.28 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.29 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1359 = s32[]{:T(128)S(7)} add(%scatter-add.28, %scatter-add.29), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1369 = s32[]{:T(128)S(7)} add(%scatter-add.28, %scatter-add.29), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.25.clone.clone (param_0.4535: s32[263], param_1.5331: s32[8], param_2.4500: s32[8]) -> s32[263] { %param_0.4535 = s32[263]{0:T(512)} parameter(0) %param_1.5331 = s32[8]{0:T(128)} parameter(1) - %reshape.3927 = s32[8]{0:T(128)} reshape(%param_1.5331), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1104 = s32[8]{0:T(128)} transpose(%reshape.3927), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %reshape.3907 = s32[8]{0:T(128)} reshape(%param_1.5331), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.1104 = s32[8]{0:T(128)} transpose(%reshape.3907), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} %param_2.4500 = s32[8]{0:T(128)} parameter(2) - %reshape.3928 = s32[8]{0:T(128)} reshape(%param_2.4500), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - %transpose.1105 = s32[8]{0:T(128)} transpose(%reshape.3928), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - ROOT %scatter-add.239 = s32[263]{0:T(512)} scatter(%param_0.4535, %transpose.1104, %transpose.1105), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_63.74, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %reshape.3908 = s32[8]{0:T(128)} reshape(%param_2.4500), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.1105 = s32[8]{0:T(128)} transpose(%reshape.3908), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.233 = s32[263]{0:T(512)} scatter(%param_0.4535, %transpose.1104, %transpose.1105), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_63.74, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.18 (param_0.4536: s32[263], param_1.5332: s32[8], param_2.4501: s32[8]) -> s32[263] { @@ -731,18 +731,18 @@ StackFrames %region_73.86.clone (scatter-add.163: s32[], scatter-add.164: s32[]) -> s32[] { %scatter-add.163 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.164 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2474 = s32[]{:T(128)S(7)} add(%scatter-add.163, %scatter-add.164), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2482 = s32[]{:T(128)S(7)} add(%scatter-add.163, %scatter-add.164), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.26.clone.clone (param_0.4540: s32[263], param_1.5334: s32[256], param_2.4503: s32[256]) -> s32[263] { %param_0.4540 = s32[263]{0:T(512)} parameter(0) %param_1.5334 = s32[256]{0:T(256)} parameter(1) - %reshape.3929 = s32[256]{0:T(256)} reshape(%param_1.5334), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1106 = s32[256]{0:T(256)} transpose(%reshape.3929), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %reshape.3909 = s32[256]{0:T(256)} reshape(%param_1.5334), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.1106 = s32[256]{0:T(256)} transpose(%reshape.3909), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} %param_2.4503 = s32[256]{0:T(256)} parameter(2) - %reshape.3930 = s32[256]{0:T(256)} reshape(%param_2.4503), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1107 = s32[256]{0:T(256)} transpose(%reshape.3930), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.240 = s32[263]{0:T(512)} scatter(%param_0.4540, %transpose.1106, %transpose.1107), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_73.86.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %reshape.3910 = s32[256]{0:T(256)} reshape(%param_2.4503), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.1107 = s32[256]{0:T(256)} transpose(%reshape.3910), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.234 = s32[263]{0:T(512)} scatter(%param_0.4540, %transpose.1106, %transpose.1107), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_73.86.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.20 (param_0.4541: s32[263], param_1.5335: s32[256], param_2.4504: s32[256]) -> s32[263] { @@ -791,18 +791,18 @@ StackFrames %region_79.95.clone (scatter-add.167: f32[], scatter-add.168: f32[]) -> f32[] { %scatter-add.167 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.168 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2476 = f32[]{:T(128)S(7)} add(%scatter-add.167, %scatter-add.168), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2484 = f32[]{:T(128)S(7)} add(%scatter-add.167, %scatter-add.168), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.27.clone.clone (param_0.4545: f32[9], param_1.5337: s32[256], param_2.4506: f32[256]) -> f32[9] { %param_0.4545 = f32[9]{0:T(128)} parameter(0) %param_1.5337 = s32[256]{0:T(256)} parameter(1) - %reshape.3931 = s32[256]{0:T(256)} reshape(%param_1.5337), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1108 = s32[256]{0:T(256)} transpose(%reshape.3931), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %reshape.3911 = s32[256]{0:T(256)} reshape(%param_1.5337), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.1108 = s32[256]{0:T(256)} transpose(%reshape.3911), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} %param_2.4506 = f32[256]{0:T(256)} parameter(2) - %reshape.3932 = f32[256]{0:T(256)} reshape(%param_2.4506), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1109 = f32[256]{0:T(256)} transpose(%reshape.3932), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.241 = f32[9]{0:T(128)} scatter(%param_0.4545, %transpose.1108, %transpose.1109), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_79.95.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %reshape.3912 = f32[256]{0:T(256)} reshape(%param_2.4506), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.1109 = f32[256]{0:T(256)} transpose(%reshape.3912), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.235 = f32[9]{0:T(128)} scatter(%param_0.4545, %transpose.1108, %transpose.1109), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_79.95.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.22 (param_0.4546: f32[9], param_1.5338: s32[256], param_2.4507: f32[256]) -> f32[9] { @@ -851,18 +851,18 @@ StackFrames %region_81.97.clone (scatter-add.171: s32[], scatter-add.172: s32[]) -> s32[] { %scatter-add.171 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.172 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2478 = s32[]{:T(128)S(7)} add(%scatter-add.171, %scatter-add.172), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2486 = s32[]{:T(128)S(7)} add(%scatter-add.171, %scatter-add.172), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.28.clone.clone (param_0.4550: s32[263], param_1.5340: s32[8], param_2.4509: s32[8]) -> s32[263] { %param_0.4550 = s32[263]{0:T(512)} parameter(0) %param_1.5340 = s32[8]{0:T(128)} parameter(1) - %reshape.3933 = s32[8]{0:T(128)} reshape(%param_1.5340), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1110 = s32[8]{0:T(128)} transpose(%reshape.3933), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %reshape.3913 = s32[8]{0:T(128)} reshape(%param_1.5340), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.1110 = s32[8]{0:T(128)} transpose(%reshape.3913), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} %param_2.4509 = s32[8]{0:T(128)} parameter(2) - %reshape.3934 = s32[8]{0:T(128)} reshape(%param_2.4509), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - %transpose.1111 = s32[8]{0:T(128)} transpose(%reshape.3934), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - ROOT %scatter-add.242 = s32[263]{0:T(512)} scatter(%param_0.4550, %transpose.1110, %transpose.1111), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_81.97.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %reshape.3914 = s32[8]{0:T(128)} reshape(%param_2.4509), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.1111 = s32[8]{0:T(128)} transpose(%reshape.3914), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.236 = s32[263]{0:T(512)} scatter(%param_0.4550, %transpose.1110, %transpose.1111), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_81.97.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.24 (param_0.4551: s32[263], param_1.5341: s32[8], param_2.4510: s32[8]) -> s32[263] { @@ -911,18 +911,18 @@ StackFrames %region_96.114 (scatter-add.48: s32[], scatter-add.49: s32[]) -> s32[] { %scatter-add.48 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.49 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1396 = s32[]{:T(128)S(7)} add(%scatter-add.48, %scatter-add.49), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1406 = s32[]{:T(128)S(7)} add(%scatter-add.48, %scatter-add.49), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.29.clone.clone (param_0.4555: s32[263], param_1.5343: s32[256], param_2.4512: s32[256]) -> s32[263] { %param_0.4555 = s32[263]{0:T(512)} parameter(0) %param_1.5343 = s32[256]{0:T(256)} parameter(1) - %reshape.3935 = s32[256]{0:T(256)} reshape(%param_1.5343), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %transpose.1112 = s32[256]{0:T(256)} transpose(%reshape.3935), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %reshape.3915 = s32[256]{0:T(256)} reshape(%param_1.5343), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %transpose.1112 = s32[256]{0:T(256)} transpose(%reshape.3915), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} %param_2.4512 = s32[256]{0:T(256)} parameter(2) - %reshape.3936 = s32[256]{0:T(256)} reshape(%param_2.4512), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1113 = s32[256]{0:T(256)} transpose(%reshape.3936), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.243 = s32[263]{0:T(512)} scatter(%param_0.4555, %transpose.1112, %transpose.1113), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_96.114, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + %reshape.3916 = s32[256]{0:T(256)} reshape(%param_2.4512), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.1113 = s32[256]{0:T(256)} transpose(%reshape.3916), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.237 = s32[263]{0:T(512)} scatter(%param_0.4555, %transpose.1112, %transpose.1113), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_96.114, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.26 (param_0.4556: s32[263], param_1.5344: s32[256], param_2.4513: s32[256]) -> s32[263] { @@ -961,18 +961,18 @@ StackFrames %region_102.120 (scatter-add.52: f32[], scatter-add.53: f32[]) -> f32[] { %scatter-add.52 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.53 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1399 = f32[]{:T(128)S(7)} add(%scatter-add.52, %scatter-add.53), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1409 = f32[]{:T(128)S(7)} add(%scatter-add.52, %scatter-add.53), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.30.clone.clone (param_0.4560: f32[9], param_1.5346: s32[256], param_2.4515: f32[256]) -> f32[9] { %param_0.4560 = f32[9]{0:T(128)} parameter(0) %param_1.5346 = s32[256]{0:T(256)} parameter(1) - %reshape.3937 = s32[256]{0:T(256)} reshape(%param_1.5346), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1114 = s32[256]{0:T(256)} transpose(%reshape.3937), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} + %reshape.3917 = s32[256]{0:T(256)} reshape(%param_1.5346), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.1114 = s32[256]{0:T(256)} transpose(%reshape.3917), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} %param_2.4515 = f32[256]{0:T(256)} parameter(2) - %reshape.3938 = f32[256]{0:T(256)} reshape(%param_2.4515), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1115 = f32[256]{0:T(256)} transpose(%reshape.3938), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.244 = f32[9]{0:T(128)} scatter(%param_0.4560, %transpose.1114, %transpose.1115), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_102.120, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + %reshape.3918 = f32[256]{0:T(256)} reshape(%param_2.4515), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.1115 = f32[256]{0:T(256)} transpose(%reshape.3918), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.238 = f32[9]{0:T(128)} scatter(%param_0.4560, %transpose.1114, %transpose.1115), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_102.120, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.28 (param_0.4561: f32[9], param_1.5347: s32[256], param_2.4516: f32[256]) -> f32[9] { @@ -1009,18 +1009,18 @@ StackFrames %region_104.122 (scatter-add.83: s32[], scatter-add.84: s32[]) -> s32[] { %scatter-add.83 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.84 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1400 = s32[]{:T(128)S(7)} add(%scatter-add.83, %scatter-add.84), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1410 = s32[]{:T(128)S(7)} add(%scatter-add.83, %scatter-add.84), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.31.clone.clone (param_0.4565: s32[263], param_1.5349: s32[8], param_2.4518: s32[8]) -> s32[263] { %param_0.4565 = s32[263]{0:T(512)} parameter(0) %param_1.5349 = s32[8]{0:T(128)} parameter(1) - %reshape.3939 = s32[8]{0:T(128)} reshape(%param_1.5349), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %transpose.1116 = s32[8]{0:T(128)} transpose(%reshape.3939), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %reshape.3919 = s32[8]{0:T(128)} reshape(%param_1.5349), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %transpose.1116 = s32[8]{0:T(128)} transpose(%reshape.3919), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} %param_2.4518 = s32[8]{0:T(128)} parameter(2) - %reshape.3940 = s32[8]{0:T(128)} reshape(%param_2.4518), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - %transpose.1117 = s32[8]{0:T(128)} transpose(%reshape.3940), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - ROOT %scatter-add.245 = s32[263]{0:T(512)} scatter(%param_0.4565, %transpose.1116, %transpose.1117), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_104.122, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + %reshape.3920 = s32[8]{0:T(128)} reshape(%param_2.4518), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.1117 = s32[8]{0:T(128)} transpose(%reshape.3920), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.239 = s32[263]{0:T(512)} scatter(%param_0.4565, %transpose.1116, %transpose.1117), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_104.122, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.30 (param_0.4566: s32[263], param_1.5350: s32[8], param_2.4519: s32[8]) -> s32[263] { @@ -1057,18 +1057,18 @@ StackFrames %region_14.20 (scatter-add.0: s32[], scatter-add.1: s32[]) -> s32[] { %scatter-add.0 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.1 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1312 = s32[]{:T(128)S(7)} add(%scatter-add.0, %scatter-add.1), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1322 = s32[]{:T(128)S(7)} add(%scatter-add.0, %scatter-add.1), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.17.clone.clone.clone (param_0.4570: s32[256], param_1.5352: s32[4096], param_2.4521: s32[4096]) -> s32[256] { %param_0.4570 = s32[256]{0:T(256)} parameter(0) %param_1.5352 = s32[4096]{0:T(1024)} parameter(1) - %reshape.3941 = s32[4096]{0:T(1024)} reshape(%param_1.5352), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} - %transpose.1118 = s32[4096]{0:T(1024)} transpose(%reshape.3941), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} + %reshape.3921 = s32[4096]{0:T(1024)} reshape(%param_1.5352), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} + %transpose.1118 = s32[4096]{0:T(1024)} transpose(%reshape.3921), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} %param_2.4521 = s32[4096]{0:T(1024)} parameter(2) - %reshape.3942 = s32[4096]{0:T(1024)} reshape(%param_2.4521), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - %transpose.1119 = s32[4096]{0:T(1024)} transpose(%reshape.3942), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.246 = s32[256]{0:T(256)} scatter(%param_0.4570, %transpose.1118, %transpose.1119), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_14.20, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} + %reshape.3922 = s32[4096]{0:T(1024)} reshape(%param_2.4521), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + %transpose.1119 = s32[4096]{0:T(1024)} transpose(%reshape.3922), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.240 = s32[256]{0:T(256)} scatter(%param_0.4570, %transpose.1118, %transpose.1119), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_14.20, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.32 (param_0.4571: s32[256], param_1.5353: s32[4096], param_2.4522: s32[4096]) -> s32[256] { @@ -1115,18 +1115,18 @@ StackFrames %region_20.26.clone.1 (scatter-add.141: s32[], scatter-add.142: s32[]) -> s32[] { %scatter-add.141 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.142 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2463 = s32[]{:T(128)S(7)} add(%scatter-add.141, %scatter-add.142), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2471 = s32[]{:T(128)S(7)} add(%scatter-add.141, %scatter-add.142), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.18.clone.clone.clone (param_0.4575: s32[263], param_1.5355: s32[256], param_2.4524: s32[256]) -> s32[263] { %param_0.4575 = s32[263]{0:T(512)} parameter(0) %param_1.5355 = s32[256]{0:T(256)} parameter(1) - %reshape.3943 = s32[256]{0:T(256)} reshape(%param_1.5355), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1120 = s32[256]{0:T(256)} transpose(%reshape.3943), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %reshape.3923 = s32[256]{0:T(256)} reshape(%param_1.5355), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.1120 = s32[256]{0:T(256)} transpose(%reshape.3923), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} %param_2.4524 = s32[256]{0:T(256)} parameter(2) - %reshape.3944 = s32[256]{0:T(256)} reshape(%param_2.4524) - %transpose.1121 = s32[256]{0:T(256)} transpose(%reshape.3944), dimensions={0} - ROOT %scatter-add.247 = s32[263]{0:T(512)} scatter(%param_0.4575, %transpose.1120, %transpose.1121), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_20.26.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %reshape.3924 = s32[256]{0:T(256)} reshape(%param_2.4524) + %transpose.1121 = s32[256]{0:T(256)} transpose(%reshape.3924), dimensions={0} + ROOT %scatter-add.241 = s32[263]{0:T(512)} scatter(%param_0.4575, %transpose.1120, %transpose.1121), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_20.26.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.34 (param_0.4576: s32[263], param_1.5356: s32[256], param_2.4525: s32[256]) -> s32[263] { @@ -1175,18 +1175,18 @@ StackFrames %region_26.33.clone.1 (scatter-add.145: f32[], scatter-add.146: f32[]) -> f32[] { %scatter-add.145 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.146 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2465 = f32[]{:T(128)S(7)} add(%scatter-add.145, %scatter-add.146), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2473 = f32[]{:T(128)S(7)} add(%scatter-add.145, %scatter-add.146), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.19.clone.clone.clone (param_0.4580: f32[9], param_1.5358: s32[256], param_2.4527: f32[256]) -> f32[9] { %param_0.4580 = f32[9]{0:T(128)} parameter(0) %param_1.5358 = s32[256]{0:T(256)} parameter(1) - %reshape.3945 = s32[256]{0:T(256)} reshape(%param_1.5358), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1122 = s32[256]{0:T(256)} transpose(%reshape.3945), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %reshape.3925 = s32[256]{0:T(256)} reshape(%param_1.5358), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.1122 = s32[256]{0:T(256)} transpose(%reshape.3925), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} %param_2.4527 = f32[256]{0:T(256)} parameter(2) - %reshape.3946 = f32[256]{0:T(256)} reshape(%param_2.4527) - %transpose.1123 = f32[256]{0:T(256)} transpose(%reshape.3946), dimensions={0} - ROOT %scatter-add.248 = f32[9]{0:T(128)} scatter(%param_0.4580, %transpose.1122, %transpose.1123), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_26.33.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %reshape.3926 = f32[256]{0:T(256)} reshape(%param_2.4527) + %transpose.1123 = f32[256]{0:T(256)} transpose(%reshape.3926), dimensions={0} + ROOT %scatter-add.242 = f32[9]{0:T(128)} scatter(%param_0.4580, %transpose.1122, %transpose.1123), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_26.33.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.36 (param_0.4581: f32[9], param_1.5359: s32[256], param_2.4528: f32[256]) -> f32[9] { @@ -1235,18 +1235,18 @@ StackFrames %region_28.35.clone.1 (scatter-add.149: s32[], scatter-add.150: s32[]) -> s32[] { %scatter-add.149 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.150 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2467 = s32[]{:T(128)S(7)} add(%scatter-add.149, %scatter-add.150), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2475 = s32[]{:T(128)S(7)} add(%scatter-add.149, %scatter-add.150), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" %fused_computation.20.clone.clone.clone (param_0.4585: s32[263], param_1.5361: s32[8], param_2.4530: s32[8]) -> s32[263] { %param_0.4585 = s32[263]{0:T(512)} parameter(0) %param_1.5361 = s32[8]{0:T(128)} parameter(1) - %reshape.3947 = s32[8]{0:T(128)} reshape(%param_1.5361), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1124 = s32[8]{0:T(128)} transpose(%reshape.3947), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %reshape.3927 = s32[8]{0:T(128)} reshape(%param_1.5361), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.1124 = s32[8]{0:T(128)} transpose(%reshape.3927), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} %param_2.4530 = s32[8]{0:T(128)} parameter(2) - %reshape.3948 = s32[8]{0:T(128)} reshape(%param_2.4530) - %transpose.1125 = s32[8]{0:T(128)} transpose(%reshape.3948), dimensions={0} - ROOT %scatter-add.249 = s32[263]{0:T(512)} scatter(%param_0.4585, %transpose.1124, %transpose.1125), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_28.35.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + %reshape.3928 = s32[8]{0:T(128)} reshape(%param_2.4530) + %transpose.1125 = s32[8]{0:T(128)} transpose(%reshape.3928), dimensions={0} + ROOT %scatter-add.243 = s32[263]{0:T(512)} scatter(%param_0.4585, %transpose.1124, %transpose.1125), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_28.35.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %called_computation.38 (param_0.4586: s32[263], param_1.5362: s32[8], param_2.4531: s32[8]) -> s32[263] { @@ -1292,8 +1292,8 @@ StackFrames %param_0.4170 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(0) %bitcast.672 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_0.4170), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} %square.564 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%bitcast.672, %bitcast.672), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5105 = f32[]{:T(128)} constant(0) - ROOT %reduce.669 = f32[]{:T(128)} reduce(%square.564, %constant.5105), dimensions={0,1,2,3}, to_apply=%region_154.179, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.5086 = f32[]{:T(128)} constant(0) + ROOT %reduce.669 = f32[]{:T(128)} reduce(%square.564, %constant.5086), dimensions={0,1,2,3}, to_apply=%region_154.179, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } %fused_computation.468 (param_0.1421: f32[1536,3,128,192]) -> bf16[3,1536,128,192] { @@ -1317,52 +1317,52 @@ StackFrames %fused_computation.469 (param_0.4140: f32[1536,3,128,192], param_1.5025: f32[], param_2.4298: f32[], param_3.2951: f32[], param_4.2203: f32[1536,3,128,192], param_5.2006: f32[], param_6.1443: f32[3,1536,128,192], param_7.1124: pred[], param_8.889: f32[1536,3,128,192]) -> (f32[], f32[1536,3,128,192], f32[1536,3,128,192], f32[1536,3,128,192], f32[]) { %param_0.4140 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) %param_3.2951 = f32[]{:T(128)S(6)} parameter(3) - %mul.4727.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_3.2951), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4715.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_3.2951), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1124 = pred[]{:T(512)S(6)} parameter(7) - %select_n.2165.clone.1 = pred[1536,3,128,192]{2,3,1,0:T(8,128)(4,1)} broadcast(%param_7.1124), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %select_n.2121.clone.1 = pred[1536,3,128,192]{2,3,1,0:T(8,128)(4,1)} broadcast(%param_7.1124), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} %param_6.1443 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(6) %bitcast.1374.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_6.1443), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} %param_5.2006 = f32[]{:T(128)} parameter(5) - %div.2575.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_5.2006), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2574.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%bitcast.1374.clone.1, %div.2575.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2164.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%select_n.2165.clone.1, %bitcast.1374.clone.1, %div.2574.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4864.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4279.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4864.clone.1), dimensions={}, metadata={op_name="broadcast.334"} - %mul.4733.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2164.clone.1, %broadcast.4279.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2565.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_5.2006), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2564.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%bitcast.1374.clone.1, %div.2565.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2120.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%select_n.2121.clone.1, %bitcast.1374.clone.1, %div.2564.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4845.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4252.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4845.clone.1), dimensions={}, metadata={op_name="broadcast.334"} + %mul.4721.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2120.clone.1, %broadcast.4252.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.889 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(8) - %constant.4868.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4734.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4868.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4732.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_8.889, %mul.4734.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3443.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4733.clone.1, %mul.4732.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4849.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4722.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4849.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4720.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_8.889, %mul.4722.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3429.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4721.clone.1, %mul.4720.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_2.4298 = f32[]{:T(128)S(6)} parameter(2) - %div.2571.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_2.4298), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.399.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2164.clone.1, %select_n.2164.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4867.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4731.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4867.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4729.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%integer_pow.399.clone.1, %mul.4731.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2561.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_2.4298), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.399.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2120.clone.1, %select_n.2120.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4848.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4719.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4848.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4717.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%integer_pow.399.clone.1, %mul.4719.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2203 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(4) - %constant.4866.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4730.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4866.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4728.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_4.2203, %mul.4730.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3442.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4729.clone.1, %mul.4728.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4847.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4718.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4847.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4716.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_4.2203, %mul.4718.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3428.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4717.clone.1, %mul.4716.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_1.5025 = f32[]{:T(128)S(6)} parameter(1) - %div.2570.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_1.5025), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2569.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3442.clone.1, %div.2570.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.157.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} sqrt(%div.2569.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4865.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3441.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4865.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3440.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%sqrt.157.clone.1, %add.3441.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1293.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%div.2571.clone.1, %add.3440.clone.1), metadata={op_name="multiply.290"} - %div.2568.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3443.clone.1, %multiply.1293.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4726.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_0.4140, %broadcast.4279.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3439.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%div.2568.clone.1, %mul.4726.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4725.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%mul.4727.clone.1, %add.3439.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3438.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%param_0.4140, %mul.4725.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.565 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%add.3438.clone.1, %add.3438.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5075 = f32[]{:T(128)} constant(0) - %reduce.670 = f32[]{:T(128)} reduce(%square.565, %constant.5075), dimensions={0,1,2,3}, to_apply=%region_221.246, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.671.clone.1 = f32[]{:T(128)} reduce(%integer_pow.399.clone.1, %constant.5075), dimensions={0,1,2,3}, to_apply=%region_187.212, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.660 = (f32[]{:T(128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.670, %add.3438.clone.1, %add.3442.clone.1, %add.3443.clone.1, %reduce.671.clone.1) + %div.2560.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_1.5025), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2559.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3428.clone.1, %div.2560.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.157.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} sqrt(%div.2559.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4846.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3427.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4846.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3426.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%sqrt.157.clone.1, %add.3427.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1293.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%div.2561.clone.1, %add.3426.clone.1), metadata={op_name="multiply.290"} + %div.2558.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3429.clone.1, %multiply.1293.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.4714.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_0.4140, %broadcast.4252.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3425.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%div.2558.clone.1, %mul.4714.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.4713.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%mul.4715.clone.1, %add.3425.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3424.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%param_0.4140, %mul.4713.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.565 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%add.3424.clone.1, %add.3424.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5056 = f32[]{:T(128)} constant(0) + %reduce.670 = f32[]{:T(128)} reduce(%square.565, %constant.5056), dimensions={0,1,2,3}, to_apply=%region_221.246, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.671.clone.1 = f32[]{:T(128)} reduce(%integer_pow.399.clone.1, %constant.5056), dimensions={0,1,2,3}, to_apply=%region_187.212, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.656 = (f32[]{:T(128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.670, %add.3424.clone.1, %add.3428.clone.1, %add.3429.clone.1, %reduce.671.clone.1) } %region_160.185 (reduce_sum.473: f32[], reduce_sum.293: f32[]) -> f32[] { @@ -1379,17 +1379,17 @@ StackFrames %fused_computation.495 (param_0.4166: bf16[256,512,512], param_1.5047: bf16[256,512,512]) -> (f32[], f32[]) { %param_0.4166 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) - %broadcast_in_dim.1358 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4166), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.695 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %broadcast_in_dim.1245 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4166), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.695 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1245), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %square.570 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.695, %bitcast.695), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5101 = f32[]{:T(128)} constant(0) - %reduce.672 = f32[]{:T(128)} reduce(%square.570, %constant.5101), dimensions={0,1,2,3}, to_apply=%region_160.185, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.5082 = f32[]{:T(128)} constant(0) + %reduce.672 = f32[]{:T(128)} reduce(%square.570, %constant.5082), dimensions={0,1,2,3}, to_apply=%region_160.185, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} %param_1.5047 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(1) - %broadcast_in_dim.1366.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_1.5047), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.703.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1366.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %broadcast_in_dim.1253.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_1.5047), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.703.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1253.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %square.576.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.703.clone.1, %bitcast.703.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.674.clone.1 = f32[]{:T(128)} reduce(%square.576.clone.1, %constant.5101), dimensions={0,1,2,3}, to_apply=%region_158.183, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.767 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.672, %reduce.674.clone.1) + %reduce.674.clone.1 = f32[]{:T(128)} reduce(%square.576.clone.1, %constant.5082), dimensions={0,1,2,3}, to_apply=%region_158.183, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.763 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.672, %reduce.674.clone.1) } %region_159.184 (reduce_sum.466: f32[], reduce_sum.279: f32[]) -> f32[] { @@ -1400,11 +1400,11 @@ StackFrames %fused_computation.497 (param_0.4165: bf16[256,512,512]) -> f32[] { %param_0.4165 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) - %broadcast_in_dim.1362 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4165), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.699 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1362), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %broadcast_in_dim.1249 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4165), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.699 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1249), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %square.573 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.699, %bitcast.699), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5100 = f32[]{:T(128)} constant(0) - ROOT %reduce.673 = f32[]{:T(128)} reduce(%square.573, %constant.5100), dimensions={0,1,2,3}, to_apply=%region_159.184, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.5081 = f32[]{:T(128)} constant(0) + ROOT %reduce.673 = f32[]{:T(128)} reduce(%square.573, %constant.5081), dimensions={0,1,2,3}, to_apply=%region_159.184, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } %region_227.252 (reduce_sum.935: f32[], reduce_sum.631: f32[]) -> f32[] { @@ -1423,57 +1423,57 @@ StackFrames %param_8.883 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) %bitcast.1359.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.883), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} %param_7.1118 = f32[]{:T(128)S(6)} parameter(7) - %mul.4676.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1118), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4664.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1118), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_6.1437 = pred[]{:T(512)S(6)} parameter(6) - %select_n.2147.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1437), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %select_n.2103.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1437), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} %param_5.2000 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) - %broadcast_in_dim.1572.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2000), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.1361.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1572.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %broadcast_in_dim.1459.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2000), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1361.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1459.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %param_4.2197 = f32[]{:T(128)} parameter(4) - %div.2533.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2197), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2532.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1361.clone.1, %div.2533.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2146.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2147.clone.1, %bitcast.1361.clone.1, %div.2532.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4834.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4259.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4834.clone.1), dimensions={}, metadata={op_name="broadcast.2345"} - %mul.4678.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2146.clone.1, %broadcast.4259.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2523.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2197), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2522.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1361.clone.1, %div.2523.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2102.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2103.clone.1, %bitcast.1361.clone.1, %div.2522.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4815.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4232.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4815.clone.1), dimensions={}, metadata={op_name="broadcast.2344"} + %mul.4666.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2102.clone.1, %broadcast.4232.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_3.2945 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) %bitcast.1360.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2945), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} - %constant.4833.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.4258.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4833.clone.1), dimensions={}, metadata={op_name="broadcast.329"} - %mul.4677.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1360.clone.1, %broadcast.4258.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3408.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4678.clone.1, %mul.4677.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4814.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4231.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4814.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.4665.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1360.clone.1, %broadcast.4231.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3394.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4666.clone.1, %mul.4665.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_2.4292 = f32[]{:T(128)S(6)} parameter(2) - %div.2531.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4292), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.393.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2146.clone.1, %select_n.2146.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4832.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.4261.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4832.clone.1), dimensions={}, metadata={op_name="broadcast.2348"} - %mul.4680.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.393.clone.1, %broadcast.4261.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2521.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4292), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.393.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2102.clone.1, %select_n.2102.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4813.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4234.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4813.clone.1), dimensions={}, metadata={op_name="broadcast.2347"} + %mul.4668.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.393.clone.1, %broadcast.4234.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_1.5019 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) %bitcast.1362.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5019), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} - %constant.4831.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.4260.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4831.clone.1), dimensions={}, metadata={op_name="broadcast.312"} - %mul.4679.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1362.clone.1, %broadcast.4260.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3409.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4680.clone.1, %mul.4679.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4812.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4233.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4812.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.4667.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1362.clone.1, %broadcast.4233.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3395.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4668.clone.1, %mul.4667.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_0.4134 = f32[]{:T(128)S(6)} parameter(0) - %div.2530.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4134), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2529.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3409.clone.1, %div.2530.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.151.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2529.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4835.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.4257.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4835.clone.1), dimensions={}, metadata={op_name="broadcast.305"} - %add.3407.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.151.clone.1, %broadcast.4257.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1287.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2531.clone.1, %add.3407.clone.1), metadata={op_name="multiply.296"} - %div.2528.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3408.clone.1, %multiply.1287.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4675.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1359.clone.1, %broadcast.4259.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3406.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2528.clone.1, %mul.4675.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4674.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4676.clone.1, %add.3406.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3405.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1359.clone.1, %mul.4674.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.577 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3405.clone.1, %add.3405.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5069 = f32[]{:T(128)} constant(0) - %reduce.675 = f32[]{:T(128)} reduce(%square.577, %constant.5069), dimensions={0,1,2,3}, to_apply=%region_227.252, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %bitcast.849.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3409.clone.1) - %bitcast.822.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3408.clone.1) - %reduce.684.clone.1 = f32[]{:T(128)} reduce(%integer_pow.393.clone.1, %constant.5069), dimensions={0,1,2,3}, to_apply=%region_193.218, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.670 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.675, %add.3405.clone.1, %bitcast.849.clone.1, %bitcast.822.clone.1, %reduce.684.clone.1) + %div.2520.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4134), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2519.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3395.clone.1, %div.2520.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.151.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2519.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4816.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4230.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4816.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3393.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.151.clone.1, %broadcast.4230.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1287.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2521.clone.1, %add.3393.clone.1), metadata={op_name="multiply.296"} + %div.2518.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3394.clone.1, %multiply.1287.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.4663.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1359.clone.1, %broadcast.4232.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3392.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2518.clone.1, %mul.4663.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.4662.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4664.clone.1, %add.3392.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3391.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1359.clone.1, %mul.4662.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.577 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3391.clone.1, %add.3391.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5050 = f32[]{:T(128)} constant(0) + %reduce.675 = f32[]{:T(128)} reduce(%square.577, %constant.5050), dimensions={0,1,2,3}, to_apply=%region_227.252, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.849.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3395.clone.1) + %bitcast.822.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3394.clone.1) + %reduce.684.clone.1 = f32[]{:T(128)} reduce(%integer_pow.393.clone.1, %constant.5050), dimensions={0,1,2,3}, to_apply=%region_193.218, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.666 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.675, %add.3391.clone.1, %bitcast.849.clone.1, %bitcast.822.clone.1, %reduce.684.clone.1) } %region_226.251 (reduce_sum.928: f32[], reduce_sum.625: f32[]) -> f32[] { @@ -1492,57 +1492,57 @@ StackFrames %param_8.884 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) %bitcast.1363.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.884), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} %param_7.1119 = f32[]{:T(128)S(6)} parameter(7) - %mul.4683.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1119), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4671.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1119), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_6.1438 = pred[]{:T(512)S(6)} parameter(6) - %select_n.2149.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1438), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %select_n.2105.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1438), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} %param_5.2001 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) - %broadcast_in_dim.1573.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2001), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.1365.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1573.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %broadcast_in_dim.1460.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2001), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1365.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1460.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %param_4.2198 = f32[]{:T(128)} parameter(4) - %div.2539.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2198), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2538.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1365.clone.1, %div.2539.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2148.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2149.clone.1, %bitcast.1365.clone.1, %div.2538.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4839.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4264.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4839.clone.1), dimensions={}, metadata={op_name="broadcast.2345"} - %mul.4685.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2148.clone.1, %broadcast.4264.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2529.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2198), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2528.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1365.clone.1, %div.2529.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2104.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2105.clone.1, %bitcast.1365.clone.1, %div.2528.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4820.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4237.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4820.clone.1), dimensions={}, metadata={op_name="broadcast.2344"} + %mul.4673.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2104.clone.1, %broadcast.4237.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_3.2946 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) %bitcast.1364.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2946), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} - %constant.4838.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.4263.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4838.clone.1), dimensions={}, metadata={op_name="broadcast.329"} - %mul.4684.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1364.clone.1, %broadcast.4263.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3413.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4685.clone.1, %mul.4684.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4819.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4236.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4819.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.4672.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1364.clone.1, %broadcast.4236.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3399.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4673.clone.1, %mul.4672.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_2.4293 = f32[]{:T(128)S(6)} parameter(2) - %div.2537.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4293), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.394.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2148.clone.1, %select_n.2148.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4837.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.4266.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4837.clone.1), dimensions={}, metadata={op_name="broadcast.2348"} - %mul.4687.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.394.clone.1, %broadcast.4266.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2527.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4293), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.394.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2104.clone.1, %select_n.2104.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4818.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4239.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4818.clone.1), dimensions={}, metadata={op_name="broadcast.2347"} + %mul.4675.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.394.clone.1, %broadcast.4239.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_1.5020 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) %bitcast.1366.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5020), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} - %constant.4836.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.4265.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4836.clone.1), dimensions={}, metadata={op_name="broadcast.312"} - %mul.4686.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1366.clone.1, %broadcast.4265.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3414.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4687.clone.1, %mul.4686.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4817.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4238.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4817.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.4674.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1366.clone.1, %broadcast.4238.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3400.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4675.clone.1, %mul.4674.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_0.4135 = f32[]{:T(128)S(6)} parameter(0) - %div.2536.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4135), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2535.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3414.clone.1, %div.2536.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.152.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2535.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4840.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.4262.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4840.clone.1), dimensions={}, metadata={op_name="broadcast.305"} - %add.3412.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.152.clone.1, %broadcast.4262.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1288.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2537.clone.1, %add.3412.clone.1), metadata={op_name="multiply.295"} - %div.2534.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3413.clone.1, %multiply.1288.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4682.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1363.clone.1, %broadcast.4264.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3411.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2534.clone.1, %mul.4682.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4681.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4683.clone.1, %add.3411.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3410.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1363.clone.1, %mul.4681.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.578 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3410.clone.1, %add.3410.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5070 = f32[]{:T(128)} constant(0) - %reduce.676 = f32[]{:T(128)} reduce(%square.578, %constant.5070), dimensions={0,1,2,3}, to_apply=%region_226.251, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %bitcast.840.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3414.clone.1) - %bitcast.813.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3413.clone.1) - %reduce.685.clone.1 = f32[]{:T(128)} reduce(%integer_pow.394.clone.1, %constant.5070), dimensions={0,1,2,3}, to_apply=%region_192.217, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.669 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.676, %add.3410.clone.1, %bitcast.840.clone.1, %bitcast.813.clone.1, %reduce.685.clone.1) + %div.2526.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4135), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2525.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3400.clone.1, %div.2526.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.152.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2525.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4821.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4235.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4821.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3398.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.152.clone.1, %broadcast.4235.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1288.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2527.clone.1, %add.3398.clone.1), metadata={op_name="multiply.295"} + %div.2524.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3399.clone.1, %multiply.1288.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.4670.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1363.clone.1, %broadcast.4237.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3397.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2524.clone.1, %mul.4670.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.4669.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4671.clone.1, %add.3397.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3396.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1363.clone.1, %mul.4669.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.578 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3396.clone.1, %add.3396.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5051 = f32[]{:T(128)} constant(0) + %reduce.676 = f32[]{:T(128)} reduce(%square.578, %constant.5051), dimensions={0,1,2,3}, to_apply=%region_226.251, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.840.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3400.clone.1) + %bitcast.813.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3399.clone.1) + %reduce.685.clone.1 = f32[]{:T(128)} reduce(%integer_pow.394.clone.1, %constant.5051), dimensions={0,1,2,3}, to_apply=%region_192.217, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.665 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.676, %add.3396.clone.1, %bitcast.840.clone.1, %bitcast.813.clone.1, %reduce.685.clone.1) } %region_225.250 (reduce_sum.921: f32[], reduce_sum.619: f32[]) -> f32[] { @@ -1561,57 +1561,57 @@ StackFrames %param_8.885 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) %bitcast.1367.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.885), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} %param_7.1120 = f32[]{:T(128)S(6)} parameter(7) - %mul.4690.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1120), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4678.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1120), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_6.1439 = pred[]{:T(512)S(6)} parameter(6) - %select_n.2151.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1439), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %select_n.2107.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1439), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} %param_5.2002 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) - %broadcast_in_dim.1574.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2002), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.1369.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1574.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %broadcast_in_dim.1461.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2002), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1369.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1461.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %param_4.2199 = f32[]{:T(128)} parameter(4) - %div.2545.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2199), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2544.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1369.clone.1, %div.2545.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2150.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2151.clone.1, %bitcast.1369.clone.1, %div.2544.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4844.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4269.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4844.clone.1), dimensions={}, metadata={op_name="broadcast.2345"} - %mul.4692.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2150.clone.1, %broadcast.4269.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2535.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2199), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2534.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1369.clone.1, %div.2535.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2106.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2107.clone.1, %bitcast.1369.clone.1, %div.2534.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4825.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4242.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4825.clone.1), dimensions={}, metadata={op_name="broadcast.2344"} + %mul.4680.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2106.clone.1, %broadcast.4242.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_3.2947 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) %bitcast.1368.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2947), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} - %constant.4843.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.4268.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4843.clone.1), dimensions={}, metadata={op_name="broadcast.329"} - %mul.4691.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1368.clone.1, %broadcast.4268.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3418.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4692.clone.1, %mul.4691.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4824.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4241.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4824.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.4679.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1368.clone.1, %broadcast.4241.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3404.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4680.clone.1, %mul.4679.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_2.4294 = f32[]{:T(128)S(6)} parameter(2) - %div.2543.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4294), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.395.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2150.clone.1, %select_n.2150.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4842.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.4271.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4842.clone.1), dimensions={}, metadata={op_name="broadcast.2348"} - %mul.4694.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.395.clone.1, %broadcast.4271.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2533.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4294), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.395.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2106.clone.1, %select_n.2106.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4823.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4244.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4823.clone.1), dimensions={}, metadata={op_name="broadcast.2347"} + %mul.4682.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.395.clone.1, %broadcast.4244.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_1.5021 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) %bitcast.1370.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5021), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} - %constant.4841.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.4270.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4841.clone.1), dimensions={}, metadata={op_name="broadcast.312"} - %mul.4693.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1370.clone.1, %broadcast.4270.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3419.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4694.clone.1, %mul.4693.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4822.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4243.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4822.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.4681.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1370.clone.1, %broadcast.4243.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3405.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4682.clone.1, %mul.4681.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_0.4136 = f32[]{:T(128)S(6)} parameter(0) - %div.2542.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4136), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2541.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3419.clone.1, %div.2542.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.153.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2541.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4845.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.4267.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4845.clone.1), dimensions={}, metadata={op_name="broadcast.305"} - %add.3417.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.153.clone.1, %broadcast.4267.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1289.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2543.clone.1, %add.3417.clone.1), metadata={op_name="multiply.294"} - %div.2540.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3418.clone.1, %multiply.1289.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4689.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1367.clone.1, %broadcast.4269.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3416.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2540.clone.1, %mul.4689.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4688.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4690.clone.1, %add.3416.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3415.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1367.clone.1, %mul.4688.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.579 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3415.clone.1, %add.3415.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5071 = f32[]{:T(128)} constant(0) - %reduce.677 = f32[]{:T(128)} reduce(%square.579, %constant.5071), dimensions={0,1,2,3}, to_apply=%region_225.250, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %bitcast.831.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3419.clone.1) - %bitcast.804.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3418.clone.1) - %reduce.686.clone.1 = f32[]{:T(128)} reduce(%integer_pow.395.clone.1, %constant.5071), dimensions={0,1,2,3}, to_apply=%region_191.216, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.668 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.677, %add.3415.clone.1, %bitcast.831.clone.1, %bitcast.804.clone.1, %reduce.686.clone.1) + %div.2532.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4136), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2531.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3405.clone.1, %div.2532.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.153.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2531.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4826.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4240.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4826.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3403.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.153.clone.1, %broadcast.4240.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1289.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2533.clone.1, %add.3403.clone.1), metadata={op_name="multiply.294"} + %div.2530.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3404.clone.1, %multiply.1289.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.4677.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1367.clone.1, %broadcast.4242.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3402.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2530.clone.1, %mul.4677.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.4676.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4678.clone.1, %add.3402.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3401.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1367.clone.1, %mul.4676.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.579 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3401.clone.1, %add.3401.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5052 = f32[]{:T(128)} constant(0) + %reduce.677 = f32[]{:T(128)} reduce(%square.579, %constant.5052), dimensions={0,1,2,3}, to_apply=%region_225.250, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.831.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3405.clone.1) + %bitcast.804.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3404.clone.1) + %reduce.686.clone.1 = f32[]{:T(128)} reduce(%integer_pow.395.clone.1, %constant.5052), dimensions={0,1,2,3}, to_apply=%region_191.216, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.664 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.677, %add.3401.clone.1, %bitcast.831.clone.1, %bitcast.804.clone.1, %reduce.686.clone.1) } %region_155.180 (reduce_sum.438: f32[], reduce_sum.259: f32[]) -> f32[] { @@ -1622,39 +1622,39 @@ StackFrames %fused_computation.529.clone.clone.clone (param_0.4079: bf16[4,128,129280], param_1.4953: s32[4,128], param_2.4225: f32[4,128], param_3.2913: f32[4,128], param_4.2170: bf16[4,128], param_5.1978: f32[4,128]) -> bf16[4,128,129280] { %param_5.1978 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.4903 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_5.1978), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %mul.4891 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_5.1978), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} %param_3.2913 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.4902 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2913), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %mul.4890 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2913), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} %param_0.4079 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.3163 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4079), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %convert_element_type.3155 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4079), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} %param_4.2170 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %sub.804 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_4.2170), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.803 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.3163, %sub.804), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %exp.534 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.803), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.4901 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4902, %exp.534), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %sub.791 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_4.2170), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.790 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.3155, %sub.791), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.534 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.790), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.4889 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4890, %exp.534), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} %param_2.4225 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.2698 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4225), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.2697 = f32[4,128,129280]{2,1,0:T(8,128)} divide(%mul.4901, %div.2698), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.2688 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4225), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.2687 = f32[4,128,129280]{2,1,0:T(8,128)} divide(%mul.4889, %div.2688), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} %param_1.4953 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.371 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.4953), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %eq.370 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %eq.369 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.371, %eq.370), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.3162 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%eq.369), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.802 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%div.2697, %convert_element_type.3162), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.4900 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4903, %sub.802), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.3161 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} convert(%mul.4900), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %eq.363 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.4953), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.362 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.361 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.363, %eq.362), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %convert_element_type.3154 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%eq.361), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.789 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%div.2687, %convert_element_type.3154), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.4888 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4891, %sub.789), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.3153 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} convert(%mul.4888), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} } %fused_computation.939.clone.clone (param_0.4080: f32[4,128], param_1.4954: bf16[4,128,512], param_2.4227: bf16[512]) -> bf16[4,128,512] { %param_2.4227 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(2) %dot_general.831 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4227), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} %param_1.4954 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.3165 = f32[4,128,512]{2,1,0:T(8,128)} convert(%param_1.4954), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %convert_element_type.3157 = f32[4,128,512]{2,1,0:T(8,128)} convert(%param_1.4954), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} %param_0.4080 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.4905 = f32[4,128,512]{2,1,0:T(8,128)} broadcast(%param_0.4080), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.4904 = f32[4,128,512]{2,1,0:T(8,128)} multiply(%convert_element_type.3165, %mul.4905), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.3164 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} convert(%mul.4904), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.830 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.831, %convert_element_type.3164), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} + %mul.4893 = f32[4,128,512]{2,1,0:T(8,128)} broadcast(%param_0.4080), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.4892 = f32[4,128,512]{2,1,0:T(8,128)} multiply(%convert_element_type.3157, %mul.4893), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.3156 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} convert(%mul.4892), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + ROOT %dot_general.830 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.831, %convert_element_type.3156), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} } %fused_computation.518 (param_0.4169: bf16[4,128,129280], param_1.5049: s32[4,128], param_2.4319: f32[4,128], param_3.2969: f32[4,128], param_4.2219: bf16[4,128], param_5.2020: f32[4,128], param_6.1457: f32[4,128], param_7.1138: bf16[4,128,512], param_8.902: bf16[512]) -> (f32[], bf16[512,129280,1]) { @@ -1671,11 +1671,11 @@ StackFrames %multiply_convert_fusion.1.clone.1 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} fusion(%param_0.4169, %param_1.5049, %param_2.4319, %param_3.2969, %param_4.2219, /*index=5*/%param_5.2020), kind=kLoop, calls=%fused_computation.529.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} %convolution.141.clone.1 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.577.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} %bitcast.776 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%convolution.141.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.2665 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.776), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - %square.581 = f32[512,129280]{1,0:T(8,128)} multiply(%convert_element_type.2665, %convert_element_type.2665), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5104 = f32[]{:T(128)} constant(0) - %reduce.678 = f32[]{:T(128)} reduce(%square.581, %constant.5104), dimensions={0,1}, to_apply=%region_155.180, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.757 = (f32[]{:T(128)}, bf16[512,129280,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.678, %convolution.141.clone.1) + %convert_element_type.2657 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.776), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %square.581 = f32[512,129280]{1,0:T(8,128)} multiply(%convert_element_type.2657, %convert_element_type.2657), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5085 = f32[]{:T(128)} constant(0) + %reduce.678 = f32[]{:T(128)} reduce(%square.581, %constant.5085), dimensions={0,1}, to_apply=%region_155.180, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.753 = (f32[]{:T(128)}, bf16[512,129280,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.678, %convolution.141.clone.1) } %region_174.199 (reduce_sum.564: f32[], reduce_sum.387: f32[]) -> f32[] { @@ -1686,10 +1686,10 @@ StackFrames %fused_computation.519 (param_0.4153: bf16[129280,512]) -> f32[] { %param_0.4153 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.2667 = f32[129280,512]{1,0:T(8,128)} convert(%param_0.4153), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %square.583 = f32[129280,512]{1,0:T(8,128)} multiply(%convert_element_type.2667, %convert_element_type.2667), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5088 = f32[]{:T(128)} constant(0) - ROOT %reduce.679 = f32[]{:T(128)} reduce(%square.583, %constant.5088), dimensions={0,1}, to_apply=%region_174.199, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %convert_element_type.2659 = f32[129280,512]{1,0:T(8,128)} convert(%param_0.4153), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %square.583 = f32[129280,512]{1,0:T(8,128)} multiply(%convert_element_type.2659, %convert_element_type.2659), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5069 = f32[]{:T(128)} constant(0) + ROOT %reduce.679 = f32[]{:T(128)} reduce(%square.583, %constant.5069), dimensions={0,1}, to_apply=%region_174.199, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } %region_240.265 (reduce_sum.1026: f32[], reduce_sum.689: f32[]) -> f32[] { @@ -1707,52 +1707,52 @@ StackFrames %fused_computation.520 (param_0.4121: f32[129280,512], param_1.5006: f32[], param_2.4279: f32[], param_3.2932: f32[], param_4.2184: f32[129280,512], param_5.1987: f32[], param_6.1424: bf16[129280,512], param_7.1105: pred[], param_8.870: f32[129280,512]) -> (f32[], f32[129280,512], f32[129280,512], f32[129280,512], f32[]) { %param_0.4121 = f32[129280,512]{1,0:T(8,128)} parameter(0) %param_3.2932 = f32[]{:T(128)S(6)} parameter(3) - %mul.4564.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_3.2932), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4552.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_3.2932), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1105 = pred[]{:T(512)S(6)} parameter(7) - %select_n.2105.clone.1 = pred[129280,512]{1,0:T(8,128)(4,1)} broadcast(%param_7.1105), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %select_n.2061.clone.1 = pred[129280,512]{1,0:T(8,128)(4,1)} broadcast(%param_7.1105), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} %param_6.1424 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(6) - %convert_element_type.3106.clone.1 = f32[129280,512]{1,0:T(8,128)} convert(%param_6.1424), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %convert_element_type.3098.clone.1 = f32[129280,512]{1,0:T(8,128)} convert(%param_6.1424), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} %param_5.1987 = f32[]{:T(128)} parameter(5) - %div.2439.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_5.1987), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2438.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%convert_element_type.3106.clone.1, %div.2439.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2104.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%select_n.2105.clone.1, %convert_element_type.3106.clone.1, %div.2438.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4754.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4209.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4754.clone.1), dimensions={}, metadata={op_name="broadcast.318"} - %mul.4570.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2104.clone.1, %broadcast.4209.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2429.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_5.1987), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2428.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%convert_element_type.3098.clone.1, %div.2429.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2060.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%select_n.2061.clone.1, %convert_element_type.3098.clone.1, %div.2428.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4735.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4182.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4735.clone.1), dimensions={}, metadata={op_name="broadcast.318"} + %mul.4558.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2060.clone.1, %broadcast.4182.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.870 = f32[129280,512]{1,0:T(8,128)} parameter(8) - %constant.4758.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4571.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4758.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4569.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_8.870, %mul.4571.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3338.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4570.clone.1, %mul.4569.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4739.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4559.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4739.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4557.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_8.870, %mul.4559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3324.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4558.clone.1, %mul.4557.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_2.4279 = f32[]{:T(128)S(6)} parameter(2) - %div.2435.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_2.4279), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.380.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2104.clone.1, %select_n.2104.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4757.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4568.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4757.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4566.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%integer_pow.380.clone.1, %mul.4568.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2425.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_2.4279), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.380.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2060.clone.1, %select_n.2060.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4738.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4556.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4738.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4554.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%integer_pow.380.clone.1, %mul.4556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2184 = f32[129280,512]{1,0:T(8,128)} parameter(4) - %constant.4756.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4567.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4756.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4565.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_4.2184, %mul.4567.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3337.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4566.clone.1, %mul.4565.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4737.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4555.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4737.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4553.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_4.2184, %mul.4555.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3323.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4554.clone.1, %mul.4553.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_1.5006 = f32[]{:T(128)S(6)} parameter(1) - %div.2434.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_1.5006), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2433.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3337.clone.1, %div.2434.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.138.clone.1 = f32[129280,512]{1,0:T(8,128)} sqrt(%div.2433.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4755.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3336.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4755.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3335.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%sqrt.138.clone.1, %add.3336.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1274.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%div.2435.clone.1, %add.3335.clone.1), metadata={op_name="multiply.309"} - %div.2432.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3338.clone.1, %multiply.1274.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4563.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_0.4121, %broadcast.4209.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3334.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%div.2432.clone.1, %mul.4563.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4562.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%mul.4564.clone.1, %add.3334.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3333.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%param_0.4121, %mul.4562.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.584 = f32[129280,512]{1,0:T(8,128)} multiply(%add.3333.clone.1, %add.3333.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5056 = f32[]{:T(128)} constant(0) - %reduce.680 = f32[]{:T(128)} reduce(%square.584, %constant.5056), dimensions={0,1}, to_apply=%region_240.265, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.687.clone.1 = f32[]{:T(128)} reduce(%integer_pow.380.clone.1, %constant.5056), dimensions={0,1}, to_apply=%region_206.231, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.671 = (f32[]{:T(128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.680, %add.3333.clone.1, %add.3337.clone.1, %add.3338.clone.1, %reduce.687.clone.1) + %div.2424.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_1.5006), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2423.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3323.clone.1, %div.2424.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.138.clone.1 = f32[129280,512]{1,0:T(8,128)} sqrt(%div.2423.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4736.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3322.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4736.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3321.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%sqrt.138.clone.1, %add.3322.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1274.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%div.2425.clone.1, %add.3321.clone.1), metadata={op_name="multiply.309"} + %div.2422.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3324.clone.1, %multiply.1274.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.4551.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_0.4121, %broadcast.4182.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3320.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%div.2422.clone.1, %mul.4551.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.4550.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%mul.4552.clone.1, %add.3320.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3319.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%param_0.4121, %mul.4550.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.584 = f32[129280,512]{1,0:T(8,128)} multiply(%add.3319.clone.1, %add.3319.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5037 = f32[]{:T(128)} constant(0) + %reduce.680 = f32[]{:T(128)} reduce(%square.584, %constant.5037), dimensions={0,1}, to_apply=%region_240.265, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.687.clone.1 = f32[]{:T(128)} reduce(%integer_pow.380.clone.1, %constant.5037), dimensions={0,1}, to_apply=%region_206.231, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.667 = (f32[]{:T(128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.680, %add.3319.clone.1, %add.3323.clone.1, %add.3324.clone.1, %reduce.687.clone.1) } %region_222.247 (reduce_sum.900: f32[], reduce_sum.605: f32[]) -> f32[] { @@ -1770,53 +1770,53 @@ StackFrames %fused_computation.521 (param_0.4139: f32[512,129280], param_1.5024: f32[], param_2.4297: f32[], param_3.2950: f32[], param_4.2202: f32[512,129280], param_5.2005: f32[], param_6.1442: bf16[512,129280,1], param_7.1123: pred[], param_8.888: f32[512,129280]) -> (f32[], f32[512,129280], f32[512,129280], f32[512,129280], f32[]) { %param_0.4139 = f32[512,129280]{1,0:T(8,128)} parameter(0) %param_3.2950 = f32[]{:T(128)S(6)} parameter(3) - %mul.4717.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_3.2950), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4705.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_3.2950), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1123 = pred[]{:T(512)S(6)} parameter(7) - %select_n.2161.clone.1 = pred[512,129280]{1,0:T(8,128)(4,1)} broadcast(%param_7.1123), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %select_n.2117.clone.1 = pred[512,129280]{1,0:T(8,128)(4,1)} broadcast(%param_7.1123), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} %param_6.1442 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} parameter(6) %bitcast.1372.clone.1 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%param_6.1442), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.3108.clone.1 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.1372.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %convert_element_type.3100.clone.1 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.1372.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} %param_5.2005 = f32[]{:T(128)} parameter(5) - %div.2567.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_5.2005), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2566.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%convert_element_type.3108.clone.1, %div.2567.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2160.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%select_n.2161.clone.1, %convert_element_type.3108.clone.1, %div.2566.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4858.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4277.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4858.clone.1), dimensions={}, metadata={op_name="broadcast.333"} - %mul.4723.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2160.clone.1, %broadcast.4277.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2557.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_5.2005), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2556.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%convert_element_type.3100.clone.1, %div.2557.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2116.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%select_n.2117.clone.1, %convert_element_type.3100.clone.1, %div.2556.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4839.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4250.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4839.clone.1), dimensions={}, metadata={op_name="broadcast.333"} + %mul.4711.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2116.clone.1, %broadcast.4250.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.888 = f32[512,129280]{1,0:T(8,128)} parameter(8) - %constant.4862.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4724.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4862.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4722.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_8.888, %mul.4724.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3437.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4723.clone.1, %mul.4722.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4843.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4712.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4843.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4710.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_8.888, %mul.4712.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3423.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4711.clone.1, %mul.4710.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_2.4297 = f32[]{:T(128)S(6)} parameter(2) - %div.2563.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_2.4297), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.398.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2160.clone.1, %select_n.2160.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4861.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4721.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4861.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4719.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%integer_pow.398.clone.1, %mul.4721.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2553.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_2.4297), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.398.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2116.clone.1, %select_n.2116.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4842.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4709.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4842.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4707.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%integer_pow.398.clone.1, %mul.4709.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2202 = f32[512,129280]{1,0:T(8,128)} parameter(4) - %constant.4860.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4720.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4860.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4718.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_4.2202, %mul.4720.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3436.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4719.clone.1, %mul.4718.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4841.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4708.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4841.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4706.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_4.2202, %mul.4708.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3422.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4707.clone.1, %mul.4706.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_1.5024 = f32[]{:T(128)S(6)} parameter(1) - %div.2562.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_1.5024), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2561.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3436.clone.1, %div.2562.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.156.clone.1 = f32[512,129280]{1,0:T(8,128)} sqrt(%div.2561.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4859.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3435.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4859.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3434.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%sqrt.156.clone.1, %add.3435.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1292.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%div.2563.clone.1, %add.3434.clone.1), metadata={op_name="multiply.291"} - %div.2560.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3437.clone.1, %multiply.1292.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4716.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_0.4139, %broadcast.4277.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3433.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%div.2560.clone.1, %mul.4716.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4715.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%mul.4717.clone.1, %add.3433.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3432.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%param_0.4139, %mul.4715.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.585 = f32[512,129280]{1,0:T(8,128)} multiply(%add.3432.clone.1, %add.3432.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5074 = f32[]{:T(128)} constant(0) - %reduce.681 = f32[]{:T(128)} reduce(%square.585, %constant.5074), dimensions={0,1}, to_apply=%region_222.247, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.688.clone.1 = f32[]{:T(128)} reduce(%integer_pow.398.clone.1, %constant.5074), dimensions={0,1}, to_apply=%region_188.213, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.672 = (f32[]{:T(128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.681, %add.3432.clone.1, %add.3436.clone.1, %add.3437.clone.1, %reduce.688.clone.1) + %div.2552.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_1.5024), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2551.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3422.clone.1, %div.2552.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.156.clone.1 = f32[512,129280]{1,0:T(8,128)} sqrt(%div.2551.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4840.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3421.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4840.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3420.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%sqrt.156.clone.1, %add.3421.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1292.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%div.2553.clone.1, %add.3420.clone.1), metadata={op_name="multiply.291"} + %div.2550.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3423.clone.1, %multiply.1292.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.4704.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_0.4139, %broadcast.4250.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3419.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%div.2550.clone.1, %mul.4704.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.4703.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%mul.4705.clone.1, %add.3419.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3418.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%param_0.4139, %mul.4703.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.585 = f32[512,129280]{1,0:T(8,128)} multiply(%add.3418.clone.1, %add.3418.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5055 = f32[]{:T(128)} constant(0) + %reduce.681 = f32[]{:T(128)} reduce(%square.585, %constant.5055), dimensions={0,1}, to_apply=%region_222.247, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.688.clone.1 = f32[]{:T(128)} reduce(%integer_pow.398.clone.1, %constant.5055), dimensions={0,1}, to_apply=%region_188.213, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.668 = (f32[]{:T(128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.681, %add.3418.clone.1, %add.3422.clone.1, %add.3423.clone.1, %reduce.688.clone.1) } %region_207.232 (reduce_sum.795: f32[], reduce_sum.535: f32[]) -> f32[] { @@ -1827,21 +1827,21 @@ StackFrames %fused_computation.522 (param_0.4190: bf16[4,128,129280], param_1.5063: f32[4,128], param_2.4329: s32[4,128], param_3.2977: bf16[4,128]) -> f32[4,128] { %param_2.4329 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %eq.307 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4329), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %eq.302 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %eq.301 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.307, %eq.302), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.299 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4329), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.294 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.293 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.299, %eq.294), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %param_0.4190 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.2672 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %convert_element_type.2664 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} %param_3.2977 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) - %sub.665 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2977), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.656 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2672, %sub.665), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.652 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2977), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.643 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2664, %sub.652), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %param_1.5063 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %sub.663 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5063), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.652 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%sub.656, %sub.663), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %constant.5128 = f32[]{:T(128)} constant(0) - %broadcast.3784 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%constant.5128), dimensions={}, metadata={op_name="broadcast.518"} - %mul.3624 = f32[4,128,129280]{2,1,0:T(8,128)} select(%eq.301, %sub.652, %broadcast.3784), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - ROOT %reduce.682 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.3624, %constant.5128), dimensions={2}, to_apply=%region_207.232, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %sub.650 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5063), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.639 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%sub.643, %sub.650), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %constant.5109 = f32[]{:T(128)} constant(0) + %broadcast.3757 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%constant.5109), dimensions={}, metadata={op_name="broadcast.514"} + %mul.3612 = f32[4,128,129280]{2,1,0:T(8,128)} select(%eq.293, %sub.639, %broadcast.3757), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.682 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.3612, %constant.5109), dimensions={2}, to_apply=%region_207.232, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } %region_37.47 (reduce_sum.76: f32[], reduce_sum.80: f32[]) -> f32[] { @@ -1852,13 +1852,13 @@ StackFrames %fused_computation.533 (param_0.4191: bf16[4,128,129280], param_1.5064: bf16[4,128]) -> f32[4,128] { %param_0.4191 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.2678 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4191), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %convert_element_type.2670 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4191), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} %param_1.5064 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) - %sub.666 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5064), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.662 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2678, %sub.666), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %exp.448 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.662), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %constant.5129 = f32[]{:T(128)} constant(0) - ROOT %reduce.683 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.448, %constant.5129), dimensions={2}, to_apply=%region_37.47, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %sub.653 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5064), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.649 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2670, %sub.653), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.448 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.649), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %constant.5110 = f32[]{:T(128)} constant(0) + ROOT %reduce.683 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.448, %constant.5110), dimensions={2}, to_apply=%region_37.47, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } %region_152.177 (reduce_sum.417: f32[], reduce_sum.244: f32[]) -> f32[] { @@ -1871,8 +1871,8 @@ StackFrames %param_0.4172 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(0) %bitcast.752 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_0.4172), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} %square.588 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%bitcast.752, %bitcast.752), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5107 = f32[]{:T(128)} constant(0) - ROOT %reduce.689 = f32[]{:T(128)} reduce(%square.588, %constant.5107), dimensions={0,1,2,3}, to_apply=%region_152.177, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.5088 = f32[]{:T(128)} constant(0) + ROOT %reduce.689 = f32[]{:T(128)} reduce(%square.588, %constant.5088), dimensions={0,1,2,3}, to_apply=%region_152.177, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } %fused_computation.542 (param_0.1602: f32[512,3,128,256]) -> bf16[3,512,128,256] { @@ -1896,52 +1896,52 @@ StackFrames %fused_computation.543 (param_0.4142: f32[512,3,128,256], param_1.5027: f32[], param_2.4300: f32[], param_3.2953: f32[], param_4.2205: f32[512,3,128,256], param_5.2008: f32[], param_6.1445: f32[3,512,128,256], param_7.1126: pred[], param_8.891: f32[512,3,128,256]) -> (f32[], f32[512,3,128,256], f32[512,3,128,256], f32[512,3,128,256], f32[]) { %param_0.4142 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) %param_3.2953 = f32[]{:T(128)S(6)} parameter(3) - %mul.4747.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_3.2953), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4735.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_3.2953), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1126 = pred[]{:T(512)S(6)} parameter(7) - %select_n.2173.clone.1 = pred[512,3,128,256]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.1126), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %select_n.2129.clone.1 = pred[512,3,128,256]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.1126), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} %param_6.1445 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(6) %bitcast.1378.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_6.1445), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} %param_5.2008 = f32[]{:T(128)} parameter(5) - %div.2591.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_5.2008), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2590.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%bitcast.1378.clone.1, %div.2591.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2172.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%select_n.2173.clone.1, %bitcast.1378.clone.1, %div.2590.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4876.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4283.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4876.clone.1), dimensions={}, metadata={op_name="broadcast.336"} - %mul.4753.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2172.clone.1, %broadcast.4283.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2581.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_5.2008), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2580.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%bitcast.1378.clone.1, %div.2581.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2128.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%select_n.2129.clone.1, %bitcast.1378.clone.1, %div.2580.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4857.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4256.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4857.clone.1), dimensions={}, metadata={op_name="broadcast.336"} + %mul.4741.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2128.clone.1, %broadcast.4256.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.891 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(8) - %constant.4880.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4754.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4880.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4752.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_8.891, %mul.4754.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3455.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4753.clone.1, %mul.4752.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4861.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.4742.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4861.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4740.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_8.891, %mul.4742.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3441.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4741.clone.1, %mul.4740.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_2.4300 = f32[]{:T(128)S(6)} parameter(2) - %div.2587.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_2.4300), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.401.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2172.clone.1, %select_n.2172.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4879.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4751.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4879.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4749.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%integer_pow.401.clone.1, %mul.4751.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.2577.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_2.4300), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.401.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2128.clone.1, %select_n.2128.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4860.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.4739.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4860.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4737.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%integer_pow.401.clone.1, %mul.4739.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2205 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(4) - %constant.4878.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4750.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4878.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4748.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_4.2205, %mul.4750.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3454.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4749.clone.1, %mul.4748.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %constant.4859.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.4738.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4859.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.4736.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_4.2205, %mul.4738.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3440.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4737.clone.1, %mul.4736.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} %param_1.5027 = f32[]{:T(128)S(6)} parameter(1) - %div.2586.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_1.5027), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2585.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3454.clone.1, %div.2586.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.159.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} sqrt(%div.2585.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4877.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3453.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4877.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3452.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%sqrt.159.clone.1, %add.3453.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1295.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%div.2587.clone.1, %add.3452.clone.1), metadata={op_name="multiply.288"} - %div.2584.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3455.clone.1, %multiply.1295.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4746.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_0.4142, %broadcast.4283.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3451.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%div.2584.clone.1, %mul.4746.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4745.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%mul.4747.clone.1, %add.3451.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3450.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%param_0.4142, %mul.4745.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.589 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%add.3450.clone.1, %add.3450.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5077 = f32[]{:T(128)} constant(0) - %reduce.690 = f32[]{:T(128)} reduce(%square.589, %constant.5077), dimensions={0,1,2,3}, to_apply=%region_219.244, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.691.clone.1 = f32[]{:T(128)} reduce(%integer_pow.401.clone.1, %constant.5077), dimensions={0,1,2,3}, to_apply=%region_185.210, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.667 = (f32[]{:T(128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.690, %add.3450.clone.1, %add.3454.clone.1, %add.3455.clone.1, %reduce.691.clone.1) + %div.2576.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_1.5027), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2575.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3440.clone.1, %div.2576.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.159.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} sqrt(%div.2575.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4858.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3439.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4858.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3438.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%sqrt.159.clone.1, %add.3439.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1295.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%div.2577.clone.1, %add.3438.clone.1), metadata={op_name="multiply.288"} + %div.2574.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3441.clone.1, %multiply.1295.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.4734.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_0.4142, %broadcast.4256.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3437.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%div.2574.clone.1, %mul.4734.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.4733.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%mul.4735.clone.1, %add.3437.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3436.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%param_0.4142, %mul.4733.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.589 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%add.3436.clone.1, %add.3436.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5058 = f32[]{:T(128)} constant(0) + %reduce.690 = f32[]{:T(128)} reduce(%square.589, %constant.5058), dimensions={0,1,2,3}, to_apply=%region_219.244, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.691.clone.1 = f32[]{:T(128)} reduce(%integer_pow.401.clone.1, %constant.5058), dimensions={0,1,2,3}, to_apply=%region_185.210, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.663 = (f32[]{:T(128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.690, %add.3436.clone.1, %add.3440.clone.1, %add.3441.clone.1, %reduce.691.clone.1) } %region_172.197 (reduce_sum.557: f32[], reduce_sum.381: f32[]) -> f32[] { @@ -1954,12 +1954,12 @@ StackFrames %param_2.4261 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(2) %dot_general.851 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4261), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} %param_1.4998 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.3187 = f32[4,128,1536]{2,1,0:T(8,128)} convert(%param_1.4998), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} + %convert_element_type.3179 = f32[4,128,1536]{2,1,0:T(8,128)} convert(%param_1.4998), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} %param_0.4106 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.4951 = f32[4,128,1536]{2,1,0:T(8,128)} broadcast(%param_0.4106), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} - %mul.4950 = f32[4,128,1536]{2,1,0:T(8,128)} multiply(%convert_element_type.3187, %mul.4951), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} - %convert_element_type.3186 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} convert(%mul.4950), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} - %dot_general.850 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.851, %convert_element_type.3186), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} + %mul.4939 = f32[4,128,1536]{2,1,0:T(8,128)} broadcast(%param_0.4106), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %mul.4938 = f32[4,128,1536]{2,1,0:T(8,128)} multiply(%convert_element_type.3179, %mul.4939), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %convert_element_type.3178 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} convert(%mul.4938), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} + %dot_general.850 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.851, %convert_element_type.3178), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} ROOT %bitcast.1466 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dot_general.850), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} } @@ -1977,12 +1977,12 @@ StackFrames %fusion.751 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.4154), kind=kLoop, calls=%bitcast_fusion.12 %convolution.146.clone.1 = bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)} convolution(%fusion.460.clone.1, %fusion.751), window={size=192x4 pad=191_191x0_0 rhs_reversal=1x0}, dim_labels=1fb0_1io0->bf01, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} %bitcast.861 = bf16[1536,128,192]{1,0,2:T(8,128)(2,1)} bitcast(%convolution.146.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} - %broadcast_in_dim.1388 = f32[1536,128,192]{1,0,2:T(8,128)} convert(%bitcast.861), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.763 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} bitcast(%broadcast_in_dim.1388), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %broadcast_in_dim.1275 = f32[1536,128,192]{1,0,2:T(8,128)} convert(%bitcast.861), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.763 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} bitcast(%broadcast_in_dim.1275), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %square.592 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%bitcast.763, %bitcast.763), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5089 = f32[]{:T(128)} constant(0) - %reduce.692 = f32[]{:T(128)} reduce(%square.592, %constant.5089), dimensions={0,1,2,3}, to_apply=%region_172.197, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.766 = (f32[]{:T(128)}, bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)}) tuple(%reduce.692, %convolution.146.clone.1) + %constant.5070 = f32[]{:T(128)} constant(0) + %reduce.692 = f32[]{:T(128)} reduce(%square.592, %constant.5070), dimensions={0,1,2,3}, to_apply=%region_172.197, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.762 = (f32[]{:T(128)}, bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)}) tuple(%reduce.692, %convolution.146.clone.1) } %region_239.264 (reduce_sum.1019: f32[], reduce_sum.687: f32[]) -> f32[] {