Skip to content

Commit 3c0fa5a

Browse files
Merge branch 'feature/puzzletron' into jrausch/distillation-consolidation
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
2 parents 8c2fa10 + 2669a26 commit 3c0fa5a

File tree

99 files changed

+666
-675
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

99 files changed

+666
-675
lines changed

examples/llm_eval/lm_eval_hf.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@
5252
from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified
5353

5454
try:
55-
import modelopt.torch.puzzletron.anymodel.models # noqa: F401
56-
from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import (
57-
resolve_descriptor_from_pretrained,
58-
)
59-
from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
55+
import modelopt.torch.puzzletron as mtpz
6056

6157
_ANYMODEL_AVAILABLE = True
6258
except ImportError:
@@ -68,12 +64,12 @@ def _anymodel_patcher_context(pretrained, trust_remote_code=False):
6864
if not _ANYMODEL_AVAILABLE or not pretrained:
6965
return contextlib.nullcontext()
7066
try:
71-
descriptor = resolve_descriptor_from_pretrained(
67+
descriptor = mtpz.resolve_descriptor_from_pretrained(
7268
pretrained, trust_remote_code=trust_remote_code
7369
)
7470
except (ValueError, AttributeError):
7571
return contextlib.nullcontext()
76-
return deci_x_patcher(model_descriptor=descriptor)
72+
return mtpz.deci_x_patcher(model_descriptor=descriptor)
7773

7874

7975
def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T:

examples/puzzletron/evaluation/hf_deployable_anymodel.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@
3131
from peft import PeftModel
3232
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
3333

34-
from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import (
35-
resolve_descriptor_from_pretrained,
36-
)
37-
from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
34+
import modelopt.torch.puzzletron as mtpz
3835

3936
try:
4037
from pytriton.decorators import batch
@@ -148,11 +145,11 @@ def _load(
148145
# See: modelopt/torch/puzzletron/anymodel/puzzformer/utils.py
149146
# =========================================================================
150147

151-
descriptor = resolve_descriptor_from_pretrained(
148+
descriptor = mtpz.resolve_descriptor_from_pretrained(
152149
self.hf_model_id_path, trust_remote_code=hf_kwargs.get("trust_remote_code", False)
153150
)
154151

155-
with deci_x_patcher(model_descriptor=descriptor):
152+
with mtpz.deci_x_patcher(model_descriptor=descriptor):
156153
self.model = AutoModelForCausalLM.from_pretrained(
157154
self.hf_model_id_path,
158155
torch_dtype=torch_dtype,

examples/puzzletron/main.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,8 @@
3737
from pathlib import Path
3838

3939
import modelopt.torch.nas as mtn
40-
import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models
41-
import modelopt.torch.puzzletron.mip.sweep as sweep
40+
import modelopt.torch.puzzletron as mtpz
4241
import modelopt.torch.utils.distributed as dist
43-
from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel
44-
from modelopt.torch.puzzletron.tools.hydra_utils import (
45-
initialize_hydra_config_for_dir,
46-
register_hydra_resolvers,
47-
)
48-
from modelopt.torch.puzzletron.tools.logger import mprint
4942

5043

5144
def parse_args():
@@ -74,26 +67,26 @@ def run_full_puzzletron(hydra_config_path: str):
7467
Args:
7568
config_path: Path to the YAML configuration file
7669
"""
77-
mprint("Puzzletron Progress 1/8: starting puzzletron pipeline")
70+
mtpz.tools.mprint("Puzzletron Progress 1/8: starting puzzletron pipeline")
7871
dist.setup(timeout=timedelta(minutes=10))
7972

8073
# Register Hydra custom resolvers (needed for config resolution)
81-
register_hydra_resolvers()
74+
mtpz.tools.register_hydra_resolvers()
8275

8376
hydra_config_path = Path(hydra_config_path).resolve()
8477
hydra_config_dir = str(hydra_config_path.parent)
8578
hydra_config_name = hydra_config_path.stem
8679

8780
# Load hydra config
88-
hydra_cfg = initialize_hydra_config_for_dir(
81+
hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir(
8982
config_dir=hydra_config_dir,
9083
config_name=hydra_config_name,
9184
overrides=[],
9285
)
9386

9487
# Convert model (convert from HF to DeciLM, score pruning activations,
9588
# prune the model and save pruned checkpoints)
96-
input_model = PuzzletronModel()
89+
input_model = mtpz.puzzletron_nas_plugin.PuzzletronModel()
9790
converted_model = mtn.convert(
9891
input_model,
9992
mode=[
@@ -120,7 +113,7 @@ def run_full_puzzletron(hydra_config_path: str):
120113
)
121114

122115
dist.cleanup()
123-
mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
116+
mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
124117

125118

126119
def run_mip_only(hydra_config_path: str):
@@ -135,33 +128,33 @@ def run_mip_only(hydra_config_path: str):
135128
dist.setup(timeout=timedelta(minutes=10))
136129

137130
# Register Hydra custom resolvers (needed for config resolution)
138-
register_hydra_resolvers()
131+
mtpz.tools.register_hydra_resolvers()
139132

140133
hydra_config_path = Path(hydra_config_path).resolve()
141134
hydra_config_dir = str(hydra_config_path.parent)
142135
hydra_config_name = hydra_config_path.stem
143136

144137
# Load hydra config
145-
hydra_cfg = initialize_hydra_config_for_dir(
138+
hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir(
146139
config_dir=hydra_config_dir,
147140
config_name=hydra_config_name,
148141
overrides=[],
149142
)
150143

151144
# Check if sweep mode is enabled
152145
if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
153-
mprint(
146+
mtpz.tools.mprint(
154147
"Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
155148
)
156-
sweep.run_mip_sweep(hydra_cfg)
149+
mtpz.mip.run_mip_sweep(hydra_cfg)
157150
else:
158151
# mip_and_realize_models (distributed processing)
159152
# TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
160-
mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
161-
mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
153+
mtpz.tools.mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
154+
mtpz.mip.launch_mip_and_realize_model(hydra_cfg)
162155

163156
dist.cleanup()
164-
mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
157+
mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
165158

166159

167160
def main():

modelopt/torch/puzzletron/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,26 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
# NOTE: Some modules also trigger factory registration as side effect
17+
from . import (
18+
activation_scoring,
19+
anymodel,
20+
block_config,
21+
build_library_and_stats,
22+
dataset,
23+
entrypoint,
24+
export,
25+
mip,
26+
pruning,
27+
puzzletron_nas_plugin,
28+
replacement_library,
29+
scoring,
30+
sewing_kit,
31+
subblock_stats,
32+
tools,
33+
utils,
34+
)
35+
36+
# Import functions from important modules to top-level
37+
from .anymodel import *
38+
from .entrypoint import *
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from .activation_hooks import *
17+
from .score_pruning_activations import *

modelopt/torch/puzzletron/activation_scoring/activation_hooks/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
import torch
2323

2424
from modelopt.torch.prune.importance_hooks.base_hooks import ForwardHook as ActivationsHook
25-
from modelopt.torch.puzzletron.tools.logger import aprint
26-
from modelopt.torch.puzzletron.utils.dummy_modules import DummyBlock, DummyModule
25+
26+
from ...tools.logger import aprint
27+
from ...utils.dummy_modules import DummyBlock, DummyModule
2728

2829

2930
def register_activation_hooks(

modelopt/torch/puzzletron/activation_scoring/score_pruning_activations.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from omegaconf import DictConfig
2020

2121
import modelopt.torch.utils.distributed as dist
22-
from modelopt.torch.puzzletron.tools.logger import mprint
23-
from modelopt.torch.puzzletron.tools.validate_model import validate_model
22+
23+
from ..tools.logger import mprint
2424

2525

2626
def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool:
@@ -127,10 +127,9 @@ def should_skip_scoring_completely(cfg: DictConfig) -> bool:
127127
return is_completed
128128

129129

130-
# Old progress tracking removed - checkpoint manager handles all progress tracking
131-
132-
133130
def launch_score_activations(cfg: DictConfig):
131+
from ..tools.validate_model import validate_model
132+
134133
# Check if we should skip scoring entirely (only if 100% complete)
135134
if should_skip_scoring_completely(cfg):
136135
return

modelopt/torch/puzzletron/anymodel/__init__.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -37,28 +37,7 @@
3737
- (more to come: qwen2, mistral_small, etc.)
3838
"""
3939

40-
# Import models to trigger factory registration
41-
from modelopt.torch.puzzletron.anymodel import models # noqa: F401
42-
from modelopt.torch.puzzletron.anymodel.converter import Converter, ConverterFactory, convert_model
43-
from modelopt.torch.puzzletron.anymodel.model_descriptor import (
44-
ModelDescriptor,
45-
ModelDescriptorFactory,
46-
)
47-
from modelopt.torch.puzzletron.anymodel.puzzformer import (
48-
MatchingZeros,
49-
Same,
50-
deci_x_patcher,
51-
return_tuple_of_size,
52-
)
53-
54-
__all__ = [
55-
"Converter",
56-
"ConverterFactory",
57-
"ModelDescriptor",
58-
"ModelDescriptorFactory",
59-
"deci_x_patcher",
60-
"MatchingZeros",
61-
"Same",
62-
"return_tuple_of_size",
63-
"convert_model",
64-
]
40+
from . import models # trigger factory registration
41+
from .converter import *
42+
from .model_descriptor import *
43+
from .puzzformer import *

modelopt/torch/puzzletron/anymodel/converter/convert_any_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818

1919
from pathlib import Path
2020

21-
from modelopt.torch.puzzletron.anymodel.converter.converter import Converter
22-
from modelopt.torch.puzzletron.anymodel.converter.converter_factory import ConverterFactory
23-
from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptorFactory
21+
from ..model_descriptor import ModelDescriptorFactory
22+
from .converter import Converter
23+
from .converter_factory import ConverterFactory
2424

2525
__all__ = ["convert_model"]
2626

modelopt/torch/puzzletron/anymodel/converter/converter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
from transformers import PretrainedConfig
3030
from transformers.integrations.mxfp4 import convert_moe_packed_tensors
3131

32-
from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor
33-
from modelopt.torch.puzzletron.block_config import BlockConfig
34-
from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config, save_model_config
32+
from ...block_config import BlockConfig
33+
from ...tools.checkpoint_utils_hf import load_model_config, save_model_config
34+
from ..model_descriptor import ModelDescriptor
3535

3636
__all__ = ["Converter"]
3737

0 commit comments

Comments (0)