Skip to content

Commit 97a2958

Browse files
[MODEL] support hrm_text (#2905)
* add missing test_interns1.py Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> * Support multi-tree module layouts and direct meta tensor materialization Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> * support hrm_text Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> * format Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> * cleanup Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai> --------- Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
1 parent 6edc24d commit 97a2958

43 files changed

Lines changed: 937 additions & 153 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

gptqmodel/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import sys
88

9+
910
# isort: off
1011
from ._banner import get_startup_banner # noqa: E402
1112
from .utils import _MONKEY_PATCH_LOCK # noqa: E402

gptqmodel/looper/module_looper.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,14 @@
5151
rehome_module_to_device,
5252
select_forward_devices,
5353
)
54-
from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to, MoETopKState, set_moe_topk, restore_moe_topk
54+
from ..utils.model import (
55+
MoETopKState,
56+
get_layers_with_prefixes,
57+
get_module,
58+
move_to,
59+
restore_moe_topk,
60+
set_moe_topk, get_module_by_name_prefix,
61+
)
5562
from ..utils.offload import offload_to_disk
5663
from ..utils.python import has_gil_control, has_gil_disabled
5764
from ..utils.torch import (CPU, META, timed_gc_collect, torch_sync, tf32_high_precision_guard)
@@ -1428,7 +1435,10 @@ def _loop_impl(self, fallback=None, **kwargs):
14281435

14291436
forward_pass_use_cache = self.gptq_model.model.config.use_cache if hasattr(self.gptq_model.model.config, "use_cache") else False
14301437
self.gptq_model.model.config.use_cache = False
1431-
layers, layers_prefix = get_module_by_name_prefix(self.gptq_model.model, self.gptq_model.extract_layers_node())
1438+
layers, layer_names = get_layers_with_prefixes(
1439+
self.gptq_model.model,
1440+
self.gptq_model.extract_layers_node(),
1441+
)
14321442
region_timer = getattr(self.gptq_model, "quant_region_timer", None)
14331443

