Skip to content

Commit 460db5c

Browse files
committed
doc
1 parent 8b2303d commit 460db5c

7 files changed

Lines changed: 31 additions & 8 deletions

File tree

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def linkcode_resolve(domain, info):
141141
("py:class", "torch.fx.proxy.TracerBase"),
142142
("py:class", "torch.FloatTensor"),
143143
("py:class", "torch.LongTensor"),
144+
("py:class", "torch.export._trace.ExportArtifact"),
144145
("py:class", "torch.utils._pytree.Context"),
145146
("py:class", "torch.utils._pytree.KeyEntry"),
146147
("py:class", "torch.utils._pytree.TreeSpec"),

_doc/final/plot_export_tiny_llm_attention_input_observer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ def generate_text(
122122

123123
# %%
124124
# Let's measure the discrepancies.
125-
data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2, include_io=True)
125+
data = observer.check_discrepancies(
126+
filename, progress_bar=True, atol=1e-2, include_io=True, skip_none=True
127+
)
126128
df = pandas.DataFrame(data)
127129
df.to_excel("plot_export_tiny_llm_attention_input_observer.xlsx")
128130
print(df)

_doc/technical/plot_histc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def tune_threshold_histc(
193193
)
194194

195195
# %%
196-
# This does not add up. Let's proove now :func:`torch.histc` is really confusing.
196+
# This does not add up. Let's prove now :func:`torch.histc` is really confusing.
197197
# The following sum should be null but it is not.
198198

199199
diff = torch.histc(tensor, hbins, hmin, hmax) - (

onnx_diagnostic/helpers/dot_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ def _mkn(obj: object) -> int:
203203
if att.type == onnx.AttributeProto.GRAPH:
204204
unique |= get_hidden_inputs(att.g)
205205
for i in unique:
206+
if i in tiny_inits:
207+
continue
206208
edge = name_to_ids[i], _mkn(node) # type: ignore[assignment]
207209
if edge in done:
208210
continue

onnx_diagnostic/helpers/helper.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,7 @@ def max_diff(
10711071
_index: int = 0,
10721072
allow_unique_tensor_with_list_of_one_element: bool = True,
10731073
hist: Optional[Union[bool, List[float]]] = None,
1074+
skip_none: bool = False,
10741075
) -> Dict[str, Union[float, int, Tuple[Any, ...]]]:
10751076
"""
10761077
Returns the maximum discrepancy.
@@ -1087,6 +1088,7 @@ def max_diff(
10871088
:param allow_unique_tensor_with_list_of_one_element:
10881089
allow a comparison between a single tensor and a list of one tensor
10891090
:param hist: compute an histogram of the discrepancies
1091+
:param skip_none: skips none value
10901092
:return: dictionary with many values
10911093
10921094
* abs: max absolute error
@@ -1112,6 +1114,7 @@ def max_diff(
11121114
end=end,
11131115
_index=_index,
11141116
hist=hist,
1117+
skip_none=skip_none,
11151118
)
11161119
_dkws = {**_dkws_, "flatten": flatten}
11171120
_dkwsf = {**_dkws_, "flatten": False}
@@ -1129,6 +1132,7 @@ def max_diff(
11291132
debug_info=debug_info,
11301133
allow_unique_tensor_with_list_of_one_element=False,
11311134
hist=hist,
1135+
skip_none=skip_none,
11321136
)
11331137
return max_diff(
11341138
expected,
@@ -1142,6 +1146,7 @@ def max_diff(
11421146
_index=_index,
11431147
allow_unique_tensor_with_list_of_one_element=False,
11441148
hist=hist,
1149+
skip_none=skip_none,
11451150
)
11461151

11471152
if expected.__class__.__name__ == "CausalLMOutputWithPast":
@@ -1269,6 +1274,7 @@ def max_diff(
12691274
_index=_index + ip,
12701275
flatten=flatten,
12711276
hist=hist,
1277+
skip_none=skip_none,
12721278
)
12731279
am = max(am, d["abs"])
12741280
dn = max(dn, d["dnan"])
@@ -1793,6 +1799,9 @@ def max_diff(
17931799
**_dkws,
17941800
)
17951801

1802+
if skip_none and (expected is None or got is None):
1803+
return {"abs": 0, "rel": 0, "dnan": 0, "n": 0, "sum": 0}
1804+
17961805
raise AssertionError(
17971806
f"Not implemented with implemented with expected="
17981807
f"{string_type(expected)} ({type(expected)}), got={string_type(got)},\n"

onnx_diagnostic/investigate/input_observer.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -286,22 +286,28 @@ class InputObserverInfo:
286286
and the arguments to send to :func:`torch.export.export`.
287287
288288
Args:
289-
signature_names: Names of the arguments of the method
289+
signature_names:
290+
Names of the arguments of the method
290291
the collector tensors come from. They are used if it becomes
291292
necessary to move positional arguments to named ones.
292293
They are used a second time because :func:`torch.export.export`
293294
cares about the order in kwargs and dynamic shapes, it needs
294295
to be the same in the ordered dictionaries `add_inputs` receive.
295-
default_values: Default values defined by the signature of the function,
296+
default_values:
297+
Default values defined by the signature of the function,
296298
any value equal to that is ignore to simplify the export.
297-
missing: If a named argument (in kwargs) is missing,
299+
missing:
300+
If a named argument (in kwargs) is missing,
298301
a default value will be taken in this dictionary,
299302
this is used when after the prefill step, an argument
300303
disappears (such as `pixel_values`) and another one
301304
is added (such as `past_key_values`).
302305
The values are only to infer dynamic shapes and arguments,
303306
not to run the model.
304-
kwargs_name: Name of parameter **kwargs if it exists.
307+
kwargs_name:
308+
Name of parameter `**kwargs` if it exists.
309+
310+
This is used by class :class:`InputObserver`.
305311
"""
306312

307313
def __init__(
@@ -910,6 +916,7 @@ def check_discrepancies(
910916
hist=(0.1, 0.01),
911917
progress_bar: bool = False,
912918
include_io: bool = True,
919+
skip_none: bool = True,
913920
) -> list[dict[str, str | int | float | bool]]:
914921
"""Computes the discrepancies between the saved inputs and outputs
915922
with the saved onnx model.
@@ -929,6 +936,8 @@ def check_discrepancies(
929936
include_io:
930937
Shows inputs/outputs shapes in the summary
931938
returned by this function.
939+
skip_none:
940+
Dooes not check discrepancies when an output is None.
932941
933942
Returns:
934943
A list of dictionaries, ready to be consumed by a dataframe.
@@ -982,7 +991,7 @@ def check_discrepancies(
982991
if isinstance(outputs, list) and isinstance(ort_outputs, list):
983992
while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0:
984993
ort_outputs.pop()
985-
diff = max_diff(outputs, ort_outputs, hist=lhist) # type: ignore[assignment]
994+
diff = max_diff(outputs, ort_outputs, hist=lhist, skip_none=skip_none) # type: ignore[assignment]
986995
if "rep" in diff and isinstance(diff["rep"], dict):
987996
diff.update(diff["rep"])
988997
del diff["rep"]

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def patched__get_range_constraints(
143143
):
144144
"""
145145
Patches ``torch.export._trace._get_range_constraints``.
146-
See PR `#174593 whttps://github.com/pytorch/pytorch/pull/174593>`_.
146+
See PR `#174593 <https://github.com/pytorch/pytorch/pull/174593>`_.
147147
"""
148148
gm: torch.fx.GraphModule = export_artifact.aten.gm
149149
export_graph_signature: torch.export.graph_signature.ExportGraphSignature = (

0 commit comments

Comments
 (0)