44from pathlib import Path
55from typing import Dict , List , Optional
66
7- import yaml
87from model_lab import RuntimeEnum
98from sanitize .constants import ArchitectureEnum , EPNames , IconEnum , ModelStatusEnum
10- from sanitize .copy_config import CopyConfig
9+ from sanitize .copy_config import Copy , CopyConfig
1110from sanitize .generator_amd import generator_amd
1211from sanitize .generator_dml import generator_dml
1312from sanitize .generator_intel import generator_intel
1413from sanitize .generator_qnn import generator_qnn
1514from sanitize .generator_trtrtx import generator_trtrtx
1615from sanitize .model_info import ModelInfo , ModelList
1716from sanitize .project_config import ModelInfoProject , ModelProjectConfig , WorkflowItem
18- from sanitize .utils import GlobalVars , isLLM_by_id , open_ex
17+ from sanitize .utils import (
18+ GlobalVars ,
19+ WINML_COPY_EXEMPT_IDS ,
20+ isLLM_by_id ,
21+ iter_aitk_info_yml ,
22+ open_ex ,
23+ winml_copy_src_for ,
24+ )
1925
2026def fetch_pipeline_tags (model_link : str ) -> Optional [List [str ]]:
2127 """Fetch pipeline_tag from HuggingFace API for a given model link.
@@ -209,22 +215,9 @@ def project_processor():
209215
210216 all_ids = set ()
211217 all_summary = AllModelSummary ()
212- for yml_file in root_dir . rglob ( "info.yml" ):
218+ for yml_file , yaml_object in iter_aitk_info_yml ( root_dir ):
213219 # if "DEBUG_ID" in str(yml_file):
214220 # pass
215- # read yml file as yaml object
216- with yml_file .open ("r" , encoding = "utf-8" ) as file :
217- try :
218- yaml_content = file .read ()
219- yaml_object = yaml .safe_load (yaml_content )
220- except yaml .YAMLError as e :
221- print (f"Error reading { yml_file } : { e } " )
222- continue
223- aitk = yaml_object .get ("aitk" , [])
224- if not aitk :
225- if yml_file .parent .name == "aitk" :
226- raise KeyError (f"aitk not found in { yml_file } " )
227- continue
228221 print (f"Process aitk for { yml_file } " )
229222 # model info
230223 modelInfo = convert_yaml_to_model_info (root_dir , yml_file , yaml_object )
@@ -236,10 +229,24 @@ def project_processor():
236229 raise KeyError (f"same id found in { yml_file } " )
237230 all_ids .add (modelInfo .id .lower ())
238231 modelList .models .append (modelInfo )
239- # copy pre
232+ # copy pre — auto-ensure winml.py copy entry (unless exempt), then run pre-phase copies
240233 copyConfigFile = yml_file .parent / "_copy.json.config"
241- if copyConfigFile .exists ():
242- copyConfig = CopyConfig .Read (copyConfigFile .as_posix ())
234+ copyConfig : CopyConfig | None = (
235+ CopyConfig .Read (copyConfigFile .as_posix ()) if copyConfigFile .exists () else None
236+ )
237+ if modelInfo .id not in WINML_COPY_EXEMPT_IDS :
238+ desired_src = winml_copy_src_for (modelInfo .id )
239+ if copyConfig is None :
240+ copyConfig = CopyConfig ()
241+ copyConfig ._file = str (copyConfigFile )
242+ copyConfig ._fileContent = None
243+ existing = next ((c for c in copyConfig .copies if c .dst == "winml.py" ), None )
244+ if existing is None :
245+ copyConfig .copies .append (Copy (src = desired_src , dst = "winml.py" ))
246+ elif existing .src != desired_src :
247+ existing .src = desired_src
248+ GlobalVars .winmlCopyCheck += 1
249+ if copyConfig is not None :
243250 copyConfig .process (yml_file .parent .as_posix (), pre = True )
244251 copyConfig .writeIfChanged ()
245252 # model summary
@@ -258,4 +265,4 @@ def project_processor():
258265
259266
260267if __name__ == "__main__" :
261- project_processor ()
268+ project_processor ()
0 commit comments