14341444
for p_index, processor in enumerate(self.processors):
@@ -1523,7 +1533,7 @@ def _loop_impl(self, fallback=None, **kwargs):
15231533
layers=layers,
15241534
layer_modules=layer_modules,
15251535
planning_layer_modules=planning_layer_modules,
1526-
layers_prefix=layers_prefix,
1536+
layer_names=layer_names,
15271537
fallback=fallback,
15281538
shared_kv_cache_dict=shared_kv_cache_dict,
15291539
pb=pb,
@@ -1631,7 +1641,7 @@ def create_named_modules(self, module, full, is_lm_head_module, layer_index, lay
16311641
capture_only_flags[n] = True # forward-only modules should not be finalized
16321642
skipped_modules = []
16331643
for name in subset:
1634-
layer_name = self.gptq_model.lm_head if is_lm_head_module else f"{layers_prefix}.{layer_index}.{name}"
1644+
layer_name = self.gptq_model.lm_head if is_lm_head_module else f"{layers_prefix}.{name}"
16351645

16361646
# gptq task is created and stored inside processor
16371647
if not isinstance(subset[name], NamedModule):

gptqmodel/looper/stage_inputs_capture.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ..utils.device import get_device
2020
from ..utils.looper_helpers import device_ctx, select_forward_devices
2121
from ..utils.logger import setup_logger
22-
from ..utils.model import get_module_by_name_prefix, move_to, nested_move_to
22+
from ..utils.model import get_module, get_module_by_name_prefix, move_to, nested_move_to
2323
from ..utils.torch import CPU, META
2424

2525
if TYPE_CHECKING: # pragma: no cover - import for typing only
@@ -36,6 +36,15 @@ def __init__(self, looper: ModuleLooper, logger=None) -> None:
3636
self.gptq_model = looper.gptq_model
3737
self.logger = logger or setup_logger()
3838

39+
def _materialize_modules_with_direct_meta_tensors(self, device: torch.device) -> None:
40+
for module_name in self.gptq_model.get_modules_with_direct_meta_tensors(self.gptq_model.model):
41+
module = get_module(self.gptq_model.model, module_name)
42+
if isinstance(module, torch.nn.Module):
43+
self.gptq_model.shell_direct_meta_materialize(
44+
target_submodule=module,
45+
device=device,
46+
)
47+
3948
def cache_inputs(
4049
self,
4150
layers: Sequence[torch.nn.Module],
@@ -174,6 +183,9 @@ def store_input_hook(module, args, kwargs):
174183
# and wait for the first instance this callback is called
175184
raise STOP_FORWARD_EXCEPTION
176185

186+
# Parameters attached to the shell root must be ready before embedding forward.
187+
self._materialize_modules_with_direct_meta_tensors(cur_layer_device)
188+
177189
ori_outside_layer_module_devices: Dict[str, torch.device] = {}
178190
for module_name in self.gptq_model.get_base_modules(self.gptq_model.model):
179191
module, _ = get_module_by_name_prefix(self.gptq_model.model, [module_name])

gptqmodel/looper/stage_layer.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from ..utils.device import get_device, get_device_new
3737
from ..utils.looper_helpers import normalize_device_like
3838
from ..utils.logger import live_renderables_suppressed, log_time_block, setup_logger
39-
from ..utils.model import find_modules, get_module
39+
from ..utils.model import find_modules, get_layer_name, get_module
4040
from ..utils.offload import offload_to_disk
4141
from ..utils.torch import CPU, torch_empty_cache, torch_sync
4242
from .stage_subset import SubsetPlan, build_layer_subset_plans, run_subset_stage
@@ -49,11 +49,11 @@ def _find_last_quantized_layer_index(
4949
looper: "ModuleLooper",
5050
*,
5151
layer_modules: List[List[str]],
52-
layers_prefix: Optional[str],
52+
layer_names: Optional[List[str]],
5353
layer_count: int,
5454
) -> Optional[int]:
5555
"""Return the highest layer index whose tracked modules are not all dynamically skipped."""
56-
if looper.gptq_model.quantize_config.lm_head or not layers_prefix:
56+
if looper.gptq_model.quantize_config.lm_head or not layer_names:
5757
return None
5858

5959
layer_module_names = {
@@ -67,8 +67,9 @@ def _find_last_quantized_layer_index(
6767

6868
last_quantized_layer_index = -1
6969
for candidate_layer_index in range(layer_count):
70+
layer_name = get_layer_name(layer_names, candidate_layer_index)
7071
for module_name in layer_module_names:
71-
module_full_name = f"{layers_prefix}.{candidate_layer_index}.{module_name}"
72+
module_full_name = f"{layer_name}.{module_name}"
7273
# If at least one module in this layer is not dynamically excluded,
7374
# the layer still needs forward/quantization work.
7475
if looper.gptq_model.quantize_config.dynamic_get(layer_name=module_full_name) != False:
@@ -387,7 +388,7 @@ def run_layer_stage(
387388
layers: List[torch.nn.Module],
388389
layer_modules: List[List[str]],
389390
planning_layer_modules: List[List[str]],
390-
layers_prefix: Optional[str],
391+
layer_names: Optional[List[str]],
391392
fallback,
392393
shared_kv_cache_dict: Dict[int, torch.Tensor],
393394
pb,
@@ -403,7 +404,7 @@ def run_layer_stage(
403404
last_quantized_layer_index = _find_last_quantized_layer_index(
404405
looper,
405406
layer_modules=layer_modules,
406-
layers_prefix=layers_prefix,
407+
layer_names=layer_names,
407408
layer_count=layer_count,
408409
)
409410

@@ -436,10 +437,12 @@ def run_layer_stage(
436437
layer_title = "Quantizing lm_head"
437438
module = get_module(looper.gptq_model.model, key=looper.gptq_model.lm_head)
438439
pristine_group_module = None
440+
layer_name = ""
439441
else:
440442
layer_title = f"Quantizing layer {layer_index} of {layer_count - 1}"
441443
module = layers[layer_index]
442444
pristine_group_module = None
445+
layer_name = get_layer_name(layer_names, layer_index)
443446

444447
pb.title(layer_title).subtitle("").draw()
445448
if durable_progress_logs:
@@ -483,8 +486,8 @@ def run_layer_stage(
483486

484487
layers[layer_index] = module
485488

486-
if layers_prefix:
487-
layer_descriptor = f"{layers_prefix}.{layer_index}"
489+
if layer_name:
490+
layer_descriptor = layer_name
488491
else:
489492
layer_descriptor = str(layer_index)
490493

@@ -530,7 +533,7 @@ def run_layer_stage(
530533
full=full,
531534
is_lm_head_module=is_lm_head_module,
532535
layer_index=layer_index,
533-
layers_prefix=layers_prefix,
536+
layers_prefix=layer_name,
534537
fallback=fallback,
535538
)
536539
if durable_progress_logs:

gptqmodel/looper/weight_only_looper.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,14 @@
2727
from ..nn_modules.converter import MODULE_CONVERTER_MAP
2828
from ..quantization.config import BitsAndBytesConfig, FP8Config, GGUFConfig, RTNConfig
2929
from ..utils.logger import setup_logger
30-
from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to
30+
from ..utils.model import (
31+
find_modules,
32+
get_layer_name,
33+
get_layers_with_prefixes,
34+
get_module,
35+
get_module_by_name_prefix,
36+
move_to,
37+
)
3138
from ..utils.offload import offload_to_disk
3239

3340

@@ -49,7 +56,7 @@ def _resolve_named_module(
4956
layer_module: torch.nn.Module,
5057
full: Dict[str, torch.nn.Module],
5158
layer_index: int,
52-
layers_prefix: Optional[str],
59+
layer_path: Optional[str],
5360
module_name: str,
5461
is_lm_head_module: bool,
5562
) -> Optional[NamedModule]:
@@ -65,7 +72,7 @@ def _resolve_named_module(
6572
if isinstance(resolved, NamedModule):
6673
return resolved
6774

68-
layer_name = self.gptq_model.lm_head if is_lm_head_module else f"{layers_prefix}.{layer_index}.{module_name}"
75+
layer_name = self.gptq_model.lm_head if is_lm_head_module else f"{layer_path}.{module_name}"
6976
named = NamedModule(
7077
resolved,
7178
name=module_name,
@@ -132,11 +139,19 @@ def loop(self, **kwargs):
132139
# decoder-cache state while layers are being replaced.
133140
self.gptq_model.model.config.use_cache = False
134141

135-
layers, layers_prefix = get_module_by_name_prefix(
142+
layers, layer_names = get_layers_with_prefixes(
136143
self.gptq_model.model,
137144
self.gptq_model.extract_layers_node(),
138145
)
139146

147+
for module_name in self.gptq_model.get_modules_with_direct_meta_tensors(self.gptq_model.model):
148+
module = get_module(self.gptq_model.model, module_name)
149+
if module is not None:
150+
self.gptq_model.shell_direct_meta_materialize(
151+
target_submodule=module,
152+
device=CPU,
153+
)
154+
140155
if quant_config.offload_to_disk:
141156
log.info("Offloading base modules to disk...")
142157
offload_to_disk(
@@ -181,6 +196,10 @@ def loop(self, **kwargs):
181196
else:
182197
module = layers[layer_index]
183198
subsets = layer_modules
199+
# Flattened layer names preserve the source stack for split decoders.
200+
layer_name = get_layer_name(layer_names, layer_index)
201+
if is_lm_head_module:
202+
layer_name = None
184203

185204
module = self.gptq_model.pre_quantize(module)
186205
if not is_lm_head_module:
@@ -204,7 +223,7 @@ def loop(self, **kwargs):
204223
layer_module=module,
205224
full=full,
206225
layer_index=layer_index,
207-
layers_prefix=layers_prefix,
226+
layer_path=layer_name,
208227
module_name=module_name,
209228
is_lm_head_module=is_lm_head_module,
210229
)

gptqmodel/models/auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
from .definitions.gptj import GptJQModel # noqa: E402
114114
from .definitions.granitemoehybrid import GraniteMoeHybridQModel
115115
from .definitions.grinmoe import GrinMoeQModel # noqa: E402
116+
from .definitions.hrm_text import HrmTextQModel # noqa: E402
116117
from .definitions.hymba import HymbaQModel # noqa: E402
117118
from .definitions.instella import InstellaQModel # noqa: E402
118119
from .definitions.internlm import InternLMQModel # noqa: E402
@@ -227,6 +228,7 @@
227228
"internlm2": InternLM2QModel,
228229
"interns1": InternS1QModel,
229230
"internvl_chat": InternVLChatQModel,
231+
"hrm_text": HrmTextQModel,
230232
"qwen": QwenQModel,
231233
"mistral": LlamaQModel, # 100% llama clone
232234
"yi": LlamaQModel, # 100% llama clone

0 commit comments

Comments
 (0)