diff --git a/onnxscript/ir/passes/common/graph_extration.py b/onnxscript/ir/passes/common/graph_extration.py
new file mode 100644
index 0000000000..910dce964f
--- /dev/null
+++ b/onnxscript/ir/passes/common/graph_extration.py
@@ -0,0 +1,230 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Passes for extracting subgraphs from a graph."""
+
+from __future__ import annotations
+
+__all__ = [
+    "ExtractGraphPass",
+]
+
+import itertools
+import logging
+from collections.abc import Collection
+
+from onnxscript import ir
+
+logger = logging.getLogger(__name__)
+
+
+def _find_subgraph_bounded_by_values(
+    graph: ir.Graph, inputs: Collection[ir.Value], outputs: Collection[ir.Value]
+) -> tuple[list[ir.Node], list[ir.Value]]:
+    """Finds the subgraph bounded by the given inputs and outputs.
+
+    The graph is walked backwards from ``outputs`` through the producer of each
+    value. The walk does not go past the values in ``inputs``.
+
+    Args:
+        graph: The graph to search.
+        inputs: The inputs to the subgraph.
+        outputs: The outputs of the subgraph.
+
+    Returns:
+        A list of nodes in the subgraph, in their original order in ``graph``,
+        and the initializers used.
+    """
+    node_index = {node: idx for idx, node in enumerate(graph)}
+    all_nodes = []
+    value_stack: list[ir.Value] = [*outputs]
+    visited_nodes: set[ir.Node] = set()
+    # Marking the inputs as visited up front stops the walk at the boundary
+    visited_values: set[ir.Value] = set(inputs)
+    initializers = [val for val in inputs if val.name in graph.initializers]
+    while value_stack:
+        value = value_stack.pop()
+        if value in visited_values:
+            continue
+        if value.name in graph.initializers:
+            # Record the initializer
+            assert value.const_value is not None
+            initializers.append(value)
+        visited_values.add(value)
+        node = value.producer()
+        if node is None or node in visited_nodes:
+            continue
+        visited_nodes.add(node)
+        all_nodes.append(node)
+        # NOTE: Only explicit node inputs are followed. Values referenced from
+        # inside graph attributes (e.g. the branches of an If node) are not.
+        value_stack.extend(
+            node_input
+            for node_input in node.inputs
+            if node_input is not None and node_input not in visited_values
+        )
+    # Preserve the original order
+    all_nodes.sort(key=lambda n: node_index[n])
+    return all_nodes, initializers
+
+
+def _find_dangling_values(
+    graph: ir.Graph,
+    inputs: Collection[ir.Value],
+    outputs: Collection[ir.Value],
+    nodes: Collection[ir.Node],
+) -> list[ir.Value]:
+    """Finds the values a subgraph needs but that nothing provides.
+
+    A value is dangling when it is one of ``outputs`` or an input of one of
+    ``nodes``, and it is not in ``inputs``, has no producer node and is not an
+    initializer of ``graph``.
+
+    Args:
+        graph: The graph the subgraph was found in.
+        inputs: The inputs to the subgraph.
+        outputs: The outputs of the subgraph.
+        nodes: The nodes in the subgraph.
+
+    Returns:
+        The dangling values, without duplicates, in the order they were found.
+    """
+    boundary = set(inputs)
+    required = itertools.chain(
+        outputs, itertools.chain.from_iterable(node.inputs for node in nodes)
+    )
+    # dict.fromkeys removes duplicates while keeping the order deterministic
+    return list(
+        dict.fromkeys(
+            value
+            for value in required
+            if value is not None
+            and value not in boundary
+            and value.producer() is None
+            and value.name not in graph.initializers
+        )
+    )
+
+
+class ExtractGraphPass(ir.passes.InPlacePass):
+    """This pass extracts a subgraph from the given graph."""
+
+    def __init__(self, input_names: Collection[str], output_names: Collection[str]) -> None:
+        """Extracts sub-model from an ONNX model.
+
+        The sub-model is defined by the names of the input and output tensors *exactly*.
+
+        Args:
+            input_names: The names of the inputs to extract. Repeated names are
+                ignored after their first occurrence.
+            output_names: The names of the outputs to extract. Repeated names are
+                ignored after their first occurrence.
+        """
+        super().__init__()
+        # A repeated name would otherwise become a repeated graph input/output
+        self.input_names = tuple(dict.fromkeys(input_names))
+        self.output_names = tuple(dict.fromkeys(output_names))
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        """Replaces the graph of the model with the extracted subgraph.
+
+        Args:
+            model: The model to extract the subgraph from. It is modified in place.
+
+        Returns:
+            The result holding the same model, always marked as modified.
+        """
+        values = ir.convenience.create_value_mapping(model.graph)
+        inputs = [values[name] for name in self.input_names]
+        outputs = [values[name] for name in self.output_names]
+        extracted_nodes, initializers = _find_subgraph_bounded_by_values(
+            model.graph, inputs, outputs
+        )
+
+        model.graph.remove(extracted_nodes)
+        # Create inputs for the new graph as the old inputs are owned by the old nodes
+        new_inputs = [
+            ir.Value(
+                name=old_input.name,
+                shape=old_input.shape,
+                type=old_input.type,
+                doc_string=old_input.doc_string,
+                const_value=old_input.const_value,
+            )
+            for old_input in inputs
+        ]
+        ir.convenience.replace_all_uses_with(inputs, new_inputs)
+        # Point initializers and outputs that are also inputs at the new values,
+        # so the new graph does not keep references to the replaced ones
+        replacements = dict(zip(inputs, new_inputs))
+        initializers = [replacements.get(value, value) for value in initializers]
+        outputs = [replacements.get(value, value) for value in outputs]
+
+        # Replace the model graph
+        model.graph = ir.Graph(
+            new_inputs,
+            outputs,
+            nodes=extracted_nodes,
+            initializers=initializers,
+            doc_string=model.graph.doc_string,
+            opset_imports=model.graph.opset_imports,
+            name=model.graph.name,
+            metadata_props=model.graph.metadata_props,
+        )
+
+        return ir.passes.PassResult(model, modified=True)
+
+    def requires(self, model: ir.Model) -> None:
+        """Checks that the requested subgraph can be extracted from the model.
+
+        A warning is logged for every input or output without a type or a shape.
+
+        Args:
+            model: The model the pass is about to run on.
+
+        Raises:
+            ir.passes.PreconditionError: If an input or output name is not in the
+                model, or if the subgraph is not completely bounded by the inputs
+                and outputs.
+        """
+        # All inputs and outputs can be found in the model
+        values = ir.convenience.create_value_mapping(model.graph)
+        input_names_not_found = sorted(set(self.input_names) - set(values.keys()))
+        if input_names_not_found:
+            raise ir.passes.PreconditionError(
+                f"Input names not found in the model: {input_names_not_found}"
+            )
+        output_names_not_found = sorted(set(self.output_names) - set(values.keys()))
+        if output_names_not_found:
+            raise ir.passes.PreconditionError(
+                f"Output names not found in the model: {output_names_not_found}"
+            )
+
+        # All inputs and outputs must have type and shape
+        for name in itertools.chain(self.input_names, self.output_names):
+            value = values[name]
+            if value.type is None:
+                logger.warning(
+                    "Value %%%s does not have a type: '%r'. "
+                    "Consider setting its type or running shape inference first.",
+                    name,
+                    value,
+                )
+            if value.shape is None:
+                logger.warning(
+                    "Value %%%s does not have a shape: '%r'. "
+                    "Consider setting its shape or running shape inference first.",
+                    name,
+                    value,
+                )
+
+        # The subgraph must be completely bounded by the inputs and outputs
+        inputs = [values[name] for name in self.input_names]
+        outputs = [values[name] for name in self.output_names]
+        nodes, _ = _find_subgraph_bounded_by_values(model.graph, inputs, outputs)
+        dangling = _find_dangling_values(model.graph, inputs, outputs, nodes)
+        if dangling:
+            dangling_names = sorted(str(value.name) for value in dangling)
+            raise ir.passes.PreconditionError(
+                "The subgraph is not bounded by the given inputs and outputs. "
+                f"Values required but not provided: {dangling_names}"
+            )
diff --git a/onnxscript/ir/passes/common/graph_extration_test.py b/onnxscript/ir/passes/common/graph_extration_test.py
new file mode 100644
index 0000000000..e1ab63c56c
--- /dev/null
+++ b/onnxscript/ir/passes/common/graph_extration_test.py
@@ -0,0 +1,240 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+
+from onnxscript import ir
+from onnxscript.ir.passes.common.graph_extration import ExtractGraphPass
+
+
+def _float_value(name: str) -> ir.Value:
+    """Creates a FLOAT tensor value of shape (2, 3) with the given name."""
+    return ir.Value(name=name, type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3)))
+
+
+def _op_types(graph: ir.Graph) -> list[str]:
+    """Returns the op types of the nodes in the graph, in iteration order."""
+    return [node.op_type for node in graph]
+
+
+def _model_with_initializer() -> tuple[ir.Model, ir.Node]:
+    """Creates a model computing Add(input_a, weight), where weight is an initializer.
+
+    Returns:
+        The model and its Add node.
+    """
+    input_a = _float_value("input_a")
+    weight = ir.Value(
+        name="weight",
+        type=ir.TensorType(ir.DataType.FLOAT),
+        shape=ir.Shape((2, 3)),
+        const_value=ir.tensor(np.random.rand(2, 3).astype(np.float32)),
+    )
+    add_node = ir.node("Add", inputs=[input_a, weight])
+    model = ir.Model(
+        graph=ir.Graph(
+            inputs=[input_a],
+            outputs=add_node.outputs,
+            nodes=[add_node],
+            initializers=[weight],
+            opset_imports={"": 20},
+        ),
+        ir_version=10,
+    )
+    return model, add_node
+
+
+class TestExtractGraphPass(unittest.TestCase):
+    def test_extract_subgraph(self):
+        inputs = [_float_value("input_a"), _float_value("input_b")]
+
+        add_node = ir.node("Add", inputs=inputs)
+        mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
+
+        model = ir.Model(
+            graph=ir.Graph(
+                inputs=inputs,
+                outputs=mul_node.outputs,
+                nodes=[add_node, mul_node],
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
+
+        # Perform extract graph pass
+        extract_pass = ExtractGraphPass(
+            input_names=["input_a", "input_b"], output_names=[mul_node.outputs[0].name]
+        )
+        result = extract_pass(model)
+        self.assertTrue(result.modified)
+        self.assertEqual(_op_types(result.model.graph), ["Add", "Mul"])
+        self.assertEqual(
+            [value.name for value in result.model.graph.inputs], ["input_a", "input_b"]
+        )
+
+    def test_extract_subgraph_with_constant_node(self):
+        inputs = [_float_value("input_a"), _float_value("input_b")]
+
+        constant_tensor = ir.tensor(np.random.rand(2, 3).astype(ir.DataType.FLOAT.numpy()))
+        const_node = ir.node(
+            "Constant", inputs=[], attributes={"value": constant_tensor}, num_outputs=1
+        )
+        add_node = ir.node("Add", inputs=[inputs[0], const_node.outputs[0]])
+        mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
+
+        model = ir.Model(
+            graph=ir.Graph(
+                inputs=inputs,
+                outputs=mul_node.outputs,
+                nodes=[const_node, add_node, mul_node],
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
+
+        # Perform extract graph pass
+        extract_pass = ExtractGraphPass(
+            input_names=["input_a", "input_b"], output_names=[mul_node.outputs[0].name]
+        )
+        result = extract_pass(model)
+        self.assertTrue(result.modified)
+        self.assertEqual(_op_types(result.model.graph), ["Constant", "Add", "Mul"])
+
+    def test_extract_subgraph_with_initializers(self):
+        model, add_node = _model_with_initializer()
+
+        extract_pass = ExtractGraphPass(
+            input_names=["input_a"], output_names=[add_node.outputs[0].name]
+        )
+        result = extract_pass(model)
+        self.assertTrue(result.modified)
+        self.assertEqual(_op_types(result.model.graph), ["Add"])
+        self.assertIn("weight", result.model.graph.initializers)
+
+    def test_extract_subgraph_with_initializer_as_input(self):
+        model, add_node = _model_with_initializer()
+
+        extract_pass = ExtractGraphPass(
+            input_names=["input_a", "weight"], output_names=[add_node.outputs[0].name]
+        )
+        graph = extract_pass(model).model.graph
+        self.assertEqual([value.name for value in graph.inputs], ["input_a", "weight"])
+        # The registered initializer must be the value the graph uses as its input
+        self.assertIs(graph.initializers["weight"], graph.inputs[1])
+
+    def test_extract_subgraph_with_subgraph(self):
+        input_value = ir.Value(
+            name="input", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((2, 3))
+        )
+
+        then_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
+        then_const_node = ir.node(
+            "Constant", inputs=[], attributes={"value": then_constant_tensor}, num_outputs=1
+        )
+        add_node = ir.node("Add", inputs=[input_value, then_const_node.outputs[0]])
+        then_graph = ir.Graph(
+            inputs=[input_value],
+            outputs=[add_node.outputs[0]],
+            nodes=[then_const_node, add_node],
+            opset_imports={"": 20},
+        )
+        else_constant_tensor = ir.tensor(np.random.rand(2, 3).astype(np.float32))
+        else_const_node = ir.node(
+            "Constant", inputs=[], attributes={"value": else_constant_tensor}, num_outputs=1
+        )
+        mul_node = ir.node("Mul", inputs=[input_value, else_const_node.outputs[0]])
+        else_graph = ir.Graph(
+            inputs=[input_value],
+            outputs=[mul_node.outputs[0]],
+            nodes=[else_const_node, mul_node],
+            opset_imports={"": 20},
+        )
+        cond_node = ir.node(
+            "If",
+            inputs=[input_value],
+            attributes={"then_branch": then_graph, "else_branch": else_graph},
+            num_outputs=1,
+        )
+        main_graph = ir.Graph(
+            inputs=[input_value],
+            outputs=cond_node.outputs,
+            nodes=[cond_node],
+            opset_imports={"": 20},
+        )
+        main_graph.sort()
+        model = ir.Model(
+            graph=main_graph,
+            ir_version=10,
+        )
+
+        # Perform extract graph pass
+        extract_pass = ExtractGraphPass(
+            input_names=["input"], output_names=[cond_node.outputs[0].name]
+        )
+        result = extract_pass(model)
+        self.assertTrue(result.modified)
+        self.assertEqual(_op_types(result.model.graph), ["If"])
+        if_node = next(iter(result.model.graph))
+        then_branch = if_node.attributes["then_branch"].value
+        else_branch = if_node.attributes["else_branch"].value
+        self.assertEqual(_op_types(then_branch), ["Constant", "Add"])
+        self.assertEqual(_op_types(else_branch), ["Constant", "Mul"])
+
+    def test_extract_partial_subgraph(self):
+        inputs = [_float_value("input_a"), _float_value("input_b")]
+
+        add_node = ir.node("Add", inputs=inputs)
+        mul_node = ir.node("Mul", inputs=[add_node.outputs[0], inputs[1]])
+        sub_node = ir.node("Sub", inputs=[mul_node.outputs[0], inputs[0]])
+
+        model = ir.Model(
+            graph=ir.Graph(
+                inputs=inputs,
+                outputs=sub_node.outputs,
+                nodes=[add_node, mul_node, sub_node],
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
+
+        # Perform extract graph pass
+        extract_pass = ExtractGraphPass(
+            input_names=["input_a", "input_b"], output_names=[mul_node.outputs[0].name]
+        )
+        result = extract_pass(model)
+        self.assertTrue(result.modified)
+        self.assertEqual(_op_types(result.model.graph), ["Add", "Mul"])
+
+    def test_requires_rejects_unbounded_subgraph(self):
+        inputs = [_float_value("input_a"), _float_value("input_b")]
+        add_node = ir.node("Add", inputs=inputs)
+        model = ir.Model(
+            graph=ir.Graph(
+                inputs=inputs,
+                outputs=add_node.outputs,
+                nodes=[add_node],
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
+
+        # input_b feeds the Add node but is not listed as an input
+        extract_pass = ExtractGraphPass(
+            input_names=["input_a"], output_names=[add_node.outputs[0].name]
+        )
+        with self.assertRaises(ir.passes.PreconditionError):
+            extract_pass.requires(model)
+
+    def test_repeated_names_are_ignored(self):
+        extract_pass = ExtractGraphPass(
+            input_names=["input_a", "input_b", "input_a"], output_names=["out", "out"]
+        )
+        self.assertEqual(extract_pass.input_names, ("input_a", "input_b"))
+        self.assertEqual(extract_pass.output_names, ("out",))
+
+
+if __name__ == "__main__":
+    unittest.main()