|
43 | 43 | from qwix.contrib.sparsity import sparsity_module |
44 | 44 | import qwix.pallas as qpl |
45 | 45 | import tokamax |
| 46 | +from tokamax import config as tokamax_config |
| 47 | +from tokamax._src.ops.ragged_dot.pallas_mosaic_tpu import PallasMosaicTpuRaggedDot, Config |
46 | 48 |
|
47 | 49 | set_xla_metadata = xla_metadata.set_xla_metadata |
48 | 50 |
|
@@ -1121,14 +1123,25 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a |
1121 | 1123 | weight_gather_axes=weight_gather_axes, |
1122 | 1124 | ) |
1123 | 1125 | else: # tokamax (unquantized) |
1124 | | - output = tokamax.ragged_dot( |
1125 | | - lhs=inputs, |
1126 | | - rhs=kernel, |
1127 | | - group_sizes=tokamax_group_sizes, |
1128 | | - precision=jax.lax.Precision.DEFAULT, |
1129 | | - preferred_element_type=self.dtype, |
1130 | | - implementation="mosaic", |
1131 | | - ) |
| 1126 | + if self.config.tokamax_gmm_autotune: |
| 1127 | + with tokamax_config.autotuning_cache_miss_fallback("autotune"): |
| 1128 | + output = tokamax.ragged_dot( |
| 1129 | + lhs=inputs, |
| 1130 | + rhs=kernel, |
| 1131 | + group_sizes=tokamax_group_sizes, |
| 1132 | + precision=jax.lax.Precision.DEFAULT, |
| 1133 | + preferred_element_type=self.dtype, |
| 1134 | + implementation="mosaic", |
| 1135 | + ) |
| 1136 | + else: |
| 1137 | + output = tokamax.ragged_dot( |
| 1138 | + lhs=inputs, |
| 1139 | + rhs=kernel, |
| 1140 | + group_sizes=tokamax_group_sizes, |
| 1141 | + precision=jax.lax.Precision.DEFAULT, |
| 1142 | + preferred_element_type=self.dtype, |
| 1143 | + implementation="mosaic", |
| 1144 | + ) |
1132 | 1145 | elif self.config.megablox: # Older forked megablox |
1133 | 1146 | output = mblx.gmm( |
1134 | 1147 | lhs=inputs, |
|
0 commit comments