Skip to content

Commit 8eb70ec

Browse files
committed
add example to export experts part
1 parent 3e7d212 commit 8eb70ec

1 file changed

Lines changed: 133 additions & 0 deletions

File tree

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""
2+
.. _l-plot-optimind-export-input-observer:
3+
4+
Export OptiMind-SFT with InputObserver
5+
======================================
6+
7+
This reuses the recipe introduced by example :ref:`l-plot-tiny-llm-export-input-observer`
8+
for model `microsoft/OptiMind-SFT <https://huggingface.co/microsoft/OptiMind-SFT>`_.
9+
We only export class ``GptOssExperts``.
10+
11+
Let's create a random model
12+
+++++++++++++++++++++++++++
13+
"""
14+
15+
import pandas
16+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
17+
from onnx_diagnostic import doc
18+
from onnx_diagnostic.export.api import to_onnx
19+
from onnx_diagnostic.helpers import string_type
20+
from onnx_diagnostic.torch_export_patches import (
21+
register_additional_serialization_functions,
22+
torch_export_patches,
23+
)
24+
from onnx_diagnostic.investigate.input_observer import InputObserver
25+
26+
device = "cuda"
27+
model_id = "microsoft/OptiMind-SFT"
28+
print(f"get tokenizer {model_id!r}")
29+
tokenizer = AutoTokenizer.from_pretrained(model_id)
30+
print(f"get config {model_id!r}")
31+
config = AutoConfig.from_pretrained(model_id)
32+
config.num_hidden_layers = 2
33+
config.layer_types = config.layer_types[:2]
34+
print(f"create model from config for {model_id!r}")
35+
model = AutoModelForCausalLM.from_config(config)
36+
print(f"the model is created with {len(list(model.named_modules()))} subdmodules.")
37+
model = model.to(device)
38+
39+
# %%
40+
# We need to only export class GptOssExperts
41+
# ++++++++++++++++++++++++++++++++++++++++++
42+
43+
44+
def generate_text(
45+
prompt,
46+
model,
47+
tokenizer,
48+
max_length=50,
49+
temperature=0.01,
50+
top_k=50,
51+
top_p=0.95,
52+
do_sample=True,
53+
):
54+
inputs = tokenizer(prompt, return_tensors="pt")
55+
input_ids = inputs["input_ids"].to(device)
56+
attention_mask = inputs["attention_mask"].to(device)
57+
58+
outputs = model.generate(
59+
input_ids=input_ids,
60+
attention_mask=attention_mask,
61+
max_length=max_length,
62+
temperature=temperature,
63+
top_k=top_k,
64+
top_p=top_p,
65+
do_sample=do_sample,
66+
)
67+
68+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
69+
return generated_text
70+
71+
72+
export_module = None
73+
for _name, sub in model.named_modules():
74+
if sub.__class__.__name__ == "GptOssExperts":
75+
export_module = sub
76+
77+
assert export_module is not None, (
78+
f"Unable to find a submodule from class GptOssExperts in "
79+
f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}"
80+
)
81+
82+
# Define your prompt
83+
prompt = "Continue: it rains, what should I do?"
84+
observer = InputObserver()
85+
with (
86+
register_additional_serialization_functions(patch_transformers=True),
87+
observer(export_module),
88+
):
89+
generate_text(prompt, model, tokenizer)
90+
91+
92+
# %%
93+
# Export
94+
# ++++++
95+
#
96+
# First, what was inferred.
97+
98+
args = observer.infer_arguments()
99+
dynamic_shapes = observer.infer_dynamic_shapes()
100+
print(f"kwargs={string_type(args, with_shape=True)}")
101+
print(f"dynamic_shapes={dynamic_shapes}")
102+
103+
# %%
104+
# Next, the export.
105+
106+
107+
filename = "plot_export_optimind_experts_input_observer.onnx"
108+
with torch_export_patches(patch_transformers=True):
109+
to_onnx(
110+
export_module,
111+
args=args,
112+
filename=filename,
113+
dynamic_shapes=dynamic_shapes,
114+
exporter="custom",
115+
verbose=1,
116+
)
117+
118+
# %%
119+
# Let's measure the discrepancies.
120+
data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2, include_io=True)
121+
df = pandas.DataFrame(data)
122+
df.to_excel("plot_export_optimind_input_observer.xlsx")
123+
print(df)
124+
125+
# %%
126+
# Let's show the errors.
127+
for row in data:
128+
if not row["SUCCESS"] and "error" in row:
129+
print(row["error"])
130+
131+
132+
# %%
133+
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)

0 commit comments

Comments
 (0)