99import aphrodite .envs as envs
1010from aphrodite .logger import init_logger
1111from aphrodite .triton_utils import tl , triton
12+ from aphrodite .utils .torch_utils import is_torch_equal_or_newer
1213
1314logger = init_logger (__name__ )
1415
@@ -678,6 +679,10 @@ def linear_batch_invariant(input, weight, bias=None):
# Module-level state shared by enable_batch_invariant_mode() and
# disable_batch_invariant_mode().

# True while batch-invariant mode is active; both functions use it as a
# guard so that repeated enable/disable calls are no-ops.
_batch_invariant_MODE = False
# Handle torn down via ``_destroy()`` when the mode is disabled.
_batch_invariant_LIB = None
# The ``torch.bmm`` callable that was installed before it was replaced.
_original_torch_bmm = None
# Snapshots of torch.backends.cuda.matmul reduced-precision-reduction flags,
# taken on enable and written back on disable. ``None`` means "nothing saved".
_original_fp16_reduction_precision = None
_original_bf16_reduction_precision = None
# Snapshots of the CUBLAS_WORKSPACE_CONFIG / CUBLASLT_WORKSPACE_SIZE
# environment variables; ``None`` means the variable was not set (or that no
# snapshot has been taken).
_original_cublas_workspace_cfg = None
_original_cublaslt_workspace_size = None
681686
682687
683688def is_batch_invariant_mode_enabled ():
@@ -686,6 +691,8 @@ def is_batch_invariant_mode_enabled():
686691
687692def enable_batch_invariant_mode ():
688693 global _batch_invariant_MODE , _batch_invariant_LIB , _original_torch_bmm
694+ global _original_fp16_reduction_precision , _original_bf16_reduction_precision
695+ global _original_cublas_workspace_cfg , _original_cublaslt_workspace_size
689696 if _batch_invariant_MODE :
690697 return
691698
@@ -705,14 +712,59 @@ def enable_batch_invariant_mode():
705712 _original_torch_bmm = torch .bmm
706713 torch .bmm = bmm_batch_invariant
707714
715+ _original_bf16_reduction_precision = torch .backends .cuda .matmul .allow_bf16_reduced_precision_reduction
716+ _original_fp16_reduction_precision = torch .backends .cuda .matmul .allow_fp16_reduced_precision_reduction
717+
718+ reduced_precision_val = (False , False ) if is_torch_equal_or_newer ("2.10.0.dev" ) else False
719+ torch .backends .cuda .matmul .allow_fp16_reduced_precision_reduction = reduced_precision_val
720+ torch .backends .cuda .matmul .allow_bf16_reduced_precision_reduction = reduced_precision_val
721+ torch .backends .cuda .preferred_blas_library (backend = "cublaslt" )
722+
723+ if not is_torch_equal_or_newer ("2.10.0.dev" ):
724+ _original_cublas_workspace_cfg = os .environ .get ("CUBLAS_WORKSPACE_CONFIG" , None )
725+ _original_cublaslt_workspace_size = os .environ .get ("CUBLASLT_WORKSPACE_SIZE" , None )
726+ os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = ":16:8"
727+ os .environ ["CUBLASLT_WORKSPACE_SIZE" ] = "1"
728+
708729
def disable_batch_invariant_mode():
    """Undo the process-wide changes made by ``enable_batch_invariant_mode``.

    Does nothing when batch-invariant mode is not currently enabled.
    Otherwise it:

    * destroys the registered override library, if one exists;
    * puts the saved ``torch.bmm`` back;
    * restores the saved fp16/bf16 reduced-precision-reduction flags on
      ``torch.backends.cuda.matmul``;
    * resets the preferred BLAS library to ``"default"``;
    * on torch older than ``2.10.0.dev``, restores (or removes) the
      ``CUBLAS_WORKSPACE_CONFIG`` and ``CUBLASLT_WORKSPACE_SIZE`` environment
      variables;
    * clears the saved state and marks the mode as disabled.
    """
    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
    global _original_fp16_reduction_precision, _original_bf16_reduction_precision
    global _original_cublas_workspace_cfg, _original_cublaslt_workspace_size
    if not _batch_invariant_MODE:
        return

    if _batch_invariant_LIB is not None:
        _batch_invariant_LIB._destroy()
    if _original_torch_bmm is not None:
        torch.bmm = _original_torch_bmm
        _original_torch_bmm = None

    # ``None`` means no snapshot was taken, so leave the current flag alone.
    if _original_bf16_reduction_precision is not None:
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
            _original_bf16_reduction_precision
        )
        _original_bf16_reduction_precision = None
    if _original_fp16_reduction_precision is not None:
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
            _original_fp16_reduction_precision
        )
        _original_fp16_reduction_precision = None

    # NOTE(review): this resets to "default" rather than to whatever library
    # was preferred before the mode was enabled; no snapshot of the previous
    # preference is kept in this module's saved state.
    torch.backends.cuda.preferred_blas_library(backend="default")

    if not is_torch_equal_or_newer("2.10.0.dev"):
        # Put the cuBLAS env vars back the way they were. A saved value of
        # None means the variable was not set, so it is removed again. The
        # check is ``is not None`` (not truthiness) so that a variable that
        # was set to the empty string is restored instead of being deleted.
        saved_env = (
            ("CUBLAS_WORKSPACE_CONFIG", _original_cublas_workspace_cfg),
            ("CUBLASLT_WORKSPACE_SIZE", _original_cublaslt_workspace_size),
        )
        for env_name, previous in saved_env:
            if previous is not None:
                os.environ[env_name] = previous
            else:
                os.environ.pop(env_name, None)

        _original_cublas_workspace_cfg = None
        _original_cublaslt_workspace_size = None

    _batch_invariant_MODE = False
    _batch_invariant_LIB = None
718770
@@ -791,6 +843,9 @@ def override_envs_for_invariance():
791843 os .environ ["NCCL_NTHREADS" ] = "1"
792844 os .environ ["NCCL_SOCKET_NTHREADS" ] = "1"
793845
846+ # torch.compile settings
847+ os .environ ["APHRODITE_USE_AOT_COMPILE" ] = "0"
848+
794849
795850def init_batch_invariance ():
796851 # this will hit all the csrc overrides as well
0 commit comments