Skip to content

Commit 796d6ff

Browse files
authored
Merge branch 'main' into uint8-io
2 parents c4dc7e6 + 2c545f8 commit 796d6ff

22 files changed

Lines changed: 871 additions & 50 deletions

backends/arm/ethosu/partitioner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
76
from typing import final, Optional, Sequence
87

8+
import torch
99
from executorch.backends.arm.ethosu import EthosUBackend, EthosUCompileSpec
1010
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
1111
from executorch.exir.backend.partitioner import DelegationSpec
@@ -33,3 +33,4 @@ def __init__(
3333
)
3434
self.additional_checks = additional_checks
3535
self.tosa_spec = compile_spec.tosa_spec
36+
self._custom_partition_ops: set[torch._ops.OpOverload] = set()

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
op_to_dim_order_copy,
5252
op_tosa_conv2d,
5353
op_tosa_conv3d,
54+
op_tosa_custom,
5455
op_tosa_depthwise_conv2d,
5556
op_tosa_gather,
5657
op_tosa_matmul,
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2026 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, List
7+
8+
import torch
9+
import tosa_serializer as ts
10+
11+
from executorch.backends.arm.operators.node_visitor import (
12+
NodeVisitor,
13+
register_node_visitor,
14+
)
15+
from executorch.backends.arm.tosa.mapping import TosaArg
16+
17+
18+
@register_node_visitor
19+
class CustomVisitor(NodeVisitor):
20+
"""Lower the TOSA CUSTOM op from the TOSA backend dialect."""
21+
22+
target = "tosa.CUSTOM.default"
23+
24+
def define_node(
25+
self,
26+
node: torch.fx.Node,
27+
tosa_graph: Any,
28+
inputs: List[TosaArg],
29+
output: TosaArg,
30+
) -> None:
31+
allowed_kwargs = {"operator_name", "domain_name", "implementation_attrs"}
32+
unexpected = set(node.kwargs.keys()) - allowed_kwargs
33+
if unexpected:
34+
raise ValueError(
35+
f"tosa.CUSTOM received unexpected kwargs: {sorted(unexpected)}"
36+
)
37+
38+
operator_name = node.kwargs.get("operator_name")
39+
domain_name = node.kwargs.get("domain_name")
40+
implementation_attrs = node.kwargs.get("implementation_attrs")
41+
42+
if operator_name is None or domain_name is None:
43+
raise ValueError(
44+
"tosa.CUSTOM requires operator_name and domain_name in kwargs"
45+
)
46+
47+
if implementation_attrs is None:
48+
impl_list = []
49+
elif isinstance(implementation_attrs, list):
50+
# NOTE: PyTorch schemas do not support a bytes type; we pass
51+
# implementation_attrs as int[] representing raw bytes.
52+
impl_list = [int(x) for x in implementation_attrs]
53+
else:
54+
raise TypeError(
55+
"implementation_attrs must be None or list[int]; "
56+
f"got {type(implementation_attrs)}"
57+
)
58+
59+
attr = ts.TosaSerializerAttribute()
60+
attr.CustomAttribute(
61+
operator_name=operator_name,
62+
domain_name=domain_name,
63+
implementation_attrs=impl_list,
64+
)
65+
66+
expanded = [TosaArg(item, self.tosa_spec) for item in inputs[0].special]
67+
input_names = [arg.name for arg in expanded]
68+
output_names = (
69+
output.multiple_output_names
70+
if getattr(output, "multiple_output_names", None)
71+
else [output.name]
72+
)
73+
if len(output_names) != 1:
74+
# TODO: Support multi-output CUSTOM ops with per-output meta/shape.
75+
raise ValueError(
76+
f"tosa.CUSTOM currently requires a single output, got {len(output_names)}"
77+
)
78+
self._serialize_operator(
79+
node,
80+
tosa_graph,
81+
ts.Op.CUSTOM,
82+
input_names,
83+
output_names,
84+
attr,
85+
)

backends/arm/public_api_manifests/api_manifest_running.toml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ signature = "EthosUPartitioner.ops_to_not_decompose(self, ep: torch.export.expor
5656
kind = "function"
5757
signature = "EthosUPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult"
5858

59+
[python.EthosUPartitioner.register_custom_partition_op]
60+
kind = "function"
61+
signature = "EthosUPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None"
62+
5963
[python.EthosUQuantizer]
6064
kind = "class"
6165
signature = "EthosUQuantizer(compile_spec: 'EthosUCompileSpec', use_composable_quantizer: 'bool' = False) -> 'None'"
@@ -136,6 +140,10 @@ signature = "VgfPartitioner.ops_to_not_decompose(self, ep: torch.export.exported
136140
kind = "function"
137141
signature = "VgfPartitioner.partition(self, exported_program: torch.export.exported_program.ExportedProgram) -> executorch.exir.backend.partitioner.PartitionResult"
138142

143+
[python.VgfPartitioner.register_custom_partition_op]
144+
kind = "function"
145+
signature = "VgfPartitioner.register_custom_partition_op(self, op: torch._ops.OpOverload) -> None"
146+
139147
[python.VgfQuantizer]
140148
kind = "class"
141149
signature = "VgfQuantizer(compile_spec: 'VgfCompileSpec', use_composable_quantizer: 'bool' = False) -> 'None'"

