44# --------------------------------------------------------------------------
55import functools
66import json
7+ import re
78from copy import deepcopy
89from os import PathLike
910from pathlib import Path , PurePosixPath , PureWindowsPath
1415from olive .package_config import OlivePackageConfig
1516from olive .systems .common import SystemType
1617from olive .telemetry .telemetry import is_ci_environment
17- from olive .workflows .run .config import RunConfig
1818
1919if TYPE_CHECKING :
20- from olive .engine . config import RunPassConfig
20+ from olive .workflows . run . config import RunConfig
2121
2222RECIPE_HASH_REDACTED_VALUE = "<resource>"
2323CONFIG_REFERENCE_REDACTED_VALUE = "<reference>"
4545 "adapter_path" ,
4646 "user_script" ,
4747}
48+ HF_MODEL_IDENTIFIER_KEYS = {"model_path" , "_name_or_path" }
4849CONFIG_REFERENCE_KEYS = {"host" , "target" , "evaluator" }
50+ LOCAL_MODEL_FILE_SUFFIXES = {".bin" , ".model" , ".onnx" , ".pb" , ".pt" , ".pth" , ".safetensors" , ".tflite" }
51+ HF_CACHE_MODEL_PATTERN = re .compile (r"(?:^|[\\/])models--([^\\/]+)--([^\\/]+)(?:[\\/]|$)" )
52+ HF_REPO_ID_PATTERN = re .compile (r"^[A-Za-z0-9][A-Za-z0-9._-]*(/[A-Za-z0-9][A-Za-z0-9._-]*)?$" )
4953_NO_OVERRIDE = object ()
5054
5155
5256def _build_recipe_result_metadata (
5357 run_config_input : Union [str , Path , dict ],
5458 run_config_telemetry_input : Optional [Any ],
55- run_config : Optional [RunConfig ],
59+ run_config : Optional [" RunConfig" ],
5660 recipe_telemetry_metadata : Optional [dict [str , Any ]],
5761 * ,
5862 list_required_packages : bool ,
@@ -65,9 +69,17 @@ def _build_recipe_result_metadata(
6569 metadata .setdefault ("recipe_format" , default_format )
6670 metadata .setdefault ("execution_mode" , "list_required_packages" if list_required_packages else "run" )
6771 metadata .setdefault ("package_config_provided" , package_config_provided )
68- metadata .setdefault ("config_overrides" , _build_config_overrides (run_config_telemetry_input ))
72+ config_overrides = metadata .pop ("config_overrides" , _NO_OVERRIDE )
73+ if config_overrides is _NO_OVERRIDE :
74+ config_overrides = _build_config_overrides (run_config_telemetry_input )
75+ elif not isinstance (config_overrides , str ):
76+ config_overrides = _build_config_overrides (config_overrides )
77+ if config_overrides is not None :
78+ metadata ["config_overrides" ] = config_overrides
6979 if package_config_provided :
70- metadata .setdefault ("package_config_overrides" , _build_package_config_overrides (package_config_input ))
80+ package_config_overrides = _build_package_config_overrides (package_config_input )
81+ if package_config_overrides is not None :
82+ metadata .setdefault ("package_config_overrides" , package_config_overrides )
7183 metadata ["is_ci" ] = is_ci_environment ()
7284
7385 if run_config is None :
@@ -78,7 +90,7 @@ def _build_recipe_result_metadata(
7890 model_metadata = _extract_input_model_metadata (run_config_json ["input_model" ])
7991 target_metadata = _extract_target_metadata (run_config )
8092 host_metadata = _extract_host_metadata (run_config )
81- pass_types = [ pass_config . type for pass_config in _get_used_passes_configs (run_config )]
93+ pass_types = _get_used_pass_types (run_config )
8294
8395 metadata .setdefault ("recipe_name" , metadata .get ("recipe_command" ) or run_config .workflow_id )
8496 metadata .setdefault ("workflow_id" , run_config .workflow_id )
@@ -208,34 +220,44 @@ def _load_config_input_for_telemetry(config_input: Any) -> Optional[Any]:
208220 return None
209221
210222
211- def _sanitize_config_snapshot (value : Any , key : Optional [str ] = None ) -> Any :
223+ def _sanitize_config_snapshot (value : Any , key : Optional [str ] = None , model_type : Optional [str ] = None ) -> Any :
224+ if key in HF_MODEL_IDENTIFIER_KEYS :
225+ if str (model_type ).lower () == "hfmodel" :
226+ hf_model_id = _extract_huggingface_model_id (value )
227+ if hf_model_id :
228+ return hf_model_id
229+ return RECIPE_HASH_REDACTED_VALUE
212230 if key in CONFIG_SNAPSHOT_REDACTED_KEYS or _is_path_like_key (key ):
213231 return RECIPE_HASH_REDACTED_VALUE
214232 if key in CONFIG_REFERENCE_KEYS and isinstance (value , str ):
215233 return CONFIG_REFERENCE_REDACTED_VALUE
216234
217235 if isinstance (value , dict ):
236+ child_model_type = _get_model_type (value ) or model_type
218237 if key == "systems" :
219- return [_sanitize_config_snapshot (system , "system" ) for system in value .values ()]
238+ return [_sanitize_config_snapshot (system , "system" , child_model_type ) for system in value .values ()]
220239 if key == "passes" :
221240 passes = []
222241 for pass_configs in value .values ():
223242 if isinstance (pass_configs , list ):
224243 passes .extend (pass_configs )
225244 else :
226245 passes .append (pass_configs )
227- return [_sanitize_config_snapshot (pass_config , "pass" ) for pass_config in passes ]
246+ return [_sanitize_config_snapshot (pass_config , "pass" , child_model_type ) for pass_config in passes ]
228247 if key == "evaluators" :
229- return [_sanitize_config_snapshot (evaluator , "evaluator_config" ) for evaluator in value .values ()]
248+ return [
249+ _sanitize_config_snapshot (evaluator , "evaluator_config" , child_model_type )
250+ for evaluator in value .values ()
251+ ]
230252 return {
231- child_key : _sanitize_config_snapshot (child_value , child_key )
253+ child_key : _sanitize_config_snapshot (child_value , child_key , child_model_type )
232254 for child_key , child_value in value .items ()
233255 if child_value is not None
234256 }
235257 if isinstance (value , list ):
236- return [_sanitize_config_snapshot (item , key ) for item in value ]
258+ return [_sanitize_config_snapshot (item , key , model_type ) for item in value ]
237259 if isinstance (value , tuple ):
238- return [_sanitize_config_snapshot (item , key ) for item in value ]
260+ return [_sanitize_config_snapshot (item , key , model_type ) for item in value ]
239261 if isinstance (value , Path ):
240262 return RECIPE_HASH_REDACTED_VALUE
241263 if callable (value ):
@@ -255,6 +277,35 @@ def _is_path_like_key(key: Optional[str]) -> bool:
255277 )
256278
257279
280+ def _get_model_type (config : dict [str , Any ]) -> Optional [str ]:
281+ model_type = config .get ("type" )
282+ return str (model_type ).lower () if model_type is not None else None
283+
284+
285+ def _extract_huggingface_model_id (model_identifier : Any ) -> Optional [str ]:
286+ if not isinstance (model_identifier , str ):
287+ return None
288+
289+ identifier = model_identifier .strip ()
290+ if not identifier :
291+ return None
292+
293+ if identifier .startswith ("https://huggingface.co/" ):
294+ parts = identifier .removeprefix ("https://huggingface.co/" ).strip ("/" ).split ("/" )
295+ if len (parts ) >= 2 :
296+ return f"{ parts [0 ]} /{ parts [1 ]} "
297+ if parts and parts [0 ]:
298+ return parts [0 ]
299+
300+ if match := HF_CACHE_MODEL_PATTERN .search (identifier ):
301+ return f"{ match .group (1 )} /{ match .group (2 )} "
302+
303+ if HF_REPO_ID_PATTERN .match (identifier ) and not _has_local_model_file_suffix (identifier ):
304+ return identifier
305+
306+ return None
307+
308+
258309def _extract_input_model_metadata (input_model_config : dict [str , Any ]) -> dict [str , Optional [str ]]:
259310 model_config = input_model_config .get ("config" , {})
260311 model_attributes = model_config .get ("model_attributes" , {})
@@ -290,19 +341,26 @@ def _classify_input_model_source(model_identifier: Any) -> str:
290341
291342
292343def _is_explicit_local_model_path (identifier : str ) -> bool :
344+ if _has_local_model_file_suffix (identifier ):
345+ return True
293346 return (
294347 identifier .startswith (("./" , "../" , ".\\ " , "..\\ " , "~/" , "~\\ " , "/" , "\\ \\ " ))
295348 or PureWindowsPath (identifier ).is_absolute ()
296349 or PurePosixPath (identifier ).is_absolute ()
297350 )
298351
299352
300- def _extract_target_metadata (run_config : RunConfig ) -> dict [str , Optional [str ]]:
353+ def _has_local_model_file_suffix (identifier : str ) -> bool :
354+ suffix = PureWindowsPath (identifier ).suffix or PurePosixPath (identifier ).suffix
355+ return suffix .lower () in LOCAL_MODEL_FILE_SUFFIXES
356+
357+
358+ def _extract_target_metadata (run_config : "RunConfig" ) -> dict [str , Optional [str ]]:
301359 target_system = run_config .engine .target
302360 return _extract_system_metadata (target_system , "target" )
303361
304362
305- def _extract_host_metadata (run_config : RunConfig ) -> dict [str , Optional [str ]]:
363+ def _extract_host_metadata (run_config : " RunConfig" ) -> dict [str , Optional [str ]]:
306364 host_system = run_config .engine .host
307365 if host_system is None :
308366 return {
@@ -340,9 +398,9 @@ def _set_metadata_if_present(metadata: dict[str, Any], values: dict[str, Optiona
340398 metadata .setdefault (key , value )
341399
342400
343- def _get_used_passes_configs (run_config : RunConfig ) -> list ["RunPassConfig" ]:
401+ def _get_used_pass_types (run_config : " RunConfig" ) -> list [str ]:
344402 return (
345- [pass_config for _ , pass_configs in run_config .passes .items () for pass_config in pass_configs ]
403+ [pass_config . type for _ , pass_configs in run_config .passes .items () for pass_config in pass_configs ]
346404 if run_config .passes
347405 else []
348406 )
@@ -363,4 +421,12 @@ def _redact_recipe_hash_keys(value: Any, key: Optional[str] = None) -> Any:
363421 elif isinstance (value , list ):
364422 for index , item in enumerate (value ):
365423 value [index ] = _redact_recipe_hash_keys (item , key )
424+ elif isinstance (value , tuple ):
425+ return [_redact_recipe_hash_keys (item , key ) for item in value ]
426+ elif isinstance (value , Path ):
427+ return RECIPE_HASH_REDACTED_VALUE
428+ elif callable (value ):
429+ return CONFIG_CALLABLE_REDACTED_VALUE
430+ elif hasattr (value , "value" ) and isinstance (value .value , (str , int , float , bool )):
431+ return value .value
366432 return value
0 commit comments