Skip to content

Commit ad27a45

Browse files
Metal backend: Add INT32 and BOOL dtype support (#18874)
AOTI-generated code creates int32 tensors for topk indexing and bool tensors for attention masks. Add both to the SupportedDTypes enum and validation function.
1 parent 401ea8e commit ad27a45

2 files changed

Lines changed: 5 additions & 10 deletions

File tree

backends/apple/metal/runtime/shims/utils.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ extern "C" {
2020
bool is_dtype_supported_in_et_metal(int32_t dtype) {
2121
switch (dtype) {
2222
case static_cast<int32_t>(SupportedDTypes::UINT8):
23+
case static_cast<int32_t>(SupportedDTypes::INT32):
2324
case static_cast<int32_t>(SupportedDTypes::INT64):
2425
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
26+
case static_cast<int32_t>(SupportedDTypes::BOOL):
2527
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
2628
return true;
2729
default:
@@ -35,14 +37,7 @@ AOTITorchError validate_dtype(int32_t dtype) {
3537
return Error::Ok;
3638
}
3739

38-
ET_LOG(
39-
Error,
40-
"Unsupported dtype: %d. Supported dtypes: %d (uint8), %d (int64), %d (float32), %d (bfloat16)",
41-
dtype,
42-
static_cast<int32_t>(SupportedDTypes::UINT8),
43-
static_cast<int32_t>(SupportedDTypes::INT64),
44-
static_cast<int32_t>(SupportedDTypes::FLOAT32),
45-
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
40+
ET_LOG(Error, "Unsupported dtype: %d", dtype);
4641
return Error::InvalidArgument;
4742
}
4843

backends/apple/metal/runtime/shims/utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ enum class SupportedDTypes : int32_t {
2222
UINT8 = 0, // PyTorch's uint8 dtype code
2323
// INT8 = 1, // PyTorch's int8 dtype code
2424
// INT16 = 2, // PyTorch's int16 dtype code
25-
// INT32 = 3, // PyTorch's int32 dtype code
25+
INT32 = 3, // PyTorch's int32 dtype code
2626
INT64 = 4, // PyTorch's int64 dtype code
2727
// FLOAT16 = 5, // PyTorch's float16 dtype code
2828
FLOAT32 = 6, // PyTorch's float32 dtype code
2929
// FLOAT64 = 7, // PyTorch's float64 dtype code
30-
// BOOL = 11, // PyTorch's bool dtype code
30+
BOOL = 11, // PyTorch's bool dtype code
3131
BFLOAT16 = 15 // PyTorch's bfloat16 dtype code
3232
};
3333

0 commit comments

Comments
 (0)