Skip to content

Commit feb84f8

Browse files
authored
Arm backend: Make quantization of infs user configurable (pytorch#19915)
Add `QuantizeInfConfig` to the Arm pass pipeline config so compile specs can set the finite values used to quantize infinities. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com>
1 parent c5da8fb commit feb84f8

5 files changed

Lines changed: 158 additions & 31 deletions

File tree

backends/arm/_passes/arm_pass_manager.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,7 @@
150150
)
151151
from executorch.backends.arm._passes.arm_pass import ArmPass
152152
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
153-
from executorch.backends.arm.common.pipeline_config import (
154-
ArmPassPipelineConfig,
155-
SoftmaxDecompositionConfig,
156-
)
153+
from executorch.backends.arm.common.pipeline_config import SoftmaxDecompositionConfig
157154
from executorch.backends.arm.tosa.specification import (
158155
tosa_spec_in_set,
159156
TosaLoweringContext,
@@ -221,16 +218,13 @@ def __init__(self, compile_spec: ArmCompileSpec) -> None:
221218
super().__init__()
222219
self.configure_skip_passes()
223220

224-
def configure_skip_passes(
225-
self,
226-
override_config: ArmPassPipelineConfig | None = None,
227-
) -> tuple[type, ...]:
221+
def configure_skip_passes(self) -> tuple[type, ...]:
228222
"""Configures the pass manager to skip certain passes based on the
229223
ArmPassPipelineConfig class found in the compile spec.
230224
"""
231225
skip_set: set[type] = set()
232226

233-
config = override_config or self.compile_spec._get_pass_pipeline_config()
227+
config = self.compile_spec._get_pass_pipeline_config()
234228
logger.debug(f"Skip Config: {config}")
235229

236230
match config.softmax:
@@ -649,9 +643,14 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
649643
)
650644

