@@ -3220,9 +3220,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
32203220 const uint32_t D_lsb = D ^ (D & (D-1));
32213221 uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
32223222
3223- // Nvidia prefers shared memory use to load large tiles of K
3223+ // Nvidia prefers shared memory use to load large tiles of K.
3224+ // Switch to loading from global memory when it would use too much shared memory.
32243225 // AMD prefers loading K directly from global memory
3225- const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
3226+ const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
32263227
32273228 return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
32283229 };
@@ -5590,9 +5591,9 @@ static void ggml_vk_instance_init() {
55905591 // Check if there are two physical devices corresponding to the same GPU
55915592 // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),
55925593 // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication.
5593- // However, for MoltenVK on macOS, multiple GPUs on the same card may report the same UUID ,
5594- // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Until this is fixed, we'll only deduplicate
5595- // when drivers differ (same driver + same UUID = likely different GPUs)
5594+ // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards ,
5595+ // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip when both old/new
5596+ // driver is MoltenVK
55965597 auto old_device = std::find_if(
55975598 vk_instance.device_indices.begin(),
55985599 vk_instance.device_indices.end(),
@@ -5609,11 +5610,9 @@ static void ggml_vk_instance_init() {
56095610 old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
56105611 std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
56115612 );
5613+ bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);
56125614
5613- // Only deduplicate if same UUID AND different drivers
5614- // (same driver + same UUID on MoltenVK = likely different GPUs on multi-GPU card)
5615- bool different_driver = (old_driver.driverID != new_driver.driverID);
5616- return same_uuid && different_driver;
5615+ return same_uuid && !both_molten_vk;
56175616 }
56185617 );
56195618 if (old_device == vk_instance.device_indices.end()) {
@@ -8450,7 +8449,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
84508449 const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
84518450 const uint32_t sfsh = Bc * sfshstride * acctype;
84528451
8453- const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
8452+ const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ;
84548453 const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
84558454 const uint32_t vsh_stride = MatBc / 4 * row_split;
84568455 const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
0 commit comments