# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

| 7 | +"""TensorRT Converters for Scaled Dot-Product Attention (SDPA). |
| 8 | +
|
| 9 | +Supported operations: |
| 10 | +- aten.scaled_dot_product_attention: Core attention mechanism for transformers |
| 11 | +
|
| 12 | +SDPA computes: softmax(Q @ K^T / sqrt(d_k)) @ V |
| 13 | +
|
| 14 | +For TensorRT, we implement this using: |
| 15 | +1. Matrix multiply for Q @ K^T |
| 16 | +2. Scale by 1/sqrt(d_k) |
| 17 | +3. Optional attention mask application |
| 18 | +4. Softmax |
| 19 | +5. Matrix multiply with V |
| 20 | +""" |

import logging
import math
from typing import Any, Dict, Optional

import numpy as np
import torch
from executorch.backends.nvidia.tensorrt.converter_registry import converter
from executorch.backends.nvidia.tensorrt.converter_utils import (
    create_constant,
    get_trt_tensor,
)

logger: logging.Logger = logging.getLogger(__name__)


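# A minimal pure-PyTorch sketch of the decomposition the converters below
# emit. It is illustrative only and never called during conversion, but it is
# a handy numerical reference when debugging the TensorRT lowering (see the
# __main__ smoke check at the bottom of this file).
def _reference_sdpa(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Reference SDPA: softmax(Q @ K^T / sqrt(d_k) + mask) @ V."""
    # Steps 1-2: scaled attention scores.
    scale_factor = scale if scale is not None else 1.0 / math.sqrt(query.shape[-1])
    scores = (query @ key.transpose(-2, -1)) * scale_factor
    # Step 3: additive attention mask.
    if attn_mask is not None:
        scores = scores + attn_mask
    # Step 4: causal mask, -inf above the diagonal (assumes equal q/k lengths).
    if is_causal:
        seq_len = query.shape[-2]
        scores = scores + torch.triu(
            torch.full((seq_len, seq_len), float("-inf")), diagonal=1
        )
    # Steps 5-6: normalize and aggregate values.
    return torch.softmax(scores, dim=-1) @ value

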
def validate_sdpa(node: torch.fx.Node) -> bool:
    """Validate that an SDPA node can be converted to TensorRT."""
    if node.op != "call_function":
        return False

    args = node.args
    if len(args) < 3:
        return False

    # query, key, and value must be graph tensors (fx.Nodes), not Python
    # scalars or folded constants.
    for i in range(3):
        if not isinstance(args[i], torch.fx.Node):
            return False

    return True


@converter("aten.scaled_dot_product_attention.default", validator_fn=validate_sdpa)
def convert_scaled_dot_product_attention(
    node: torch.fx.Node,
    network: Any,  # trt.INetworkDefinition
    input_map: Dict[torch.fx.Node, Any],  # Dict[Node, trt.ITensor]
    edge_program: Optional[Any] = None,
) -> Any:  # trt.ITensor
    """Convert PyTorch scaled_dot_product_attention to TensorRT.

    SDPA formula: softmax(Q @ K^T / sqrt(d_k) + mask) @ V

    Also serves the flash/efficient attention variants registered below;
    their differing schemas are remapped based on node.target.
    """
    try:
        import tensorrt as trt
    except ImportError as e:
        raise ImportError("TensorRT is required to convert SDPA nodes") from e

    args = node.args
    kwargs = node.kwargs
    target = str(node.target)

    query_node = args[0]
    key_node = args[1]
    value_node = args[2]

    # The three SDPA variants share the same math but differ in schema, so
    # the mask and the is_causal flag sit at different argument positions.
    if "_scaled_dot_product_flash_attention" in target:
        # (q, k, v, dropout_p=0.0, is_causal=False, return_debug_mask=False, *, scale=None)
        attn_mask_node = None
        is_causal = args[4] if len(args) > 4 else kwargs.get("is_causal", False)
    elif "_scaled_dot_product_efficient_attention" in target:
        # (q, k, v, attn_bias, compute_log_sumexp, dropout_p=0.0, is_causal=False, *, scale=None)
        attn_mask_node = args[3] if len(args) > 3 else kwargs.get("attn_bias", None)
        is_causal = args[6] if len(args) > 6 else kwargs.get("is_causal", False)
    else:
        # (q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, *, scale=None)
        attn_mask_node = args[3] if len(args) > 3 else kwargs.get("attn_mask", None)
        is_causal = args[5] if len(args) > 5 else kwargs.get("is_causal", False)
    # scale is keyword-only in all three schemas; dropout_p is ignored because
    # dropout is the identity at inference time.
    scale = kwargs.get("scale", None)

    query_trt = input_map[query_node]
    key_trt = input_map[key_node]
    value_trt = input_map[value_node]

    query_shape = query_trt.shape
    # TensorRT reports -1 for dynamic dimensions, so d_k may not be known here.
    d_k = query_shape[-1]

    # Calculate the scale factor, falling back to the FX meta shape when the
    # head dimension is dynamic in TensorRT.
    if scale is not None:
        scale_factor = float(scale)
    elif d_k > 0:
        scale_factor = 1.0 / math.sqrt(float(d_k))  # e.g. d_k = 64 -> 0.125
    else:
        query_meta_shape = None
        if isinstance(query_node, torch.fx.Node) and "val" in query_node.meta:
            val = query_node.meta["val"]
            if hasattr(val, "shape"):
                query_meta_shape = val.shape
        if query_meta_shape is not None and len(query_meta_shape) > 0:
            d_k_static = query_meta_shape[-1]
            scale_factor = (
                1.0 / math.sqrt(float(d_k_static)) if d_k_static > 0 else 1.0
            )
        else:
            raise RuntimeError(
                f"Cannot determine head dimension for SDPA node {node.name}."
            )

    # Step 1: Q @ K^T. IMatrixMultiplyLayer treats the last two dimensions as
    # the matrix and broadcasts over the leading batch/head dimensions.
    qk_layer = network.add_matrix_multiply(
        query_trt, trt.MatrixOperation.NONE,
        key_trt, trt.MatrixOperation.TRANSPOSE,
    )
    qk_layer.name = f"sdpa_qk_{node.name}"
    qk = qk_layer.get_output(0)

    # Step 2: Scale by 1/sqrt(d_k). The constant gets a distinct name so it
    # does not collide with the elementwise layer's name below.
    scale_const = get_trt_tensor(
        network, scale_factor, f"sdpa_scale_const_{node.name}", dtype=torch.float32
    )
    scaled_qk_layer = network.add_elementwise(
        qk, scale_const, trt.ElementWiseOperation.PROD
    )
    scaled_qk_layer.name = f"sdpa_scale_{node.name}"
    scaled_qk = scaled_qk_layer.get_output(0)

    # Step 3: Apply the attention mask if provided. The mask is added to the
    # scores before softmax and must be broadcast-compatible with Q @ K^T.
    if attn_mask_node is not None and isinstance(attn_mask_node, torch.fx.Node):
        if attn_mask_node not in input_map:
            raise RuntimeError(
                f"Attention mask for SDPA node {node.name} has not been converted."
            )
        attn_mask_trt = input_map[attn_mask_node]
        mask_layer = network.add_elementwise(
            scaled_qk, attn_mask_trt, trt.ElementWiseOperation.SUM
        )
        mask_layer.name = f"sdpa_mask_{node.name}"
        scaled_qk = mask_layer.get_output(0)

    # Step 4: Handle causal masking by adding -inf above the diagonal. This
    # assumes self-attention, i.e. equal query and key sequence lengths.
    if is_causal:
        seq_len = query_shape[-2] if len(query_shape) >= 2 else -1
        if seq_len <= 0:
            raise RuntimeError(
                f"Causal SDPA node {node.name} requires a static sequence length."
            )
        causal_mask = np.triu(
            np.full((seq_len, seq_len), float("-inf"), dtype=np.float32), k=1
        )
        # Expand to the score tensor's rank; TensorRT elementwise requires
        # operands with the same number of dimensions.
        causal_mask = causal_mask.reshape(
            (1,) * (len(query_shape) - 2) + (seq_len, seq_len)
        )
        causal_mask_trt = create_constant(
            network, causal_mask, f"sdpa_causal_mask_{node.name}"
        )
        causal_layer = network.add_elementwise(
            scaled_qk, causal_mask_trt, trt.ElementWiseOperation.SUM
        )
        causal_layer.name = f"sdpa_causal_{node.name}"
        scaled_qk = causal_layer.get_output(0)

    # Step 5: Softmax over the last (key) dimension. ISoftMaxLayer.axes is a
    # bitmask of dimensions, so 1 << (rank - 1) selects the last axis.
    softmax_layer = network.add_softmax(scaled_qk)
    softmax_layer.axes = 1 << (len(query_shape) - 1)
    softmax_layer.name = f"sdpa_softmax_{node.name}"
    attn_weights = softmax_layer.get_output(0)

    # Step 6: attn_weights @ V
    output_layer = network.add_matrix_multiply(
        attn_weights, trt.MatrixOperation.NONE,
        value_trt, trt.MatrixOperation.NONE,
    )
    output_layer.name = f"sdpa_output_{node.name}"

    return output_layer.get_output(0)


@converter("aten._scaled_dot_product_flash_attention.default", validator_fn=validate_sdpa)
def convert_flash_attention(node, network, input_map, edge_program=None):
    """Convert flash attention by reusing the shared SDPA implementation,
    which remaps this variant's argument positions based on node.target.
    Only the attention output (element 0 of the ATen tuple) is produced.
    """
    return convert_scaled_dot_product_attention(node, network, input_map, edge_program)


@converter("aten._scaled_dot_product_efficient_attention.default", validator_fn=validate_sdpa)
def convert_efficient_attention(node, network, input_map, edge_program=None):
    """Convert efficient attention by reusing the shared SDPA implementation,
    which remaps this variant's argument positions based on node.target.
    Only the attention output (element 0 of the ATen tuple) is produced.
    """
    return convert_scaled_dot_product_attention(node, network, input_map, edge_program)
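

# Illustrative smoke check, not part of the converter API: verifies that the
# pure-PyTorch sketch above matches PyTorch's fused SDPA on random inputs.
if __name__ == "__main__":
    q = torch.randn(2, 4, 16, 64)  # (batch, heads, seq_len, head_dim)
    k = torch.randn(2, 4, 16, 64)
    v = torch.randn(2, 4, 16, 64)
    expected = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, is_causal=True
    )
    actual = _reference_sdpa(q, k, v, is_causal=True)
    assert torch.allclose(actual, expected, atol=1e-5), "reference SDPA mismatch"
    print("_reference_sdpa matches F.scaled_dot_product_attention")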