Skip to content

Commit a66023c

Browse files
authored
[compilation] allow torch.compile with batch invariant inference (#1570)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent 3611692 commit a66023c

4 files changed

Lines changed: 63 additions & 7 deletions

File tree

aphrodite/config/model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from aphrodite.config.scheduler import RunnerType
1818
from aphrodite.config.utils import assert_hashable, config, getattr_iter
1919
from aphrodite.logger import init_logger
20-
from aphrodite.modeling.layers.batch_invariant import aphrodite_is_batch_invariant
2120
from aphrodite.platforms import current_platform
2221
from aphrodite.transformers_utils.config import (
2322
ConfigFormat,
@@ -424,10 +423,6 @@ def __post_init__(
424423
skip_mm_profiling: bool | None,
425424
video_pruning_rate: float | None,
426425
) -> None:
427-
# Enable batch invariance settings if requested
428-
if aphrodite_is_batch_invariant():
429-
self.enforce_eager = True
430-
431426
# Set the default seed to 0 in V1.
432427
# NOTE(woosuk): In V1, we use separate processes for workers (unless
433428
# APHRODITE_ENABLE_V1_MULTIPROCESSING=0), so setting a seed here

aphrodite/envs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,11 +254,12 @@ def disable_compile_cache() -> bool:
254254

255255

256256
def use_aot_compile() -> bool:
257+
from aphrodite.modeling.layers.batch_invariant import aphrodite_is_batch_invariant
257258
from aphrodite.utils.torch_utils import is_torch_equal_or_newer
258259

259260
default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") and not disable_compile_cache() else "0"
260261

261-
return os.environ.get("APHRODITE_USE_AOT_COMPILE", default_value) == "1"
262+
return not aphrodite_is_batch_invariant() and os.environ.get("APHRODITE_USE_AOT_COMPILE", default_value) == "1"
262263

263264

264265
def env_with_choices(

aphrodite/modeling/layers/batch_invariant.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import aphrodite.envs as envs
1010
from aphrodite.logger import init_logger
1111
from aphrodite.triton_utils import tl, triton
12+
from aphrodite.utils.torch_utils import is_torch_equal_or_newer
1213

1314
logger = init_logger(__name__)
1415

@@ -678,6 +679,10 @@ def linear_batch_invariant(input, weight, bias=None):
678679
_batch_invariant_MODE = False
679680
_batch_invariant_LIB = None
680681
_original_torch_bmm = None
682+
_original_fp16_reduction_precision = None
683+
_original_bf16_reduction_precision = None
684+
_original_cublas_workspace_cfg = None
685+
_original_cublaslt_workspace_size = None
681686

682687

683688
def is_batch_invariant_mode_enabled():
@@ -686,6 +691,8 @@ def is_batch_invariant_mode_enabled():
686691

687692
def enable_batch_invariant_mode():
688693
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
694+
global _original_fp16_reduction_precision, _original_bf16_reduction_precision
695+
global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
689696
if _batch_invariant_MODE:
690697
return
691698

@@ -705,14 +712,59 @@ def enable_batch_invariant_mode():
705712
_original_torch_bmm = torch.bmm
706713
torch.bmm = bmm_batch_invariant
707714

715+
_original_bf16_reduction_precision = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
716+
_original_fp16_reduction_precision = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
717+
718+
reduced_precision_val = (False, False) if is_torch_equal_or_newer("2.10.0.dev") else False
719+
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision_val
720+
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision_val
721+
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
722+
723+
if not is_torch_equal_or_newer("2.10.0.dev"):
724+
_original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
725+
_original_cublaslt_workspace_size = os.environ.get("CUBLASLT_WORKSPACE_SIZE", None)
726+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
727+
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
728+
708729

709730
def disable_batch_invariant_mode():
710731
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
732+
global _original_fp16_reduction_precision, _original_bf16_reduction_precision
733+
global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
734+
if not _batch_invariant_MODE:
735+
return
736+
711737
if _batch_invariant_LIB is not None:
712738
_batch_invariant_LIB._destroy()
713739
if _original_torch_bmm is not None:
714740
torch.bmm = _original_torch_bmm
715741
_original_torch_bmm = None
742+
743+
if _original_bf16_reduction_precision is not None:
744+
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = _original_bf16_reduction_precision
745+
_original_bf16_reduction_precision = None
746+
if _original_fp16_reduction_precision is not None:
747+
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = _original_fp16_reduction_precision
748+
_original_fp16_reduction_precision = None
749+
750+
torch.backends.cuda.preferred_blas_library(backend="default")
751+
752+
if not is_torch_equal_or_newer("2.10.0.dev"):
753+
# Restore the cuBLAS env vars to their previous values. If a previous value is None,
754+
# that means the env vars were not set, so we should remove them.
755+
if _original_cublas_workspace_cfg:
756+
os.environ["CUBLAS_WORKSPACE_CONFIG"] = _original_cublas_workspace_cfg
757+
elif "CUBLAS_WORKSPACE_CONFIG" in os.environ:
758+
del os.environ["CUBLAS_WORKSPACE_CONFIG"]
759+
760+
if _original_cublaslt_workspace_size:
761+
os.environ["CUBLASLT_WORKSPACE_SIZE"] = _original_cublaslt_workspace_size
762+
elif "CUBLASLT_WORKSPACE_SIZE" in os.environ:
763+
del os.environ["CUBLASLT_WORKSPACE_SIZE"]
764+
765+
_original_cublas_workspace_cfg = None
766+
_original_cublaslt_workspace_size = None
767+
716768
_batch_invariant_MODE = False
717769
_batch_invariant_LIB = None
718770

@@ -791,6 +843,9 @@ def override_envs_for_invariance():
791843
os.environ["NCCL_NTHREADS"] = "1"
792844
os.environ["NCCL_SOCKET_NTHREADS"] = "1"
793845

846+
# torch.compile settings
847+
os.environ["APHRODITE_USE_AOT_COMPILE"] = "0"
848+
794849

795850
def init_batch_invariance():
796851
# this will hit all the csrc overrides as well

aphrodite/quantization/fp8.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ def __init__(self, quant_config: Fp8Config):
319319
self.use_marlin = False
320320

321321
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
322+
self.use_deep_gemm = is_deep_gemm_supported()
322323

323324
self.weight_block_size = self.quant_config.weight_block_size
324325
self.block_quant = self.weight_block_size is not None
@@ -493,7 +494,11 @@ def apply(
493494
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
494495
# we will use BF16 dequant when DeepGEMM is not supported.
495496
if aphrodite_is_batch_invariant():
496-
if self.block_quant and should_use_deepgemm_for_fp8_linear(torch.bfloat16, layer.weight, None):
497+
# Call is_deep_gemm_supported() ahead of time for torch.compile
498+
# because dynamo has trouble tracing through that call.
499+
if self.block_quant and should_use_deepgemm_for_fp8_linear(
500+
torch.bfloat16, layer.weight, self.use_deep_gemm
501+
):
497502
# use group quant consistent with block size across K
498503
assert self.act_q_group_shape is not None
499504
q_input, input_scale = QuantFP8(

0 commit comments

Comments
 (0)