Skip to content

Commit 8eb18c5

Browse files
committed
[Kernel][gfx1250] Fix FlyDSL MXScale import and arch gates
1 parent cd8989b commit 8eb18c5

3 files changed

Lines changed: 17 additions & 3 deletions

File tree

aiter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def getLogger():
127127
gemm_mxfp8,
128128
gemm_mxa8w4,
129129
)
130-
except ImportError:
130+
except (ImportError, RuntimeError, OSError, KeyError):
131131
pass
132132

133133
# Import Triton-based communication primitives from ops.triton.comms (optional, only if Iris is available)

aiter/ops/flydsl/mxscale_gemm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from torch import Tensor
2424

2525
from aiter import logger
26-
from aiter.jit.utils.chip_info import get_gfx
26+
from aiter.jit.utils.chip_info import get_gfx_runtime
2727

2828
from .mxscale_layout import (
2929
SCALE_BLOCK,
@@ -273,7 +273,7 @@ def flydsl_mxscale_gemm(
273273
raise ValueError(
274274
f"data_format must be one of {_VALID_FORMATS}, got {data_format!r}"
275275
)
276-
cur_gfx = get_gfx()
276+
cur_gfx = get_gfx_runtime()
277277
if cur_gfx != _TARGET_GFX:
278278
raise RuntimeError(
279279
f"flydsl_mxscale_gemm requires {_TARGET_GFX}, current arch is {cur_gfx!r}"

aiter/ops/flydsl/test_flydsl_mxscale_gemm.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,20 @@ def test_format_named_wrappers_reject_data_format_kwarg():
334334
gemm_mxa8w4(a, b, s, s, data_format="a8w4")
335335

336336

337+
def test_runtime_arch_gate_ignores_gpu_archs(monkeypatch):
338+
from aiter.jit.utils import chip_info
339+
340+
monkeypatch.setenv("GPU_ARCHS", "gfx950")
341+
chip_info.get_gfx.cache_clear()
342+
chip_info.get_gfx_custom_op_core.cache_clear()
343+
344+
a = torch.zeros((32,), dtype=torch.uint8)
345+
b = torch.zeros((32,), dtype=torch.uint8)
346+
s = torch.zeros((1, 1), dtype=torch.uint8)
347+
with pytest.raises(ValueError, match="A and B must be 2-D"):
348+
flydsl_mxscale_gemm(a, b, s, s, data_format="fp8")
349+
350+
337351
# ---------------------------------------------------------------------------
338352
# GPU correctness tests (gfx1250 + flydsl required, gated above)
339353
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)