
Commit 3fcedfb

feat: add support for mamba cp
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 80d32ab commit 3fcedfb

4 files changed: 64 additions & 0 deletions

File tree

tuning/config/acceleration_configs/__init__.py
tuning/config/acceleration_configs/acceleration_framework_config.py
tuning/config/acceleration_configs/mcp.py
tuning/sft_trainer.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,5 +18,6 @@
 from .callbacks import get_additional_accel_framework_callbacks
 from .fast_moe import FastMoeConfig
 from .fused_ops_and_kernels import FusedOpsAndKernelsConfig
+from .mcp import MCP, MCPConfig
 from .odm import ODM, ODMConfig
 from .quantized_lora_config import QuantizedLoraConfig
```
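With this re-export in place, the new dataclasses become importable from the package root; a trivial sketch:

```python
# Package-level import made possible by the re-export above.
from tuning.config.acceleration_configs import MCP, MCPConfig
```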

tuning/config/acceleration_configs/acceleration_framework_config.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -24,6 +24,7 @@
 from .attention_and_distributed_packing import MultiPack, PaddingFree
 from .fast_moe import FastMoe
 from .fused_ops_and_kernels import FastKernelsConfig, FusedLoraConfig
+from .mcp import MCP
 from .odm import ODM
 from .quantized_lora_config import AutoGPTQLoraConfig, BNBQLoraConfig
 from tuning.utils.import_utils import is_fms_accelerate_available
@@ -133,6 +134,17 @@ class AccelerationFrameworkConfig:
         ),
     ] = None
 
+    mcp: Annotated[
+        MCP,
+        ConfigAnnotation(
+            path="training.mamba",
+            key="cp",
+            standalone=True,
+            experimental=True,
+            required_packages=["mcp"],
+        ),
+    ] = None
+
     multipack: Annotated[
         MultiPack,
         ConfigAnnotation(
```
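If this annotation behaves like the other fms-acceleration entries, `path="training.mamba"` and `key="cp"` determine where the MCP fields land in the generated framework plugin configuration. A minimal sketch of that nesting; all field values below are hypothetical illustrations, not defaults from the diff:

```python
# Illustrative only: the nesting implied by path="training.mamba" / key="cp".
framework_config = {
    "training": {
        "mamba": {
            "cp": {
                "degree": 2,              # hypothetical context-parallel degree
                "mamba_impl": None,       # implementation selector; not specified in the diff
                "attn_impl": None,        # likewise unspecified
                "mamba_recompute": True,  # hypothetical
            }
        }
    }
}
```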
tuning/config/acceleration_configs/mcp.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -0,0 +1,39 @@
+# Copyright The FMS HF Tuning Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Standard
+from dataclasses import dataclass
+from typing import Union
+
+# Local
+from .utils import ensure_nested_dataclasses_initialized, parsable_dataclass
+
+
+@parsable_dataclass
+@dataclass
+class MCP:
+    degree: int = None
+    mamba_impl: str = None
+    attn_impl: str = None
+    mamba_recompute: bool = None
+
+
+@dataclass
+class MCPConfig:
+
+    cp: MCP = None
+
+    def __post_init__(self):
+        # ensure nested dataclasses initialized
+        ensure_nested_dataclasses_initialized(self)
```
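A minimal usage sketch of the new dataclasses. Field values are illustrative, and the dict-coercion behaviour is an assumption based on the name and placement of `ensure_nested_dataclasses_initialized`:

```python
from tuning.config.acceleration_configs.mcp import MCP, MCPConfig

# All field values here are hypothetical; the dataclass defaults are None.
cfg = MCPConfig(cp=MCP(degree=2, mamba_recompute=True))

# Assumption: __post_init__ also re-wraps a plain dict into MCP via
# ensure_nested_dataclasses_initialized, as the helper's name suggests.
cfg_from_dict = MCPConfig(cp={"degree": 2, "mamba_recompute": True})
```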

tuning/sft_trainer.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -47,6 +47,7 @@
     AttentionAndDistributedPackingConfig,
     FastMoeConfig,
     FusedOpsAndKernelsConfig,
+    MCPConfig,
     ODMConfig,
     QuantizedLoraConfig,
     get_additional_accel_framework_callbacks,
@@ -87,6 +88,7 @@ def train(
         AttentionAndDistributedPackingConfig
     ] = None,
     fast_moe_config: Optional[FastMoeConfig] = None,
+    mcp_config: Optional[MCPConfig] = None,
     additional_data_handlers: Optional[Dict[str, DataHandler]] = None,
 ) -> tuple[SFTTrainer, dict]:
     """Call the SFTTrainer
@@ -198,6 +200,8 @@ def train(
         )
     if fast_moe_config is not None and fast_moe_config.fast_moe is None:
         fast_moe_config = None
+    if mcp_config is not None and mcp_config.cp is None:
+        mcp_config = None
     if fast_moe_config is not None:
         # If LoRA with ScatterMoE detected, raise warning
         accepted_layers = ["all-linear"]
@@ -261,6 +265,7 @@ def train(
         quantized_lora_config,
         fusedops_kernels_config,
         odm_config,
+        mcp_config,
     ).get_framework()
 
     # option to set multimodal var here
@@ -567,6 +572,7 @@ def get_parser():
             FusedOpsAndKernelsConfig,
             AttentionAndDistributedPackingConfig,
             FastMoeConfig,
+            MCPConfig,
             TrackerConfigs,
         )
     )
@@ -648,6 +654,7 @@ def parse_arguments(parser, json_config=None):
             fusedops_kernels_config,
             attention_and_distributed_packing_config,
             fast_moe_config,
+            mcp_config,
             tracker_configs,
         ) = parser.parse_dict(json_config, allow_extra_keys=True)
         peft_method = json_config.get("peft_method")
@@ -667,6 +674,7 @@ def parse_arguments(parser, json_config=None):
             fusedops_kernels_config,
             attention_and_distributed_packing_config,
             fast_moe_config,
+            mcp_config,
             tracker_configs,
             additional,
             _,
@@ -703,6 +711,7 @@ def parse_arguments(parser, json_config=None):
         fusedops_kernels_config,
         attention_and_distributed_packing_config,
         fast_moe_config,
+        mcp_config,
         tracker_configs,
         exp_metadata,
     )
@@ -725,6 +734,7 @@ def main():
         fusedops_kernels_config,
         attention_and_distributed_packing_config,
         fast_moe_config,
+        mcp_config,
         tracker_configs,
         exp_metadata,
     ) = parse_arguments(parser, job_config)
@@ -746,6 +756,7 @@ def main():
         "AADP (fms-acceleration) Config": attention_and_distributed_packing_config,
         "Fused Ops Kernels Config": fusedops_kernels_config,
         "Fast MoE Config": fast_moe_config,
+        "MCP Config": mcp_config,
         "Tracker Config": tracker_configs,
         "Extra Metadata": exp_metadata,
         "Trainer Controller Config": trainer_controller_args,
@@ -789,6 +800,7 @@ def main():
             quantized_lora_config=quantized_lora_config,
             fusedops_kernels_config=fusedops_kernels_config,
             attention_and_distributed_packing_config=attention_and_distributed_packing_config,
+            mcp_config=mcp_config,
             fast_moe_config=fast_moe_config,
         )
     except (MemoryError, OutOfMemoryError) as e:
```
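Putting it together, `parse_arguments()` now yields an extra `mcp_config` which `main()` logs and forwards to `train()`. A minimal sketch of constructing the config and the None-normalization `train()` applies; field values are hypothetical:

```python
from tuning.config.acceleration_configs import MCP, MCPConfig

# Hypothetical value; the diff does not define defaults or valid choices.
mcp_config = MCPConfig(cp=MCP(degree=2))

# train() drops an "empty" config before building the acceleration framework:
if mcp_config is not None and mcp_config.cp is None:
    mcp_config = None

# The surviving config is then passed into
# AccelerationFrameworkConfig(...).get_framework(), as in the diff above.
```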
