diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py index 2d3948beeb1..5a4077f1d47 100644 --- a/backends/arm/common/arm_compile_spec.py +++ b/backends/arm/common/arm_compile_spec.py @@ -36,6 +36,7 @@ class DebugMode(Enum): compiler_flags: list[str] = field(default_factory=list) path_for_intermediates: str | None = None tosa_debug_mode: DebugMode | None = None + tosa_dev_mode: bool | None = None _TOSA_SPEC_KEY = "tosa_spec" _COMPILE_FLAGS_KEY = "compile_flags" @@ -44,6 +45,7 @@ class DebugMode(Enum): _DEBUG_MODE_KEY = "dump_debug_info" _OUTPUT_REORDER_KEY = "ouput_reorder_workaround" _TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config" + _TOSA_DEV_MODE = "tosa_sw_dev_mode" def _set_compile_specs( self, @@ -53,6 +55,7 @@ def _set_compile_specs( tosa_debug_mode: DebugMode | None = None, output_order_workaround: bool = False, pipeline_config: ArmPassPipelineConfig | None = None, + tosa_dev_mode: bool | None = None, ): """Set all values of dataclass directly.""" self.tosa_spec = tosa_spec @@ -61,6 +64,7 @@ def _set_compile_specs( self.tosa_debug_mode = tosa_debug_mode self._pipeline_config = pipeline_config self.output_order_workaround = output_order_workaround + self.tosa_dev_mode = tosa_dev_mode if output_order_workaround: warnings.warn( "ArmCompileSpec(output_order_workaround=True) is deprecated and will be " @@ -78,6 +82,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 tosa_debug_mode: ArmCompileSpec.DebugMode | None = None output_order_workaround: bool = False pipeline_config: ArmPassPipelineConfig | None = None + tosa_dev_mode: bool | None = None unknown_specs: dict[str, str] = {} for spec in compile_specs: key = spec.key @@ -128,6 +133,20 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 "More than one transform pipeline entry in compile spec." ) pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val)) + elif key == ArmCompileSpec._TOSA_DEV_MODE: + if tosa_dev_mode is not None: + raise ValueError( + "More than one tosa_sw_dev_mode entry in compile spec." + ) + raw = bytes(spec.value) + if raw == b"\x01": + tosa_dev_mode = True + elif raw == b"\x00": + tosa_dev_mode = False + else: + raise ValueError( + f"Invalid tosa_sw_dev_mode byte value: {raw!r}, expected b'\\x00' or b'\\x01'." + ) else: unknown_specs[key] = val @@ -151,6 +170,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 tosa_debug_mode=tosa_debug_mode, output_order_workaround=output_order_workaround, pipeline_config=pipeline_config, + tosa_dev_mode=tosa_dev_mode, ) cls._from_list_hook(compile_spec, unknown_specs) compile_spec._validate() @@ -227,6 +247,15 @@ def _to_list(self): self._pipeline_config.serialize(), ) ) + + if self.tosa_dev_mode is not None: + compile_spec.append( + CompileSpec( + ArmCompileSpec._TOSA_DEV_MODE, + b"\x01" if self.tosa_dev_mode else b"\x00", + ) + ) + return compile_spec def _get_pass_pipeline_config(self) -> ArmPassPipelineConfig: @@ -290,6 +319,16 @@ def dump_debug_info(self, debug_mode: DebugMode | None): self.tosa_debug_mode = debug_mode return self + def _set_tosa_dev_mode(self, tosa_dev_mode: bool): + """Sets whether to enable TOSA software development mode. + + Args: + tosa_dev_mode: Boolean indicating whether to enable TOSA software development mode. + + """ + self.tosa_dev_mode = tosa_dev_mode + return self + @deprecated( "set_output_order_workaround() is deprecated and will be removed in v1.5; please remove this call." ) diff --git a/backends/arm/tosa/backend.py b/backends/arm/tosa/backend.py index 0d1dfb4dfa1..6d44c87ac4e 100644 --- a/backends/arm/tosa/backend.py +++ b/backends/arm/tosa/backend.py @@ -227,6 +227,9 @@ def _preprocess( # noqa: C901 targetDraft=True if version.minor > 0 else False, ) + if compile_spec.tosa_dev_mode: + tosa_graph.setExperimentalDevVersion() + if not ( tosa_spec.version.major == ts.TOSA_VERSION_MAJOR and tosa_spec.version.minor <= ts.TOSA_VERSION_MINOR @@ -440,4 +443,5 @@ def filter_tosa_compile_specs( ) .dump_debug_info(compile_spec.tosa_debug_mode) .set_output_order_workaround(compile_spec.output_order_workaround) + ._set_tosa_dev_mode(compile_spec.tosa_dev_mode) ) diff --git a/backends/arm/tosa/schemas/tosa_1.1.fbs b/backends/arm/tosa/schemas/tosa_1.1.fbs new file mode 100644 index 00000000000..3538a9f99c7 --- /dev/null +++ b/backends/arm/tosa/schemas/tosa_1.1.fbs @@ -0,0 +1,701 @@ + +// Copyright (c) 2020-2026 Arm Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +namespace tosa; + +// This corresponds to the version. +file_identifier "TOSA"; +// File extension of any written files. +file_extension "tosa"; + +// NOTE: New values added to the schema should be placed +// at the end of the list in order to keep schema stable. + +enum DType:uint32 { + UNKNOWN = 0, + BOOL, + INT4, + INT8, + INT16, + INT32, + INT48, + FP32, + FP16, + BF16, + SHAPE, + FP8E4M3, + FP8E5M2, + FP6E2M3, + FP6E3M2, + FP4E2M1, + FP8UE8M0, + INT64, + MXINT8, +} + +enum ResizeMode:uint32 { + UNKNOWN = 0, + NEAREST, + BILINEAR, +} + +enum NanPropagationMode:uint32 { + UNKNOWN = 0, + PROPAGATE, + IGNORE, +} + +enum RoundingMode:uint32 { + UNKNOWN = 0, + SINGLE_ROUND, + INEXACT_ROUND, + DOUBLE_ROUND +} + +enum BlockSize:uint32 { + UNKNOWN = 0, + BLOCK_SIZE_32 = 32, +} + +enum Op:uint32 { + UNKNOWN = 0, + ARGMAX, + AVG_POOL2D, + CONV2D, + CONV3D, + DEPTHWISE_CONV2D, + FFT2D, + MATMUL, + MAX_POOL2D, + RFFT2D, + TRANSPOSE_CONV2D, + CLAMP, + ERF, + SIGMOID, + TANH, + ADD, + ARITHMETIC_RIGHT_SHIFT, + BITWISE_AND, + BITWISE_OR, + BITWISE_XOR, + INTDIV, + LOGICAL_AND, + LOGICAL_LEFT_SHIFT, + LOGICAL_RIGHT_SHIFT, + LOGICAL_OR, + LOGICAL_XOR, + MAXIMUM, + MINIMUM, + MUL, + POW, + SUB, + TABLE, + ABS, + BITWISE_NOT, + CEIL, + CLZ, + COS, + EXP, + FLOOR, + LOG, + LOGICAL_NOT, + NEGATE, + RECIPROCAL, + RSQRT, + SIN, + SELECT, + EQUAL, + GREATER, + GREATER_EQUAL, + REDUCE_ALL, + REDUCE_ANY, + REDUCE_MAX, + REDUCE_MIN, + REDUCE_PRODUCT, + REDUCE_SUM, + CONCAT, + PAD, + RESHAPE, + REVERSE, + SLICE, + TILE, + TRANSPOSE, + GATHER, + SCATTER, + RESIZE, + CAST, + RESCALE, + CONST, + IDENTITY, + CUSTOM, + COND_IF, + WHILE_LOOP, + VARIABLE, + VARIABLE_WRITE, + VARIABLE_READ, + CONST_SHAPE, + MATMUL_T_BLOCK_SCALED, + CAST_FROM_BLOCK_SCALED, + CAST_TO_BLOCK_SCALED, + DIM, + CONCAT_SHAPE, + ADD_SHAPE, + SUB_SHAPE, + MUL_SHAPE, + SLICE_SHAPE, + EXP2_SHAPE, + LOG2_CEIL_SHAPE, + LOG2_FLOOR_SHAPE, + MAX_SHAPE, + MIN_SHAPE, + MOD_SHAPE, + DIV_CEIL_SHAPE, + DIV_FLOOR_SHAPE, + ASSERT_EQUAL_SHAPE, + CONV2D_BLOCK_SCALED, + MAX_POOL2D_ADAPTIVE, + AVG_POOL2D_ADAPTIVE +} + +union Attribute { + ArgMaxAttribute, + AvgPool2dAttribute, + Conv2dAttribute, + Conv3dAttribute, + DepthwiseConv2dAttribute, + FFT2dAttribute, + MatMulAttribute, + MaxPool2dAttribute, + RFFT2dAttribute, + TransposeConv2dAttribute, + ClampAttribute, + ErfAttribute, + SigmoidAttribute, + TanhAttribute, + AddAttribute, + ArithmeticRightShiftAttribute, + BitwiseAndAttribute, + BitwiseOrAttribute, + BitwiseXorAttribute, + IntDivAttribute, + LogicalAndAttribute, + LogicalLeftShiftAttribute, + LogicalRightShiftAttribute, + LogicalOrAttribute, + LogicalXorAttribute, + MaximumAttribute, + MinimumAttribute, + MulAttribute, + PowAttribute, + SubAttribute, + TableAttribute, + AbsAttribute, + BitwiseNotAttribute, + CeilAttribute, + ClzAttribute, + CosAttribute, + ExpAttribute, + FloorAttribute, + LogAttribute, + LogicalNotAttribute, + NegateAttribute, + ReciprocalAttribute, + RsqrtAttribute, + SinAttribute, + SelectAttribute, + EqualAttribute, + GreaterAttribute, + GreaterEqualAttribute, + ReduceAllAttribute, + ReduceAnyAttribute, + ReduceMaxAttribute, + ReduceMinAttribute, + ReduceProductAttribute, + ReduceSumAttribute, + ConcatAttribute, + PadAttribute, + ReshapeAttribute, + ReverseAttribute, + SliceAttribute, + TileAttribute, + TransposeAttribute, + GatherAttribute, + ScatterAttribute, + ResizeAttribute, + CastAttribute, + RescaleAttribute, + ConstAttribute, + IdentityAttribute, + CustomAttribute, + CondIfAttribute, + WhileLoopAttribute, + VariableAttribute, + VariableWriteAttribute, + VariableReadAttribute, + ConstShapeAttribute, + MatMulTBlockScaledAttribute, + CastFromBlockScaledAttribute, + CastToBlockScaledAttribute, + DimAttribute, + ConcatShapeAttribute, + AddShapeAttribute, + SubShapeAttribute, + MulShapeAttribute, + SliceShapeAttribute, + Exp2ShapeAttribute, + Log2CeilShapeAttribute, + Log2FloorShapeAttribute, + MaxShapeAttribute, + MinShapeAttribute, + ModShapeAttribute, + DivCeilShapeAttribute, + DivFloorShapeAttribute, + AssertEqualShapeAttribute, + Conv2dBlockScaledAttribute, + MaxPool2dAdaptiveAttribute, + AvgPool2dAdaptiveAttribute +} + +table ArgMaxAttribute { + axis: int32; + nan_mode: NanPropagationMode; +} + +table AvgPool2dAttribute { + kernel: [int32]; + stride: [int32]; + pad: [int32]; + acc_type: DType; +} + +table AvgPool2dAdaptiveAttribute { + acc_type: DType; +} + +table Conv2dAttribute { + pad: [int32]; + stride: [int32]; + dilation: [int32]; + local_bound: bool; + acc_type: DType; +} + +table Conv3dAttribute { + pad: [int32]; + stride: [int32]; + dilation: [int32]; + local_bound: bool; + acc_type: DType; +} + +table DepthwiseConv2dAttribute { + pad: [int32]; + stride: [int32]; + dilation: [int32]; + local_bound: bool; + acc_type: DType; +} + +table FFT2dAttribute { + inverse: bool; + local_bound: bool; +} + +table MatMulAttribute { +} + +table MaxPool2dAttribute { + kernel: [int32]; + stride: [int32]; + pad: [int32]; + nan_mode: NanPropagationMode; +} + +table MaxPool2dAdaptiveAttribute { + nan_mode: NanPropagationMode; +} + +table RFFT2dAttribute { + local_bound: bool; +} + +table TransposeConv2dAttribute { + out_pad: [int32]; + stride: [int32]; + local_bound: bool; + acc_type: DType; +} + +table ClampAttribute { + min_val: [ubyte] (force_align: 8); + max_val: [ubyte] (force_align: 8); + nan_mode: NanPropagationMode; +} + +table ErfAttribute { +} + +table SigmoidAttribute { +} + +table TanhAttribute { +} + +table AddAttribute { +} + +table ArithmeticRightShiftAttribute { + round: bool; +} + +table BitwiseAndAttribute { +} + +table BitwiseOrAttribute { +} + +table BitwiseXorAttribute { +} + +table IntDivAttribute { +} + +table LogicalAndAttribute { +} + +table LogicalLeftShiftAttribute { +} + +table LogicalRightShiftAttribute { +} + +table LogicalOrAttribute { +} + +table LogicalXorAttribute { +} + +table MaximumAttribute { + nan_mode: NanPropagationMode; +} + +table MinimumAttribute { + nan_mode: NanPropagationMode; +} + +table MulAttribute { +} + +table PowAttribute { +} + +table SubAttribute { +} + +table TableAttribute { +} + +table AbsAttribute { +} + +table BitwiseNotAttribute { +} + +table CeilAttribute { +} + +table ClzAttribute { +} + +table CosAttribute { +} + +table ExpAttribute { +} + +table FloorAttribute { +} + +table LogAttribute { +} + +table LogicalNotAttribute { +} + +table NegateAttribute { +} + +table ReciprocalAttribute { +} + +table RsqrtAttribute { +} + +table SinAttribute { +} + +table SelectAttribute { +} + +table EqualAttribute { +} + +table GreaterAttribute { +} + +table GreaterEqualAttribute { +} + +table ReduceAllAttribute { + axis: int32; +} + +table ReduceAnyAttribute { + axis: int32; +} + +table ReduceMaxAttribute { + axis: int32; + nan_mode: NanPropagationMode; +} + +table ReduceMinAttribute { + axis: int32; + nan_mode: NanPropagationMode; +} + +table ReduceProductAttribute { + axis: int32; +} + +table ReduceSumAttribute { + axis: int32; +} + +table ConcatAttribute { + axis: int32; +} + +table PadAttribute { +} + +table ReshapeAttribute { +} + +table ReverseAttribute { + axis: int32; +} + +table SliceAttribute { +} + +table TileAttribute { +} + +table TransposeAttribute { + perms: [int32]; +} + +table GatherAttribute { +} + +table ScatterAttribute { +} + +table ResizeAttribute { + mode: ResizeMode; +} + +table CastAttribute { +} + +table RescaleAttribute { + scale32: bool; + rounding_mode: RoundingMode; + per_channel: bool; + input_unsigned: bool; + output_unsigned: bool; +} + +table ConstAttribute { + // value is stored in output TosaTensor +} + +table IdentityAttribute { +} + +table CustomAttribute { + operator_name:string; + domain_name:string; + implementation_attrs:[ubyte]; +} + +table CondIfAttribute { + then_graph: string; + else_graph: string; +} + +table WhileLoopAttribute { + cond_graph: string; + body_graph: string; +} + +table VariableAttribute { +} + +table VariableWriteAttribute { +} + +table VariableReadAttribute { +} + +table ConstShapeAttribute { + // value is stored in output TosaTensor +} + +table MatMulTBlockScaledAttribute { + block_size: BlockSize; +} + +table CastFromBlockScaledAttribute { + block_size: BlockSize; +} + +table CastToBlockScaledAttribute { + block_size: BlockSize; +} + +table Conv2dBlockScaledAttribute { + block_size: BlockSize; +} + +table SoftwareVersion { + _major: int32 = -1; + _minor: int32 = -1; + _micro: int32 = -1; + _modifier: string; +} + +table DimAttribute { + axis: int32; +} + +table ConcatShapeAttribute { +} + +table AddShapeAttribute { +} + +table SubShapeAttribute { +} + +table MulShapeAttribute { +} + +table SliceShapeAttribute { +} + +table Exp2ShapeAttribute { +} + +table Log2CeilShapeAttribute { +} + +table Log2FloorShapeAttribute { +} + +table MaxShapeAttribute { +} + +table MinShapeAttribute { +} + +table ModShapeAttribute { +} + +table DivCeilShapeAttribute { +} + +table DivFloorShapeAttribute { +} + +table AssertEqualShapeAttribute { + allow_broadcast: bool; +} + + +table Version { + _major: int32 = -1; + _minor: int32 = -1; + _patch: int32 = -1; + _draft: bool = true; +} + +table TosaTensor { + name:string; // name of the tensor, used for solving dependency + shape:[int32]; // shape of the tensor + type:DType; // data type of the tensor + data: [ubyte] (force_align: 8); // raw data array if it's a constant tensor. + variable: bool; // is this a variable tensor + is_unranked: bool; // whether this is an unranked tensor + variable_name:string; // name for variable attribute + + // In a model that is larger than 2GB, then tensors instead uses the following + // attributes to find stored data, which is outside of flatbuffers + // the offset is calculated relative to the beginning of the file and is only + // valid if > 1. + offset: ulong; + size: ulong; +} + +table TosaShape { + name: string; // name of the shape + rank: uint32; // rank of the shape + data: [ubyte] (force_align: 8); // raw data array if it's a constant shape +} + +table OpLocation { + text: string; // Opaque string, interpretted by user +} + +table TosaOperator { + op:Op; // operator enum + attribute:Attribute; // union structure. operator attribute + inputs:[string]; // list of input tensor or shape names + outputs:[string]; // list of output tensor or shape names + location: OpLocation; // location of this Op in mlir +} + +table TosaBasicBlock { + name:string; // basic block name + operators:[TosaOperator]; // operators array + tensors:[TosaTensor]; // tensors array + inputs:[string]; // name of graph inputs + outputs:[string]; // name of graph outputs + shapes:[TosaShape]; // shapes array +} + +table TosaRegion { + name:string; // name of region + blocks:[TosaBasicBlock]; // basic blocks array +} + +table TosaGraph { + version:Version (required); + regions:[TosaRegion]; // regions array + software_version:SoftwareVersion; // cannot be required for back-compat +} + +root_type TosaGraph; diff --git a/backends/arm/vgf/compile_spec.py b/backends/arm/vgf/compile_spec.py index 034bb7af0db..e4bf80c705a 100644 --- a/backends/arm/vgf/compile_spec.py +++ b/backends/arm/vgf/compile_spec.py @@ -43,6 +43,8 @@ def __init__( if compiler_flags is None: compiler_flags = [] self._set_compile_specs(tosa_spec, compiler_flags) + # intermediate handling + self._set_tosa_dev_mode(True) self._validate() def _validate(self): diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 1e579d8bc04..1d8523e7d3f 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -335,35 +335,12 @@ if [[ $is_script_sourced -eq 0 ]]; then CMAKE_POLICY_VERSION_MINIMUM=3.5 \ pip install --no-dependencies -r "$et_dir/backends/arm/requirements-arm-tosa.txt" - pushd "$root_dir" - if [[ ! -d "tosa-tools" ]]; then - git clone https://git.gitlab.arm.com/tosa/tosa-tools.git - fi - - pushd tosa-tools - git checkout v2025.11.2 - - if [[ ! -d "reference_model" ]]; then - log_step "main" "[error] Missing reference_model directory in tosa-tools repo." - exit 1 - fi - if [[ ! -d "serialization" ]]; then - log_step "main" "[error] Missing serialization directory in tosa-tools repo." - exit 1 - fi - export CMAKE_BUILD_PARALLEL_LEVEL="$(get_parallel_jobs)" CMAKE_POLICY_VERSION_MINIMUM=3.5 \ BUILD_PYBIND=1 \ BUILD_TOSA_REFERENCE_MODEL_TESTS=0 \ - pip install --no-dependencies ./reference_model - - CMAKE_POLICY_VERSION_MINIMUM=3.5 \ - BUILD_PYBIND=1 \ - pip install --no-dependencies ./serialization - popd - popd + pip install --no-dependencies git+https://git.gitlab.arm.com/tosa/tosa-tools.git@d6a1a0dea558d00af88241d39499b980bb70dae0 if [[ "${enable_vela}" -eq 1 ]]; then log_step "deps" "Installing Ethos-U Vela compiler"