Commit 9c2bff5

update for group_size for tokamax
update utils, update format

1 parent 72e96f5

3 files changed: 24 additions & 4 deletions

src/maxtext/layers/moe.py

Lines changed: 6 additions & 2 deletions
@@ -891,6 +891,10 @@ def sparse_matmul(
     def gmm(
         inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
     ):
+      tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
+          group_sizes,
+          representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
+      )
       pad_length = self.config.wi_tile_fwd_batch_seq
       hs_shape = inputs.shape
       # pad length is the 1st dimension of tiling size in gmm call
@@ -921,7 +925,7 @@ def gmm(
       output = mblx.gmm(
           lhs=inputs,
           rhs=kernel,
-          group_sizes=group_sizes,
+          group_sizes=tokamax_group_sizes,
           preferred_element_type=self.dtype,
           tiling=tiling,
           lhs_quantize_dtype=lhs_quantize_dtype,
@@ -936,7 +940,7 @@ def gmm(
       output = tokamax.ragged_dot(
           lhs=inputs,
           rhs=kernel,
-          group_sizes=group_sizes,
+          group_sizes=tokamax_group_sizes,
           precision=jax.lax.Precision.DEFAULT,
           preferred_element_type=self.dtype,
           implementation="mosaic",
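
Note: the change above routes the existing group_sizes array through tokamax.RaggedDotGroupSizes, attaching a static representative_value hint derived from the number of token rows (inputs.shape[0]) and the number of experts (kernel.shape[0]). For context, the sketch below is a plain-NumPy reference of the grouped ("ragged") matmul semantics that mblx.gmm and tokamax.ragged_dot implement; ragged_dot_reference is an illustrative name and is not part of this commit.

```python
import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes):
  """Reference semantics of the grouped matmul used by gmm.

  Rows of lhs are partitioned into contiguous groups (one per expert);
  group i of lhs is multiplied by the expert weight matrix rhs[i].
  """
  m, k = lhs.shape          # (tokens, hidden)
  g, _, n = rhs.shape       # (experts, hidden, out)
  assert len(group_sizes) == g and sum(group_sizes) == m
  out = np.zeros((m, n), dtype=lhs.dtype)
  start = 0
  for i, size in enumerate(group_sizes):
    out[start:start + size] = lhs[start:start + size] @ rhs[i]
    start += size
  return out
```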

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 8 additions & 2 deletions
@@ -27,6 +27,7 @@
 from maxtext.layers import attention_op
 from maxtext.layers import moe as moe_lib
 from maxtext.layers import quantizations
+from maxtext.utils import max_utils
 import qwix.pallas as qpl
 import tokamax

@@ -803,11 +804,16 @@ def gmm(
     input_buffer_count,
     combine_scopes,
 ):
+
+  tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
+      group_sizes,
+      representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
+  )
   if config.use_qwix_quantization:
     output = megablox.gmm(
         lhs=inputs,
         rhs=kernel,
-        group_sizes=group_sizes,
+        group_sizes=tokamax_group_sizes,
         preferred_element_type=preferred_element_type,
         tiling=tiling,
         use_qwix_quantization=config.use_qwix_quantization,
@@ -820,7 +826,7 @@ def gmm(
   output = tokamax.ragged_dot(
       lhs=inputs,
       rhs=kernel,
-      group_sizes=group_sizes,
+      group_sizes=tokamax_group_sizes,
       precision=jax.lax.Precision.DEFAULT,
       preferred_element_type=preferred_element_type,
       implementation="mosaic",
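
Both call sites make the same change: build the wrapper once at the top of gmm and pass it to whichever backend runs. A minimal sketch of that pattern, assuming only the two-argument RaggedDotGroupSizes constructor shown in the diff above; wrap_group_sizes is a hypothetical helper for illustration.

```python
import tokamax
from maxtext.utils import max_utils

def wrap_group_sizes(group_sizes, inputs, kernel):
  # inputs: (m, k) token activations; kernel: (g, k, n) stacked expert weights.
  m = inputs.shape[0]  # total (padded) token rows
  g = kernel.shape[0]  # number of experts / groups
  return tokamax.RaggedDotGroupSizes(
      group_sizes,
      representative_value=max_utils.generate_representative_group_sizes(m, g),
  )
```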

src/maxtext/utils/max_utils.py

Lines changed: 10 additions & 0 deletions
@@ -1078,3 +1078,13 @@ def transformer_engine_context():
     yield
   except (ImportError, AttributeError):
     yield
+
+
+def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...]:
+  """Generate group sizes for a given target m."""
+  np.random.seed(0)
+  repr_val = np.random.uniform(size=(g,))
+  repr_val = np.random.binomial(1, 0.9, (g,)) * repr_val
+  repr_val = np.int32((repr_val / np.sum(repr_val)) * target_m)
+  repr_val[0] += target_m - np.sum(repr_val)
+  return tuple(map(int, repr_val))
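
The new helper deterministically splits target_m rows across g groups (fixed seed 0), zeroes roughly 10% of groups to mimic unbalanced routing, and folds any rounding remainder into group 0 so the sizes always sum back to target_m. Below is a standalone check of that invariant; it copies the helper's logic verbatim so it runs without MaxText installed.

```python
import numpy as np

def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...]:
  # Same logic as the new max_utils helper above.
  np.random.seed(0)
  repr_val = np.random.uniform(size=(g,))
  repr_val = np.random.binomial(1, 0.9, (g,)) * repr_val
  repr_val = np.int32((repr_val / np.sum(repr_val)) * target_m)
  repr_val[0] += target_m - np.sum(repr_val)
  return tuple(map(int, repr_val))

sizes = generate_representative_group_sizes(target_m=1024, g=8)
assert len(sizes) == 8 and sum(sizes) == 1024  # remainder folded into group 0
print(sizes)
```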
