Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Sequence[float],
Sequence[bool],
Sequence[str],
None,
]

# Mapping from Python scalar types to their default ONNX DataType,
Expand Down Expand Up @@ -258,7 +259,7 @@ def initializer(

def _input_to_ir_value(
self, value: VALUE_LIKE, like_type: ir.Value | None = None
) -> ir.Value:
) -> ir.Value | None:
"""Convert a permissible input (for a call to an op) into an ir.Value.

Permissible values include ir.Value as well as python constants that can be converted
Expand All @@ -267,6 +268,8 @@ def _input_to_ir_value(
"""
if isinstance(value, ir.Value):
return value
if value is None:
return value
Comment thread
justinchuby marked this conversation as resolved.
dtype = (
like_type.type.dtype
if like_type is not None and like_type.type is not None
Expand Down Expand Up @@ -356,7 +359,7 @@ def _get_schema(
def _partition_inputs_attributes(
self,
schema: onnx.defs.OpSchema | None,
inputs: Sequence[ir.Value | ir.TensorProtocol],
inputs: Sequence[ir.Value | ir.TensorProtocol | None],
kwargs: dict[str, Any],
) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]:
if schema is None:
Expand Down Expand Up @@ -504,7 +507,7 @@ def subgraph(
def call_op(
self,
op_type: str,
inputs: Sequence[ir.Value | ir.TensorProtocol],
inputs: Sequence[ir.Value | ir.TensorProtocol | None],
Comment thread
justinchuby marked this conversation as resolved.
kwargs: dict[str, Any],
):
"""Create an ONNX node and add it to the graph, returning its output value(s)."""
Expand Down
33 changes: 33 additions & 0 deletions onnxscript/_internal/builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,39 @@ def add_mul(X, Y):

self.assertIn("does not match", str(cm.exception))

def test_none_input_is_passed_through(self):
"""Test that None inputs are preserved as None in the node's inputs."""
op, x, y = _create_builder_with_inputs()

# Gemm's third input (C) is optional; passing None should work
result = op.Gemm(x, y, None, alpha=1.0)

nodes = list(op.builder.graph)
self.assertEqual(len(nodes), 1)
node = nodes[0]
self.assertEqual(node.op_type, "Gemm")
# The third input should be None (optional, omitted)
self.assertEqual(len(list(node.inputs)), 3)
self.assertIs(node.inputs[0], x)
self.assertIs(node.inputs[1], y)
self.assertIsNone(node.inputs[2])
self.assertIsNotNone(result)

def test_none_input_with_custom_domain(self):
"""Test that None inputs work with custom domain ops."""
op, x, y = _create_builder_with_inputs()

result = op.CustomOp(x, None, y, _domain="com.custom")

nodes = list(op.builder.graph)
self.assertEqual(len(nodes), 1)
node = nodes[0]
self.assertEqual(node.op_type, "CustomOp")
self.assertIs(node.inputs[0], x)
self.assertIsNone(node.inputs[1])
self.assertIs(node.inputs[2], y)
self.assertIsNotNone(result)


class BuildSubgraphTest(unittest.TestCase):
"""Tests for GraphBuilder.subgraph()."""
Expand Down
Loading