@@ -6,10 +6,9 @@
6 | 6 |
7 | 7 | # pyre-strict |
8 | 8 |
9 | | -from typing import Any, cast, Dict, List, Tuple |
| 9 | +from typing import Any, Dict, List, Tuple |
10 | 10 |
11 | 11 | import torch |
12 | | -from executorch.backends.cadence.aot.compiler_utils import get_shape |
13 | 12 | from executorch.backends.cadence.aot.quantizer.patterns import ( |
14 | 13 | AddmmPattern, |
15 | 14 | AddPattern, |
@@ -26,7 +25,6 @@
26 | 25 | MatmulPattern, |
27 | 26 | ReluPattern0, |
28 | 27 | ReluPattern1, |
29 | | - SoftmaxPattern, |
30 | 28 | ) |
31 | 29 | from executorch.backends.cadence.aot.quantizer.utils import ( |
32 | 30 | check_out_zero_point_is_min_range, |
@@ -390,73 +388,6 @@ def get_args_and_kwargs_relu( |
390 | 388 | return args, kwargs |
391 | 389 |
392 | 390 |
393 | | -def get_args_and_kwargs_softmax( |
394 | | - graph_module: GraphModule, |
395 | | - inputs_inputs: List[fx.Node], |
396 | | - dequants_inputs: List[fx.Node], |
397 | | - quant_node: fx.Node, |
398 | | - op_node: fx.Node, |
399 | | -) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]: |
400 | | - # Make a dummy mask tensor |
401 | | - mask_shape = get_shape(graph_module, cast(fx.Node, quant_node.args[0])) |
402 | | - mask_shape = list(mask_shape) if mask_shape else [] |
403 | | - mask_shape[-1] = mask_shape[-1] // 16 |
404 | | - mask_tensor = graph_module.graph.call_function( |
405 | | - torch.ops.aten.full.default, |
406 | | - ( |
407 | | - mask_shape, |
408 | | - 0.0, |
409 | | - ), |
410 | | - {"dtype": torch.int32}, |
411 | | - ) |
412 | | - # Make the scale and zero_point tensors |
413 | | - in_scale_tensor = graph_module.graph.call_function( |
414 | | - torch.ops.aten.full.default, |
415 | | - ( |
416 | | - [1], |
417 | | - dequants_inputs[0].args[1], |
418 | | - ), |
419 | | - {"dtype": torch.float32}, |
420 | | - ) |
421 | | - in_zero_point_tensor = graph_module.graph.call_function( |
422 | | - torch.ops.aten.full.default, |
423 | | - ( |
424 | | - [1], |
425 | | - dequants_inputs[0].args[2], |
426 | | - ), |
427 | | - {"dtype": torch.int32}, |
428 | | - ) |
429 | | - out_scale_tensor = graph_module.graph.call_function( |
430 | | - torch.ops.aten.full.default, |
431 | | - ( |
432 | | - [1], |
433 | | - quant_node.args[1], |
434 | | - ), |
435 | | - {"dtype": torch.float32}, |
436 | | - ) |
437 | | - out_zero_point_tensor = graph_module.graph.call_function( |
438 | | - torch.ops.aten.full.default, |
439 | | - ( |
440 | | - [1], |
441 | | - quant_node.args[2], |
442 | | - ), |
443 | | - {"dtype": torch.int32}, |
444 | | - ) |
445 | | - |
446 | | - # Make the args and kwargs for the replacement op |
447 | | - args = ( |
448 | | - inputs_inputs[0], |
449 | | - mask_tensor, |
450 | | - op_node.args[1], |
451 | | - in_scale_tensor, |
452 | | - in_zero_point_tensor, |
453 | | - out_scale_tensor, |
454 | | - out_zero_point_tensor, |
455 | | - ) |
456 | | - kwargs = {} |
457 | | - return args, kwargs |
458 | | - |
459 | | - |
460 | 391 | class QuantFusion(ExportPass): |
461 | 392 | # pyre-ignore[2]: Parameter `patterns` has no type specified |
462 | 393 | def __init__(self, patterns) -> None: |
@@ -612,14 +543,6 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 |
612 | 543 | dequants_inputs, |
613 | 544 | quant_node, |
614 | 545 | ) |
615 | | - elif isinstance(pattern, SoftmaxPattern): |
616 | | - args, kwargs = get_args_and_kwargs_softmax( |
617 | | - graph_module, |
618 | | - inputs_inputs, |
619 | | - dequants_inputs, |
620 | | - quant_node, |
621 | | - anchor_output_node, |
622 | | - ) |
623 | 546 | fused = graph_module.graph.call_function( |
624 | 547 | pattern.replacement_op(), |
625 | 548 | args, |