We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 61cc2c1 commit 67a947fCopy full SHA for 67a947f
1 file changed
plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py
@@ -67,6 +67,11 @@ def patch_mamba_layers_with_cp_head(
67
cp_mamba_impl,
68
cp_mamba_recompute,
69
):
70
+ # to avoid rechunking/sharding of the buffers
71
+ # ideally this is not optimal
72
+ from torch.distributed.tensor.experimental._attention import _cp_options
73
+ _cp_options.enable_load_balance = False
74
+
75
config_ssm = hf_config_ssm_config(model.config)
76
device = torch.device(f"cuda:{rank}")
77
if is_fsdp_enabled():
0 commit comments