Skip to content

Commit 8ea0131

Browse files
authored
Merge branch 'main' into add-mlx-op-handler-aten-isnan
2 parents b55b95c + 7fdd306 commit 8ea0131

35 files changed

Lines changed: 1500 additions & 116 deletions

.ci/scripts/export_model_artifact.sh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -184,7 +184,7 @@ case "$HF_MODEL" in
184184
PREPROCESSOR_FEATURE_SIZE=""
185185
PREPROCESSOR_OUTPUT=""
186186
;;
187-
SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4)
187+
SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4)
188188
MODEL_NAME="qwen3_5_moe"
189189
TASK=""
190190
MAX_SEQ_LEN=""
@@ -194,7 +194,7 @@ case "$HF_MODEL" in
194194
;;
195195
*)
196196
echo "Error: Unsupported model '$HF_MODEL'"
197-
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4"
197+
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, openai/whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}, google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/diar_streaming_sortformer_4spk-v2, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4"
198198
exit 1
199199
;;
200200
esac

.ci/scripts/test_model_e2e.sh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,7 @@ case "$HF_MODEL" in
216216
AUDIO_FILE="test_audio.wav"
217217
IMAGE_PATH=""
218218
;;
219-
SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4)
219+
SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4)
220220
MODEL_NAME="qwen3_5_moe"
221221
RUNNER_TARGET="qwen3_5_moe_runner"
222222
RUNNER_PATH="qwen3_5_moe"
@@ -230,7 +230,7 @@ case "$HF_MODEL" in
230230
;;
231231
*)
232232
echo "Error: Unsupported model '$HF_MODEL'"
233-
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.5-35B-A3B-HQQ-INT4"
233+
echo "Supported models: mistralai/Voxtral-Mini-3B-2507, mistralai/Voxtral-Mini-4B-Realtime-2602, nvidia/diar_streaming_sortformer_4spk-v2, openai/whisper series (whisper-{small, medium, large, large-v2, large-v3, large-v3-turbo}), google/gemma-3-4b-it, Qwen/Qwen3-0.6B, nvidia/parakeet-tdt, facebook/dinov2-small-imagenet1k-1-layer, SocialLocalMobile/Qwen3.6-35B-A3B-HQQ-INT4"
234234
exit 1
235235
;;
236236
esac

