Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions docs/tutorial/builder/graph_builder.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,110 @@ def build_linear(op, x, weight, bias_value):

This pattern keeps function signatures simple while preserving access to the
full builder API when needed.
## Calling Script Functions from OpBuilder

The `OpBuilder` provides a `call()` method to inline `@script`-decorated ONNX functions directly into the builder's graph. This enables composition of both imperative (builder) and declarative (`@script`) code within a single graph.

### Basic function inlining

Define an ONNX script function and then call it through `op.call()`:

```python
from onnxscript import script, opset23 as op23

# Define a reusable script function
@script(default_opset=op23)
def mul_add_relu(X, Y):
tmp = X * Y
tmp = tmp + X
return op23.Relu(tmp)

# Now build a graph using OpBuilder
graph = ir.Graph(
name="my_graph",
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": 23},
)
x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4]))
y = ir.Value(name="y", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([3, 4]))
graph.inputs.extend([x, y])

builder = onnxscript.GraphBuilder(graph)
op = builder.op

# Call the script function — it gets inlined into the graph
result = op.call(mul_add_relu, x, y)
graph.outputs.append(result)
```

The function body (three nodes: Mul, Add, Relu) is inlined directly into the graph.

### Renaming outputs with `_outputs`

By default, inlined function outputs keep their original names, qualified by the
current naming context. You can rename them explicitly with `_outputs`:

```python
@script(default_opset=op23)
def add_mul(X, Y):
a = X + Y
b = X * Y
return a, b

# Inline with custom output names
result_sum, result_prod = op.call(
add_mul, x, y,
_outputs=["custom_sum", "custom_product"]
)
```

### Adding hierarchical context with `_prefix`

Use `_prefix` to add a naming context to all nodes and intermediate values created
by the inlined function:

```python
result = op.call(
mul_add_relu, x, y,
_prefix="layer1"
)
# Node names will be "layer1.Mul_n...", "layer1.Add_n...", "layer1.Relu_n..."
# Intermediate value names will also start with "layer1."
```

You can combine both options:

```python
result_a, result_b = op.call(
add_mul, x, y,
_outputs=["sum_out", "prod_out"],
_prefix="math_ops"
)
# Final outputs: "sum_out", "prod_out" (renamed before prefix context)
# Intermediate values: "math_ops.Add_n...", "math_ops.Mul_n..." (with prefix)
```

### Using OpBuilder as the default_opset

`OpBuilder` can be passed directly as the `default_opset` when decorating a script
function. This enables scripted functions to use the same opset version as the
builder they will be inlined into:

```python
builder = onnxscript.GraphBuilder(graph)
op = builder.op

# Define the function *after* creating the builder, using op as default_opset
@script(default_opset=op)
def my_func(X, Y):
t = X + Y
return op.Relu(t) # Uses the op directly

# Inline it
result = op.call(my_func, x, y)
```

This pattern ensures consistency: the script function operates in the same domain
and opset version as the builder.
56 changes: 56 additions & 0 deletions onnxscript/_internal/_inliner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) Microsoft Corporation.
Comment thread Fixed
Comment thread Fixed
# Licensed under the MIT License.

from __future__ import annotations

from typing import Mapping, Sequence

import onnx_ir as ir
from onnx_ir._cloner import Cloner
Comment thread
gramalingam marked this conversation as resolved.


def instantiate(
function: ir.Function,
inputs: Sequence[ir.Value | None],
attributes: Mapping[str, ir.Attr],
*,
prefix: str = "",
) -> tuple[list[ir.Node], list[ir.Value | None]]:
"""Instantiate (inline) a function, substituting inputs and attributes.

Args:
function: The function to instantiate.
inputs: Actual input values to bind to the function's formal parameters.
attributes: Attribute values to substitute for reference attributes.
prefix: Optional prefix to prepend to node and output names.

Returns:
A tuple of (nodes, outputs) where nodes are the cloned function body
and outputs are the values corresponding to the function's outputs.
"""
formal_inputs = function.inputs
if len(inputs) > len(formal_inputs):
raise ValueError(
f"Too many inputs: got {len(inputs)}, "
f"but function has {len(formal_inputs)} parameters."
)
value_map: dict[ir.Value, ir.Value | None] = dict(zip(formal_inputs, inputs))

