|
19 | 19 | import functools |
20 | 20 | import math |
21 | 21 | import random |
| 22 | +import os |
22 | 23 | from typing import Iterable, Optional, Tuple, Union |
23 | 24 |
|
24 | 25 | from aqt.jax.v2 import aqt_tensor as aqt |
@@ -1082,9 +1083,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel): |
1082 | 1083 |   elif self.config.attention == "vllm_rpa": |
1083 | 1084 |     return group_sizes |
1084 | 1085 |   else: |
| 1086 | +    ep = self.get_expert_parallelism_size() |
1085 | 1087 |     return tokamax.RaggedDotGroupSizes( |
1086 | 1088 |         group_sizes, |
1087 | | -        (inputs.shape[0] // kernel.shape[0],) * kernel.shape[0], |
| 1089 | +        (inputs.shape[0] // kernel.shape[0] // ep,) * kernel.shape[0], |
1088 | 1090 |     ) |
1089 | 1091 |
|
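For context on the group-size change above: the second argument to tokamax.RaggedDotGroupSizes appears to be a tuple of expected per-group row counts, and the PR divides that estimate by the expert-parallelism degree. A minimal arithmetic sketch with hypothetical shapes (the concrete numbers and the hint interpretation are assumptions for illustration, not taken from this PR):

    num_rows = 4096   # inputs.shape[0], rows of the ragged lhs
    num_groups = 8    # kernel.shape[0], one rhs matrix per expert
    ep = 2            # self.get_expert_parallelism_size()

    old_hint = (num_rows // num_groups,) * num_groups        # (512,) * 8
    new_hint = (num_rows // num_groups // ep,) * num_groups  # (256,) * 8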
1090 | 1092 | def get_quantization_dtypes(): |
@@ -1124,14 +1126,23 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a |
1124 | 1126 |     ) |
1125 | 1127 |   else:  # tokamax (unquantized) |
1126 | 1128 |     if self.config.tokamax_gmm_autotune: |
1127 | | -      with tokamax_config.autotuning_cache_miss_fallback("autotune"): |
| 1129 | +      cache_file = "tokamax_autotune_cache.json" |
| 1130 | +      if os.path.exists(cache_file): |
| 1131 | +        with open(cache_file, "r") as f: |
| 1132 | +          autotune_result_json = f.read() |
| 1133 | +        autotune_result = tokamax.AutotuningResult.loads(autotune_result_json) |
| 1134 | +        autotune_context = autotune_result |
| 1135 | +      else: |
| 1136 | +        autotune_context = tokamax_config.autotuning_cache_miss_fallback("heuristics") |
| 1137 | + |
| 1138 | +      with autotune_context: |
1128 | 1139 |         output = tokamax.ragged_dot( |
1129 | 1140 |             lhs=inputs, |
1130 | 1141 |             rhs=kernel, |
1131 | 1142 |             group_sizes=tokamax_group_sizes, |
1132 | 1143 |             precision=jax.lax.Precision.DEFAULT, |
1133 | 1144 |             preferred_element_type=self.dtype, |
1134 | | -            implementation="mosaic", |
| 1145 | +            implementation=None, |
1135 | 1146 |         ) |
1136 | 1147 |     else: |
1137 | 1148 |       output = tokamax.ragged_dot( |
|
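Two things change in the hunk above: the autotuning context now prefers an on-disk cache over live autotuning (falling back to "heuristics" rather than "autotune" on a miss), and implementation=None presumably lets tokamax pick the ragged-dot backend instead of forcing the Mosaic kernel. The cache branch reads naturally as a small helper; a sketch equivalent to the added lines, using only the calls that appear in the diff (the helper name and factoring are mine, not the PR's):

    import os  # the same import this PR adds at the top of the module

    def get_autotune_context(cache_file="tokamax_autotune_cache.json"):
      # Reuse a cached AutotuningResult when one exists on disk; otherwise
      # fall back to tokamax heuristics instead of autotuning at trace time.
      if os.path.exists(cache_file):
        with open(cache_file, "r") as f:
          return tokamax.AutotuningResult.loads(f.read())
      return tokamax_config.autotuning_cache_miss_fallback("heuristics")

Note this keeps the diff's assumption that a loaded AutotuningResult is itself usable as a context manager; how the cache file gets written in the first place is not shown in this hunk.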