Skip to content

Commit 7359cf2

Browse files
committed
Update on "Use caching allocator for runner (#15730)"
Summary: We observed that on iOS it improves perf by 6% because SDPA op does temp allocations. No significant difference on android though. ghstack-source-id: 328001114 exported-using-ghexport Reviewed By: navsud, derekdixu Differential Revision: D86120038 [ghstack-poisoned]
2 parents 704fb2e + 467774d commit 7359cf2

40 files changed

Lines changed: 1997 additions & 115 deletions

backends/apple/mps/runtime/MPSDevice.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
138138
ET_CHECK_OR_RETURN_ERROR(
139139
err == Error::Ok,
140140
Internal,
141-
"An error occured occured while compiling library %d", libraryType
141+
"An error occurred while compiling library %d", libraryType
142142
);
143143
}
144144
if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) {

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import math
7+
import operator
78
from copy import copy
89
from typing import cast, Dict, Optional, Set, Tuple, Type
910

@@ -34,22 +35,67 @@ class InsertRescalePass(ArmPass):
3435

3536
_passes_required_after: Set[Type[ExportPass]] = set()
3637

38+
def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None:
39+
"""Ensure uint8 tensors only appear at IO boundaries.
40+
41+
TOSA has no true uint8 tensor type; unsigned semantics are carried via
42+
RESCALE input/output flags. If uint8 appears for other nodes, it means
43+
unsigned data leaked past IO.
44+
45+
"""
46+
for node in graph_module.graph.nodes:
47+
meta_val = node.meta.get("val")
48+
if not isinstance(meta_val, torch.Tensor):
49+
continue
50+
if meta_val.dtype != torch.uint8:
51+
continue
52+
if node.op in ("placeholder", "output"):
53+
continue
54+
if node.op == "call_function" and node.target == operator.getitem:
55+
if all(user.op == "output" for user in node.users):
56+
continue
57+
if (
58+
node.op == "call_function"
59+
and node.target
60+
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
61+
):
62+
# dim_order is a view-like transform; allow it to preserve uint8 at IO.
63+
continue
64+
if (
65+
node.op == "call_function"
66+
and node.target == exir_ops.backend.tosa.RESCALE.default
67+
):
68+
continue
69+
raise ValueError(
70+
f"Found internal uint8 tensor at node {node.name} "
71+
f"({node.target}). Uint8 is only allowed at IO boundaries."
72+
)
73+
3774
def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
3875
dq_args = QuantArgs.from_operator(node.target, node.args)
3976
q_args = QuantArgs.from_operator(user.target, user.args)
4077
new_scale = dq_args.scale / q_args.scale
78+
input_unsigned = dq_args.dtype == torch.uint8
79+
output_unsigned = q_args.dtype == torch.uint8
80+
# TOSA has no true uint8 tensors; unsigned semantics are handled via
81+
# the RESCALE flags, so uint8 does not propagate as a tensor dtype.
82+
output_dtype = torch.int8 if output_unsigned else q_args.dtype
4183

4284
with graph_module.graph.inserting_before(node):
4385
rescale_node = create_node(
4486
graph_module.graph,
4587
exir_ops.backend.tosa.RESCALE.default,
4688
(
4789
node.all_input_nodes[0],
48-
q_args.dtype,
90+
output_dtype,
4991
[new_scale],
5092
dq_args.zp,
5193
q_args.zp,
5294
),
95+
kwargs={
96+
"input_unsigned": input_unsigned,
97+
"output_unsigned": output_unsigned,
98+
},
5399
)
54100
rescale_node.meta = copy(user.meta)
55101
user.replace_all_uses_with(rescale_node)
@@ -74,6 +120,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
74120
graph_module.recompile()
75121
return PassResult(graph_module, modified)
76122

123+
    def ensures(self, graph_module: GraphModule) -> None:
        """Post-condition hook: verify uint8 tensors only occur at IO boundaries.

        Delegates to ``_ensure_uint8_io_only``, which raises ``ValueError`` if
        an interior (non-placeholder/output) node carries a uint8 value.
        """
        self._ensure_uint8_io_only(graph_module)
125+
77126

78127
class InsertRescaleInt32Pass(ArmPass):
79128
"""Numerous TOSA ops require inputs and outputs to be 32-bit integers in

backends/arm/common/arm_compile_spec.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class DebugMode(Enum):
3636
compiler_flags: list[str] = field(default_factory=list)
3737
path_for_intermediates: str | None = None
3838
tosa_debug_mode: DebugMode | None = None
39+
preserve_io_quantization: bool = False
3940

4041
_TOSA_SPEC_KEY = "tosa_spec"
4142
_COMPILE_FLAGS_KEY = "compile_flags"
@@ -44,6 +45,7 @@ class DebugMode(Enum):
4445
_DEBUG_MODE_KEY = "dump_debug_info"
4546
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
4647
_TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config"
48+
_PRESERVE_IO_QUANT_KEY = "preserve_io_quantization"
4749

4850
def _set_compile_specs(
4951
self,
@@ -53,6 +55,7 @@ def _set_compile_specs(
5355
tosa_debug_mode: DebugMode | None = None,
5456
output_order_workaround: bool = False,
5557
pipeline_config: ArmPassPipelineConfig | None = None,
58+
preserve_io_quantization: bool = False,
5659
):
5760
"""Set all values of dataclass directly."""
5861
self.tosa_spec = tosa_spec
@@ -61,6 +64,8 @@ def _set_compile_specs(
6164
self.tosa_debug_mode = tosa_debug_mode
6265
self._pipeline_config = pipeline_config
6366
self.output_order_workaround = output_order_workaround
67+
self.preserve_io_quantization = preserve_io_quantization
68+
self._warn_if_redundant_preserve_io_quantization()
6469
if output_order_workaround:
6570
warnings.warn(
6671
"ArmCompileSpec(output_order_workaround=True) is deprecated and will be "
@@ -78,6 +83,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
7883
tosa_debug_mode: ArmCompileSpec.DebugMode | None = None
7984
output_order_workaround: bool = False
8085
pipeline_config: ArmPassPipelineConfig | None = None
86+
preserve_io_quantization: bool = False
8187
unknown_specs: dict[str, str] = {}
8288
for spec in compile_specs:
8389
key = spec.key
@@ -128,6 +134,8 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
128134
"More than one transform pipeline entry in compile spec."
129135
)
130136
pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val))
137+
elif key == ArmCompileSpec._PRESERVE_IO_QUANT_KEY:
138+
preserve_io_quantization = str(val).lower() in ("1", "true", "yes")
131139
else:
132140
unknown_specs[key] = val
133141

@@ -151,6 +159,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
151159
tosa_debug_mode=tosa_debug_mode,
152160
output_order_workaround=output_order_workaround,
153161
pipeline_config=pipeline_config,
162+
preserve_io_quantization=preserve_io_quantization,
154163
)
155164
cls._from_list_hook(compile_spec, unknown_specs)
156165
compile_spec._validate()
@@ -227,8 +236,35 @@ def _to_list(self):
227236
self._pipeline_config.serialize(),
228237
)
229238
)
239+
compile_spec.append(
240+
CompileSpec(
241+
ArmCompileSpec._PRESERVE_IO_QUANT_KEY,
242+
str(bool(self.preserve_io_quantization)).encode(),
243+
)
244+
)
230245
return compile_spec
231246

247+
    def _set_preserve_io_quantization(self, enabled: bool) -> "ArmCompileSpec":
        """Preserve Q/DQ nodes at IO boundaries when lowering.

        Args:
            enabled: Whether boundary quantize/dequantize nodes should be
                preserved.

        Returns:
            This compile spec, to allow fluent chaining.
        """
        self.preserve_io_quantization = enabled
        # Warn immediately if the flag is a no-op for the current TOSA spec.
        self._warn_if_redundant_preserve_io_quantization()
        return self
252+
253+
def _warn_if_redundant_preserve_io_quantization(self) -> None:
254+
"""Warn when preserve_io_quantization has no effect for INT-only
255+
specs.
256+
"""
257+
if (
258+
self.preserve_io_quantization
259+
and self.tosa_spec.support_integer()
260+
and not self.tosa_spec.support_float()
261+
):
262+
warnings.warn(
263+
"preserve_io_quantization=True is redundant for INT-only TOSA "
264+
"specifications because boundary Q/DQ are already de-tagged.",
265+
stacklevel=3,
266+
)
267+
232268
def _get_pass_pipeline_config(self) -> ArmPassPipelineConfig:
233269
"""Returns configuration that controls how the Arm pass pipeline should
234270
behave.

backends/arm/ethosu/compile_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _default_system_config_and_memory_mode(
4848
return resolved_system_config, resolved_memory_mode
4949
if "ethos-u65" in target_lower:
5050
resolved_system_config = (
51-
"Ethos_U65_SYS_DRAM_Mid" if system_config is None else system_config
51+
"Ethos_U65_High_End" if system_config is None else system_config
5252
)
5353
resolved_memory_mode = "Sram_Only" if memory_mode is None else memory_mode
5454
return resolved_system_config, resolved_memory_mode

backends/arm/operators/op_tosa_rescale.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ def _build_rescale(
161161
rounding_mode: ts.RoundingMode,
162162
per_channel: bool = False,
163163
is_scale32: bool = True,
164+
input_unsigned: bool = False,
165+
output_unsigned: bool = False,
164166
):
165167
"""Insert a TOSA RESCALE operator configured for the quantized path.
166168
@@ -198,8 +200,8 @@ def _build_rescale(
198200
scale32=is_scale32,
199201
rounding_mode=rounding_mode,
200202
per_channel=per_channel,
201-
input_unsigned=False,
202-
output_unsigned=False,
203+
input_unsigned=input_unsigned,
204+
output_unsigned=output_unsigned,
203205
)
204206

205207
tosa_fb.addOperator(
@@ -228,6 +230,14 @@ def define_node(
228230
scales = cast(list[float], node.args[2])
229231
input_zp = cast(int, node.args[3])
230232
output_zp = cast(int, node.args[4])
233+
if "input_unsigned" in node.kwargs:
234+
input_unsigned = cast(bool, node.kwargs.get("input_unsigned", False))
235+
else:
236+
input_unsigned = cast(bool, node.args[5]) if len(node.args) > 5 else False
237+
if "output_unsigned" in node.kwargs:
238+
output_unsigned = cast(bool, node.kwargs.get("output_unsigned", False))
239+
else:
240+
output_unsigned = cast(bool, node.args[6]) if len(node.args) > 6 else False
231241

232242
if (
233243
input_dtype
@@ -244,7 +254,6 @@ def define_node(
244254
raise ValueError(
245255
f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
246256
)
247-
248257
_build_rescale(
249258
tosa_graph,
250259
scale=scales,
@@ -255,4 +264,6 @@ def define_node(
255264
output_zp=[output_zp],
256265
rounding_mode=ts.RoundingMode.SINGLE_ROUND,
257266
per_channel=len(scales) > 1,
267+
input_unsigned=input_unsigned,
268+
output_unsigned=output_unsigned,
258269
)

backends/arm/quantizer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
EthosUQuantizer,
1616
get_symmetric_a16w8_quantization_config,
1717
get_symmetric_quantization_config,
18+
get_uint8_io_quantization_config,
1819
TOSAQuantizer,
1920
VgfQuantizer,
2021
)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
"VgfQuantizer",
106106
"get_symmetric_a16w8_quantization_config",
107107
"get_symmetric_quantization_config",
108+
"get_uint8_io_quantization_config",
108109
]
109110

110111
logger = logging.getLogger(__name__)
@@ -234,6 +235,53 @@ def get_symmetric_quantization_config(
234235
return quantization_config
235236

236237

238+
@functools.lru_cache
def get_uint8_io_quantization_config(
    is_qat: bool = False,
    is_dynamic: bool = False,
    eps: float = 2**-16,
) -> QuantizationConfig:
    """Create a uint8 IO quantization config for TOSA backends.

    This config is intended for model inputs/outputs only. Internal tensors
    should remain int8 for TOSA INT lowering.
    """
    extra_args: Dict[str, Any] = {"eps": eps}
    # Pick the activation observer/fake-quant constructor for the four
    # (is_qat, is_dynamic) combinations.
    if is_qat and is_dynamic:
        act_observer_or_fake_quant_ctr = FakeQuantize
        # Dynamic QAT tracks ranges with a moving-average observer.
        extra_args["observer"] = MovingAverageMinMaxObserver.with_args(
            averaging_constant=1
        )
    elif is_qat:
        act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize  # type: ignore[assignment]
    elif is_dynamic:
        act_observer_or_fake_quant_ctr = PlaceholderObserver  # type: ignore[assignment]
    else:
        act_observer_or_fake_quant_ctr = HistogramObserver  # type: ignore[assignment]

    uint8_info = torch.iinfo(torch.uint8)
    act_quantization_spec = QuantizationSpec(
        dtype=torch.uint8,
        quant_min=uint8_info.min,
        quant_max=uint8_info.max,
        qscheme=torch.per_tensor_affine,
        is_dynamic=is_dynamic,
        observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args(
            **extra_args,
        ),
    )

    # Same spec for input and output activations; weights/bias left unset.
    return TOSAQuantizationConfig(
        act_quantization_spec,
        act_quantization_spec,
        None,
        None,
    )
283+
284+
237285
def get_symmetric_a8w4_quantization_config(
238286
is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False
239287
):

0 commit comments

Comments
 (0)