Commit 40d1d82

feat: add missing deps for accelerate patches (#159)

* fix: add dependencies
* fix: pin accelerate to latest until they make a release; this is needed for fsdp2

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

1 parent 4e74cd8 commit 40d1d82

3 files changed: 22 additions & 1 deletion


.github/workflows/format.yml (2 additions, 0 deletions)

@@ -34,6 +34,8 @@ jobs:
           - "online-data-mixing"

     steps:
+      - name: Delete huge unnecessary tools folder
+        run: rm -rf /opt/hostedtoolcache
       - uses: actions/checkout@v4
       - name: Set up Python 3.11
         uses: actions/setup-python@v4
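The new workflow step frees disk space on the GitHub-hosted runner by deleting the preinstalled tool cache before checkout. As a hedged illustration (not part of this repo), a small helper like the following could report free space so a CI step can decide whether such cleanup is worthwhile; `free_gib` is a hypothetical name:

```python
# Hypothetical CI helper: report free disk space at a path, e.g. to log
# how much room a cleanup step such as `rm -rf /opt/hostedtoolcache` recovered.
import shutil


def free_gib(path: str = "/") -> float:
    """Return the free space at `path` in GiB."""
    usage = shutil.disk_usage(path)
    return usage.free / (1024 ** 3)


if __name__ == "__main__":
    print(f"free space: {free_gib('.'):.1f} GiB")
```

Deleting `/opt/hostedtoolcache` is a common workaround when a job needs more scratch space than the default runner image leaves available.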

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py (16 additions, 0 deletions)

@@ -735,6 +735,9 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dict):
         full_sd (`dict`): The full state dict to load, can only be on rank 0
     """
     # Third Party
+    # pylint: disable=import-outside-toplevel
+    from accelerate.utils.fsdp_utils import get_parameters_from_modules
+
     # pylint: disable=import-outside-toplevel
     from torch.distributed.tensor import distribute_tensor
     import torch.distributed as dist

@@ -847,7 +850,20 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     Returns:
         `torch.nn.Module`: Prepared model
     """
+    # Standard
+    # pylint: disable=import-outside-toplevel
+    import copy
+    import warnings
+
     # Third Party
+    # pylint: disable=import-outside-toplevel
+    from accelerate.utils.fsdp_utils import (
+        fsdp2_prepare_auto_wrap_policy,
+        get_parameters_from_modules,
+    )
+    from accelerate.utils.modeling import get_non_persistent_buffers
+    from accelerate.utils.other import get_module_children_bottom_up, is_compiled_module
+
     # pylint: disable=import-outside-toplevel
     from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard
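Both patched functions import the `accelerate` helpers inside the function body (with a matching `pylint: disable=import-outside-toplevel`) rather than at module scope. A minimal sketch of that deferred-import pattern, using a hypothetical `optional_lib`/`fancy_helper` stand-in rather than the real accelerate API:

```python
# Deferred-import pattern: the third-party helper is imported inside the
# function so that merely importing this module never fails; a clear error
# is raised only if the optional code path is actually exercised.
# `optional_lib` and `fancy_helper` are hypothetical stand-ins.


def prepare_model(model):
    # pylint: disable=import-outside-toplevel
    try:
        from optional_lib import fancy_helper  # hypothetical optional dependency
    except ImportError as exc:
        raise RuntimeError(
            "this code path needs `optional_lib`; install a recent version"
        ) from exc
    return fancy_helper(model)
```

This keeps the plugin importable even when the pinned `accelerate` commit (which provides the new FSDP2 helpers) is not yet installed.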

plugins/framework/pyproject.toml (4 additions, 1 deletion)

@@ -25,12 +25,15 @@ dependencies = [
     "numpy<2.0", # numpy needs to be bounded due to incompatiblity with current torch<2.3
     "torch>2.2",
     "peft>=0.15.0",
-    "accelerate",
+    "accelerate @ git+https://github.com/huggingface/accelerate.git@5998f8625b8dfde9253c241233ff13bc2c18635d",
     "pandas",
 ]

 [tool.hatch.build.targets.wheel]
 only-include = ["src/fms_acceleration"]

+[tool.hatch.metadata]
+allow-direct-references = true
+
 [tool.hatch.build.targets.wheel.sources]
 "src" = ""
