Skip to content

Commit f4f616e

Browse files
Update
[ghstack-poisoned]
1 parent 4632a83 commit f4f616e

1 file changed

Lines changed: 19 additions & 0 deletions

File tree

backends/apple/metal/runtime/ops/op_gather_qmv.mm

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -429,6 +429,18 @@ AOTITorchError aoti_torch_mps_gather_qmv(
429429
ET_LOG(Error, "aoti_torch_mps_gather_qmv: w must be 3D [E, N, K_packed], got %d", (int)w_tensor->dim());
430430
return Error::InvalidArgument;
431431
}
432+
if (s_tensor->dim() != 3) {
433+
ET_LOG(Error, "aoti_torch_mps_gather_qmv: scales must be 3D [E, N, K/gs], got %d", (int)s_tensor->dim());
434+
return Error::InvalidArgument;
435+
}
436+
if (z_tensor->dim() != 3) {
437+
ET_LOG(Error, "aoti_torch_mps_gather_qmv: biases must be 3D [E, N, K/gs], got %d", (int)z_tensor->dim());
438+
return Error::InvalidArgument;
439+
}
440+
if (idx_tensor->dim() != 1) {
441+
ET_LOG(Error, "aoti_torch_mps_gather_qmv: expert_indices must be 1D [P], got %d", (int)idx_tensor->dim());
442+
return Error::InvalidArgument;
443+
}
432444

433445
int32_t P = static_cast<int32_t>(x_tensor->sizes()[0]);
434446
int32_t K = static_cast<int32_t>(x_tensor->sizes()[1]);
@@ -444,6 +456,13 @@ AOTITorchError aoti_torch_mps_gather_qmv(
444456
return Error::InvalidArgument;
445457
}
446458

459+
// Validate expert_indices size matches P
460+
if (idx_tensor->sizes()[0] != P) {
461+
ET_LOG(Error, "aoti_torch_mps_gather_qmv: expert_indices size %d != P=%d",
462+
(int)idx_tensor->sizes()[0], P);
463+
return Error::InvalidArgument;
464+
}
465+
447466
// Determine dtype
448467
int32_t dtype = static_cast<int32_t>(x_tensor->scalar_type());
449468
size_t element_size;

0 commit comments

Comments
 (0)