@@ -429,6 +429,18 @@ AOTITorchError aoti_torch_mps_gather_qmv(
429429 ET_LOG (Error, " aoti_torch_mps_gather_qmv: w must be 3D [E, N, K_packed], got %d" , (int )w_tensor->dim ());
430430 return Error::InvalidArgument;
431431 }
432+ if (s_tensor->dim () != 3 ) {
433+ ET_LOG (Error, " aoti_torch_mps_gather_qmv: scales must be 3D [E, N, K/gs], got %d" , (int )s_tensor->dim ());
434+ return Error::InvalidArgument;
435+ }
436+ if (z_tensor->dim () != 3 ) {
437+ ET_LOG (Error, " aoti_torch_mps_gather_qmv: biases must be 3D [E, N, K/gs], got %d" , (int )z_tensor->dim ());
438+ return Error::InvalidArgument;
439+ }
440+ if (idx_tensor->dim () != 1 ) {
441+ ET_LOG (Error, " aoti_torch_mps_gather_qmv: expert_indices must be 1D [P], got %d" , (int )idx_tensor->dim ());
442+ return Error::InvalidArgument;
443+ }
432444
433445 int32_t P = static_cast <int32_t >(x_tensor->sizes ()[0 ]);
434446 int32_t K = static_cast <int32_t >(x_tensor->sizes ()[1 ]);
@@ -444,6 +456,13 @@ AOTITorchError aoti_torch_mps_gather_qmv(
444456 return Error::InvalidArgument;
445457 }
446458
459+ // Validate expert_indices size matches P
460+ if (idx_tensor->sizes ()[0 ] != P) {
461+ ET_LOG (Error, " aoti_torch_mps_gather_qmv: expert_indices size %d != P=%d" ,
462+ (int )idx_tensor->sizes ()[0 ], P);
463+ return Error::InvalidArgument;
464+ }
465+
447466 // Determine dtype
448467 int32_t dtype = static_cast <int32_t >(x_tensor->scalar_type ());
449468 size_t element_size;
0 commit comments