|
43 | 43 | from qwix.contrib.sparsity import sparsity_module |
44 | 44 | import qwix.pallas as qpl |
45 | 45 | import tokamax |
| 46 | +from tokamax import config as tokamax_config |
| 47 | +from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config |
46 | 48 |
|
47 | 49 | set_xla_metadata = xla_metadata.set_xla_metadata |
48 | 50 |
|
@@ -1104,14 +1106,35 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a |
1104 | 1106 | weight_gather_axes=weight_gather_axes, |
1105 | 1107 | ) |
1106 | 1108 | else: # tokamax (unquantized) |
1107 | | - output = tokamax.ragged_dot( |
1108 | | - lhs=inputs, |
1109 | | - rhs=kernel, |
1110 | | - group_sizes=tokamax_group_sizes, |
1111 | | - precision=jax.lax.Precision.DEFAULT, |
1112 | | - preferred_element_type=self.dtype, |
1113 | | - implementation="mosaic", |
1114 | | - ) |
| 1109 | + if self.config.tokamax_gmm_autotune: |
| 1110 | + with tokamax_config.autotuning_cache_miss_fallback("heuristics"): |
| 1111 | + output = tokamax.ragged_dot( |
| 1112 | + lhs=inputs, |
| 1113 | + rhs=kernel, |
| 1114 | + group_sizes=tokamax_group_sizes, |
| 1115 | + precision=jax.lax.Precision.DEFAULT, |
| 1116 | + preferred_element_type=self.dtype, |
| 1117 | + implementation="mosaic", |
| 1118 | + ) |
| 1119 | + else: |
| 1120 | + custom_impl = PallasMosaicTpuRaggedDot( |
| 1121 | + config=Config( |
| 1122 | + # Forward Pass |
| 1123 | + gmm_tiling=(tiling[0], tiling[1], tiling[2]), |
| 1124 | + # Backward DLHS Pass |
| 1125 | + gmm_rhs_transpose_tiling=(tiling[3], tiling[4], tiling[5]), |
| 1126 | + # Backward DRHS Pass |
| 1127 | + tgmm_tiling=(tiling[6], tiling[7], tiling[8]), |
| 1128 | + ) |
| 1129 | + ) |
| 1130 | + output = tokamax.ragged_dot( |
| 1131 | + lhs=inputs, |
| 1132 | + rhs=kernel, |
| 1133 | + group_sizes=tokamax_group_sizes, |
| 1134 | + precision=jax.lax.Precision.DEFAULT, |
| 1135 | + preferred_element_type=self.dtype, |
| 1136 | + implementation=custom_impl, |
| 1137 | + ) |
1115 | 1138 | elif self.config.megablox: # Older forked megablox |
1116 | 1139 | output = mblx.gmm( |
1117 | 1140 | lhs=inputs, |
|
0 commit comments