@@ -286,22 +286,28 @@ class InputObserverInfo:
286286 and the arguments to send to :func:`torch.export.export`.
287287
288288 Args:
289- signature_names: Names of the arguments of the method
289+ signature_names:
290+ Names of the arguments of the method
290291 the collector tensors come from. They are used if it becomes
291292 necessary to move positional arguments to named ones.
292293 They are used a second time because :func:`torch.export.export`
293294 cares about the order in kwargs and dynamic shapes, it needs
294295 to be the same in the ordered dictionaries `add_inputs` receive.
295- default_values: Default values defined by the signature of the function,
296+ default_values:
297+ Default values defined by the signature of the function,
296298 any value equal to that is ignore to simplify the export.
297- missing: If a named argument (in kwargs) is missing,
299+ missing:
300+ If a named argument (in kwargs) is missing,
298301 a default value will be taken in this dictionary,
299302 this is used when after the prefill step, an argument
300303 disappears (such as `pixel_values`) and another one
301304 is added (such as `past_key_values`).
302305 The values are only to infer dynamic shapes and arguments,
303306 not to run the model.
304- kwargs_name: Name of parameter **kwargs if it exists.
307+ kwargs_name:
308+ Name of parameter `**kwargs` if it exists.
309+
310+ This is used by class :class:`InputObserver`.
305311 """
306312
307313 def __init__ (
@@ -910,6 +916,7 @@ def check_discrepancies(
910916 hist = (0.1 , 0.01 ),
911917 progress_bar : bool = False ,
912918 include_io : bool = True ,
919+ skip_none : bool = True ,
913920 ) -> list [dict [str , str | int | float | bool ]]:
914921 """Computes the discrepancies between the saved inputs and outputs
915922 with the saved onnx model.
@@ -929,6 +936,8 @@ def check_discrepancies(
929936 include_io:
930937 Shows inputs/outputs shapes in the summary
931938 returned by this function.
939+ skip_none:
940+ Dooes not check discrepancies when an output is None.
932941
933942 Returns:
934943 A list of dictionaries, ready to be consumed by a dataframe.
@@ -982,7 +991,7 @@ def check_discrepancies(
982991 if isinstance (outputs , list ) and isinstance (ort_outputs , list ):
983992 while len (ort_outputs ) > len (outputs ) and ort_outputs [- 1 ].numel () == 0 :
984993 ort_outputs .pop ()
985- diff = max_diff (outputs , ort_outputs , hist = lhist ) # type: ignore[assignment]
994+ diff = max_diff (outputs , ort_outputs , hist = lhist , skip_none = skip_none ) # type: ignore[assignment]
986995 if "rep" in diff and isinstance (diff ["rep" ], dict ):
987996 diff .update (diff ["rep" ])
988997 del diff ["rep" ]
0 commit comments