Skip to content

Commit a3a42e4

Browse files
Update
[ghstack-poisoned]
1 parent 5707e2a commit a3a42e4

2 files changed

Lines changed: 6 additions & 8 deletions

File tree

backends/apple/metal/runtime/shims/utils.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ extern "C" {
2020
bool is_dtype_supported_in_et_metal(int32_t dtype) {
2121
switch (dtype) {
2222
case static_cast<int32_t>(SupportedDTypes::UINT8):
23+
case static_cast<int32_t>(SupportedDTypes::INT32):
2324
case static_cast<int32_t>(SupportedDTypes::INT64):
2425
case static_cast<int32_t>(SupportedDTypes::FLOAT32):
26+
case static_cast<int32_t>(SupportedDTypes::BOOL):
2527
case static_cast<int32_t>(SupportedDTypes::BFLOAT16):
2628
return true;
2729
default:
@@ -37,12 +39,8 @@ AOTITorchError validate_dtype(int32_t dtype) {
3739

3840
ET_LOG(
3941
Error,
40-
"Unsupported dtype: %d. Supported dtypes: %d (uint8), %d (int64), %d (float32), %d (bfloat16)",
41-
dtype,
42-
static_cast<int32_t>(SupportedDTypes::UINT8),
43-
static_cast<int32_t>(SupportedDTypes::INT64),
44-
static_cast<int32_t>(SupportedDTypes::FLOAT32),
45-
static_cast<int32_t>(SupportedDTypes::BFLOAT16));
42+
"Unsupported dtype: %d",
43+
dtype);
4644
return Error::InvalidArgument;
4745
}
4846

backends/apple/metal/runtime/shims/utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ enum class SupportedDTypes : int32_t {
2222
UINT8 = 0, // PyTorch's uint8 dtype code
2323
// INT8 = 1, // PyTorch's int8 dtype code
2424
// INT16 = 2, // PyTorch's int16 dtype code
25-
// INT32 = 3, // PyTorch's int32 dtype code
25+
INT32 = 3, // PyTorch's int32 dtype code
2626
INT64 = 4, // PyTorch's int64 dtype code
2727
// FLOAT16 = 5, // PyTorch's float16 dtype code
2828
FLOAT32 = 6, // PyTorch's float32 dtype code
2929
// FLOAT64 = 7, // PyTorch's float64 dtype code
30-
// BOOL = 11, // PyTorch's bool dtype code
30+
BOOL = 11, // PyTorch's bool dtype code
3131
BFLOAT16 = 15 // PyTorch's bfloat16 dtype code
3232
};
3333

0 commit comments

Comments
 (0)