Skip to content

Commit 6f0613a

Browse files
authored
Don't include unused inputs in ndonnx_schema when they are dropped by ndonnx.build (#211)
1 parent d27e3fa commit 6f0613a

3 files changed

Lines changed: 73 additions & 12 deletions

File tree

CHANGELOG.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@
77
Changelog
88
=========
99

10+
0.18.2 (unreleased)
11+
-------------------
12+
13+
**Bug fix**
14+
15+
- Fixed an issue whereby any unused inputs dropped via the `drop_unused` argument to :func:`ndonnx.build` were still included in `ndonnx_schema`.
16+
17+
1018
0.18.1 (2026-04-22)
1119
-------------------
1220

ndonnx/_build.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
# Copyright (c) QuantCo 2023-2025
1+
# Copyright (c) QuantCo 2023-2026
22
# SPDX-License-Identifier: BSD-3-Clause
33

4-
54
import onnx
65
from spox import Var
76
from spox import build as spox_build
@@ -32,9 +31,19 @@ def build(
3231

3332
mp = spox_build(ins, outs, drop_unused_inputs=drop_unused)
3433

34+
# Find the names of inputs that have not been dropped, so that only these
35+
# are inserted into the schema.
36+
graph_inputs = [input.name for input in mp.graph.input]
37+
kept_inputs = []
38+
for name, arr in inputs.items():
39+
if any(n in graph_inputs for n in _disassemble_named_array(name, arr)):
40+
kept_inputs.append(name)
41+
3542
schema_v1 = {
3643
"ndonnx_schema": SchemaV1(
37-
input_schema={k: v.dtype.__ndx_infov1__ for k, v in inputs.items()},
44+
input_schema={
45+
k: v.dtype.__ndx_infov1__ for k, v in inputs.items() if k in kept_inputs
46+
},
3847
output_schema={k: v.dtype.__ndx_infov1__ for k, v in outputs.items()},
3948
version=1,
4049
).to_json()
@@ -45,14 +54,21 @@ def build(
4554

4655

4756
def _arrays_to_vars(dct_of_arrs: dict[str, Array]) -> dict[str, Var]:
57+
out: dict[str, Var] = {}
58+
for name, arr in dct_of_arrs.items():
59+
out |= _disassemble_named_array(name, arr)
60+
return out
61+
62+
63+
def _disassemble_named_array(name: str, arr: Array) -> dict[str, Var]:
64+
# Take an Array, and create a map of its component parts prefixed with a name.
4865
# TODO: Use a different separator for the public name and the nested components?
4966
public_separator = "_"
50-
out = {}
51-
for k, v in dct_of_arrs.items():
52-
components = v._tyarray.disassemble()
53-
if isinstance(components, Var):
54-
out[k] = components
55-
continue
56-
for k_inner, v_inner in components.items():
57-
out[f"{k}{public_separator}{k_inner}"] = v_inner
67+
out: dict[str, Var] = {}
68+
components = arr._tyarray.disassemble()
69+
if isinstance(components, Var):
70+
out[name] = components
71+
else:
72+
for k, v in components.items():
73+
out[f"{name}{public_separator}{k}"] = v
5874
return out

tests/test_build_utils.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
# Copyright (c) QuantCo 2023-2025
1+
# Copyright (c) QuantCo 2023-2026
22
# SPDX-License-Identifier: BSD-3-Clause
33
import json
44
from pathlib import Path
55

66
import pytest
7+
from onnx import ModelProto
78

89
import ndonnx as ndx
910
import ndonnx._typed_array as tydx
@@ -106,3 +107,39 @@ def test_schema_against_snapshots(dtype, update_schema_snapshots):
106107
# test json round trip of schema data
107108
assert candidate_schemas["input_schema"]["a"] == a.dtype.__ndx_infov1__.__dict__
108109
assert candidate_schemas["output_schema"]["b"] == b.dtype.__ndx_infov1__.__dict__
110+
111+
112+
def _get_input_schema(model_proto: ModelProto):
113+
for entry in model_proto.metadata_props:
114+
if entry.key == "ndonnx_schema":
115+
return json.loads(entry.value)["input_schema"]
116+
raise KeyError("'ndonnx_schema' not present")
117+
118+
119+
def test_unused_inputs_not_in_schema():
120+
# Case 1: simple graph with an unused input.
121+
a = ndx.argument(shape=("N",), dtype=ndx.float64)
122+
b = ndx.argument(shape=("N",), dtype=ndx.float64)
123+
unused = ndx.argument(shape=("N",), dtype=ndx.float64)
124+
125+
c = a + b
126+
127+
no_drop = ndx.build(
128+
inputs={"a": a, "b": b, "unused": unused}, outputs={"c": c}, drop_unused=False
129+
)
130+
131+
assert "unused" in _get_input_schema(no_drop)
132+
133+
drop = ndx.build(
134+
inputs={"a": a, "b": b, "unused": unused}, outputs={"c": c}, drop_unused=True
135+
)
136+
assert "unused" not in _get_input_schema(drop)
137+
138+
# Case 2: part of a nullable input is not used.
139+
a = ndx.argument(shape=("N",), dtype=ndx.nint64)
140+
141+
b = ndx.isnan(a)
142+
143+
drop = ndx.build(inputs={"a": a}, outputs={"b": b}, drop_unused=True)
144+
145+
assert "a" in _get_input_schema(no_drop)

0 commit comments

Comments
 (0)