Skip to content

Commit b878e73

Browse files
author
Han Wang
committed
refactor(pt_expt): remove model_params kwarg from deserialize_to_file
Read the training config from data["model_def_script"] instead of a separate kwarg, matching the convention of the pt/jax/tf/pd backends. This fixes dp convert-backend for pt_expt, which passes the universal dict (already containing model_def_script) with only two positional arguments.
1 parent ff2d1ca commit b878e73

4 files changed

Lines changed: 18 additions & 18 deletions

File tree

deepmd/pt_expt/entrypoints/compress.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ def enable_compression(
118118
deserialize_to_file(
119119
output,
120120
uncompressed_data,
121-
model_params=model_dict.get("model_def_script"),
122121
model_json_override={
123122
"model": compressed_model_dict,
124123
"model_def_script": model_dict.get("model_def_script"),

deepmd/pt_expt/entrypoints/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def freeze(
259259
m.eval()
260260

261261
model_dict = m.serialize()
262-
deserialize_to_file(output, {"model": model_dict}, model_params=model_params)
262+
deserialize_to_file(output, {"model": model_dict, "model_def_script": model_params})
263263
log.info("Saved frozen model to %s", output)
264264

265265

@@ -441,7 +441,7 @@ def change_bias(
441441
)
442442
model_dict = model_to_change.serialize()
443443
deserialize_to_file(
444-
output_path, {"model": model_dict}, model_params=model_params
444+
output_path, {"model": model_dict, "model_def_script": model_params}
445445
)
446446
log.info(f"Saved model to {output_path}")
447447

deepmd/pt_expt/utils/serialization.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ def _serialize_from_file_pt2(model_file: str) -> dict:
270270
def deserialize_to_file(
271271
model_file: str,
272272
data: dict,
273-
model_params: dict | None = None,
274273
model_json_override: dict | None = None,
275274
) -> None:
276275
"""Deserialize a dictionary to a .pte or .pt2 model file.
@@ -285,19 +284,18 @@ def deserialize_to_file(
285284
data : dict
286285
The dictionary to be deserialized (same format as dpmodel's
287286
serialize output, with "model" and optionally "model_def_script" keys).
288-
model_params : dict or None
289-
Original model config (the dict passed to ``get_model``).
290-
If provided, embedded in the .pte so that ``--use-pretrain-script``
291-
can extract descriptor/fitting params at finetune time.
287+
If ``data["model_def_script"]`` is present, it is embedded in the
288+
output so that ``--use-pretrain-script`` can extract descriptor/fitting
289+
params at finetune time.
292290
model_json_override : dict or None
293291
If provided, this dict is stored in model.json instead of ``data``.
294292
Used by ``dp compress`` to store the compressed model dict while
295293
tracing the uncompressed model (make_fx cannot trace custom ops).
296294
"""
297295
if model_file.endswith(".pt2"):
298-
_deserialize_to_file_pt2(model_file, data, model_json_override, model_params)
296+
_deserialize_to_file_pt2(model_file, data, model_json_override)
299297
else:
300-
_deserialize_to_file_pte(model_file, data, model_json_override, model_params)
298+
_deserialize_to_file_pte(model_file, data, model_json_override)
301299

302300

303301
def _trace_and_export(
@@ -397,17 +395,17 @@ def _deserialize_to_file_pte(
397395
model_file: str,
398396
data: dict,
399397
model_json_override: dict | None = None,
400-
model_params: dict | None = None,
401398
) -> None:
402399
"""Deserialize a dictionary to a .pte model file."""
403400
exported, metadata, data_for_json, output_keys = _trace_and_export(
404401
data, model_json_override
405402
)
406403

404+
model_def_script = data.get("model_def_script") or {}
407405
metadata["output_keys"] = output_keys
408406
extra_files = {
409407
"metadata.json": json.dumps(metadata),
410-
"model_def_script.json": json.dumps(model_params or {}),
408+
"model_def_script.json": json.dumps(model_def_script),
411409
"model.json": json.dumps(data_for_json, separators=(",", ":")),
412410
}
413411

@@ -418,7 +416,6 @@ def _deserialize_to_file_pt2(
418416
model_file: str,
419417
data: dict,
420418
model_json_override: dict | None = None,
421-
model_params: dict | None = None,
422419
) -> None:
423420
"""Deserialize a dictionary to a .pt2 model file (AOTInductor).
424421
@@ -440,10 +437,11 @@ def _deserialize_to_file_pt2(
440437
aoti_compile_and_package(exported, package_path=model_file)
441438

442439
# Embed metadata into the .pt2 ZIP archive
440+
model_def_script = data.get("model_def_script") or {}
443441
metadata["output_keys"] = output_keys
444442
with zipfile.ZipFile(model_file, "a") as zf:
445443
zf.writestr("extra/metadata.json", json.dumps(metadata))
446-
zf.writestr("extra/model_def_script.json", json.dumps(model_params or {}))
444+
zf.writestr("extra/model_def_script.json", json.dumps(model_def_script))
447445
zf.writestr(
448446
"extra/model.json",
449447
json.dumps(data_for_json, separators=(",", ":")),

source/tests/pt_expt/infer/test_deep_eval.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ def test_get_model_def_script_with_params(self) -> None:
115115
with tempfile.NamedTemporaryFile(suffix=".pte", delete=False) as f:
116116
tmpfile2 = f.name
117117
try:
118-
deserialize_to_file(tmpfile2, self.model_data, model_params=training_config)
118+
data_with_config = {**self.model_data, "model_def_script": training_config}
119+
deserialize_to_file(tmpfile2, data_with_config)
119120
dp2 = DeepPot(tmpfile2)
120121
mds = dp2.deep_eval.get_model_def_script()
121122
self.assertEqual(mds, training_config)
@@ -598,9 +599,11 @@ def test_get_model_def_script_with_params(self) -> None:
598599
try:
599600
torch.set_default_device(None)
600601
try:
601-
deserialize_to_file(
602-
tmpfile2, self.model_data, model_params=training_config
603-
)
602+
data_with_config = {
603+
**self.model_data,
604+
"model_def_script": training_config,
605+
}
606+
deserialize_to_file(tmpfile2, data_with_config)
604607
finally:
605608
torch.set_default_device("cuda:9999999")
606609
dp2 = DeepPot(tmpfile2)

0 commit comments

Comments (0)