@@ -47,7 +47,7 @@ def _get_max_input_seq_len(program) -> int:
     return sizes[1] if len(sizes) >= 2 else 1


-def _load_text_processor(model_id: str):
+def _load_text_processor(model_id: str, revision: str | None):
     """
     Load a text processor for the model.

@@ -58,13 +58,13 @@ def _load_text_processor(model_id: str):
5858 """
5959 logger .info (f"Loading tokenizer from HuggingFace: { model_id } ..." )
6060 try :
61- tokenizer = AutoTokenizer .from_pretrained (model_id )
61+ tokenizer = AutoTokenizer .from_pretrained (model_id , revision = revision )
6262 return tokenizer , False
6363 except Exception as exc :
6464 logger .info (f"AutoTokenizer unavailable for { model_id } : { exc } " )
6565
6666 try :
67- processor = AutoProcessor .from_pretrained (model_id )
67+ processor = AutoProcessor .from_pretrained (model_id , revision = revision )
6868 if hasattr (processor , "apply_chat_template" ) and hasattr (processor , "decode" ):
6969 logger .info (f"Loaded processor from HuggingFace: { model_id } " )
7070 return processor , True
@@ -101,11 +101,12 @@ def _get_eos_token_id(text_processor):
 def run_inference(
     pte_path: str,
     model_id: str,
+    revision: str | None,
     prompt: str,
     max_new_tokens: int = 50,
 ) -> str:
     """Run inference on the exported HuggingFace model."""
-    text_processor, uses_processor = _load_text_processor(model_id)
+    text_processor, uses_processor = _load_text_processor(model_id, revision)

     logger.info(f"Loading model from {pte_path}...")
     et_runtime = Runtime.get()
@@ -208,6 +209,12 @@ def main():
         default="unsloth/Llama-3.2-1B-Instruct",
         help="HuggingFace model ID (used to load tokenizer or processor)",
     )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        help="Optional HuggingFace model revision/commit to pin",
+    )
     parser.add_argument(
         "--prompt",
         type=str,
@@ -226,6 +233,7 @@ def main():
     generated_text = run_inference(
         pte_path=args.pte,
         model_id=args.model_id,
+        revision=args.revision,
         prompt=args.prompt,
         max_new_tokens=args.max_new_tokens,
     )
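
For reference, a minimal standalone sketch of the pinning behavior the new --revision flag exposes; the model ID below is just the script's existing default, and the revision value is a placeholder, not something taken from this patch:

from transformers import AutoTokenizer

# revision accepts a branch name, tag, or full commit SHA; passing None
# (the new flag's default) falls back to the repo's default branch.
tokenizer = AutoTokenizer.from_pretrained(
    "unsloth/Llama-3.2-1B-Instruct",
    revision=None,  # e.g. "main" or a commit SHA to pin the download
)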