|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | + |
| 18 | +""" |
| 19 | +.. _tutorial-byoc-npu-example: |
| 20 | +
|
| 21 | +Bring Your Own Codegen: NPU Backend Example |
| 22 | +=========================================== |
| 23 | +**Author**: `Sheldon Aristide <https://github.com/Aristide021/>`_ |
| 24 | +
|
| 25 | +This tutorial walks through the example NPU BYOC backend included in TVM. |
| 26 | +It demonstrates the key concepts needed to offload operations to a custom |
| 27 | +accelerator: pattern registration, graph partitioning, codegen, and runtime |
| 28 | +dispatch. |
| 29 | +
|
| 30 | +NPUs are purpose-built accelerators designed around a fixed set of operations |
| 31 | +common in neural network inference, such as matrix multiplication, convolution, |
| 32 | +and activation functions. |
| 33 | +The example backend uses CPU emulation so no real NPU hardware is required. |
| 34 | +
|
| 35 | +**Prerequisites**: Build TVM with ``USE_EXAMPLE_NPU_CODEGEN=ON`` and |
| 36 | +``USE_EXAMPLE_NPU_RUNTIME=ON``. |
| 37 | +""" |
| 38 | + |
| 39 | +###################################################################### |
| 40 | +# Overview of the BYOC Flow |
| 41 | +# ------------------------- |
| 42 | +# |
| 43 | +# The BYOC framework lets you plug a custom backend into TVM's compilation |
| 44 | +# pipeline in four steps: |
| 45 | +# |
| 46 | +# 1. **Register patterns** - describe which sequences of Relax ops the |
| 47 | +# backend can handle. |
| 48 | +# 2. **Partition the graph** - group matched ops into composite functions. |
| 49 | +# 3. **Run codegen** - lower composite functions to backend-specific |
| 50 | +# representation (JSON graph for the example NPU). |
| 51 | +# 4. **Execute** - the runtime dispatches composite functions to the |
| 52 | +# registered backend runtime. |
| 53 | + |
| 54 | +###################################################################### |
| 55 | +# Step 1: Import the backend to register its patterns |
| 56 | +# --------------------------------------------------- |
| 57 | +# |
| 58 | +# Importing the module is enough to register all supported patterns with |
| 59 | +# TVM's pattern registry. |
| 60 | + |
| 61 | +import tvm |
| 62 | +import tvm.relax.backend.contrib.example_npu # registers patterns |
| 63 | +from tvm import relax |
| 64 | +from tvm.relax.backend.pattern_registry import get_patterns_with_prefix |
| 65 | +from tvm.relax.transform import FuseOpsByPattern, MergeCompositeFunctions, RunCodegen |
| 66 | +from tvm.script import relax as R |
| 67 | + |
| 68 | +has_example_npu_codegen = tvm.get_global_func("relax.ext.example_npu", True) |
| 69 | +has_example_npu_runtime = tvm.get_global_func("runtime.ExampleNPUJSONRuntimeCreate", True) |
| 70 | +has_example_npu = has_example_npu_codegen and has_example_npu_runtime |
| 71 | + |
| 72 | +patterns = get_patterns_with_prefix("example_npu") |
| 73 | +print("Registered patterns:", [p.name for p in patterns]) |
| 74 | + |
| 75 | +###################################################################### |
| 76 | +# Step 2: Define a model |
| 77 | +# ---------------------- |
| 78 | +# |
| 79 | +# We use a simple MatMul + ReLU module to illustrate the flow. |
| 80 | + |
| 81 | + |
| 82 | +@tvm.script.ir_module |
| 83 | +class MatmulReLU: |
| 84 | + @R.function |
| 85 | + def main( |
| 86 | + x: R.Tensor((2, 4), "float32"), |
| 87 | + w: R.Tensor((4, 8), "float32"), |
| 88 | + ) -> R.Tensor((2, 8), "float32"): |
| 89 | + with R.dataflow(): |
| 90 | + y = relax.op.matmul(x, w) |
| 91 | + z = relax.op.nn.relu(y) |
| 92 | + R.output(z) |
| 93 | + return z |
| 94 | + |
| 95 | + |
| 96 | +###################################################################### |
| 97 | +# Step 3: Partition the graph |
| 98 | +# --------------------------- |
| 99 | +# |
| 100 | +# ``FuseOpsByPattern`` groups ops that match a registered pattern into |
| 101 | +# composite functions. ``MergeCompositeFunctions`` consolidates them |
| 102 | +# so each group becomes a single external call. |
| 103 | + |
| 104 | +mod = MatmulReLU |
| 105 | +mod = FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod) |
| 106 | +mod = MergeCompositeFunctions()(mod) |
| 107 | +print("After partitioning:") |
| 108 | +print(mod) |
| 109 | + |
| 110 | +###################################################################### |
| 111 | +# Step 4: Run codegen |
| 112 | +# ------------------- |
| 113 | +# |
| 114 | +# ``RunCodegen`` lowers each annotated composite function to the backend's |
| 115 | +# serialization format. For the example NPU this produces a JSON graph |
| 116 | +# that the C++ runtime can execute. |
| 117 | +# |
| 118 | +# Steps 4 and 5 require TVM to be built with ``USE_EXAMPLE_NPU_CODEGEN=ON`` |
| 119 | +# and ``USE_EXAMPLE_NPU_RUNTIME=ON``. |
| 120 | + |
| 121 | +if has_example_npu: |
| 122 | + mod = RunCodegen()(mod) |
| 123 | + print("After codegen:") |
| 124 | + print(mod) |
| 125 | + |
| 126 | + ###################################################################### |
| 127 | + # Step 5: Build and run |
| 128 | + # --------------------- |
| 129 | + # |
| 130 | + # Build the module for the host target, create a virtual machine, and |
| 131 | + # execute the compiled function. |
| 132 | + |
| 133 | + import numpy as np |
| 134 | + |
| 135 | + np.random.seed(0) |
| 136 | + x_np = np.random.randn(2, 4).astype("float32") |
| 137 | + w_np = np.random.randn(4, 8).astype("float32") |
| 138 | + |
| 139 | + target = tvm.target.Target("llvm") |
| 140 | + with tvm.transform.PassContext(opt_level=3): |
| 141 | + built = relax.build(mod, target) |
| 142 | + |
| 143 | + vm = relax.VirtualMachine(built, tvm.cpu()) |
| 144 | + result = vm["main"](tvm.runtime.tensor(x_np, tvm.cpu()), tvm.runtime.tensor(w_np, tvm.cpu())) |
| 145 | + |
| 146 | + expected_shape = (2, 8) |
| 147 | + assert result.numpy().shape == expected_shape |
| 148 | + print("Execution completed. Output shape:", result.numpy().shape) |
| 149 | + |
| 150 | +###################################################################### |
| 151 | +# Step 6: Conv2D + ReLU |
| 152 | +# --------------------- |
| 153 | +# |
| 154 | +# The same flow applies to convolution workloads. |
| 155 | + |
| 156 | + |
| 157 | +@tvm.script.ir_module |
| 158 | +class Conv2dReLU: |
| 159 | + @R.function |
| 160 | + def main( |
| 161 | + x: R.Tensor((1, 3, 32, 32), "float32"), |
| 162 | + w: R.Tensor((16, 3, 3, 3), "float32"), |
| 163 | + ) -> R.Tensor((1, 16, 30, 30), "float32"): |
| 164 | + with R.dataflow(): |
| 165 | + y = relax.op.nn.conv2d(x, w) |
| 166 | + z = relax.op.nn.relu(y) |
| 167 | + R.output(z) |
| 168 | + return z |
| 169 | + |
| 170 | + |
| 171 | +if has_example_npu: |
| 172 | + mod2 = Conv2dReLU |
| 173 | + mod2 = FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)(mod2) |
| 174 | + mod2 = MergeCompositeFunctions()(mod2) |
| 175 | + mod2 = RunCodegen()(mod2) |
| 176 | + |
| 177 | + with tvm.transform.PassContext(opt_level=3): |
| 178 | + built2 = relax.build(mod2, target) |
| 179 | + |
| 180 | + print("Conv2dReLU compiled successfully.") |
| 181 | + |
| 182 | +###################################################################### |
| 183 | +# Next steps |
| 184 | +# ---------- |
| 185 | +# |
| 186 | +# To build a real NPU backend using this example as a starting point: |
| 187 | +# |
| 188 | +# - Replace ``example_npu_runtime.cc`` with your hardware SDK calls. |
| 189 | +# - Extend ``patterns.py`` with the ops your hardware supports. |
| 190 | +# - Add a C++ codegen under ``src/relax/backend/contrib/`` if your |
| 191 | +# hardware requires a non-JSON serialization format. |
| 192 | +# - Add your cmake module under ``cmake/modules/contrib/`` following |
| 193 | +# the pattern in ``cmake/modules/contrib/ExampleNPU.cmake``. |
0 commit comments