|
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| 14 | + |
"""SparseCore gather-reduce kernel implementation using Pallas.

This module contains a Pallas kernel implementation for performing a
gather-reduce operation on TPU SparseCore. It groups rows of an operand
based on provided indices, sums them up, and scatters the results.
"""
| 21 | + |
| 22 | +import functools |
| 23 | + |
| 24 | +import jax |
| 25 | +from jax import lax |
| 26 | +from jax.experimental import pallas as pl |
| 27 | +from jax.experimental.pallas import tpu as pltpu |
| 28 | +from jax.experimental.pallas import tpu_sc as plsc |
| 29 | +import jax.numpy as jnp |
| 30 | + |
| 31 | + |
| 32 | +def sc_gather_reduce( |
| 33 | + op: jax.Array, |
| 34 | + idx: jax.Array, |
| 35 | + topk_weights: jax.Array | None = None, |
| 36 | + *, |
| 37 | + reduce_group_size: int, |
| 38 | + single_sc: bool = False, |
| 39 | + col_chunk_size: int = int(3.5 * 1024), |
| 40 | + row_chunk_size: int = 512, |
| 41 | + topk_wgt_zero_nan: bool = False, |
| 42 | +) -> jax.Array: |
| 43 | + """Performs a gather-reduce operation on SparseCore. |
| 44 | +
|
| 45 | + This kernel groups rows of the operand ``op`` based on ``idx``, sums them |
| 46 | + up, and scatters the results. The gather and add operations are performed |
| 47 | + in fp32, and the results are written back in bf16. |
| 48 | +
|
| 49 | + Equivalent JAX code:: |
| 50 | +
|
| 51 | + gathered = op[idx, :] |
| 52 | + if topk_weights is not None: |
| 53 | + flat_weights = topk_weights.flatten() |
| 54 | + gathered = gathered * flat_weights[:, None].astype(jnp.float32) |
| 55 | + gathered = jnp.reshape(gathered, (-1, reduce_group_size, op.shape[1])) |
| 56 | + output = jnp.sum(gathered.astype(jnp.float32), axis=1).astype(jnp.bfloat16) |
| 57 | +
|
| 58 | + Args: |
| 59 | + op: The operand matrix [B, K] in f32 or bf16 to gather from and reduce. |
| 60 | + idx: The indices [M,] in int32 guiding the gather. |
| 61 | + topk_weights: Optional weights [M // 128, 128] in bf16 to apply to the |
| 62 | + gathered rows before reduction. |
| 63 | + reduce_group_size: The number of gathered rows to sum per output row. |
| 64 | + single_sc: Whether to use a single SparseCore. |
| 65 | + col_chunk_size: The size of column chunks to process. |
| 66 | + row_chunk_size: The size of row chunks for internal processing. Must be ``2 |
| 67 | + * reduce_group_size``. |
| 68 | + topk_wgt_zero_nan: If True, treat zero ``topk_weights`` as indicators of NaN |
| 69 | + during multiplication, resulting in zero output. |
| 70 | +
|
| 71 | + Returns: |
| 72 | + The reduced result as a bf16 matrix [M / reduce_group_size, K]. |
| 73 | + """ |
| 74 | + if op.dtype != jnp.bfloat16: |
| 75 | + raise ValueError(f"op.dtype must be f32 or bf16, but got {op.dtype}") |
| 76 | + if op.shape[0] % reduce_group_size != 0: |
| 77 | + raise ValueError(f"{op.shape[0]=} must be divisible by {reduce_group_size=}") |
| 78 | + |
| 79 | + sc_info = pltpu.get_tpu_info().sparse_core |
| 80 | + if sc_info is None: |
| 81 | + raise RuntimeError("SparseCore is not available on this TPU version.") |
| 82 | + |
| 83 | + [M] = idx.shape |
| 84 | + _, K = op.shape |
| 85 | + M_out = M // reduce_group_size |
| 86 | + |
| 87 | + if topk_weights is not None: |
| 88 | + topk_weights = topk_weights.flatten() |
| 89 | + |
| 90 | + @jax.jit |
| 91 | + @pl.kernel( |
| 92 | + out_type=jax.ShapeDtypeStruct((M_out, K), op.dtype), |
| 93 | + mesh=plsc.VectorSubcoreMesh( |
| 94 | + core_axis_name="core", |
| 95 | + subcore_axis_name="subcore", |
| 96 | + num_cores=1 if single_sc else 2, |
| 97 | + ), |
| 98 | + compiler_params=pltpu.CompilerParams(needs_layout_passes=True), |
| 99 | + ) |
| 100 | + def kernel(in_hbm_ref, idx_hbm_ref, weights_hbm_ref, out_hbm_ref): |
| 101 | + row_wave_size = row_chunk_size * lax.axis_size(("core", "subcore")) |
| 102 | + if M % row_wave_size: |
| 103 | + raise NotImplementedError( |
| 104 | + f"{M=} must be divisible by {row_chunk_size=} *" |
| 105 | + f" num_cores={lax.axis_size('core')} *" |
| 106 | + f" num_vector_subcores={lax.axis_size('subcore')} = {row_wave_size}" |
| 107 | + ) |
| 108 | + num_row_chunks = M // row_wave_size |
| 109 | + num_col_chunks = K // col_chunk_size |
| 110 | + packing = 32 // jax.dtypes.itemsize_bits(op.dtype) |
| 111 | + |
| 112 | + subcore_first_row_chunk = lax.axis_index(("core", "subcore")) * num_row_chunks |
| 113 | + |
| 114 | + in_spec = pl.BlockSpec((row_chunk_size,), lambda i: (subcore_first_row_chunk + i,)) |
| 115 | + in_specs = (in_spec,) * (1 + (weights_hbm_ref is not None)) |
| 116 | + |
| 117 | + @functools.partial(pltpu.emit_pipeline, grid=(num_row_chunks,), in_specs=in_specs) |
| 118 | + def idx_pipeline(idx_ref, weights_ref=None): |
| 119 | + row_chunk_idx = subcore_first_row_chunk + pl.program_id(0) |
| 120 | + |
| 121 | + row_subchunk_size = 16 |
| 122 | + out_rows_per_step = row_subchunk_size // reduce_group_size |
| 123 | + assert reduce_group_size * out_rows_per_step == sc_info.num_lanes |
| 124 | + num_row_subchunks = row_chunk_size // row_subchunk_size |
| 125 | + if row_chunk_size % row_subchunk_size: |
| 126 | + raise ValueError(f"row_chunk_size needs to be a multiple of {row_subchunk_size}, but" f" got {row_chunk_size}") |
| 127 | + |
| 128 | + @functools.partial( |
| 129 | + pltpu.emit_pipeline, |
| 130 | + grid=(num_row_subchunks, num_col_chunks), |
| 131 | + in_specs=pl.BlockSpec( |
| 132 | + (pl.Indirect(row_subchunk_size), col_chunk_size), |
| 133 | + lambda r, c: ( |
| 134 | + lax.div( |
| 135 | + idx_ref[pl.ds(r * row_subchunk_size, row_subchunk_size)], |
| 136 | + packing, |
| 137 | + ), |
| 138 | + c, |
| 139 | + ), |
| 140 | + ), |
| 141 | + out_specs=pl.BlockSpec( |
| 142 | + (out_rows_per_step // packing, col_chunk_size), |
| 143 | + lambda r, c: (row_chunk_idx * num_row_subchunks + r, c), |
| 144 | + ), |
| 145 | + ) |
| 146 | + def data_pipeline(gather_ref, out_ref): |
| 147 | + gather_ref = gather_ref.bitcast(op.dtype) |
| 148 | + out_ref = out_ref.bitcast(op.dtype) |
| 149 | + |
| 150 | + row_slice = pl.ds(pl.program_id(0) * row_subchunk_size, row_subchunk_size) |
| 151 | + subchunk_idxs = idx_ref[row_slice] |
| 152 | + weights = None if weights_ref is None else weights_ref[row_slice].astype(jnp.float32) |
| 153 | + |
| 154 | + unpack_col_chunk = 32 # 32 seems to works best when tuning. |
| 155 | + |
| 156 | + @plsc.parallel_loop(0, col_chunk_size, step=unpack_col_chunk) |
| 157 | + def _(col_base): |
| 158 | + accs = [] |
| 159 | + for reduce_group in range(out_rows_per_step): |
| 160 | + acc = jnp.zeros((unpack_col_chunk,), dtype=jnp.float32) |
| 161 | + for row_in_group in range(reduce_group_size): |
| 162 | + row = reduce_group * reduce_group_size + row_in_group |
| 163 | + row_data = gather_ref[pl.ds(row * packing, packing), pl.ds(col_base, unpack_col_chunk)].astype(jnp.float32) |
| 164 | + if packing == 1: |
| 165 | + row_data = row_data[0] |
| 166 | + else: |
| 167 | + assert packing == 2 |
| 168 | + # For dtypes narrower than 32-bit, we end up gathering multiple |
| 169 | + # rows (since we had to bitcast to int32 before the gather). |
| 170 | + # This uses the remainder of the packing to choose the only row |
| 171 | + # we actually care about. |
| 172 | + row_data = jnp.where( |
| 173 | + lax.rem(subchunk_idxs[row], 2) == 0, |
| 174 | + row_data[0], |
| 175 | + row_data[1], |
| 176 | + ) |
| 177 | + if weights is not None: |
| 178 | + row_data *= weights[row] |
| 179 | + if topk_wgt_zero_nan: |
| 180 | + row_data = jnp.where(weights[row] == 0.0, jnp.zeros_like(row_data), row_data) |
| 181 | + acc += row_data |
| 182 | + accs.append(acc) |
| 183 | + out = jnp.stack(accs, axis=0).astype(op.dtype) |
| 184 | + out_ref[:, pl.ds(col_base, unpack_col_chunk)] = out |
| 185 | + |
| 186 | + data_pipeline(in_hbm_ref.bitcast(jnp.int32), out_hbm_ref.bitcast(jnp.int32)) |
| 187 | + |
| 188 | + idx_pipeline(idx_hbm_ref, *([weights_hbm_ref] if weights_hbm_ref is not None else [])) |
| 189 | + |
| 190 | + return kernel(op, idx, topk_weights) # pylint: disable=no-value-for-parameter |
0 commit comments