@@ -270,7 +270,6 @@ def _serialize_from_file_pt2(model_file: str) -> dict:
270270def deserialize_to_file (
271271 model_file : str ,
272272 data : dict ,
273- model_params : dict | None = None ,
274273 model_json_override : dict | None = None ,
275274) -> None :
276275 """Deserialize a dictionary to a .pte or .pt2 model file.
@@ -285,19 +284,18 @@ def deserialize_to_file(
285284 data : dict
286285 The dictionary to be deserialized (same format as dpmodel's
287286 serialize output, with "model" and optionally "model_def_script" keys).
288- model_params : dict or None
289- Original model config (the dict passed to ``get_model``).
290- If provided, embedded in the .pte so that ``--use-pretrain-script``
291- can extract descriptor/fitting params at finetune time.
287+ If ``data["model_def_script"]`` is present, it is embedded in the
288+ output so that ``--use-pretrain-script`` can extract descriptor/fitting
289+ params at finetune time.
292290 model_json_override : dict or None
293291 If provided, this dict is stored in model.json instead of ``data``.
294292 Used by ``dp compress`` to store the compressed model dict while
295293 tracing the uncompressed model (make_fx cannot trace custom ops).
296294 """
297295 if model_file .endswith (".pt2" ):
298- _deserialize_to_file_pt2 (model_file , data , model_json_override , model_params )
296+ _deserialize_to_file_pt2 (model_file , data , model_json_override )
299297 else :
300- _deserialize_to_file_pte (model_file , data , model_json_override , model_params )
298+ _deserialize_to_file_pte (model_file , data , model_json_override )
301299
302300
303301def _trace_and_export (
@@ -397,17 +395,17 @@ def _deserialize_to_file_pte(
397395 model_file : str ,
398396 data : dict ,
399397 model_json_override : dict | None = None ,
400- model_params : dict | None = None ,
401398) -> None :
402399 """Deserialize a dictionary to a .pte model file."""
403400 exported , metadata , data_for_json , output_keys = _trace_and_export (
404401 data , model_json_override
405402 )
406403
404+ model_def_script = data .get ("model_def_script" ) or {}
407405 metadata ["output_keys" ] = output_keys
408406 extra_files = {
409407 "metadata.json" : json .dumps (metadata ),
410- "model_def_script.json" : json .dumps (model_params or {} ),
408+ "model_def_script.json" : json .dumps (model_def_script ),
411409 "model.json" : json .dumps (data_for_json , separators = ("," , ":" )),
412410 }
413411
@@ -418,7 +416,6 @@ def _deserialize_to_file_pt2(
418416 model_file : str ,
419417 data : dict ,
420418 model_json_override : dict | None = None ,
421- model_params : dict | None = None ,
422419) -> None :
423420 """Deserialize a dictionary to a .pt2 model file (AOTInductor).
424421
@@ -440,10 +437,11 @@ def _deserialize_to_file_pt2(
440437 aoti_compile_and_package (exported , package_path = model_file )
441438
442439 # Embed metadata into the .pt2 ZIP archive
440+ model_def_script = data .get ("model_def_script" ) or {}
443441 metadata ["output_keys" ] = output_keys
444442 with zipfile .ZipFile (model_file , "a" ) as zf :
445443 zf .writestr ("extra/metadata.json" , json .dumps (metadata ))
446- zf .writestr ("extra/model_def_script.json" , json .dumps (model_params or {} ))
444+ zf .writestr ("extra/model_def_script.json" , json .dumps (model_def_script ))
447445 zf .writestr (
448446 "extra/model.json" ,
449447 json .dumps (data_for_json , separators = ("," , ":" )),
0 commit comments