Commit 65a6d8e

Update get_tokamax_group_sizes in moe.py

Parent: 682db7f
1 file changed: 14 additions & 3 deletions

src/maxtext/layers/moe.py

@@ -19,6 +19,7 @@
 import functools
 import math
 import random
+import os
 from typing import Iterable, Optional, Tuple, Union
 
 from aqt.jax.v2 import aqt_tensor as aqt
@@ -1082,9 +1083,10 @@ def get_tokamax_group_sizes(group_sizes, inputs, kernel):
       elif self.config.attention == "vllm_rpa":
         return group_sizes
       else:
+        ep = self.get_expert_parallelism_size()
         return tokamax.RaggedDotGroupSizes(
             group_sizes,
-            (inputs.shape[0] // kernel.shape[0],) * kernel.shape[0],
+            (inputs.shape[0] // kernel.shape[0] // ep,) * kernel.shape[0],
         )
 
     def get_quantization_dtypes():
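
The group-size change is easiest to check with concrete numbers: under expert parallelism each shard holds only 1/ep of the tokens, so the per-group size passed to tokamax.RaggedDotGroupSizes shrinks by the same factor. A minimal sketch, assuming the second argument is such a per-group size hint (the commit does not spell out its semantics) and using made-up shapes:

# Illustrative values only; none of these numbers come from the commit.
tokens = 4096    # inputs.shape[0]: rows of the ragged-dot LHS
groups = 8       # kernel.shape[0]: number of expert groups
ep = 4           # self.get_expert_parallelism_size()

# Before the fix: assumes all tokens are split across the groups.
before = (tokens // groups,) * groups        # 8 entries of 512

# After the fix: each of the ep shards sees only 1/ep of the tokens,
# so the expected per-group count drops by the same factor.
after = (tokens // groups // ep,) * groups   # 8 entries of 128

assert before[0] == 512 and after[0] == 128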
@@ -1124,14 +1126,23 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_a
         )
       else:  # tokamax (unquantized)
         if self.config.tokamax_gmm_autotune:
-          with tokamax_config.autotuning_cache_miss_fallback("autotune"):
+          cache_file = "tokamax_autotune_cache.json"
+          if os.path.exists(cache_file):
+            with open(cache_file, "r") as f:
+              autotune_result_json = f.read()
+            autotune_result = tokamax.AutotuningResult.loads(autotune_result_json)
+            autotune_context = autotune_result
+          else:
+            autotune_context = tokamax_config.autotuning_cache_miss_fallback("heuristics")
+
+          with autotune_context:
             output = tokamax.ragged_dot(
                 lhs=inputs,
                 rhs=kernel,
                 group_sizes=tokamax_group_sizes,
                 precision=jax.lax.Precision.DEFAULT,
                 preferred_element_type=self.dtype,
-                implementation="mosaic",
+                implementation=None,
             )
         else:
           output = tokamax.ragged_dot(
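
The autotuning change replaces the unconditional cache_miss_fallback("autotune") context (re-tune on every cache miss) with a load-or-fallback scheme: if tokamax_autotune_cache.json exists, the deserialized tokamax.AutotuningResult is itself entered as the context manager; otherwise the code falls back to tokamax's heuristics on a cache miss. (The switch from implementation="mosaic" to implementation=None presumably lets tokamax pick the kernel backend itself.) Below is a self-contained sketch of that control-flow pattern with stand-in types (heuristics_fallback and CachedResult are illustrative, not tokamax API), showing why both branches must yield something usable in a with statement:

import contextlib
import json
import os

@contextlib.contextmanager
def heuristics_fallback():
  # Stand-in for tokamax_config.autotuning_cache_miss_fallback("heuristics").
  yield

class CachedResult:
  # Stand-in for a deserialized tokamax.AutotuningResult, which the diff
  # enters directly as a context manager, so the stand-in does the same.
  def __init__(self, payload):
    self.payload = payload

  def __enter__(self):
    return self

  def __exit__(self, *exc):
    return False

cache_file = "tokamax_autotune_cache.json"
if os.path.exists(cache_file):
  with open(cache_file, "r") as f:
    ctx = CachedResult(json.load(f))
else:
  ctx = heuristics_fallback()

with ctx:
  pass  # the tokamax.ragged_dot call would run here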
