diff --git a/docs/tutorial/builder/graph_builder.md b/docs/tutorial/builder/graph_builder.md index 81b1e7df87..55bd83a90e 100644 --- a/docs/tutorial/builder/graph_builder.md +++ b/docs/tutorial/builder/graph_builder.md @@ -406,3 +406,110 @@ def build_linear(op, x, weight, bias_value): This pattern keeps function signatures simple while preserving access to the full builder API when needed. +## Calling Script Functions from OpBuilder + +The `OpBuilder` provides a `call()` method to inline `@script`-decorated ONNX functions directly into the builder's graph. This enables composition of both imperative (builder) and declarative (`@script`) code within a single graph. + +### Basic function inlining + +Define an ONNX script function and then call it through `op.call()`: + +```python +from onnxscript import script, opset23 as op23 + +# Define a reusable script function +@script(default_opset=op23) +def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op23.Relu(tmp) + +# Now build a graph using OpBuilder +graph = ir.Graph( + name="my_graph", + inputs=[], + outputs=[], + nodes=[], + opset_imports={"": 23}, +) +x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4])) +y = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4])) +graph.inputs.extend([x, y]) + +builder = onnxscript.GraphBuilder(graph) +op = builder.op + +# Call the script function — it gets inlined into the graph +result = op.call(mul_add_relu, x, y) +graph.outputs.append(result) +``` + +The function body (three nodes: Mul, Add, Relu) is inlined directly into the graph. + +### Renaming outputs with `_outputs` + +By default, inlined function outputs keep their original names, qualified by the +current naming context. 
You can rename them explicitly with `_outputs`: + +```python +@script(default_opset=op23) +def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + +# Inline with custom output names +result_sum, result_prod = op.call( + add_mul, x, y, + _outputs=["custom_sum", "custom_product"] +) +``` + +### Adding hierarchical context with `_prefix` + +Use `_prefix` to add a naming context to all nodes and intermediate values created +by the inlined function: + +```python +result = op.call( + mul_add_relu, x, y, + _prefix="layer1" +) +# Node names are scoped with a "/" separator, so they all start with "layer1/" +# Intermediate value names are scoped with a "." separator and start with "v_layer1." +``` + +You can combine both options: + +```python +result_a, result_b = op.call( + add_mul, x, y, + _outputs=["sum_out", "prod_out"], + _prefix="math_ops" +) +# Final outputs: "v_sum_out", "v_prod_out" (named before the prefix context, so unprefixed) +# Node names start with "math_ops/"; intermediate value names start with "v_math_ops." +``` + +### Using OpBuilder as the default_opset + +`OpBuilder` can be passed directly as the `default_opset` when decorating a script +function. This enables scripted functions to use the same opset version as the +builder they will be inlined into: + +```python +builder = onnxscript.GraphBuilder(graph) +op = builder.op + +# Define the function *after* creating the builder, using op as default_opset +@script(default_opset=op) +def my_func(X, Y): + t = X + Y + return op.Relu(t) # Uses the op directly + +# Inline it +result = op.call(my_func, x, y) +``` + +This pattern ensures consistency: the script function operates in the same domain +and opset version as the builder. diff --git a/onnxscript/_internal/_inliner.py b/onnxscript/_internal/_inliner.py new file mode 100644 index 0000000000..6a4d6d6742 --- /dev/null +++ b/onnxscript/_internal/_inliner.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from __future__ import annotations + +from typing import Mapping, Sequence + +import onnx_ir as ir +from onnx_ir._cloner import Cloner + + +def instantiate( + function: ir.Function, + inputs: Sequence[ir.Value | None], + attributes: Mapping[str, ir.Attr], + *, + prefix: str = "", +) -> tuple[list[ir.Node], list[ir.Value | None]]: + """Instantiate (inline) a function, substituting inputs and attributes. + + Args: + function: The function to instantiate. + inputs: Actual input values to bind to the function's formal parameters. + attributes: Attribute values to substitute for reference attributes. + prefix: Optional prefix to prepend to node and output names. + + Returns: + A tuple of (nodes, outputs) where nodes are the cloned function body + and outputs are the values corresponding to the function's outputs. + """ + formal_inputs = function.inputs + if len(inputs) > len(formal_inputs): + raise ValueError( + f"Too many inputs: got {len(inputs)}, " + f"but function has {len(formal_inputs)} parameters." + ) + value_map: dict[ir.Value, ir.Value | None] = dict(zip(formal_inputs, inputs)) + + def rename(node: ir.Node) -> None: + if prefix: + if node.name: + node.name = prefix + node.name + for output in node.outputs: + if output is not None and output.name: + output.name = prefix + output.name + + cloner = Cloner( + attr_map=attributes, + value_map=value_map, + metadata_props={}, + post_process=rename, + resolve_ref_attrs=True, + ) + nodes = [cloner.clone_node(n) for n in function] + outputs = [value_map.get(v) for v in function.outputs] + return nodes, outputs diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 4589c99e59..d60db2c7da 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -1,5 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +"""Graph builder for constructing ONNX IR graphs imperatively. 
+ +This module provides imperative builders for constructing ONNX IR graphs with automatic +constant promotion, type casting, and shape inference. The GraphBuilder class enables +programmatic construction of graphs with proper scoping, constant management, and node +creation. The OpBuilder class provides dynamic op dispatching via attribute access. +""" from __future__ import annotations @@ -10,6 +17,7 @@ import onnxscript._internal._inference as inference import onnxscript.optimizer +from onnxscript._internal import _inliner # A permissible value for an op input, which can be converted to an ir.Value. VALUE_LIKE = Union[ @@ -334,6 +342,49 @@ def call_op( return node.outputs if len(node.outputs) > 1 else node.outputs[0] + def call( + self, + function, + *args, + _outputs: Sequence[str] | None = None, + _prefix: str = "", + **kwargs, + ): + if isinstance(function, ir.Function): + function_ir = function + elif isinstance(function, onnxscript.OnnxFunction): + function_proto = function.to_function_proto() + function_ir = ir.serde.deserialize_function(function_proto) + else: + raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction") + output_renaming: dict[str, str] = {} + if _outputs is not None: + if len(_outputs) != len(function_ir.outputs): + raise ValueError( + f"Number of provided output names {_outputs} does not match " + f"number of function outputs {len(function_ir.outputs)}." 
+ ) + for output, name in zip(function_ir.outputs, _outputs): + output_renaming[output.name] = self._qualify_value_name(name) + else: + for output in function_ir.outputs: + output_renaming[output.name] = self._qualify_value_name(output.name) + nodes, outputs = _inliner.instantiate(function_ir, args, kwargs) + if _prefix: + self.push_module(_prefix) + for node in nodes: + node.name = self._qualify_node_name(node.name) + for output in node.outputs: + if output.name: + if output.name in output_renaming: + output.name = output_renaming[output.name] + else: + output.name = self._qualify_value_name(output.name) + self.add_node(node) + if _prefix: + self.pop_module() + return outputs if len(outputs) > 1 else outputs[0] + def push_module(self, module: str, class_name: str = "") -> None: """Push a new module scope onto the stack. @@ -414,6 +465,14 @@ def __init__( def builder(self) -> GraphBuilder: return self._builder + @property + def domain(self) -> str: + return self._domain + + @property + def version(self) -> int | None: + return self._version + def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]): if "_domain" not in kwargs: kwargs["_domain"] = self._domain @@ -426,3 +485,28 @@ def __getattr__(self, op_type: str) -> Callable: def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value: return self._builder.initializer(tensor, name) + + def call( + self, + function, + *args, + _outputs: Sequence[str] | None = None, + _prefix: str = "", + **kwargs, + ): + """Call a function and inline it into the graph. + + Args: + function: The function to call (ir.Function or onnxscript.OnnxFunction). + *args: Positional arguments to pass to the function. + _outputs: Optional sequence of output names. If provided, must match the + number of function outputs. + _prefix: Optional prefix for module scoping (e.g., "layers.0"). + **kwargs: Keyword arguments to pass to the function. 
+ + Returns: + The output value(s) from the function call. + """ + return self._builder.call( + function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs + ) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 209ccb6165..8dbb81525a 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -10,6 +10,7 @@ import onnx_ir as ir import onnxscript._internal.builder as builder +from onnxscript import script _default_opset_version = 23 @@ -192,12 +193,12 @@ def test_value_naming_with_ir_value_objects(self): self.assertIs(t3, out3) def test_default_output_naming_strategy(self): - """Test the default naming strategy for generated output values using op_type_output format.""" + """Test the default naming strategy for generated output values using op_type_nX_output format.""" def _ops_with_default_names( op: builder.OpBuilder, x: ir.Value, y: ir.Value ) -> ir.Value: - # Single output operations should be named {op_type}_output + # Single output operations should be named {op_type}_nX_output where X is node count t1 = op.Add(x, y) t2 = op.Mul(x, y) z = op.Add(t1, t2) @@ -276,6 +277,8 @@ def test_shape_inference_add(self): self.assertIsNotNone(result.shape) self.assertEqual(list(result.shape), [2, 3, 4]) + self.assertEqual(result.name, "v_Add_0") + def test_custom_domain_explicit(self): """Test using operations from custom domains with explicit _domain parameter.""" op, x, y = _create_builder_with_inputs() @@ -644,6 +647,177 @@ def test_attributes_are_created_properly(self): self.assertEqual(strs_attr.type, ir.AttributeType.STRINGS) self.assertEqual(list(strs_attr.value), ["a", "b", "c"]) + def test_call_inlines_onnxscript_function(self): + """Test that GraphBuilder.call inlines an @onnxscript.script function.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + 
def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + result = op.call(mul_add_relu, x, y) + + # The inlined function should produce 3 nodes: Mul, Add, Relu + nodes = list(op.builder.graph) + op_types = [n.op_type for n in nodes] + self.assertEqual(op_types, ["Mul", "Add", "Relu"]) + + # The result should be a single ir.Value (the Relu output) + self.assertIsInstance(result, ir.Value) + + # Verify connectivity: Relu takes the Add output + relu_node = nodes[2] + add_node = nodes[1] + self.assertIs(relu_node.inputs[0], add_node.outputs[0]) + + # Verify the Add takes the Mul output and original input x + mul_node = nodes[0] + self.assertIs(add_node.inputs[0], mul_node.outputs[0]) + self.assertIs(add_node.inputs[1], x) + + # Verify the Mul takes the original inputs x and y + self.assertIs(mul_node.inputs[0], x) + self.assertIs(mul_node.inputs[1], y) + + def test_call_with_outputs_option(self): + """Test that GraphBuilder.call respects the _outputs option for renaming.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + + result = op.call(add_mul, x, y, _outputs=["sum_result", "product_result"]) + + # The result should be a list of 2 ir.Values (when function returns multiple outputs) + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + sum_result, product_result = result + + # Verify output names are correctly set (with v_ prefix from value naming convention) + self.assertEqual(sum_result.name, "v_sum_result") + self.assertEqual(product_result.name, "v_product_result") + + # Verify the nodes were created correctly + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 2) + self.assertEqual(nodes[0].op_type, "Add") + self.assertEqual(nodes[1].op_type, "Mul") + + def test_call_with_prefix_option(self): + """Test that 
GraphBuilder.call respects the _prefix option for hierarchical naming.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + def mul_add_relu(X, Y): + tmp = X * Y + tmp = tmp + X + return op.Relu(tmp) + + result = op.call(mul_add_relu, x, y, _prefix="layer1") + + # The nodes should have the prefix in their names + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 3) + + # Check that all node names start with the prefix (node names use / separator) + for node in nodes: + self.assertTrue( + node.name.startswith("layer1/"), + f"Node name {node.name} should start with layer1/", + ) + + # Verify the result is a single ir.Value + self.assertIsInstance(result, ir.Value) + + def test_call_with_outputs_and_prefix_options(self): + """Test that GraphBuilder.call respects both _outputs and _prefix options together. + + Note: _outputs names are set before the prefix context is applied, so they don't get + the prefix in their names. However, the inlined nodes do get the prefix applied, and + intermediate values (not renamed by _outputs) do get the prefix applied. 
+ """ + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + def add_mul(X, Y): + # Intermediate values that are not explicitly renamed by _outputs + XSquare = X * X + YSquare = Y * Y + # Final outputs that will be renamed by _outputs + a = XSquare + Y + b = XSquare * YSquare + return a, b + + result = op.call( + add_mul, x, y, _outputs=["custom_sum", "custom_product"], _prefix="math_ops" + ) + + # The result should be a list of 2 ir.Values + self.assertIsInstance(result, list) + self.assertEqual(len(result), 2) + sum_result, product_result = result + + # Verify output names are set (with v_ prefix from value naming convention) + self.assertEqual(sum_result.name, "v_custom_sum") + self.assertEqual(product_result.name, "v_custom_product") + + # Verify all nodes have the prefix applied to their names + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 4) # Mul (XSquare), Mul (YSquare), Add, Mul (final) + + # All node names should start with prefix (node names use / separator) + for node in nodes: + self.assertTrue( + node.name.startswith("math_ops/"), + f"Node name {node.name} should start with math_ops/", + ) + + # Verify intermediate value names also get the prefix (value names use v_ prefix and . 
separator) + # The first Mul produces XSquare + x_square = nodes[0].outputs[0] + self.assertTrue( + x_square.name.startswith("v_math_ops."), + f"Intermediate value {x_square.name} should have prefix", + ) + + # The second Mul produces YSquare + y_square = nodes[1].outputs[0] + self.assertTrue( + y_square.name.startswith("v_math_ops."), + f"Intermediate value {y_square.name} should have prefix", + ) + + def test_call_outputs_mismatch_error(self): + """Test that GraphBuilder.call raises an error if _outputs has wrong count.""" + # Create a GraphBuilder first + op, x, y = _create_builder_with_inputs() + + # Define the script function after creating op, using op as default_opset + @script(default_opset=op) + def add_mul(X, Y): + a = X + Y + b = X * Y + return a, b + + # The function returns 2 outputs, but we provide only 1 name + with self.assertRaises(ValueError) as cm: + op.call(add_mul, x, y, _outputs=["only_one_name"]) + + self.assertIn("does not match", str(cm.exception)) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 6ebe5bda4a..dd215a7c06 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -20,6 +20,7 @@ analysis, ast_utils, autocast, + builder, irbuilder, param_manipulation, sourceinfo, @@ -171,14 +172,18 @@ def __init__( opset: values.Opset | None = None, global_names: dict[str, Any] | None = None, source: str | None = None, - default_opset: values.Opset | None = None, + default_opset: Union[values.Opset, builder.OpBuilder, None] = None, ): self.source = source if global_names is not None: # We make a copy in case function eval modifies it. 
self.globals = global_names.copy() self.this_module = opset - self.default_opset_ = default_opset + # Convert OpBuilder to Opset if necessary and store the converted value + if isinstance(default_opset, builder.OpBuilder): + self.default_opset_ = values.Opset(default_opset.domain, default_opset.version) + else: + self.default_opset_ = default_opset # States initialized by `_init_function_translation` self._outer: list[irbuilder.IRFunction] = [] @@ -225,14 +230,17 @@ def _set_default_opset(self, opset: values.Opset, node: ast.AST) -> None: def _find_onnx_opset(self, node: ast.AST) -> values.Opset | None: """Find the (first) ONNX opset used in the function, if any.""" # Search for a Call expression of form "op.OpName(...)" - if isinstance(node, ast.Call): - if isinstance(node.func, ast.Attribute): - opset_expr = node.func.value - if isinstance(opset_expr, ast.Name): - if opset_expr.id in self.globals: - opset = self.globals[opset_expr.id] - if isinstance(opset, values.Opset) and opset.domain == "": - return opset + if isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute): + opset_expr = node.func.value + if isinstance(opset_expr, ast.Name) and opset_expr.id in self.globals: + opset = self.globals[opset_expr.id] + # Accept both values.Opset and builder.OpBuilder + if isinstance(opset, values.Opset): + if opset.domain == "": + return opset + elif isinstance(opset, builder.OpBuilder): + if opset.domain == "": + return values.Opset(opset.domain, opset.version) for child in ast.iter_child_nodes(node): res = self._find_onnx_opset(child) if res is not None: @@ -954,6 +962,9 @@ def _translate_opset_expr(self, node: ast.Attribute) -> values.Opset: val = self._lookup(node.id, self._source_of(node), raise_exception=False) if isinstance(val, values.Opset): return val + elif isinstance(val, builder.OpBuilder): + # Convert OpBuilder to Opset for compatibility + return values.Opset(val.domain, val.version) self.fail(node, f"'{node.id}' is not an instance of type Opset 
but {type(val)}.") elif isinstance(node, ast.Attribute): self.fail(node, "Nested module unimplemented.") # TODO