Skip to content

Commit 0c2a2ee

Browse files
Merge branch 'feature/puzzletron' into jrausch/distillation-consolidation
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
2 parents 8c2fa10 + 977d60a commit 0c2a2ee

File tree

118 files changed

+899
-770
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

118 files changed

+899
-770
lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ modelopt_recipes @NVIDIA/modelopt-recipes-codeowners
5050
/examples/model_hub @NVIDIA/modelopt-examples-model_hub-codeowners
5151
/examples/onnx_ptq @NVIDIA/modelopt-onnx-codeowners
5252
/examples/pruning @NVIDIA/modelopt-torch-nas-prune-codeowners
53+
/examples/puzzletron @NVIDIA/modelopt-torch-puzzletron-codeowners
5354
/examples/specdec_bench @NVIDIA/modelopt-torch-speculative-codeowners
5455
/examples/speculative_decoding @NVIDIA/modelopt-torch-speculative-codeowners
5556
/examples/torch_onnx @NVIDIA/modelopt-onnx-codeowners

docs/source/conf.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
# import sys
3232
# sys.path.insert(0, os.path.abspath('.'))
3333

34+
import contextlib
3435
import os
3536
import sys
3637

@@ -44,6 +45,14 @@
4445
sys.path.insert(0, os.path.abspath("../../"))
4546
sys.path.append(os.path.abspath("./_ext"))
4647

48+
# Pre-import modelopt.torch so it is cached in sys.modules before Sphinx applies
49+
# autodoc_mock_imports. Mocking triton/tensorrt_llm at the Sphinx level can break
50+
# transitive imports (transformers, transformer_engine, …) and cause modelopt.torch
51+
# to fail inside autosummary. Importing here — while the real packages are still on
52+
# sys.path — avoids that problem entirely.
53+
with contextlib.suppress(Exception):
54+
import modelopt.torch # noqa: F401
55+
4756
# -- Project information -----------------------------------------------------
4857

4958
project = "Model Optimizer" # pylint: disable=C0103

examples/llm_eval/lm_eval_hf.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,7 @@
5252
from modelopt.torch.sparsity.attention_sparsity.conversion import is_attn_sparsified
5353

5454
try:
55-
import modelopt.torch.puzzletron.anymodel.models # noqa: F401
56-
from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import (
57-
resolve_descriptor_from_pretrained,
58-
)
59-
from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
55+
import modelopt.torch.puzzletron as mtpz
6056

