Skip to content

Commit b26324a

Browse files
committed
fix
1 parent a7f0bb6 commit b26324a

2 files changed

Lines changed: 12 additions & 4 deletions

File tree

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,7 @@ def test_cache_dependant_input_preparation_exporting(self):
948948
torch.testing.assert_close(eager1, export1)
949949
torch.testing.assert_close(eager2, export2)
950950

951+
@requires_transformers("4.57")
951952
def test_prepare_inputs_for_generation_decoder_llm(self):
952953
data = get_untrained_model_with_inputs(
953954
"hf-internal-testing/tiny-random-LlamaForCausalLM"

onnx_diagnostic/helpers/onnx_helper.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,7 +1356,7 @@ def _mkv_(name, itype, irank):
13561356

13571357

13581358
def unknown_names_within_nodes(nodes: List[NodeProto]) -> Set[str]:
1359-
"""Returns the list of unkonwn results from a list of nodes."""
1359+
"""Returns the list of unknown results from a list of nodes."""
13601360
not_known: Set[str] = set()
13611361
for node in nodes[::-1]:
13621362
not_known -= {o for o in node.output if o}
@@ -1875,15 +1875,22 @@ def make_model_with_local_functions(
18751875
f"nodes in partition {function_name!r}"
18761876
)
18771877
outputs = _find_used_names(new_nodes, node_indices)
1878-
function_nodes = [new_nodes[i] for i in node_indices]
1878+
function_nodes = [new_nodes[i] for i in node_indices if new_nodes[i]]
1879+
1880+
check_for_non_recursivity(
1881+
function_nodes, unknown_names_within_nodes(function_nodes), outputs
1882+
)
1883+
18791884
lf = make_subfunction(
18801885
function_name,
1881-
[n for n in function_nodes if n],
1886+
function_nodes,
18821887
model.opset_import,
18831888
outputs,
18841889
domain=domain,
18851890
)
1886-
check_for_non_recursivity(new_nodes, lf.input, lf.output)
1891+
1892+
check_for_non_recursivity(function_nodes, lf.input, lf.output)
1893+
18871894
functions.append(lf)
18881895
maxi = max(node_indices)
18891896
for i in node_indices:

0 commit comments

Comments
 (0)