Skip to content

Commit 76e7b79

Browse files
fix rope_parameters is not inited (#2882)
1 parent cde1604 commit 76e7b79

1 file changed

Lines changed: 10 additions & 10 deletions

File tree

gptqmodel/utils/hf.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,26 @@
33
# SPDX-License-Identifier: Apache-2.0
44
# Contact: qubitium@modelcloud.ai, x.com/qubitium
55

6-
import inspect
76
import json
7+
import numpy as np
88
import os
99
import sys
10+
import transformers
1011
import warnings
12+
from accelerate import init_empty_weights
1113
from contextlib import contextmanager
1214
from functools import lru_cache
13-
from typing import Any, Optional
14-
15-
import numpy as np
16-
import torch
17-
import transformers
18-
from accelerate import init_empty_weights
1915
from transformers import (
2016
AutoConfig,
2117
AutoModelForCausalLM,
2218
AutoTokenizer,
2319
GenerationConfig,
2420
PreTrainedModel,
2521
)
22+
from typing import Any, Optional
2623

24+
import inspect
25+
import torch
2726
from ..nn_modules.qlinear.gguf import (
2827
PRISM_Q1_0_G128_BLOCK_SIZE,
2928
PRISM_Q1_0_G128_NAME,
@@ -991,9 +990,6 @@ def _normalize_rope_parameters_config_compat(config: Any) -> None:
991990
legacy_rope_scaling = getattr(config, "rope_scaling", None)
992991
rope_parameters = dict(legacy_rope_scaling) if isinstance(legacy_rope_scaling, dict) else dict(rope_parameters or {})
993992

994-
if not rope_parameters and getattr(config, "rope_theta", None) is None and getattr(config, "default_theta", None) is None:
995-
return
996-
997993
rope_parameters.setdefault("rope_type", rope_parameters.get("type", "default"))
998994
if rope_parameters.get("rope_theta") is None:
999995
rope_theta = getattr(config, "rope_theta", None)
@@ -1138,6 +1134,10 @@ def normalize_hf_config_compat(config: Any, *, trust_remote_code: bool = False)
11381134
# some config classes synchronize `rope_scaling` from `rope_parameters` and
11391135
# can drop legacy keys like `rope_scaling["type"]`.
11401136
_normalize_remote_code_config_compat(config)
1137+
# Some config classes can still nullify rope_parameters during the second
1138+
# remote-code field normalization pass. Ensure final config always carries
1139+
# valid rope_parameters for model classes that directly subscript it.
1140+
_normalize_rope_parameters_config_compat(config)
11411141

11421142

11431143
def prepare_remote_code_compat(config: Any) -> None:

0 commit comments

Comments
 (0)