docs/reference/core_concepts/moe_configuration.md
5 lines changed: 5 additions & 0 deletions
@@ -88,6 +88,11 @@ Dropping:
 - Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
 - Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.
 
+`ragged_buffer_factor`: A scalar multiplier for the size of the ragged buffer (effectively the expert capacity). Effective only when `sparse_matmul` is True.
+
+- Value > 0: Uses an explicit buffer size; tokens are dropped when this size is exceeded.
+- Value = -1: Uses a worst-case calculated buffer size that is guaranteed not to drop any tokens.
+
 `use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for an efficient backward pass in sparse matmul. Recommended as a replacement for the inefficient scatter-add generated by `jax.numpy.take` in the backward pass.
 
 `mlp_bias`: If enabled, adds learnable bias terms to the MLP matmuls. Originally implemented to support the GPT-OSS model architecture.
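To make the two `ragged_buffer_factor` modes concrete, here is a minimal sketch of how a buffer size could be derived from the factor. This is an illustration only, not the library's actual implementation: the function name, the mean-per-expert-load heuristic, and the example numbers are all assumptions for exposition.

```python
import math

def ragged_buffer_size(tokens_per_batch: int, num_experts: int, top_k: int,
                       ragged_buffer_factor: float) -> int:
    """Hypothetical sketch of sizing a ragged buffer from a scalar factor."""
    # Worst case: every routed assignment could land on a single expert,
    # so a buffer of tokens * top_k can never overflow -- guaranteed dropless.
    worst_case = tokens_per_batch * top_k
    if ragged_buffer_factor == -1:
        return worst_case
    # Explicit size: scale the mean per-expert load by the factor;
    # assignments beyond this size would be dropped.
    mean_load = tokens_per_batch * top_k / num_experts
    return min(worst_case, math.ceil(mean_load * ragged_buffer_factor))

# Example: 1024 tokens, 8 experts, top-2 routing.
print(ragged_buffer_size(1024, 8, 2, -1))    # dropless worst case: 2048
print(ragged_buffer_size(1024, 8, 2, 1.25))  # explicit capacity: 320
```

The trade-off the two modes express: `-1` spends memory (the full worst-case buffer) to guarantee no tokens are dropped, while a positive factor bounds memory at the cost of dropping tokens whenever routing is more imbalanced than the factor allows.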