def rename(node: ir.Node) -> None:
if prefix:
if node.name:
node.name = prefix + node.name
for output in node.outputs:
if output is not None and output.name:
output.name = prefix + output.name

cloner = Cloner(
attr_map=attributes,
value_map=value_map,
metadata_props={},
post_process=rename,
resolve_ref_attrs=True,
)
nodes = [cloner.clone_node(n) for n in function]
outputs = [value_map.get(v) for v in function.outputs]
return nodes, outputs
84 changes: 84 additions & 0 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Graph builder for constructing ONNX IR graphs imperatively.

This module provides imperative builders for constructing ONNX IR graphs with automatic
constant promotion, type casting, and shape inference. The GraphBuilder class enables
programmatic construction of graphs with proper scoping, constant management, and node
creation. The OpBuilder class provides dynamic op dispatching via attribute access.
"""

from __future__ import annotations

Expand All @@ -10,6 +17,7 @@

import onnxscript._internal._inference as inference
import onnxscript.optimizer
from onnxscript._internal import _inliner

# A permissible value for an op input, which can be converted to an ir.Value.
VALUE_LIKE = Union[
Expand Down Expand Up @@ -334,6 +342,49 @@ def call_op(

return node.outputs if len(node.outputs) > 1 else node.outputs[0]

def call(
self,
function,
*args,
_outputs: Sequence[str] | None = None,
_prefix: str = "",
**kwargs,
):
if isinstance(function, ir.Function):
function_ir = function
elif isinstance(function, onnxscript.OnnxFunction):
function_proto = function.to_function_proto()
function_ir = ir.serde.deserialize_function(function_proto)
Comment thread
gramalingam marked this conversation as resolved.
else:
raise TypeError("Function must be an ir.Function or onnxscript.OnnxFunction")
output_renaming: dict[str, str] = {}
if _outputs is not None:
if len(_outputs) != len(function_ir.outputs):
raise ValueError(
f"Number of provided output names {_outputs} does not match "
f"number of function outputs {len(function_ir.outputs)}."
)
for output, name in zip(function_ir.outputs, _outputs):
output_renaming[output.name] = self._qualify_value_name(name)
else:
for output in function_ir.outputs:
output_renaming[output.name] = self._qualify_value_name(output.name)
nodes, outputs = _inliner.instantiate(function_ir, args, kwargs)
if _prefix:
self.push_module(_prefix)
Comment thread
justinchuby marked this conversation as resolved.
for node in nodes:
node.name = self._qualify_node_name(node.name)
for output in node.outputs:
if output.name:
if output.name in output_renaming:
output.name = output_renaming[output.name]
else:
output.name = self._qualify_value_name(output.name)
self.add_node(node)
if _prefix:
self.pop_module()
return outputs if len(outputs) > 1 else outputs[0]

def push_module(self, module: str, class_name: str = "") -> None:
"""Push a new module scope onto the stack.

Expand Down Expand Up @@ -414,6 +465,14 @@ def __init__(
def builder(self) -> GraphBuilder:
return self._builder

@property
def domain(self) -> str:
return self._domain

@property
def version(self) -> int | None:
return self._version

def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
if "_domain" not in kwargs:
kwargs["_domain"] = self._domain
Expand All @@ -426,3 +485,28 @@ def __getattr__(self, op_type: str) -> Callable:

def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
return self._builder.initializer(tensor, name)

def call(
self,
function,
*args,
_outputs: Sequence[str] | None = None,
_prefix: str = "",
**kwargs,
):
Comment thread
gramalingam marked this conversation as resolved.
"""Call a function and inline it into the graph.

Args:
function: The function to call (ir.Function or onnxscript.OnnxFunction).
*args: Positional arguments to pass to the function.
Comment thread
gramalingam marked this conversation as resolved.
_outputs: Optional sequence of output names. If provided, must match the
number of function outputs.
_prefix: Optional prefix for module scoping (e.g., "layers.0").
**kwargs: Keyword arguments to pass to the function.

Returns:
The output value(s) from the function call.
"""
return self._builder.call(
function, *args, _outputs=_outputs, _prefix=_prefix, **kwargs
)
Loading
Loading