Skip to content

Commit 67a947f

Browse files
committed
debug
Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
1 parent 61cc2c1 commit 67a947f

1 file changed

Lines changed: 5 additions & 0 deletions

File tree

  • plugins/mamba-cp/src/fms_acceleration_mcp/utils

plugins/mamba-cp/src/fms_acceleration_mcp/utils/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def patch_mamba_layers_with_cp_head(
6767
cp_mamba_impl,
6868
cp_mamba_recompute,
6969
):
70+
# to avoid rechunking/sharding of the buffers
71+
# ideally this is not optimal
72+
from torch.distributed.tensor.experimental._attention import _cp_options
73+
_cp_options.enable_load_balance = False
74+
7075
config_ssm = hf_config_ssm_config(model.config)
7176
device = torch.device(f"cuda:{rank}")
7277
if is_fsdp_enabled():

0 commit comments

Comments
 (0)