
Commit ada737d

feat: MoE Kernels, EP, and Fast Kernels for Granite 4 Preview architecture
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 1a804e4 commit ada737d

9 files changed: 217 additions & 10 deletions


plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class ScatterMoEAccelerationPlugin(AccelerationPlugin):
         "GraniteMoeForCausalLM",
         "MixtralForCausalLM",
         "GraniteMoeSharedForCausalLM",
+        "GraniteMoeHybridForCausalLM",
     ]

     def __init__(self, configurations: Dict[str, Dict]):
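
The change above simply adds the Granite 4 hybrid architecture to the plugin's list of supported model architectures. As a rough illustration of how such an allow-list is typically consulted (the names SUPPORTED_ARCHS and is_supported below are illustrative stand-ins, not the plugin's actual attributes):

# Hypothetical sketch: checking a model against an architecture allow-list.
# SUPPORTED_ARCHS / is_supported are illustrative names, not the plugin's API.
from transformers import AutoConfig

SUPPORTED_ARCHS = [
    "GraniteMoeForCausalLM",
    "MixtralForCausalLM",
    "GraniteMoeSharedForCausalLM",
    "GraniteMoeHybridForCausalLM",  # newly enabled by this commit
]

def is_supported(model_name_or_path: str) -> bool:
    # HF model configs record their model class names under `architectures`
    config = AutoConfig.from_pretrained(model_name_or_path)
    return any(arch in SUPPORTED_ARCHS for arch in (config.architectures or []))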

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_constants.py

Lines changed: 7 additions & 0 deletions
@@ -83,6 +83,13 @@
         SCATTERMOE_SPEC_HAS_GATE,
         False,
     ),
+    "GraniteMoeHybridForCausalLM": (
+        "GraniteMoeHybridMoE",
+        "router",
+        "input_linear|output_linear|input_linear",
+        SCATTERMOE_SPEC_HAS_GATE,
+        False,
+    ),
 }

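Judging from the sibling entries in this constants table, the five tuple fields name the MoE block class to target, its router (gating) submodule, the expert linear submodules joined by "|", whether the spec has a gate, and a trailing boolean flag. A minimal sketch of unpacking the new entry; the field meanings here are inferred from the neighbouring entries, not documented API:

# Assumed reading of the spec tuple registered above; field names are
# inferred from sibling entries and are not the library's documented API.
SCATTERMOE_SPEC_HAS_GATE = True  # placeholder for the real constant

spec = (
    "GraniteMoeHybridMoE",                      # MoE module class to target
    "router",                                   # router / gating submodule name
    "input_linear|output_linear|input_linear",  # expert linear submodule names
    SCATTERMOE_SPEC_HAS_GATE,                   # spec has a gate
    False,                                      # trailing flag, per sibling entries
)

moe_cls_name, router_name, expert_linears, has_gate, flag = spec
print(expert_linears.split("|"))
# ['input_linear', 'output_linear', 'input_linear']
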
plugins/accelerated-peft/requirements.txt

Lines changed: 3 additions & 4 deletions
@@ -5,10 +5,9 @@
 accelerate >= 0.29

 # bitsandbytes for the BNB plugin
-# - lower bound is because bnb is missing quant_state
-# - upper bound is because of segmentation faults
-# see https://github.com/foundation-model-stack/fms-acceleration/issues/17
-bitsandbytes >=0.41,<=0.43.3
+# exact version 0.45.1 is needed for the torch upgrade to 2.6
+
+bitsandbytes == 0.45.1

 # Used to manage the thread limit in functions for converting old
 # GPTQ models to new GPTQ model format that support symmetrical=False

plugins/fused-ops-and-kernels/src/fms_acceleration_foak/framework_plugin_fast_kernels.py

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,7 @@ def register_foak_model_patch_rules(
         granite,
         granitemoe,
         granitemoeshared,
+        granitemoehybrid,
         llama,
         mistral,
         mixtral,
@@ -56,6 +57,7 @@ def register_foak_model_patch_rules(
         *granite.get_mp_rules(base_type, config),
         *granitemoe.get_mp_rules(base_type),
         *granitemoeshared.get_mp_rules(base_type),
+        *granitemoehybrid.get_mp_rules(base_type),
         *llama.get_mp_rules(base_type, config),
         *mistral.get_mp_rules(base_type, config),
         *mixtral.get_mp_rules(base_type),
@@ -94,6 +96,7 @@ class FastKernelsAccelerationPlugin(AccelerationPlugin):
         "LlamaForCausalLM",
         "MistralForCausalLM",
         "GraniteMoeSharedForCausalLM",
+        "GraniteMoeHybridForCausalLM",
     ]

     def __init__(self, configurations: Dict[str, Dict]):
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from functools import partial
+
+# Third Party
+from fms_acceleration.model_patcher import (
+    ModelPatcherRule,
+    ModelPatcherTrigger,
+    combine_functions,
+    combine_triggers,
+)
+
+# Local
+from ..kernels.unsloth.cross_entropy_loss import (
+    FastCrossEntropyLoss,
+    replace_custom_loss_when_triggered,
+)
+from ..kernels.unsloth.rms_layernorm import fast_rms_layernorm
+from ..kernels.unsloth.rope_embedding import fast_rope_embedding
+from .utils import (
+    KEY_O,
+    KEY_QKV,
+    build_lora_fused_ops,
+    get_transformers_version,
+    trigger_fused_ops,
+)
+
+
+def get_mp_rules(base_type: str):
+    """
+    Function to access all patch rules in this module.
+    If it is a forward_builder rule with `base_type` in
+    its forward builder argument, wrap the forward_builder
+    function as a partial function with the base_type argument
+    """
+    try:
+        # Third Party
+        from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (  # pylint: disable=import-outside-toplevel
+            GraniteMoeHybridAttention,
+            GraniteMoeHybridForCausalLM,
+            GraniteMoeHybridRMSNorm,
+        )
+    except ImportError:
+        return []
+
+    return [
+        # TODO: have a generic version of this rule
+        # - do regex on RMSNorm class name
+        # - check on the tensors required for fast_rms_layernorm
+        ModelPatcherRule(
+            rule_id="granitemoehybrid-rms",
+            trigger=ModelPatcherTrigger(check=GraniteMoeHybridRMSNorm),
+            forward=fast_rms_layernorm,
+        ),
+        # TODO: have a generic version of this rule
+        # - do regex on Attention class name
+        # - have a set of qkv / o module names and check on that
+        ModelPatcherRule(
+            rule_id="granitemoehybrid-qkvo",
+            trigger=combine_triggers(
+                ModelPatcherTrigger(
+                    check=partial(
+                        trigger_fused_ops,
+                        attn_cls=GraniteMoeHybridAttention,
+                        submodule_names=["q_proj", "k_proj", "v_proj"],
+                    )
+                ),
+                ModelPatcherTrigger(
+                    check=partial(
+                        trigger_fused_ops,
+                        attn_cls=GraniteMoeHybridAttention,
+                        submodule_names=["o_proj"],
+                    )
+                ),
+                logic="OR",
+            ),
+            forward_builder=combine_functions(
+                partial(
+                    build_lora_fused_ops,
+                    submodule_names=["q_proj", "k_proj", "v_proj"],
+                    fused_op=KEY_QKV,
+                    base_type=base_type,
+                ),
+                partial(
+                    build_lora_fused_ops,
+                    submodule_names=["o_proj"],
+                    fused_op=KEY_O,
+                    base_type=base_type,
+                ),
+                logic="APPEND",
+            ),
+        ),
+        *[
+            (
+                ModelPatcherRule(
+                    rule_id="granitemoehybrid-custom-loss",
+                    trigger=ModelPatcherTrigger(
+                        check=replace_custom_loss_when_triggered(
+                            GraniteMoeHybridForCausalLM,
+                            custom_loss_type="granite-custom-loss",
+                        )
+                    ),
+                )
+                if get_transformers_version() >= "4.46"
+                else ModelPatcherRule(
+                    rule_id="granitemoehybrid-cross-ent",
+                    import_and_maybe_reload=(
+                        "torch.nn.CrossEntropyLoss",
+                        FastCrossEntropyLoss,
+                        "transformers.models.granitemoehybrid.modeling_granitemoehybrid",
+                    ),
+                )
+            )
+        ],
+        # TODO: have a generic version of this rule
+        # - get the module name
+        # - check if "apply_rotary_pos_emb" exists
+        # - patch
+        ModelPatcherRule(
+            rule_id="granitemoehybrid-rope",
+            import_and_maybe_reload=(
+                "transformers.models.granitemoehybrid.\
+modeling_granitemoehybrid.apply_rotary_pos_emb",
+                fast_rope_embedding,
+                None,
+            ),
+        ),
+    ]
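
One structural detail worth noting in the new module: the rule list splices in exactly one of two mutually exclusive loss rules depending on the installed transformers version, via a *[x if cond else y] unpacking. A minimal standalone sketch of that idiom, with stand-in strings in place of the real ModelPatcherRule objects:

# Stand-in sketch of the version-gated rule selection used in get_mp_rules;
# plain strings replace the real ModelPatcherRule instances.
def pick_rules(transformers_version: str) -> list:
    return [
        "rms-rule",
        "qkvo-rule",
        *[
            "custom-loss-rule"            # transformers >= 4.46: patch the loss fn
            if transformers_version >= "4.46"
            else "cross-ent-reload-rule"  # older: reload module with fast cross-entropy
        ],
        "rope-rule",
    ]

assert pick_rules("4.48")[2] == "custom-loss-rule"
assert pick_rules("4.45")[2] == "cross-ent-reload-rule"
# NOTE: plain string comparison happens to work for these values; the real
# code presumably compares parsed versions via get_transformers_version().
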
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+# FMS Acceleration Plugin Configuration.
+#
+# Each stanza incorporates various configurations for
+# different fine-tuning / training tasks.
+plugins:
+  # Configurations to accelerate data packing/padding in training
+  training:
+
+    # attention module configurations
+    # e.g. padding-free modifications to attention layer
+    attention:
+
+      # this controls the configurations for padding-free computation of flash attention
+      padding_free:
+        method: huggingface
+    fused_ops_and_kernels:
+
+      # if under training stanza, then putting
+      # base_layer and fused_lora will be a misnomer
+      # - this should be in peft.quantized
+      # However, if it is specified, it will still
+      # be read. This is useful in use cases where
+      # the yaml is system generated and not shown
+      # to a user.
+
+      # activate various unsloth optimizations
+      # there are two versions of the plugin
+      # - the FastKernel version supports individual kernels
+      # - the FastQuantized version is all-or-nothing
+
+      # fast loss triton kernels
+      fast_loss: true
+
+      # fast rms norm triton kernels
+      fast_rms_layernorm: true
+
+      # fast RoPE embedding triton kernels
+      fast_rope_embeddings: true
+    moe:
+
+      # expert-parallel for MoE
+      scattermoe:
+
+        # The level of expert parallel sharding.
+        # - 1 means no sharding
+        # - if > 1, please ensure that this divides the world_size. This is because
+        #   the devices will be replicated for every ep_degree devices, and
+        #   the experts will be sharded within each group.
+        # - if > 1, also ensure that it divides the number of experts, as each device
+        #   will then have num_of_experts / ep_degree experts.
+        ep_degree: 8

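To make the ep_degree constraints in the config comments concrete: with 8 GPUs and ep_degree: 8, all devices form a single expert-parallel group, and a model with, say, 64 experts (an illustrative count, not any specific Granite model's) places 64 / 8 = 8 experts on each device. A small sketch of the divisibility checks:

# Worked example of the ep_degree constraints described in the config above
# (the expert count is illustrative; actual Granite models vary).
def experts_per_device(num_experts: int, world_size: int, ep_degree: int) -> int:
    # devices are replicated in groups of ep_degree ...
    if world_size % ep_degree != 0:
        raise ValueError("ep_degree must divide world_size")
    # ... and the experts are sharded within each group
    if num_experts % ep_degree != 0:
        raise ValueError("ep_degree must divide the number of experts")
    return num_experts // ep_degree

print(experts_per_device(num_experts=64, world_size=8, ep_degree=8))  # -> 8
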
scripts/benchmarks/compare_with_reference.py

Lines changed: 5 additions & 4 deletions
@@ -10,15 +10,16 @@

 # default columns to compare
 DEFAULT_PLOT_COLUMNS = [
-    "mem_torch_mem_alloc_in_bytes",
-    "mem_peak_torch_mem_alloc_in_bytes",
+    # "mem_torch_mem_alloc_in_bytes",
+    # "mem_peak_torch_mem_alloc_in_bytes",
+    'mem_nvidia_mem_reserved',
     "train_loss",
     "train_tokens_per_second",
 ]
 # Used as combined identifier of experiment
 DEFAULT_INDICES = [
     "framework_config",
-    "peft_method",
+    # "peft_method",
     "model_name_or_path",
     "num_gpus",
     "per_device_train_batch_size",
@@ -29,7 +30,7 @@
     "train_runtime",
     "train_steps_per_second",
     "train_samples_per_second",
-    "mem_nvidia_mem_reserved",
+    # "mem_nvidia_mem_reserved",
 ]

 DEFAULT_REFERENCE_FILEPATH = "scripts/benchmarks/refs/a100_80gb.csv"

scripts/benchmarks/scenarios-moe.yaml

Lines changed: 2 additions & 1 deletion
@@ -59,6 +59,7 @@ scenarios:
     model_name_or_path:
       - 'ibm-granite/granite-3.0-3b-a800m-instruct'
       - 'ibm-research/moe-7b-1b-active-shared-experts'
+      - 'ibm-granite/granite-4.0-tiny-preview'


   - name: accelerated-moe-full-mixtral
@@ -77,4 +78,4 @@ scenarios:
     packing: False
     adam_epsilon: 1e-8
     model_name_or_path:
-    - 'mistralai/Mixtral-8x7B-Instruct-v0.1'
\ No newline at end of file
+    - 'mistralai/Mixtral-8x7B-Instruct-v0.1'

tox.ini

Lines changed: 4 additions & 1 deletion
@@ -34,6 +34,9 @@ commands =
     # some models need this for tokenizers
     pip install protobuf

+    # for mamba based models
+    pip install --no-build-isolation mamba_ssm[causal-conv1d]>=2.0.0
+
     # install the plugins for test
     # NOTE: when there are more plugins install here
     python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-peft
@@ -42,7 +45,7 @@ commands =
     python -m fms_acceleration.cli install -e {toxinidir}/plugins/accelerated-moe

     # install the flash attn at the last
-    pip install flash-attn
+    pip install flash-attn --no-build-isolation

     # run the benchmark script
     bash scripts/run_benchmarks.sh {posargs:"1 2" "4 8" benchmark_outputs}
