Skip to content

Commit 9ea2e5f

Browse files
committed
a few fixes
1 parent f3be6c8 commit 9ea2e5f

3 files changed

Lines changed: 50 additions & 39 deletions

File tree

_doc/examples/plot_export_with_modelbuilder.py

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
77
"""
88

9+
import sys
910
import os
1011
import pandas
11-
from transformers import AutoModelForCausalLM, AutoTokenizer
12+
import torch
13+
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
1214
from onnx_diagnostic import doc
1315
from onnx_diagnostic.investigate.input_observer import InputObserver
1416
from onnx_diagnostic.helpers.rt_helper import onnx_generate
@@ -28,10 +30,11 @@ def generate_text(
2830
top_k=50,
2931
top_p=0.95,
3032
do_sample=True,
33+
device="cpu",
3134
):
3235
inputs = tokenizer(prompt, return_tensors="pt")
33-
input_ids = inputs["input_ids"]
34-
attention_mask = inputs["attention_mask"]
36+
input_ids = inputs["input_ids"].to(device)
37+
attention_mask = inputs["attention_mask"].to(device)
3538

3639
outputs = model.generate(
3740
input_ids=input_ids,
@@ -47,58 +50,71 @@ def generate_text(
4750
return generated_text
4851

4952

53+
# %%
54+
# filename for the model
55+
MODEL_NAME = sys.argv[1] if sys.argv and len(sys.argv) > 1 else "arnir0/Tiny-LLM"
56+
cache_dir = "dump_modelbuilder"
57+
os.makedirs(cache_dir, exist_ok=True)
58+
name = MODEL_NAME.replace("/", "_")
59+
filename = os.path.join(cache_dir, f"plot_export_with_modelbuilder_{name}.onnx")
60+
61+
5062
# %%
5163
# Creating the model
52-
print("-- creating...")
53-
MODEL_NAME = "arnir0/Tiny-LLM"
64+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5465
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
55-
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
66+
if not os.path.exists(filename):
67+
print(f"-- creating... on {device} into {filename!r}")
68+
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16)
69+
model = model.to(device)
70+
config = model.config
71+
else:
72+
config = AutoConfig.from_pretrained(MODEL_NAME)
5673

5774

5875
# %%
5976
# Capturing inputs/outputs to infer dynamic shapes and arguments
6077
print("-- capturing...")
6178
prompt = "Continue: it rains, what should I do?"
62-
observer = InputObserver()
63-
with register_additional_serialization_functions(patch_transformers=True), observer(model):
64-
generate_text(prompt, model, tokenizer)
79+
if not os.path.exists(filename):
80+
observer = InputObserver()
81+
with register_additional_serialization_functions(patch_transformers=True), observer(model):
82+
generate_text(prompt, model, tokenizer, device=device)
6583

6684

6785
# %%
6886
# Exporting.
69-
print("-- exporting...")
70-
observer.remove_inputs(["cache_position", "logits_to_keep", "position_ids"])
71-
ds = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
72-
kwargs = observer.infer_arguments()
73-
74-
cache_dir = "dump_modelbuilder"
75-
os.makedirs(cache_dir, exist_ok=True)
76-
filename = os.path.join(cache_dir, "plot_export_with_modelbuilder.onnx")
77-
with torch_export_patches(patch_transformers=True):
78-
to_onnx(
79-
model,
80-
filename=filename,
81-
kwargs=kwargs,
82-
dynamic_shapes=ds,
83-
exporter="modelbuilder",
84-
)
85-
86-
data = observer.check_discrepancies(filename, progress_bar=True)
87-
print(pandas.DataFrame(data))
87+
if not os.path.exists(filename):
88+
print("-- exporting...")
89+
observer.remove_inputs(["cache_position", "logits_to_keep", "position_ids"])
90+
ds = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
91+
kwargs = observer.infer_arguments()
92+
93+
with torch_export_patches(patch_transformers=True):
94+
to_onnx(
95+
model,
96+
filename=filename,
97+
kwargs=kwargs,
98+
dynamic_shapes=ds,
99+
exporter="modelbuilder",
100+
)
101+
102+
data = observer.check_discrepancies(filename, progress_bar=True)
103+
print(pandas.DataFrame(data))
88104

89105
# %%
90106
# ONNX Prompt
91107
# +++++++++++
92108
print("-- ONNX prompts...")
93109
inputs = tokenizer(prompt, return_tensors="pt")
94-
input_ids = inputs["input_ids"]
95-
attention_mask = inputs["attention_mask"]
110+
input_ids = inputs["input_ids"].to(device)
111+
attention_mask = inputs["attention_mask"].to(device)
96112

97113
onnx_tokens = onnx_generate(
98114
filename,
99115
input_ids=input_ids,
100116
attention_mask=attention_mask,
101-
eos_token_id=model.config.eos_token_id,
117+
eos_token_id=config.eos_token_id,
102118
max_new_tokens=50,
103119
)
104120
onnx_generated_text = tokenizer.decode(onnx_tokens, skip_special_tokens=True)
@@ -108,4 +124,5 @@ def generate_text(
108124
print("-----------------")
109125

110126
# %%
111-
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)
127+
if os.stat(filename).st_size < 2**14:
128+
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)

_unittests/ut_investigate/test_input_observer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1196,7 +1196,6 @@ def forward(self, a, *args, **kwargs):
11961196
)
11971197
torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=ds)
11981198

1199-
12001199
def test_remove_inputs_kwargs(self):
12011200
"""Test that remove_inputs removes a kwarg from the observer info."""
12021201

onnx_diagnostic/export/api.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,12 +320,7 @@ def to_onnx(
320320
temp_filename = os.path.join(os.path.dirname(filename), "model.onnx")
321321
# renaming
322322
onx = onnx.load(temp_filename, load_external_data=True)
323-
onnx.save(
324-
onx,
325-
filename,
326-
save_as_external_data=True,
327-
location=f"{os.path.splitext(filename[0])}.data",
328-
)
323+
onnx.save(onx, filename, save_as_external_data=True)
329324
return onx
330325

331326
raise ValueError(f"Unknown exporter={exporter!r}")

0 commit comments

Comments
 (0)