1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ # Standard
1415from typing import Dict
1516
17+ # Third Party
18+ from mamba_ssm .modules .mamba2_cp import Mamba2CP
19+
1620# pylint: disable=import-error
17- from torch .distributed ._tensor .device_mesh import DeviceMesh , init_device_mesh
18- import torch
19- from transformers .modeling_utils import is_fsdp_enabled , is_local_dist_rank_0
21+ from torch .distributed ._tensor .device_mesh import init_device_mesh
2022from tqdm import tqdm
21-
22- try :
23- from mamba_ssm .modules .mamba2_cp import Mamba2CP
24- except ImportError :
25- ValueError ("Mamba2CP is required to enable context parallelism for mamba layers" )
23+ from transformers .modeling_utils import is_fsdp_enabled
24+ import torch
2625
2726key_ep = "cp"
2827key_rep = "dp_shard"
@@ -42,9 +41,22 @@ def hf_config_ssm_config(hf_config) -> Dict:
4241
4342class Mamba2CPHF (Mamba2CP ):
4443 def forward (
45- self , hidden_states , cache_params = None , cache_position = None , attention_mask = None , seq_idx = None , ** kwargs
44+ self ,
45+ hidden_states ,
46+ cache_params = None ,
47+ cache_position = None ,
48+ attention_mask = None ,
49+ seq_idx = None ,
50+ ** kwargs ,
4651 ):
47- return super ().forward (u = hidden_states , seqlen = None , seq_idx = None , cu_seqlens = None , inference_params = None )
52+ return super ().forward (
53+ u = hidden_states ,
54+ seqlen = None ,
55+ seq_idx = None ,
56+ cu_seqlens = None ,
57+ inference_params = None ,
58+ )
59+
4860
4961def patch_mamba_layers_with_cp_head (
5062 model ,
@@ -63,7 +75,7 @@ def patch_mamba_layers_with_cp_head(
6375
6476 if cp_degree == 1 :
6577 raise ValueError ("CP degree can't be one" )
66- elif rep_size == 1 :
78+ if rep_size == 1 :
6779 device_mesh = init_device_mesh (
6880 "cuda" ,
6981 (cp_degree ,),
0 commit comments