.github/workflows/cuda.yml

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,7 @@ jobs:
180180
- repo: "facebook"
181181
name: "dinov2-small-imagenet1k-1-layer"
182182
- repo: "SocialLocalMobile"
183-
name: "Qwen3.5-35B-A3B-HQQ-INT4"
183+
name: "Qwen3.6-35B-A3B-HQQ-INT4"
184184
quant:
185185
- "non-quantized"
186186
- "quantized-int4-tile-packed"
@@ -194,11 +194,11 @@ jobs:
194194
# Qwen3.5 MoE uses a prequantized checkpoint, only tile-packed
195195
- model:
196196
repo: "SocialLocalMobile"
197-
name: "Qwen3.5-35B-A3B-HQQ-INT4"
197+
name: "Qwen3.6-35B-A3B-HQQ-INT4"
198198
quant: "non-quantized"
199199
- model:
200200
repo: "SocialLocalMobile"
201-
name: "Qwen3.5-35B-A3B-HQQ-INT4"
201+
name: "Qwen3.6-35B-A3B-HQQ-INT4"
202202
quant: "quantized-int4-weight-only"
203203
# Voxtral Realtime only supports int4-tile-packed on CUDA
204204
- model:
@@ -254,7 +254,7 @@ jobs:
254254
with:
255255
timeout: 90
256256
secrets-env: EXECUTORCH_HF_TOKEN
257-
runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }}
257+
runner: ${{ matrix.model.name == 'Qwen3.6-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }}
258258
gpu-arch-type: cuda
259259
gpu-arch-version: 12.6
260260
use-custom-docker-registry: false
@@ -310,7 +310,7 @@ jobs:
310310
- repo: "facebook"
311311
name: "dinov2-small-imagenet1k-1-layer"
312312
- repo: "SocialLocalMobile"
313-
name: "Qwen3.5-35B-A3B-HQQ-INT4"
313+
name: "Qwen3.6-35B-A3B-HQQ-INT4"
314314
quant:
315315
- "non-quantized"
316316
- "quantized-int4-tile-packed"
@@ -324,11 +324,11 @@ jobs:
324324
# Qwen3.5 MoE uses a prequantized checkpoint, only tile-packed
325325
- model:
326326
repo: "SocialLocalMobile"
327-
name: "Qwen3.5-35B-A3B-HQQ-INT4"
327+
name: "Qwen3.6-35B-A3B-HQQ-INT4"
328328
quant: "non-quantized"
329329
- model:
330330
repo: "SocialLocalMobile"
331-
name: "Qwen3.5-35B-A3B-HQQ-INT4"
331+
name: "Qwen3.6-35B-A3B-HQQ-INT4"
332332
quant: "quantized-int4-weight-only"
333333
# Voxtral Realtime only supports int4-tile-packed on CUDA
334334
- model:
@@ -378,7 +378,7 @@ jobs:
378378
quant: "non-quantized"
379379
with:
380380
timeout: 90
381-
runner: ${{ matrix.model.name == 'Qwen3.5-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }}
381+
runner: ${{ matrix.model.name == 'Qwen3.6-35B-A3B-HQQ-INT4' && 'linux.aws.a100' || 'linux.g5.4xlarge.nvidia.gpu' }}
382382
gpu-arch-type: cuda
383383
gpu-arch-version: 12.6
384384
use-custom-docker-registry: false

backends/apple/mps/runtime/MPSDevice.mm

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
138138
ET_CHECK_OR_RETURN_ERROR(
139139
err == Error::Ok,
140140
Internal,
141-
"An error occured occured while compiling library %d", libraryType
141+
"An error occurred while compiling library %d", libraryType
142142
);
143143
}
144144
if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) {

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 50 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import math
7+
import operator
78
from copy import copy
89
from typing import cast, Dict, Optional, Set, Tuple, Type
910

@@ -34,22 +35,67 @@ class InsertRescalePass(ArmPass):
3435

3536
_passes_required_after: Set[Type[ExportPass]] = set()
3637

38+
def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None:
39+
"""Ensure uint8 tensors only appear at IO boundaries.
40+
41+
TOSA has no true uint8 tensor type; unsigned semantics are carried via
42+
RESCALE input/output flags. If uint8 appears for other nodes, it means
43+
unsigned data leaked past IO.
44+
45+
"""
46+
for node in graph_module.graph.nodes:
47+
meta_val = node.meta.get("val")
48+
if not isinstance(meta_val, torch.Tensor):
49+
continue
50+
if meta_val.dtype != torch.uint8:
51+
continue
52+
if node.op in ("placeholder", "output"):
53+
continue
54+
if node.op == "call_function" and node.target == operator.getitem:
55+
if all(user.op == "output" for user in node.users):
56+
continue
57+
if (
58+
node.op == "call_function"
59+
and node.target
60+
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
61+
):
62+
# dim_order is a view-like transform; allow it to preserve uint8 at IO.
63+
continue
64+
if (
65+
node.op == "call_function"
66+
and node.target == exir_ops.backend.tosa.RESCALE.default
67+
):
68+
continue
69+
raise ValueError(
70+
f"Found internal uint8 tensor at node {node.name} "
71+
f"({node.target}). Uint8 is only allowed at IO boundaries."
72+
)
73+
3774
def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
3875
dq_args = QuantArgs.from_operator(node.target, node.args)
3976
q_args = QuantArgs.from_operator(user.target, user.args)
4077
new_scale = dq_args.scale / q_args.scale
78+
input_unsigned = dq_args.dtype == torch.uint8
79+
output_unsigned = q_args.dtype == torch.uint8
80+
# TOSA has no true uint8 tensors; unsigned semantics are handled via
81+
# the RESCALE flags, so uint8 does not propagate as a tensor dtype.
82+
output_dtype = torch.int8 if output_unsigned else q_args.dtype
4183

4284
with graph_module.graph.inserting_before(node):
4385
rescale_node = create_node(
4486
graph_module.graph,
4587
exir_ops.backend.tosa.RESCALE.default,
4688
(
4789
node.all_input_nodes[0],
48-
q_args.dtype,
90+
output_dtype,
4991
[new_scale],
5092
dq_args.zp,
5193
q_args.zp,
5294
),
95+
kwargs={
96+
"input_unsigned": input_unsigned,
97+
"output_unsigned": output_unsigned,
98+
},
5399
)
54100
rescale_node.meta = copy(user.meta)
55101
user.replace_all_uses_with(rescale_node)
@@ -74,6 +120,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
74120
graph_module.recompile()
75121
return PassResult(graph_module, modified)
76122

123+
def ensures(self, graph_module: GraphModule) -> None:
124+
self._ensure_uint8_io_only(graph_module)
125+
77126

78127
class InsertRescaleInt32Pass(ArmPass):
79128
"""Numerous TOSA ops require inputs and outputs to be 32-bit integers in

backends/arm/common/arm_compile_spec.py

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ class DebugMode(Enum):
3636
compiler_flags: list[str] = field(default_factory=list)
3737
path_for_intermediates: str | None = None
3838
tosa_debug_mode: DebugMode | None = None
39+
preserve_io_quantization: bool = False
3940

4041
_TOSA_SPEC_KEY = "tosa_spec"
4142
_COMPILE_FLAGS_KEY = "compile_flags"
@@ -44,6 +45,7 @@ class DebugMode(Enum):
4445
_DEBUG_MODE_KEY = "dump_debug_info"
4546
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
4647
_TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config"
48+
_PRESERVE_IO_QUANT_KEY = "preserve_io_quantization"
4749

4850
def _set_compile_specs(
4951
self,
@@ -53,6 +55,7 @@ def _set_compile_specs(
5355
tosa_debug_mode: DebugMode | None = None,
5456
output_order_workaround: bool = False,
5557
pipeline_config: ArmPassPipelineConfig | None = None,
58+
preserve_io_quantization: bool = False,
5659
):
5760
"""Set all values of dataclass directly."""
5861
self.tosa_spec = tosa_spec
@@ -61,6 +64,8 @@ def _set_compile_specs(
6164
self.tosa_debug_mode = tosa_debug_mode
6265
self._pipeline_config = pipeline_config
6366
self.output_order_workaround = output_order_workaround
67+
self.preserve_io_quantization = preserve_io_quantization
68+
self._warn_if_redundant_preserve_io_quantization()
6469
if output_order_workaround:
6570
warnings.warn(
6671
"ArmCompileSpec(output_order_workaround=True) is deprecated and will be "
@@ -78,6 +83,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
7883
tosa_debug_mode: ArmCompileSpec.DebugMode | None = None
7984
output_order_workaround: bool = False
8085
pipeline_config: ArmPassPipelineConfig | None = None
86+
preserve_io_quantization: bool = False
8187
unknown_specs: dict[str, str] = {}
8288
for spec in compile_specs:
8389
key = spec.key
@@ -128,6 +134,8 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
128134
"More than one transform pipeline entry in compile spec."
129135
)
130136
pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val))
137+
elif key == ArmCompileSpec._PRESERVE_IO_QUANT_KEY:
138+
preserve_io_quantization = str(val).lower() in ("1", "true", "yes")
131139
else:
132140
unknown_specs[key] = val
133141

@@ -151,6 +159,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
151159
tosa_debug_mode=tosa_debug_mode,
152160
output_order_workaround=output_order_workaround,
153161
pipeline_config=pipeline_config,
162+
preserve_io_quantization=preserve_io_quantization,
154163
)
155164
cls._from_list_hook(compile_spec, unknown_specs)
156165
compile_spec._validate()
@@ -227,8 +236,35 @@ def _to_list(self):
227236
self._pipeline_config.serialize(),
228237
)
229238
)
239+
compile_spec.append(
240+
CompileSpec(
241+
ArmCompileSpec._PRESERVE_IO_QUANT_KEY,
242+
str(bool(self.preserve_io_quantization)).encode(),
243+
)
244+
)
230245
return compile_spec
231246

247+
def _set_preserve_io_quantization(self, enabled: bool) -> "ArmCompileSpec":
248+
"""Preserve Q/DQ nodes at IO boundaries when lowering."""
249+
self.preserve_io_quantization = enabled
250+
self._warn_if_redundant_preserve_io_quantization()
251+
return self
252+
253+
def _warn_if_redundant_preserve_io_quantization(self) -> None:
254+
"""Warn when preserve_io_quantization has no effect for INT-only
255+
specs.
256+
"""
257+
if (
258+
self.preserve_io_quantization
259+
and self.tosa_spec.support_integer()
260+
and not self.tosa_spec.support_float()
261+
):
262+
warnings.warn(
263+
"preserve_io_quantization=True is redundant for INT-only TOSA "
264+
"specifications because boundary Q/DQ are already de-tagged.",
265+
stacklevel=3,
266+
)
267+
232268
def _get_pass_pipeline_config(self) -> ArmPassPipelineConfig:
233269
"""Returns configuration that controls how the Arm pass pipeline should
234270
behave.

backends/arm/operators/op_tosa_rescale.py

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -161,6 +161,8 @@ def _build_rescale(
161161
rounding_mode: ts.RoundingMode,
162162
per_channel: bool = False,
163163
is_scale32: bool = True,
164+
input_unsigned: bool = False,
165+
output_unsigned: bool = False,
164166
):
165167
"""Insert a TOSA RESCALE operator configured for the quantized path.
166168
@@ -198,8 +200,8 @@ def _build_rescale(
198200
scale32=is_scale32,
199201
rounding_mode=rounding_mode,
200202
per_channel=per_channel,
201-
input_unsigned=False,
202-
output_unsigned=False,
203+
input_unsigned=input_unsigned,
204+
output_unsigned=output_unsigned,
203205
)
204206

205207
tosa_fb.addOperator(
@@ -228,6 +230,14 @@ def define_node(
228230
scales = cast(list[float], node.args[2])
229231
input_zp = cast(int, node.args[3])
230232
output_zp = cast(int, node.args[4])
233+
if "input_unsigned" in node.kwargs:
234+
input_unsigned = cast(bool, node.kwargs.get("input_unsigned", False))
235+
else:
236+
input_unsigned = cast(bool, node.args[5]) if len(node.args) > 5 else False
237+
if "output_unsigned" in node.kwargs:
238+
output_unsigned = cast(bool, node.kwargs.get("output_unsigned", False))
239+
else:
240+
output_unsigned = cast(bool, node.args[6]) if len(node.args) > 6 else False
231241

232242
if (
233243
input_dtype
@@ -244,7 +254,6 @@ def define_node(
244254
raise ValueError(
245255
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
246256
)
247-
248257
_build_rescale(
249258
tosa_graph,
250259
scale=scales,
@@ -255,4 +264,6 @@ def define_node(
255264
output_zp=[output_zp],
256265
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
257266
per_channel=len(scales) > 1,
267+
input_unsigned=input_unsigned,
268+
output_unsigned=output_unsigned,
258269
)

backends/arm/quantizer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
EthosUQuantizer,
1616
get_symmetric_a16w8_quantization_config,
1717
get_symmetric_quantization_config,
18+
get_uint8_io_quantization_config,
1819
TOSAQuantizer,
1920
VgfQuantizer,
2021
)

0 commit comments

Comments
 (0)