6157
_ANYMODEL_AVAILABLE = True
6258
except ImportError:
@@ -68,12 +64,12 @@ def _anymodel_patcher_context(pretrained, trust_remote_code=False):
6864
if not _ANYMODEL_AVAILABLE or not pretrained:
6965
return contextlib.nullcontext()
7066
try:
71-
descriptor = resolve_descriptor_from_pretrained(
67+
descriptor = mtpz.anymodel.resolve_descriptor_from_pretrained(
7268
pretrained, trust_remote_code=trust_remote_code
7369
)
7470
except (ValueError, AttributeError):
7571
return contextlib.nullcontext()
76-
return deci_x_patcher(model_descriptor=descriptor)
72+
return mtpz.anymodel.deci_x_patcher(model_descriptor=descriptor)
7773

7874

7975
def create_from_arg_obj(cls: type[T], arg_dict: dict, additional_config: dict | None = None) -> T:

examples/puzzletron/evaluation/hf_deployable_anymodel.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,7 @@
3131
from peft import PeftModel
3232
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
3333

34-
from modelopt.torch.puzzletron.anymodel.model_descriptor.model_descriptor_factory import (
35-
resolve_descriptor_from_pretrained,
36-
)
37-
from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher
34+
import modelopt.torch.puzzletron as mtpz
3835

3936
try:
4037
from pytriton.decorators import batch
@@ -145,14 +142,14 @@ def _load(
145142
# =========================================================================
146143
# BEGIN ANYMODEL PATCH
147144
# Wraps model loading with deci_x_patcher for heterogeneous layer configs.
148-
# See: modelopt/torch/puzzletron/anymodel/puzzformer/utils.py
145+
# See: modelopt/torch/puzzletron/anymodel/puzzformer/patcher.py
149146
# =========================================================================
150147

151-
descriptor = resolve_descriptor_from_pretrained(
148+
descriptor = mtpz.anymodel.resolve_descriptor_from_pretrained(
152149
self.hf_model_id_path, trust_remote_code=hf_kwargs.get("trust_remote_code", False)
153150
)
154151

155-
with deci_x_patcher(model_descriptor=descriptor):
152+
with mtpz.anymodel.deci_x_patcher(model_descriptor=descriptor):
156153
self.model = AutoModelForCausalLM.from_pretrained(
157154
self.hf_model_id_path,
158155
torch_dtype=torch_dtype,

examples/puzzletron/main.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,8 @@
3737
from pathlib import Path
3838

3939
import modelopt.torch.nas as mtn
40-
import modelopt.torch.puzzletron.mip.mip_and_realize_models as mip_and_realize_models
41-
import modelopt.torch.puzzletron.mip.sweep as sweep
40+
import modelopt.torch.puzzletron as mtpz
4241
import modelopt.torch.utils.distributed as dist
43-
from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel
44-
from modelopt.torch.puzzletron.tools.hydra_utils import (
45-
initialize_hydra_config_for_dir,
46-
register_hydra_resolvers,
47-
)
48-
from modelopt.torch.puzzletron.tools.logger import mprint
4942

5043

5144
def parse_args():
@@ -74,26 +67,26 @@ def run_full_puzzletron(hydra_config_path: str):
7467
Args:
7568
config_path: Path to the YAML configuration file
7669
"""
77-
mprint("Puzzletron Progress 1/8: starting puzzletron pipeline")
70+
mtpz.tools.mprint("Puzzletron Progress 1/8: starting puzzletron pipeline")
7871
dist.setup(timeout=timedelta(minutes=10))
7972

8073
# Register Hydra custom resolvers (needed for config resolution)
81-
register_hydra_resolvers()
74+
mtpz.tools.register_hydra_resolvers()
8275

8376
hydra_config_path = Path(hydra_config_path).resolve()
8477
hydra_config_dir = str(hydra_config_path.parent)
8578
hydra_config_name = hydra_config_path.stem
8679

8780
# Load hydra config
88-
hydra_cfg = initialize_hydra_config_for_dir(
81+
hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir(
8982
config_dir=hydra_config_dir,
9083
config_name=hydra_config_name,
9184
overrides=[],
9285
)
9386

9487
# Convert model (convert from HF to DeciLM, score pruning activations,
9588
# prune the model and save pruned checkpoints)
96-
input_model = PuzzletronModel()
89+
input_model = mtpz.puzzletron_nas_plugin.PuzzletronModel()
9790
converted_model = mtn.convert(
9891
input_model,
9992
mode=[
@@ -120,7 +113,7 @@ def run_full_puzzletron(hydra_config_path: str):
120113
)
121114

122115
dist.cleanup()
123-
mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
116+
mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
124117

125118

126119
def run_mip_only(hydra_config_path: str):
@@ -135,33 +128,33 @@ def run_mip_only(hydra_config_path: str):
135128
dist.setup(timeout=timedelta(minutes=10))
136129

137130
# Register Hydra custom resolvers (needed for config resolution)
138-
register_hydra_resolvers()
131+
mtpz.tools.register_hydra_resolvers()
139132

140133
hydra_config_path = Path(hydra_config_path).resolve()
141134
hydra_config_dir = str(hydra_config_path.parent)
142135
hydra_config_name = hydra_config_path.stem
143136

144137
# Load hydra config
145-
hydra_cfg = initialize_hydra_config_for_dir(
138+
hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir(
146139
config_dir=hydra_config_dir,
147140
config_name=hydra_config_name,
148141
overrides=[],
149142
)
150143

151144
# Check if sweep mode is enabled
152145
if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
153-
mprint(
146+
mtpz.tools.mprint(
154147
"Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
155148
)
156-
sweep.run_mip_sweep(hydra_cfg)
149+
mtpz.mip.run_mip_sweep(hydra_cfg)
157150
else:
158151
# mip_and_realize_models (distributed processing)
159152
# TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
160-
mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
161-
mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
153+
mtpz.tools.mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
154+
mtpz.mip.launch_mip_and_realize_model(hydra_cfg)
162155

163156
dist.cleanup()
164-
mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
157+
mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
165158

166159

167160
def main():

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ def keep_conversation(entry):
142142
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
143143
if tokenizer.pad_token is None:
144144
tokenizer.pad_token = tokenizer.eos_token
145-
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
145+
if tokenizer.chat_template is not None:
146+
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
146147

147148
output_dir = args.output_dir
148149
output_dir.mkdir(parents=True, exist_ok=True)

modelopt/torch/prune/importance_hooks/expert_removal_hooks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from .base_hooks import ForwardHook
2828

2929
if TYPE_CHECKING:
30+
# This import is okay because it is only used for type hints; otherwise we
31+
# should not import puzzletron here, as its dependencies may not be installed
3032
from modelopt.torch.puzzletron.block_config import BlockConfig
3133

3234
__all__ = [

modelopt/torch/puzzletron/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,22 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
# NOTE: Some modules also trigger factory registration as side effect
17+
from . import (
18+
activation_scoring,
19+
anymodel,
20+
block_config,
21+
build_library_and_stats,
22+
dataset,
23+
entrypoint,
24+
mip,
25+
plugins,
26+
pruning,
27+
puzzletron_nas_plugin,
28+
replacement_library,
29+
scoring,
30+
sewing_kit,
31+
subblock_stats,
32+
tools,
33+
utils,
34+
)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from .activation_hooks import *
17+
from .score_pruning_activations import *

modelopt/torch/puzzletron/activation_scoring/activation_hooks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
from .utils import *

0 commit comments

Comments (0)