1- # Copyright (c) QuantCo 2023-2025
1+ # Copyright (c) QuantCo 2023-2026
22# SPDX-License-Identifier: BSD-3-Clause
33
4-
54import onnx
65from spox import Var
76from spox import build as spox_build
@@ -32,9 +31,19 @@ def build(
3231
3332 mp = spox_build (ins , outs , drop_unused_inputs = drop_unused )
3433
34+ # Find the names of inputs that have not been dropped, so that only these
35+ # are inserted into the schema.
36+ graph_inputs = [input .name for input in mp .graph .input ]
37+ kept_inputs = []
38+ for name , arr in inputs .items ():
39+ if any (n in graph_inputs for n in _disassemble_named_array (name , arr )):
40+ kept_inputs .append (name )
41+
3542 schema_v1 = {
3643 "ndonnx_schema" : SchemaV1 (
37- input_schema = {k : v .dtype .__ndx_infov1__ for k , v in inputs .items ()},
44+ input_schema = {
45+ k : v .dtype .__ndx_infov1__ for k , v in inputs .items () if k in kept_inputs
46+ },
3847 output_schema = {k : v .dtype .__ndx_infov1__ for k , v in outputs .items ()},
3948 version = 1 ,
4049 ).to_json ()
@@ -45,14 +54,21 @@ def build(
4554
4655
4756def _arrays_to_vars (dct_of_arrs : dict [str , Array ]) -> dict [str , Var ]:
57+ out : dict [str , Var ] = {}
58+ for name , arr in dct_of_arrs .items ():
59+ out |= _disassemble_named_array (name , arr )
60+ return out
61+
62+
63+ def _disassemble_named_array (name : str , arr : Array ) -> dict [str , Var ]:
64+ # Take an Array, and create a map of its component parts prefixed with a name.
4865 # TODO: Use a different separator for the public name and the nested components?
4966 public_separator = "_"
50- out = {}
51- for k , v in dct_of_arrs .items ():
52- components = v ._tyarray .disassemble ()
53- if isinstance (components , Var ):
54- out [k ] = components
55- continue
56- for k_inner , v_inner in components .items ():
57- out [f"{ k } { public_separator } { k_inner } " ] = v_inner
67+ out : dict [str , Var ] = {}
68+ components = arr ._tyarray .disassemble ()
69+ if isinstance (components , Var ):
70+ out [name ] = components
71+ else :
72+ for k , v in components .items ():
73+ out [f"{ name } { public_separator } { k } " ] = v
5874 return out
0 commit comments