Skip to content

Commit eeff734

Browse files
WOODchen7woodchenwu
andauthored
self.model_path change to self.absolute_model_path, Use GlobalConfig to access absolute_model_path (#105)
Co-authored-by: woodchenwu <woodchenwu@tencent.com>
1 parent a937651 commit eeff734

4 files changed

Lines changed: 18 additions & 6 deletions

File tree

angelslim/compressor/quant/ptq.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(self, model, slim_config=None):
3939
self.quant_model = model
4040
# init ptq config of model
4141
self.quant_model.init_ptq(slim_config)
42-
self.model_path = slim_config.get("model_path")
42+
self.absolute_model_path = slim_config["global_config"].absolute_model_path
4343
self.quant_algo = self.quant_model.quant_config.quant_algo
4444
self.quant_helpers = self.quant_model.quant_config.quant_helpers
4545
if (
@@ -213,12 +213,15 @@ def _convert(self):
213213
):
214214
if sub_layer.weight.device.type == "meta":
215215
with open(
216-
os.path.join(self.model_path, "model.safetensors.index.json"),
216+
os.path.join(
217+
self.absolute_model_path, "model.safetensors.index.json"
218+
),
217219
"r",
218220
) as f:
219221
model_index = json.load(f)
220222
orign_w_file = os.path.join(
221-
self.model_path, model_index["weight_map"][name + ".weight"]
223+
self.absolute_model_path,
224+
model_index["weight_map"][name + ".weight"],
222225
)
223226
orign_w = load_file(orign_w_file, device="cpu")
224227
print_info(f"Load meta weight {name} from file {orign_w_file}")
@@ -228,7 +231,7 @@ def _convert(self):
228231
if hasattr(sub_layer, "bias"):
229232
if (name + ".bias") in model_index["weight_map"]:
230233
orign_b_file = os.path.join(
231-
self.model_path,
234+
self.absolute_model_path,
232235
model_index["weight_map"][name + ".bias"],
233236
)
234237
orign_b = load_file(orign_b_file, device="cpu")

angelslim/engine.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ def prepare_compressor(
204204
slim_config = {
205205
"global_config": global_config,
206206
"compress_config": compress_config,
207-
"model_path": self.model_path,
208207
}
209208
self.compress_type = compress_names
210209
self.only_inference = (

angelslim/utils/config_parser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import yaml
2121

22-
from .utils import get_hf_config
22+
from .utils import get_hf_config, get_hf_model_path
2323

2424

2525
class CompressionMethod(str, Enum):
@@ -62,6 +62,7 @@ class GlobalConfig:
6262
max_seq_length: int = field(default=2048)
6363
hidden_size: int = field(default=2048)
6464
model_arch_type: str = field(default=None)
65+
absolute_model_path: str = field(default=None)
6566
deploy_backend: str = field(default="vllm")
6667

6768
def update(self, model_path: str = None, max_seq_length: int = None):
@@ -78,6 +79,7 @@ def update(self, model_path: str = None, max_seq_length: int = None):
7879
if model_path:
7980
self.set_model_hidden_size(model_path)
8081
self.set_model_arch_type(model_path)
82+
self.absolute_model_path = get_hf_model_path(model_path)
8183
if max_seq_length:
8284
self.set_max_seq_length(max_seq_length)
8385

angelslim/utils/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ def get_hf_config(model_path) -> dict:
148148
return json_data
149149

150150

151+
def get_hf_model_path(model_path) -> str:
152+
"When model_path does not exist, fetch the model.config from cached_file."
153+
if os.path.isfile(model_path):
154+
return model_path
155+
else:
156+
return os.path.dirname(cached_file(model_path, "config.json"))
157+
158+
151159
def common_prefix(str1, str2):
152160
return "".join(
153161
x[0] for x in takewhile(lambda x: x[0] == x[1], zip(str1, str2))

0 commit comments

Comments
 (0)