Commit d31d4be

Sync Python-side quantized_softmax schema with C++ kernel (add mask_type and pos args) (#18495)
Differential Revision: D98145095
Pull Request resolved: #18495
1 parent 84a4a6c commit d31d4be

4 files changed: 44 additions & 4 deletions


backends/cadence/aot/ops_registrations.py

Lines changed: 8 additions & 4 deletions
@@ -472,16 +472,16 @@ def register_fake(
 )
 
 lib.define(
-    "quantized_softmax(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
+    "quantized_softmax(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point) -> (Tensor out)"
 )
 lib.define(
-    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
+    "quantized_softmax.per_tensor(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point) -> (Tensor out)"
 )
 lib.define(
-    "quantized_softmax.out(Tensor input, Tensor mask, int dim, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_softmax.out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, Tensor in_scale, Tensor in_zero_point, Tensor out_scale, Tensor out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define(
-    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_softmax.per_tensor_out(Tensor input, Tensor mask, int dim, int mask_type, Tensor pos, float in_scale, int in_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
 )
 
 # pack float/bool mask tensor into a bitmask of type uint8 (each element holding 8 bool mask elements)
@@ -2957,6 +2957,8 @@ def quantized_softmax_meta(
     input: torch.Tensor,
     mask: torch.Tensor,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: torch.Tensor,
     in_zero_point: torch.Tensor,
     out_scale: torch.Tensor,
@@ -2970,6 +2972,8 @@ def quantized_softmax_per_tensor_meta(
     input: torch.Tensor,
     mask: torch.Tensor,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: float,
     in_zero_point: int,
     out_scale: float,

backends/cadence/aot/quantizer/fusion_pass.py

Lines changed: 17 additions & 0 deletions
@@ -378,6 +378,21 @@ def get_args_and_kwargs_softmax(
     with fake_mode:
         mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
     copy_node_metadata(mask_tensor, inputs_inputs[0])
+
+    # Default mask_type=0 (no masking) and dummy pos tensor
+    mask_type = 0
+    pos_tensor = graph_module.graph.call_function(
+        torch.ops.aten.full.default,
+        (
+            [1],
+            0,
+        ),
+        {"dtype": torch.int64},
+    )
+    with fake_mode:
+        pos_tensor.meta["val"] = torch.full([1], 0, dtype=torch.int64)
+    copy_node_metadata(pos_tensor, inputs_inputs[0])
+
     # Make the scale and zero_point tensors
     in_scale = dequants_inputs[0].args[1]
     in_zero_point = dequants_inputs[0].args[2]
@@ -389,6 +404,8 @@ def get_args_and_kwargs_softmax(
         inputs_inputs[0],
         mask_tensor,
         op_node.args[1],
+        mask_type,
+        pos_tensor,
         in_scale,
         in_zero_point,
         out_scale,
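The pos_tensor construction above follows the standard torch.fx pattern of inserting a constant-producing node into a traced graph. A standalone sketch of that pattern, with a made-up module purely for illustration:

import torch
import torch.fx as fx

class M(torch.nn.Module):  # made-up module for illustration
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.softmax(x, dim=-1)

gm = fx.symbolic_trace(M())

# Insert a [1]-shaped int64 zero-tensor node ahead of the softmax node,
# mirroring how the fusion pass builds its dummy `pos` argument.
softmax_node = next(n for n in gm.graph.nodes if n.op == "call_function")
with gm.graph.inserting_before(softmax_node):
    pos_node = gm.graph.call_function(
        torch.ops.aten.full.default,
        ([1], 0),
        {"dtype": torch.int64},
    )
gm.recompile()
print(gm.code)  # the generated forward now materializes the constant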

backends/cadence/aot/ref_implementations.py

Lines changed: 15 additions & 0 deletions
@@ -2480,6 +2480,8 @@ def quantized_softmax_per_tensor_common(
     input_tensor: torch.Tensor,
     mask: torch.Tensor | None,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: float,
     in_zero_point: int,
     out_scale: float,
@@ -2492,13 +2494,18 @@
     - input_tensor (Tensor): The quantized input tensor
     - mask (Tensor): Mask tensor
     - dim (int): The dimension along which softmax is computed
+    - mask_type (int): Masking strategy (0=none, 1=position-based causal)
+    - pos (Tensor): Position tensor for causal masking
     - in_scale (float): The scale of the input quantization
     - in_zero_point (int): The zero point of the input quantization
     - out_scale (float): The scale of the output quantization
     - out_zero_point (int): The zero point of the output quantization
     """
     # TODO: T228751479 - Add support for mask parameter in softmax
     assert mask is None
+    assert (
+        mask_type == 0
+    ), f"Only mask_type=0 (no masking) is supported, got {mask_type}"
     supported_dtypes = [torch.int8, torch.uint8, torch.int16]
     if input_tensor.dtype not in supported_dtypes:
         raise ValueError(
@@ -2531,6 +2538,8 @@ def quantized_softmax_per_tensor(
     input_tensor: torch.Tensor,
     mask: torch.Tensor | None,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: float,
     in_zero_point: int,
     out_scale: float,
@@ -2540,6 +2549,8 @@
         input_tensor,
         mask,
         dim,
+        mask_type,
+        pos,
         in_scale,
         in_zero_point,
         out_scale,
@@ -2552,6 +2563,8 @@ def quantized_softmax(
     input_tensor: torch.Tensor,
     mask: torch.Tensor | None,
     dim: int,
+    mask_type: int,
+    pos: torch.Tensor,
     in_scale: torch.Tensor,
     in_zero_point: torch.Tensor,
     out_scale: float,
@@ -2561,6 +2574,8 @@
         input_tensor,
         mask,
         dim,
+        mask_type,
+        pos,
         float(in_scale.item()),
         int(in_zero_point.item()),
         out_scale,
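For intuition, the per-tensor reference semantics amount to a dequantize, float softmax, requantize round trip. A minimal numeric sketch follows; it is not the actual reference kernel, and the scales, shapes, and values are arbitrary examples:

import torch

def quantized_softmax_sketch(
    q_in: torch.Tensor,  # quantized int8 input
    dim: int,
    in_scale: float,
    in_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    # Dequantize to float, take an ordinary softmax, then requantize.
    x = (q_in.to(torch.float32) - in_zero_point) * in_scale
    y = torch.softmax(x, dim=dim)
    q = torch.round(y / out_scale) + out_zero_point
    return q.clamp(-128, 127).to(torch.int8)

q_in = torch.randint(-128, 128, (2, 8), dtype=torch.int8)
print(quantized_softmax_sketch(q_in, -1, 0.1, 0, 1 / 256, -128))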

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 4 additions & 0 deletions
@@ -3152,6 +3152,8 @@ def test_quantized_softmax_per_tensor(
             input_tensor,
             mask,
             dim,
+            0,  # mask_type (no masking)
+            torch.zeros(1, dtype=torch.int64),  # pos
             in_scale,
             in_zero_point,
             out_scale,
@@ -3189,6 +3191,8 @@ def test_quantized_softmax(self) -> None:
             input_tensor,
             None,  # mask
             1,  # dim
+            0,  # mask_type (no masking)
+            torch.zeros(1, dtype=torch.int64),  # pos
             in_scale,
             in_zero_point,
             0.004,  # out_scale
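The new assert in the reference implementation also makes the unsupported path fail fast. A hedged sketch of exercising it; the import path assumes the repository's usual executorch package layout and that the decorated reference function remains directly callable:

import torch
from executorch.backends.cadence.aot.ref_implementations import (
    quantized_softmax_per_tensor,
)

try:
    quantized_softmax_per_tensor(
        torch.zeros(2, 4, dtype=torch.int8),  # quantized input
        None,                                 # mask
        1,                                    # dim
        1,                                    # mask_type != 0 should be rejected
        torch.zeros(1, dtype=torch.int64),    # pos
        0.1,                                  # in_scale
        0,                                    # in_zero_point
        0.004,                                # out_scale
        0,                                    # out_zero_point
    )
except AssertionError as e:
    print(e)  # Only mask_type=0 (no masking) is supported, got 1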
