Skip to content

Commit 24aeb0f

Browse files
leofang, emcastillo, and claude
committed
Fix linter errors: unused import and missing blank line
- Replace try/import ml_dtypes pattern with try/except around numpy.dtype("bfloat16") to avoid unused import warning
- Add blank line after docstring in test_torch_tensor_bridge_decorator

Co-Authored-By: Emilio Castillo <ecastillo@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 8c20237 commit 24aeb0f

File tree

2 files changed

+14
-17
lines changed

2 files changed

+14
-17
lines changed

cuda_core/cuda/core/_tensor_bridge.pyx

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -133,27 +133,23 @@ cdef inline int check_aoti(AOTITorchError err, const char* name) except? -1:
133133
# ---------------------------------------------------------------------------
134134

135135
cdef dict _build_dtype_map():
136-
try:
137-
from ml_dtypes import bfloat16 as _bf16
138-
has_bfloat16 = True
139-
except ImportError:
140-
has_bfloat16 = False
141-
142136
cdef dict m = {
143-
aoti_torch_dtype_float16(): numpy.dtype(numpy.float16),
144-
aoti_torch_dtype_float32(): numpy.dtype(numpy.float32),
145-
aoti_torch_dtype_float64(): numpy.dtype(numpy.float64),
146-
aoti_torch_dtype_uint8(): numpy.dtype(numpy.uint8),
147-
aoti_torch_dtype_int8(): numpy.dtype(numpy.int8),
148-
aoti_torch_dtype_int16(): numpy.dtype(numpy.int16),
149-
aoti_torch_dtype_int32(): numpy.dtype(numpy.int32),
150-
aoti_torch_dtype_int64(): numpy.dtype(numpy.int64),
151-
aoti_torch_dtype_bool(): numpy.dtype(numpy.bool_),
152-
aoti_torch_dtype_complex64(): numpy.dtype(numpy.complex64),
137+
aoti_torch_dtype_float16(): numpy.dtype(numpy.float16),
138+
aoti_torch_dtype_float32(): numpy.dtype(numpy.float32),
139+
aoti_torch_dtype_float64(): numpy.dtype(numpy.float64),
140+
aoti_torch_dtype_uint8(): numpy.dtype(numpy.uint8),
141+
aoti_torch_dtype_int8(): numpy.dtype(numpy.int8),
142+
aoti_torch_dtype_int16(): numpy.dtype(numpy.int16),
143+
aoti_torch_dtype_int32(): numpy.dtype(numpy.int32),
144+
aoti_torch_dtype_int64(): numpy.dtype(numpy.int64),
145+
aoti_torch_dtype_bool(): numpy.dtype(numpy.bool_),
146+
aoti_torch_dtype_complex64(): numpy.dtype(numpy.complex64),
153147
aoti_torch_dtype_complex128(): numpy.dtype(numpy.complex128),
154148
}
155-
if has_bfloat16:
149+
try:
156150
m[aoti_torch_dtype_bfloat16()] = numpy.dtype("bfloat16")
151+
except TypeError:
152+
pass
157153
return m
158154

159155

cuda_core/tests/test_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,7 @@ def test_torch_tensor_bridge_cpu(init_cuda):
845845
@_torch_skip
846846
def test_torch_tensor_bridge_decorator(init_cuda):
847847
"""Verify tensor bridge works through the args_viewable_as_strided_memory decorator."""
848+
848849
@args_viewable_as_strided_memory((0,))
849850
def fn(tensor, stream):
850851
return tensor.view(stream.handle)

0 commit comments

Comments (0)