# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
+import json
import logging
import time
from collections.abc import (
@@ -288,6 +289,7 @@ def single_model_stat(
    _training_data: DpLoaderSet,
    _stat_file_path: str | None,
    finetune_has_new_type: bool = False,
+    preset_observed_type: list[str] | None = None,
) -> Callable[[], Any]:
    @functools.lru_cache
    def get_sample() -> Any:
@@ -302,6 +304,7 @@ def get_sample() -> Any:
        _model.compute_or_load_stat(
            sampled_func=get_sample,
            stat_file_path=_stat_file_path,
+            preset_observed_type=preset_observed_type,
        )
        if isinstance(_stat_file_path, DPH5Path):
            _stat_file_path.root.close()
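The new `preset_observed_type` keyword lets a user-supplied type list bypass deriving the observed types from the sampled training frames. The diff does not show the internals of `compute_or_load_stat`; the sketch below is only a guess at the intended control flow, and the helper name, `type_map` argument, and frame layout are all hypothetical:

```python
from collections.abc import Callable
from typing import Any


def compute_observed_type(
    sampled_func: Callable[[], list[dict[str, Any]]],  # cached sampler, as in get_sample()
    type_map: list[str],
    preset_observed_type: list[str] | None = None,  # mirrors the new kwarg
) -> list[str]:
    """Return the atom types the statistics step should treat as observed."""
    if preset_observed_type is not None:
        # A preset skips the data scan; still validate against the type map.
        unknown = set(preset_observed_type) - set(type_map)
        if unknown:
            raise ValueError(f"unknown observed types: {sorted(unknown)}")
        return sorted(preset_observed_type, key=type_map.index)
    # No preset: collect every type index that appears in the sampled frames.
    observed: set[str] = set()
    for frame in sampled_func():
        observed.update(type_map[i] for i in frame["atype"])
    return sorted(observed, key=type_map.index)


frames = [{"atype": [0, 0, 1]}]
print(compute_observed_type(lambda: frames, ["O", "H"]))  # ['O', 'H']
print(compute_observed_type(lambda: frames, ["O", "H", "Cu"], ["Cu"]))  # ['Cu']
```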
@@ -394,7 +397,16 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
                finetune_has_new_type=self.finetune_links["Default"].get_has_new_type()
                if self.finetune_links is not None
                else False,
+                preset_observed_type=model_params.get("info", {}).get("observed_type"),
            )
+            # Persist observed_type from stat into model_params and model_def_script
+            if not resuming and self.rank == 0:
+                observed = getattr(
+                    self.model.atomic_model, "_observed_type", None
+                )
+                if observed is not None:
+                    model_params.setdefault("info", {})["observed_type"] = observed
+                    self.model.model_def_script = json.dumps(model_params)
            (
                self.training_dataloader,
                self.training_data,
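This persist step runs once (rank 0 on a fresh run, not on resume) so that the derived `observed_type` survives checkpointing: `model_def_script` is the JSON-serialized model definition saved alongside the weights, and the recorded list can be handed back as `preset_observed_type` on the next run instead of re-scanning the data. A minimal round-trip sketch with illustrative dict contents:

```python
import json

# Fresh run: statistics left an observed-type list on the atomic model.
model_params = {"type_map": ["O", "H"], "descriptor": {"rcut": 6.0}}
observed = ["O", "H"]

# Mirror of the diff: record it under "info", then re-serialize the script.
model_params.setdefault("info", {})["observed_type"] = observed
model_def_script = json.dumps(model_params)

# Later run: recover the list from the saved script and feed it back in.
restored = json.loads(model_def_script).get("info", {}).get("observed_type")
assert restored == observed
```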
@@ -432,6 +444,11 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
                training_data[model_key].preload_and_modify_all_data_torch()
                if validation_data[model_key] is not None:
                    validation_data[model_key].preload_and_modify_all_data_torch()
+                _mt_user_observed = (
+                    model_params["model_dict"][model_key]
+                    .get("info", {})
+                    .get("observed_type")
+                )
                self.get_sample_func[model_key] = single_model_stat(
                    self.model[model_key],
                    model_params["model_dict"][model_key].get("data_stat_nbatch", 10),
@@ -442,7 +459,22 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
                    ].get_has_new_type()
                    if self.finetune_links is not None
                    else False,
+                    preset_observed_type=_mt_user_observed,
                )
+                # Persist observed_type into model_params and model_def_script
+                if not resuming and self.rank == 0:
+                    observed = getattr(
+                        self.model[model_key].atomic_model,
+                        "_observed_type",
+                        None,
+                    )
+                    if observed is not None:
+                        model_params["model_dict"][model_key].setdefault(
+                            "info", {}
+                        )["observed_type"] = observed
+                        self.model[model_key].model_def_script = json.dumps(
+                            model_params["model_dict"][model_key]
+                        )

                (
                    self.training_dataloader[model_key],
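In the multi-task branch the same pattern is applied per model key: the preset is read from each `model_dict` entry, and only that entry's sub-dict is re-serialized into the per-model `model_def_script`. A hypothetical `model_params` fragment showing where the preset lives (all other keys omitted):

```python
model_params = {
    "model_dict": {
        "water": {"info": {"observed_type": ["O", "H"]}},
        "alloy": {},  # no preset: observed types get derived from the data
    },
}

# Mirrors the _mt_user_observed lookup in the diff.
for model_key, entry in model_params["model_dict"].items():
    preset = entry.get("info", {}).get("observed_type")
    print(model_key, "->", preset)
# water -> ['O', 'H']
# alloy -> None
```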