Skip to content

Commit 3cfde35

Browse files
committed
Fix index_select feature gating
1 parent f6880c7 commit 3cfde35

1 file changed

Lines changed: 10 additions & 6 deletions

File tree

backends/candle/src/layers/index_select.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@ use candle_index_select_cu;
88
#[inline]
99
#[allow(dead_code)]
1010
pub fn index_select(tensor: &Tensor, ids: &Tensor, dim: usize) -> Result<Tensor> {
11-
if cfg!(feature = "cuda")
12-
&& matches!(tensor.dtype(), DType::F16 | DType::F32)
13-
&& matches!(ids.dtype(), DType::U32)
11+
#[cfg(feature = "cuda")]
12+
{
13+
if matches!(tensor.dtype(), DType::F16 | DType::F32) && matches!(ids.dtype(), DType::U32) {
14+
// NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices
15+
candle_index_select_cu::index_select(tensor, ids, dim)
16+
} else {
17+
tensor.index_select(ids, dim)
18+
}
19+
}
20+
#[cfg(not(feature = "cuda"))]
1421
{
15-
// NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices
16-
candle_index_select_cu::index_select(tensor, ids, dim)
17-
} else {
1822
tensor.index_select(ids, dim)
1923
}
2024
}

0 commit comments

Comments
 (0)