Skip to content

Commit ccd91f4

Browse files
Merge pull request #3330 from AI-Hypercomputer:qinwen/change_ragged_dot_group_size
PiperOrigin-RevId: 880247095
2 parents ed706be + 9c2bff5 commit ccd91f4

3 files changed

Lines changed: 24 additions & 4 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,10 @@ def sparse_matmul(
896896
def gmm(
897897
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
898898
):
899+
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
900+
group_sizes,
901+
representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
902+
)
899903
pad_length = self.config.wi_tile_fwd_batch_seq
900904
hs_shape = inputs.shape
901905
# pad length is the 1st dimension of tiling size in gmm call
@@ -926,7 +930,7 @@ def gmm(
926930
output = mblx.gmm(
927931
lhs=inputs,
928932
rhs=kernel,
929-
group_sizes=group_sizes,
933+
group_sizes=tokamax_group_sizes,
930934
preferred_element_type=self.dtype,
931935
tiling=tiling,
932936
lhs_quantize_dtype=lhs_quantize_dtype,
@@ -941,7 +945,7 @@ def gmm(
941945
output = tokamax.ragged_dot(
942946
lhs=inputs,
943947
rhs=kernel,
944-
group_sizes=group_sizes,
948+
group_sizes=tokamax_group_sizes,
945949
precision=jax.lax.Precision.DEFAULT,
946950
preferred_element_type=self.dtype,
947951
implementation="mosaic",

src/maxtext/models/deepseek_batchsplit.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from maxtext.layers import attention_op
2828
from maxtext.layers import moe as moe_lib
2929
from maxtext.layers import quantizations
30+
from maxtext.utils import max_utils
3031
import qwix.pallas as qpl
3132
import tokamax
3233

@@ -803,11 +804,16 @@ def gmm(
803804
input_buffer_count,
804805
combine_scopes,
805806
):
807+
808+
tokamax_group_sizes = tokamax.RaggedDotGroupSizes(
809+
group_sizes,
810+
representative_value=max_utils.generate_representative_group_sizes(inputs.shape[0], kernel.shape[0]),
811+
)
806812
if config.use_qwix_quantization:
807813
output = megablox.gmm(
808814
lhs=inputs,
809815
rhs=kernel,
810-
group_sizes=group_sizes,
816+
group_sizes=tokamax_group_sizes,
811817
preferred_element_type=preferred_element_type,
812818
tiling=tiling,
813819
use_qwix_quantization=config.use_qwix_quantization,
@@ -821,7 +827,7 @@ def gmm(
821827
output = tokamax.ragged_dot(
822828
lhs=inputs,
823829
rhs=kernel,
824-
group_sizes=group_sizes,
830+
group_sizes=tokamax_group_sizes,
825831
precision=jax.lax.Precision.DEFAULT,
826832
preferred_element_type=preferred_element_type,
827833
implementation="mosaic",

src/maxtext/utils/max_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,3 +1078,13 @@ def transformer_engine_context():
10781078
yield
10791079
except (ImportError, AttributeError):
10801080
yield
1081+
1082+
1083+
def generate_representative_group_sizes(target_m: int, g: int) -> tuple[int, ...]:
  """Generate a deterministic, representative split of `target_m` across `g` groups.

  Produces `g` non-negative integers that sum exactly to `target_m`, intended
  as a representative value for ragged-dot group sizes. About 10% of the
  entries are zeroed via a Bernoulli(0.9) mask (presumably to model groups
  that receive no tokens); the rest get proportional pseudo-random shares.

  Args:
    target_m: Total count to distribute across the groups.
    g: Number of groups.

  Returns:
    A tuple of `g` Python ints summing exactly to `target_m`.
  """
  # Use a private legacy RandomState(0) instead of np.random.seed(0): the
  # drawn sequence is identical, but the process-global RNG state is left
  # untouched (seeding the global RNG here would silently clobber any
  # caller's randomness).
  rng = np.random.RandomState(0)
  shares = rng.uniform(size=(g,))
  # Zero out ~10% of entries with a Bernoulli(p=0.9) keep-mask.
  shares = rng.binomial(1, 0.9, (g,)) * shares
  total = np.sum(shares)
  if total == 0:
    # Degenerate case: the mask removed every entry; avoid division by zero
    # and let the remainder-assignment below put everything in group 0.
    sizes = np.zeros((g,), dtype=np.int32)
  else:
    sizes = np.int32((shares / total) * target_m)
  # int32 truncation leaves a non-negative remainder; fold it into group 0
  # so the sizes sum exactly to target_m.
  sizes[0] += target_m - np.sum(sizes)
  return tuple(map(int, sizes))

0 commit comments

Comments
 (0)