backends/arm/requirements-arm-tosa.txt

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,5 @@ ml_dtypes == 0.5.1
77
flatbuffers == 24.3.25
88
tosa-adapter-model-explorer == 0.1.0
99
ai-edge-model-explorer >= 0.1.16
10-
# NOTE: Will be removed when tosa-tools is installed via pypi
11-
pybind11 == 2.10.4
12-
pytest-timeout == 2.4.0
10+
pytest-timeout == 2.4.0
11+
tosa-tools == 2026.2.1

backends/arm/requirements-arm-vgf.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
ai_ml_emulation_layer_for_vulkan == 0.8.0
7-
ai_ml_sdk_model_converter == 0.8.0
8-
ai_ml_sdk_vgf_library == 0.8.0
6+
ai_ml_emulation_layer_for_vulkan == 0.9.0
7+
ai_ml_sdk_model_converter == 0.9.0
8+
ai_ml_sdk_vgf_library == 0.9.0

backends/arm/runtime/VGFBackend.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ class VGFBackend final : public ::executorch::runtime::BackendInterface {
157157
new (repr) VgfRepr(
158158
vk_instance, vk_physical_device, vk_device, vk_queue, vk_command_pool);
159159

160-
auto valid_vgf = repr->process_vgf(vgf_data, compile_specs);
160+
auto valid_vgf =
161+
repr->process_vgf(vgf_data, processed->size(), compile_specs);
161162
if (!valid_vgf) {
162163
ET_LOG(Error, "Failed to process VGF blob.");
163164
return Error::Internal;

backends/arm/runtime/VGFSetup.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2025 Arm Limited and/or its affiliates.
2+
* Copyright 2025-2026 Arm Limited and/or its affiliates.
33
*
44
* This source code is licensed under the BSD-style license found in the
55
* LICENSE file in the root directory of this source tree.
@@ -324,26 +324,38 @@ static void debug_print_modules(
324324
}
325325
}
326326

327-
bool VgfRepr::process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs) {
327+
bool VgfRepr::process_vgf(
328+
const char* vgf_data,
329+
size_t vgf_size,
330+
ArrayRef<CompileSpec> specs) {
328331
ET_LOG(Info, "Preparing VGF as Vulkan objects");
329332

330333
VkResult result;
331334

332335
// Prepare temporary decoders
333336
unique_ptr<vgflib::HeaderDecoder> header_decoder =
334-
vgflib::CreateHeaderDecoder(vgf_data);
337+
vgflib::CreateHeaderDecoder(vgf_data, vgflib::HeaderSize(), vgf_size);
338+
if (!header_decoder) {
339+
ET_LOG(Error, "Failed to create VGF header decoder");
340+
return false;
341+
}
342+
335343
unique_ptr<vgflib::ModelSequenceTableDecoder> sequence_decoder =
336344
vgflib::CreateModelSequenceTableDecoder(
337-
vgf_data + header_decoder->GetModelSequenceTableOffset());
345+
vgf_data + header_decoder->GetModelSequenceTableOffset(),
346+
header_decoder->GetModelSequenceTableSize());
338347
unique_ptr<vgflib::ModuleTableDecoder> module_decoder =
339348
vgflib::CreateModuleTableDecoder(
340-
vgf_data + header_decoder->GetModuleTableOffset());
349+
vgf_data + header_decoder->GetModuleTableOffset(),
350+
header_decoder->GetModuleTableSize());
341351
unique_ptr<vgflib::ModelResourceTableDecoder> resource_decoder =
342352
vgflib::CreateModelResourceTableDecoder(
343-
vgf_data + header_decoder->GetModelResourceTableOffset());
353+
vgf_data + header_decoder->GetModelResourceTableOffset(),
354+
header_decoder->GetModelResourceTableSize());
344355
unique_ptr<vgflib::ConstantDecoder> constant_decoder =
345356
vgflib::CreateConstantDecoder(
346-
vgf_data + header_decoder->GetConstantsOffset());
357+
vgf_data + header_decoder->GetConstantsOffset(),
358+
header_decoder->GetConstantsSize());
347359
// Check the VGF decoders
348360
if (not(header_decoder && module_decoder && sequence_decoder &&
349361
resource_decoder && constant_decoder && header_decoder->IsValid() &&

backends/arm/runtime/VGFSetup.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2025 Arm Limited and/or its affiliates.
2+
* Copyright 2025-2026 Arm Limited and/or its affiliates.
33
*
44
* This source code is licensed under the BSD-style license found in the
55
* LICENSE file in the root directory of this source tree.
@@ -58,7 +58,10 @@ class VgfRepr {
5858
/*
5959
* Process a VGF ready for execution, allocate necessary Vulkan objects.
6060
*/
61-
bool process_vgf(const char* vgf_data, ArrayRef<CompileSpec> specs);
61+
bool process_vgf(
62+
const char* vgf_data,
63+
size_t vgf_size,
64+
ArrayRef<CompileSpec> specs);
6265

6366
/*
6467
* Execute the VGF we've previously processed.

backends/arm/scripts/aot_arm_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ def _get_args():
691691
if args.evaluate is not None or args.evaluate_config is not None:
692692
logging.error(
693693
"Model evaluation is no longer supported in this script."
694-
" Ignore and continue."
694+
" Use evaluate_model.py instead. Ignore and continue."
695695
)
696696

697697
return args

0 commit comments

Comments
 (0)