
Commit 1f4860e

util: agnostic.py
This adds a small wrapper around common GPU calls, aiming to replace redundant vendor-specific calls with a single device-agnostic entry point.
1 parent 4b8fdf4 commit 1f4860e

5 files changed: 95 additions & 8 deletions
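
For orientation, here is a minimal usage sketch of the new module (an editor's illustration, not part of the commit). It assumes a transformers checkout that includes this commit; the printed values are examples:

```python
# `agnostic.gpu` is a module-level singleton picked once at import time:
# CUDA first, then XPU, then MPS, then a CPU fallback (see agnostic.py below).
from transformers.utils import agnostic

print(agnostic.gpu.name)                        # "cuda", "xpu", "mps", or "cpu"
print(agnostic.gpu.is_accelerator_available())  # False only on the CPU fallback
print(agnostic.gpu.current_device())            # active device index; 0 on MPS/CPU
print(agnostic.gpu.device_count())              # 0 when no accelerator is present
```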


src/transformers/models/olmo_hybrid/modular_olmo_hybrid.py

Lines changed: 2 additions & 2 deletions

@@ -32,7 +32,7 @@
 from ...modeling_rope_utils import dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils import TransformersKwargs, agnostic, auto_docstring, logging
 from ...utils.generic import maybe_autocast, merge_with_config_defaults
 from ...utils.import_utils import is_flash_linear_attention_available
 from ...utils.output_capturing import capture_outputs
@@ -513,7 +513,7 @@ def __init__(self, config: OlmoHybridConfig, layer_idx: int):
             else FusedRMSNormGated(
                 self.head_v_dim,
                 eps=1e-5,
-                device=torch.cuda.current_device(),
+                device=agnostic.gpu.current_device(),
                 dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
             )
         )
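
An editorial note on why this swap matters (the same change recurs in the two files below): `torch.cuda.current_device()` forces CUDA initialization and raises on machines without it, while the wrapper resolves to whichever backend was detected and falls back to index 0 otherwise. A sketch, assuming a machine with no CUDA device:

```python
import torch
from transformers.utils import agnostic

try:
    # The old call: initializes CUDA, which fails without an NVIDIA stack.
    torch.cuda.current_device()
except (AssertionError, RuntimeError) as err:
    print("old call fails:", err)

# The new call: the active index on CUDA/XPU, and the base-class default
# of 0 on MPS or CPU, so module construction no longer presumes CUDA.
print(agnostic.gpu.name, agnostic.gpu.current_device())
```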

src/transformers/models/qwen3_5/modeling_qwen3_5.py

Lines changed: 2 additions & 2 deletions

@@ -44,7 +44,7 @@
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
+from ...utils import TransformersKwargs, agnostic, auto_docstring, can_return_tuple, logging, torch_compilable_check
 from ...utils.generic import is_flash_attention_requested, maybe_autocast, merge_with_config_defaults
 from ...utils.import_utils import is_causal_conv1d_available, is_flash_linear_attention_available
 from ...utils.output_capturing import capture_outputs
@@ -395,7 +395,7 @@ def __init__(self, config: Qwen3_5Config, layer_idx: int):
                 self.head_v_dim,
                 eps=self.layer_norm_epsilon,
                 activation=self.activation,
-                device=torch.cuda.current_device(),
+                device=agnostic.gpu.current_device(),
                 dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
             )
         )

src/transformers/models/qwen3_next/modular_qwen3_next.py

Lines changed: 2 additions & 2 deletions

@@ -28,7 +28,7 @@
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, logging
+from ...utils import TransformersKwargs, agnostic, auto_docstring, logging
 from ...utils.generic import merge_with_config_defaults
 from ...utils.import_utils import (
     is_causal_conv1d_available,
@@ -380,7 +380,7 @@ def __init__(self, config: Qwen3NextConfig, layer_idx: int):
                 self.head_v_dim,
                 eps=self.layer_norm_epsilon,
                 activation=self.activation,
-                device=torch.cuda.current_device(),
+                device=agnostic.gpu.current_device(),
                 dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
             )
         )

src/transformers/utils/agnostic.py

Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
+# Copyright 2026 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+GPU calls that are device-agnostic.
+"""
+
+try:
+    import torch
+except Exception:
+    torch = None
+
+
+class AgnosticGPU:
+    @staticmethod
+    def configure() -> "AgnosticGPU":
+        return (
+            NoGPU()
+            if torch is None
+            else CUDAGPU()
+            if torch.cuda.is_available()
+            else XPUGPU()
+            if (hasattr(torch, "xpu") and torch.xpu.is_available())
+            else MPSGPU()
+            if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available())
+            else NoGPU()
+        )
+
+    name: str
+
+    def is_accelerator_available(self) -> bool:
+        return False
+
+    def current_device(self) -> int:
+        return 0
+
+    def device_count(self) -> int:
+        return 0
+
+
+class CUDAGPU(AgnosticGPU):
+    def __init__(self):
+        assert torch is not None
+        self.name = "cuda"
+        self.is_accelerator_available = torch.cuda.is_available
+        self.current_device = torch.cuda.current_device
+        self.device_count = torch.cuda.device_count
+
+
+class XPUGPU(AgnosticGPU):
+    def __init__(self):
+        assert torch is not None
+        self.name = "xpu"
+        self.is_accelerator_available = torch.xpu.is_available
+        self.current_device = torch.xpu.current_device
+        self.device_count = torch.xpu.device_count
+
+
+class MPSGPU(AgnosticGPU):
+    def __init__(self):
+        assert torch is not None
+        self.name = "mps"
+        self.is_accelerator_available = torch.mps.is_available
+        # self.current_device = torch.mps.current_device
+        self.device_count = torch.mps.device_count
+
+
+class NoGPU(AgnosticGPU):
+    def __init__(self) -> None:
+        self.name = "cpu"
+
+
+gpu = AgnosticGPU.configure()
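
A structural note on the file above (editor's sketch follows): the subclasses do not override the base methods in the usual way; each `__init__` binds the vendor functions as instance attributes, which shadow the base-class methods at lookup time, while `NoGPU` binds nothing and keeps the safe defaults. `MPSGPU` leaves `current_device` commented out because `torch.mps` exposes no such function, so it inherits the base default of 0. A toy version of the mechanism:

```python
# Instance attributes win over class-level methods during attribute lookup,
# which is what lets CUDAGPU forward straight to torch.cuda with no branching.
class Base:
    def current_device(self) -> int:
        return 0  # safe default, used when nothing is bound

class Forwarding(Base):
    def __init__(self, fn):
        # Binding here shadows Base.current_device for this instance only.
        self.current_device = fn

assert Base().current_device() == 0
assert Forwarding(lambda: 3).current_device() == 3
```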

src/transformers/utils/import_utils.py

Lines changed: 6 additions & 2 deletions

@@ -38,7 +38,7 @@
 import packaging.version
 from packaging import version
 
-from . import logging
+from . import agnostic, logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -795,7 +795,11 @@ def is_mamba_2_ssm_available() -> bool:
 @lru_cache
 def is_flash_linear_attention_available():
     is_available, fla_version = _is_package_available("fla", return_version=True)
-    return is_torch_cuda_available() and is_available and version.parse(fla_version) >= version.parse("0.2.2")
+    return (
+        agnostic.gpu.is_accelerator_available()
+        and is_available
+        and version.parse(fla_version) >= version.parse("0.2.2")
+    )
 
 
 @lru_cache
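
Worth flagging (editorial): this check now passes on any accelerator the wrapper recognizes, not only CUDA, and `lru_cache` freezes the answer for the lifetime of the process. A hedged sketch of the new shape of the check; `_is_package_available` is a transformers-internal helper, stubbed here purely for illustration:

```python
from functools import lru_cache
from packaging import version
from transformers.utils import agnostic

def _is_package_available(name, return_version=False):
    # Stub: pretend `fla` 0.2.2 is installed on this machine.
    return True, "0.2.2"

@lru_cache
def is_flash_linear_attention_available():
    is_available, fla_version = _is_package_available("fla", return_version=True)
    # Previously gated on is_torch_cuda_available(); now any detected
    # backend (CUDA, XPU, MPS) satisfies the accelerator test.
    return (
        agnostic.gpu.is_accelerator_available()
        and is_available
        and version.parse(fla_version) >= version.parse("0.2.2")
    )

print(is_flash_linear_attention_available())  # True on any detected accelerator
```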
