Skip to content

Commit 61cc2c1

Browse files
committed
feat: add support for mamba cp
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 6e78c0a commit 61cc2c1

3 files changed

Lines changed: 26 additions & 13 deletions

File tree

plugins/mamba-cp/src/fms_acceleration_mcp/framework_plugin_mcp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from transformers import TrainingArguments
2222
import torch
2323

24+
# Local
2425
from .utils import patch_mamba_layers_with_cp_head
2526

2627

@@ -41,6 +42,7 @@ def __init__(self, configurations: Dict[str, Dict]):
4142
key="training.mamba.cp.mamba_recompute",
4243
default=False,
4344
)
45+
4446
# data_config file should be there
4547
@property
4648
def requires_augmentation(self):
@@ -52,7 +54,7 @@ def augmentation(
5254
train_args: TrainingArguments,
5355
modifiable_args: Tuple[LoraConfig],
5456
):
55-
if self._mamba_cp_degree != None:
57+
if self._mamba_cp_degree is not None:
5658
rank = 0
5759
if torch.distributed.is_initialized():
5860
rank = torch.distributed.get_node_local_rank()

plugins/mamba-cp/src/fms_acceleration_mcp/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,3 @@
1414

1515
# Local
1616
from .utils import patch_mamba_layers_with_cp_head
17-

plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,17 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
# Standard
1415
from typing import Dict
1516

17+
# Third Party
18+
from mamba_ssm.modules.mamba2_cp import Mamba2CP
19+
1620
# pylint: disable=import-error
17-
from torch.distributed._tensor.device_mesh import DeviceMesh, init_device_mesh
18-
import torch
19-
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
21+
from torch.distributed._tensor.device_mesh import init_device_mesh
2022
from tqdm import tqdm
21-
22-
try:
23-
from mamba_ssm.modules.mamba2_cp import Mamba2CP
24-
except ImportError:
25-
ValueError("Mamba2CP is required to enable context parallelism for mamba layers")
23+
from transformers.modeling_utils import is_fsdp_enabled
24+
import torch
2625

2726
key_ep = "cp"
2827
key_rep = "dp_shard"
@@ -42,9 +41,22 @@ def hf_config_ssm_config(hf_config) -> Dict:
4241

4342
class Mamba2CPHF(Mamba2CP):
4443
def forward(
45-
self, hidden_states, cache_params=None, cache_position=None, attention_mask=None, seq_idx=None, **kwargs
44+
self,
45+
hidden_states,
46+
cache_params=None,
47+
cache_position=None,
48+
attention_mask=None,
49+
seq_idx=None,
50+
**kwargs,
4651
):
47-
return super().forward(u=hidden_states, seqlen=None, seq_idx=None, cu_seqlens=None, inference_params=None)
52+
return super().forward(
53+
u=hidden_states,
54+
seqlen=None,
55+
seq_idx=None,
56+
cu_seqlens=None,
57+
inference_params=None,
58+
)
59+
4860

4961
def patch_mamba_layers_with_cp_head(
5062
model,
@@ -63,7 +75,7 @@ def patch_mamba_layers_with_cp_head(
6375

6476
if cp_degree == 1:
6577
raise ValueError("CP degree can't be one")
66-
elif rep_size == 1:
78+
if rep_size == 1:
6779
device_mesh = init_device_mesh(
6880
"cuda",
6981
(cp_degree,),

0 commit comments

Comments
 (0)