66
77"""
88
9+ import sys
910import os
1011import pandas
11- from transformers import AutoModelForCausalLM , AutoTokenizer
12+ import torch
13+ from transformers import AutoModelForCausalLM , AutoTokenizer , AutoConfig
1214from onnx_diagnostic import doc
1315from onnx_diagnostic .investigate .input_observer import InputObserver
1416from onnx_diagnostic .helpers .rt_helper import onnx_generate
@@ -28,10 +30,11 @@ def generate_text(
2830 top_k = 50 ,
2931 top_p = 0.95 ,
3032 do_sample = True ,
33+ device = "cpu" ,
3134):
3235 inputs = tokenizer (prompt , return_tensors = "pt" )
33- input_ids = inputs ["input_ids" ]
34- attention_mask = inputs ["attention_mask" ]
36+ input_ids = inputs ["input_ids" ]. to ( device )
37+ attention_mask = inputs ["attention_mask" ]. to ( device )
3538
3639 outputs = model .generate (
3740 input_ids = input_ids ,
@@ -47,58 +50,71 @@ def generate_text(
4750 return generated_text
4851
4952
53+ # %%
54+ # model name (first command-line argument, if given) and path of the exported ONNX file
55+ MODEL_NAME = sys .argv [1 ] if sys .argv and len (sys .argv ) > 1 else "arnir0/Tiny-LLM"
56+ cache_dir = "dump_modelbuilder"
57+ os .makedirs (cache_dir , exist_ok = True )
58+ name = MODEL_NAME .replace ("/" , "_" )
59+ filename = os .path .join (cache_dir , f"plot_export_with_modelbuilder_{ name } .onnx" )
60+
61+
5062# %%
5163# Creating the model
52- print ("-- creating..." )
53- MODEL_NAME = "arnir0/Tiny-LLM"
64+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
5465tokenizer = AutoTokenizer .from_pretrained (MODEL_NAME )
55- model = AutoModelForCausalLM .from_pretrained (MODEL_NAME )
66+ if not os .path .exists (filename ):
67+ print (f"-- creating... on { device } into { filename !r} " )
68+ model = AutoModelForCausalLM .from_pretrained (MODEL_NAME , torch_dtype = torch .bfloat16 )
69+ model = model .to (device )
70+ config = model .config
71+ else :
72+ config = AutoConfig .from_pretrained (MODEL_NAME )
5673
5774
5875# %%
5976# Capturing inputs/outputs to infer dynamic shapes and arguments
6077print ("-- capturing..." )
6178prompt = "Continue: it rains, what should I do?"
62- observer = InputObserver ()
63- with register_additional_serialization_functions (patch_transformers = True ), observer (model ):
64- generate_text (prompt , model , tokenizer )
79+ if not os .path .exists (filename ):
80+ observer = InputObserver ()
81+ with register_additional_serialization_functions (patch_transformers = True ), observer (model ):
82+ generate_text (prompt , model , tokenizer , device = device )
6583
6684
6785# %%
6886# Exporting.
69- print ("-- exporting..." )
70- observer .remove_inputs (["cache_position" , "logits_to_keep" , "position_ids" ])
71- ds = observer .infer_dynamic_shapes (set_batch_dimension_for = True )
72- kwargs = observer .infer_arguments ()
73-
74- cache_dir = "dump_modelbuilder"
75- os .makedirs (cache_dir , exist_ok = True )
76- filename = os .path .join (cache_dir , "plot_export_with_modelbuilder.onnx" )
77- with torch_export_patches (patch_transformers = True ):
78- to_onnx (
79- model ,
80- filename = filename ,
81- kwargs = kwargs ,
82- dynamic_shapes = ds ,
83- exporter = "modelbuilder" ,
84- )
85-
86- data = observer .check_discrepancies (filename , progress_bar = True )
87- print (pandas .DataFrame (data ))
87+ if not os .path .exists (filename ):
88+ print ("-- exporting..." )
89+ observer .remove_inputs (["cache_position" , "logits_to_keep" , "position_ids" ])
90+ ds = observer .infer_dynamic_shapes (set_batch_dimension_for = True )
91+ kwargs = observer .infer_arguments ()
92+
93+ with torch_export_patches (patch_transformers = True ):
94+ to_onnx (
95+ model ,
96+ filename = filename ,
97+ kwargs = kwargs ,
98+ dynamic_shapes = ds ,
99+ exporter = "modelbuilder" ,
100+ )
101+
102+ data = observer .check_discrepancies (filename , progress_bar = True )
103+ print (pandas .DataFrame (data ))
88104
89105# %%
90106# ONNX Prompt
91107# +++++++++++
92108print ("-- ONNX prompts..." )
93109inputs = tokenizer (prompt , return_tensors = "pt" )
94- input_ids = inputs ["input_ids" ]
95- attention_mask = inputs ["attention_mask" ]
110+ input_ids = inputs ["input_ids" ]. to ( device )
111+ attention_mask = inputs ["attention_mask" ]. to ( device )
96112
97113onnx_tokens = onnx_generate (
98114 filename ,
99115 input_ids = input_ids ,
100116 attention_mask = attention_mask ,
101- eos_token_id = model . config .eos_token_id ,
117+ eos_token_id = config .eos_token_id ,
102118 max_new_tokens = 50 ,
103119)
104120onnx_generated_text = tokenizer .decode (onnx_tokens , skip_special_tokens = True )
@@ -108,4 +124,5 @@ def generate_text(
108124print ("-----------------" )
109125
110126# %%
111- doc .save_fig (doc .plot_dot (filename ), f"{ filename } .png" , dpi = 400 )
127+ if os .stat (filename ).st_size < 2 ** 14 :
128+ doc .save_fig (doc .plot_dot (filename ), f"{ filename } .png" , dpi = 400 )
0 commit comments