651645
# Postprocessing passes
646+
quant_inf_cfg = self.compile_spec._get_pass_pipeline_config().quantize_inf
652647
self.add_passes(
653648
[
654-
ReplaceInfAndLimitValuesPass(tfa_pass=True),
649+
ReplaceInfAndLimitValuesPass(
650+
quant_inf_cfg.neg_inf,
651+
quant_inf_cfg.pos_inf,
652+
tfa_pass=True,
653+
),
655654
DecomposeMaskedFillPass(tfa_pass=True),
656655
DeduplicateGetAttrPass(tfa_pass=True),
657656
]

backends/arm/_passes/replace_inf_and_limit_values_pass.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,22 @@
1616

1717
class ReplaceInfAndLimitValuesPass(ArmPass):
1818
"""Rewrites +inf/-inf and floating-point limit values (e.g.,
19-
torch.finfo(...).min/max) to quantization-friendly values (±255 by default),
19+
torch.finfo(...).min/max) to configured quantization-friendly values,
2020
improving quantizer stability (notably for attention mask paths).
2121
"""
2222

2323
_passes_required_after: Set[Type[ExportPass]] = set()
2424

25+
def __init__(
26+
self,
27+
neg_inf: float,
28+
pos_inf: float,
29+
tfa_pass: bool = False,
30+
):
31+
super().__init__(tfa_pass=tfa_pass)
32+
self.neg_inf = neg_inf
33+
self.pos_inf = pos_inf
34+
2535
def _allowed_to_transform_named_buffer(self, buf_name, graph_module) -> bool:
2636
attr_nodes = [
2737
node
@@ -51,19 +61,19 @@ def call(self, graph_module: torch.fx.GraphModule):
5161
continue
5262

5363
modified = True
54-
# 255 here is mainly for attention_mask in Llama for reasonable quant scale
55-
t = torch.nan_to_num(tensor, posinf=255, neginf=-255)
64+
65+
t = torch.nan_to_num(tensor, posinf=self.pos_inf, neginf=self.neg_inf)
5666
setattr(graph_module, buf_name, t)
5767

5868
for node in graph_module.graph.nodes:
5969
arg_list = list(node.args)
6070
for index, arg in enumerate(arg_list):
6171
if arg == float("-inf") or arg == torch.finfo(torch.float32).min:
6272
modified = True
63-
arg_list[index] = -255.0
73+
arg_list[index] = self.neg_inf
6474
elif arg == float("inf") or arg == torch.finfo(torch.float32).max:
6575
modified = True
66-
arg_list[index] = +255.0
76+
arg_list[index] = self.pos_inf
6777
node.args = tuple(arg_list)
6878

6979
if modified:

backends/arm/common/pipeline_config.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,75 @@
44
# LICENSE file in the root directory of this source tree.
55

66
import json
7-
from dataclasses import dataclass, fields
7+
from dataclasses import asdict, dataclass, field, fields, is_dataclass
88
from enum import auto, Enum
9-
from typing import Any
9+
from typing import Any, cast
1010

1111

1212
class SoftmaxDecompositionConfig(Enum):
1313
MASKED = auto() # Stable softmax + masked fill decomposition
1414
STABLE = auto() # Stable softmax, no masked fill decomposition
1515

1616

17+
@dataclass
18+
class QuantizeInfConfig:
19+
"""Replacement values for infinities before quantization.
20+
21+
Infinities cannot be quantized directly, so the Arm pipeline replaces them
22+
with finite values before running the quantization passes.
23+
24+
Args:
25+
neg_inf (float): Value used for ``-inf``.
26+
pos_inf (float): Value used for ``inf``.
27+
28+
"""
29+
30+
neg_inf: float = -256.0
31+
pos_inf: float = 255.0
32+
33+
1734
@dataclass
1835
class ArmPassPipelineConfig:
36+
"""Options for tuning the Arm pass pipeline.
37+
38+
Args:
39+
softmax (SoftmaxDecompositionConfig): Softmax decomposition mode.
40+
quantize_inf (QuantizeInfConfig): Values used when replacing
41+
infinities before quantization.
42+
43+
Example:
44+
compile_spec.set_pass_pipeline_config(
45+
ArmPassPipelineConfig(
46+
softmax=SoftmaxDecompositionConfig.STABLE,
47+
quantize_inf=QuantizeInfConfig(
48+
neg_inf=-100.0,
49+
pos_inf=100.0,
50+
),
51+
)
52+
)
53+
54+
"""
55+
1956
softmax: SoftmaxDecompositionConfig = SoftmaxDecompositionConfig.MASKED
57+
quantize_inf: QuantizeInfConfig = field(default_factory=QuantizeInfConfig)
2058

2159
def is_default(self) -> bool:
22-
return self.softmax is SoftmaxDecompositionConfig.MASKED
60+
return (
61+
self.softmax is SoftmaxDecompositionConfig.MASKED
62+
and self.quantize_inf == QuantizeInfConfig()
63+
)
2364

24-
def to_dict(self) -> dict[str, str]:
25-
return {f.name: getattr(self, f.name).name for f in fields(self)}
65+
def to_dict(self) -> dict[str, Any]:
66+
data: dict[str, Any] = {}
67+
for f in fields(self):
68+
value = getattr(self, f.name)
69+
if is_dataclass(value):
70+
data[f.name] = asdict(cast(Any, value))
71+
elif isinstance(value, Enum):
72+
data[f.name] = value.name
73+
else:
74+
raise AssertionError(f"Cannot serialize {f.name}")
75+
return data
2676

2777
@classmethod
2878
def from_dict(cls, data: dict[str, Any]) -> "ArmPassPipelineConfig":
@@ -31,8 +81,13 @@ def from_dict(cls, data: dict[str, Any]) -> "ArmPassPipelineConfig":
3181
raw_value = data.get(f.name)
3282
if raw_value is None:
3383
continue
34-
enum_type = f.type
35-
setattr(config, f.name, enum_type[raw_value])
84+
85+
if f.name == "quantize_inf":
86+
config.quantize_inf = QuantizeInfConfig(**raw_value)
87+
else:
88+
# The field is an enum
89+
enum_type = f.type
90+
setattr(config, f.name, enum_type[raw_value])
3691
return config
3792

3893
def serialize(self) -> bytes:

backends/arm/test/misc/test_pass_pipeline_config.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import torch
7+
68
from executorch.backends.arm._passes import (
79
DecomposeMaskedFillPass,
810
DecomposeSoftmaxPass,
@@ -11,10 +13,26 @@
1113
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
1214
from executorch.backends.arm.common.pipeline_config import (
1315
ArmPassPipelineConfig,
16+
QuantizeInfConfig,
1417
SoftmaxDecompositionConfig,
1518
)
1619
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
1720
from executorch.backends.arm.tosa.specification import TosaSpecification
21+
from torch.export import export
22+
23+
24+
class ModuleWithInf(torch.nn.Module):
25+
def __init__(self) -> None:
26+
super().__init__()
27+
self.register_buffer(
28+
"mask", torch.tensor([float("inf"), float("-inf")], dtype=torch.float32)
29+
)
30+
31+
def forward(self, x: torch.Tensor) -> torch.Tensor:
32+
x = x + self.mask # type: ignore[operator]
33+
x = torch.ops.aten.add.Tensor(x, float("-inf"))
34+
x = torch.ops.aten.add.Tensor(x, float("inf"))
35+
return x
1836

1937

2038
def test_pipeline_config_override_outside_compile_spec():
@@ -68,3 +86,27 @@ def test_softmax_config_stable_no_target():
6886
assert DecomposeSoftmaxPass not in skip_passes
6987
# STABLE: masked fill decomposition is disabled (skipped)
7088
assert DecomposeMaskedFillPass in skip_passes
89+
90+
91+
def test_quant_inf_config_reaches_annotation_pipeline():
92+
QUANT_NEG_INF = -321.0
93+
QUANT_POS_INF = 123.0
94+
95+
config = ArmPassPipelineConfig(
96+
quantize_inf=QuantizeInfConfig(neg_inf=QUANT_NEG_INF, pos_inf=QUANT_POS_INF),
97+
)
98+
compile_spec = TosaCompileSpec(
99+
TosaSpecification.create_from_string("TOSA-1.00+INT")
100+
)
101+
compile_spec.set_pass_pipeline_config(config)
102+
manager = ArmPassManager(compile_spec)
103+
exported = export(ModuleWithInf(), (torch.zeros(2),), strict=True)
104+
105+
transformed = manager.transform_for_annotation_pipeline(exported.graph_module)
106+
tensor_constant_values = sorted(
107+
constant.item()
108+
for name, constant in transformed.named_buffers()
109+
if name.startswith("_tensor_constant")
110+
)
111+
112+
assert tensor_constant_values == [QUANT_NEG_INF, QUANT_POS_INF]

backends/arm/test/passes/test_replace_inf_values_pass.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,26 +49,41 @@ def _get_mask_buffer(graph_module: fx.GraphModule) -> torch.Tensor:
4949

5050
def test_replace_inf_and_limit_values_clamps_inf_constants():
5151
"""Trace a module with infinities, run ReplaceInfAndLimitValuesPass, and
52-
expect the buffer and scalar literals to be clamped to ±255 with no
53-
infinities left.
52+
expect the buffer and scalar literals to be clamped to the configured finite
53+
values.
5454
"""
55+
QUANTIZED_NEG_INF = -42.0
56+
QUANTIZED_POS_INF = 13.0
57+
5558
gm = fx.symbolic_trace(ModuleWithInf())
5659

57-
result = ReplaceInfAndLimitValuesPass().call(gm)
60+
result = ReplaceInfAndLimitValuesPass(
61+
neg_inf=QUANTIZED_NEG_INF,
62+
pos_inf=QUANTIZED_POS_INF,
63+
).call(gm)
5864
mask_after_pass = _get_mask_buffer(result.graph_module)
5965

6066
assert result.modified
61-
expected = torch.tensor([255.0, -255.0], dtype=mask_after_pass.dtype)
67+
expected = torch.tensor(
68+
[QUANTIZED_POS_INF, QUANTIZED_NEG_INF],
69+
dtype=mask_after_pass.dtype,
70+
)
6271
assert torch.equal(mask_after_pass, expected)
6372
assert not torch.isinf(mask_after_pass).any()
64-
assert sorted(_get_add_constants(result.graph_module)) == [-255, 255]
73+
assert sorted(_get_add_constants(result.graph_module)) == [
74+
QUANTIZED_NEG_INF,
75+
QUANTIZED_POS_INF,
76+
]
6577

6678

6779
def test_replace_inf_and_limit_values_respects_disallowed_nodes():
6880
"""When nodes opt out of transforms, running the pass in TFA mode should
69-
leave the mask buffer untouched while still clamping scalar literals to
70-
±255.
81+
leave the mask buffer untouched while still clamping scalar literals to the
82+
configured finite values.
7183
"""
84+
QUANTIZED_NEG_INF = -1_000_000.0
85+
QUANTIZED_POS_INF = 10_000.0
86+
7287
gm = fx.symbolic_trace(ModuleWithInf())
7388
mask_before = _get_mask_buffer(gm).clone()
7489

@@ -82,7 +97,10 @@ def test_replace_inf_and_limit_values_respects_disallowed_nodes():
8297
):
8398
node.meta[DISALLOW_TFA_META_KEY] = True
8499

85-
replace_inf = ReplaceInfAndLimitValuesPass()
100+
replace_inf = ReplaceInfAndLimitValuesPass(
101+
neg_inf=QUANTIZED_NEG_INF,
102+
pos_inf=QUANTIZED_POS_INF,
103+
)
86104
replace_inf.is_tfa_pass = True
87105

88106
result = replace_inf.call(gm)
@@ -91,4 +109,7 @@ def test_replace_inf_and_limit_values_respects_disallowed_nodes():
91109
mask_after = _get_mask_buffer(result.graph_module)
92110
assert torch.equal(mask_after, mask_before)
93111
assert torch.isinf(mask_after).tolist() == [True, True]
94-
assert sorted(_get_add_constants(result.graph_module)) == [-255, 255]
112+
assert sorted(_get_add_constants(result.graph_module)) == [
113+
QUANTIZED_NEG_INF,
114+
QUANTIZED_POS_INF,
115+
]

0 commit comments

Comments
 (0)