File tree Expand file tree Collapse file tree
backends/candle/src/layers Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -8,13 +8,17 @@ use candle_index_select_cu;
88#[ inline]
99#[ allow( dead_code) ]
1010pub fn index_select ( tensor : & Tensor , ids : & Tensor , dim : usize ) -> Result < Tensor > {
11- if cfg ! ( feature = "cuda" )
12- && matches ! ( tensor. dtype( ) , DType :: F16 | DType :: F32 )
13- && matches ! ( ids. dtype( ) , DType :: U32 )
11+ #[ cfg( feature = "cuda" ) ]
12+ {
13+ if matches ! ( tensor. dtype( ) , DType :: F16 | DType :: F32 ) && matches ! ( ids. dtype( ) , DType :: U32 ) {
14+ // NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices
15+ candle_index_select_cu:: index_select ( tensor, ids, dim)
16+ } else {
17+ tensor. index_select ( ids, dim)
18+ }
19+ }
20+ #[ cfg( not( feature = "cuda" ) ) ]
1421 {
15- // NOTE: `candle-index-select-cu` supports f16/f32 data and u32 indices
16- candle_index_select_cu:: index_select ( tensor, ids, dim)
17- } else {
1822 tensor. index_select ( ids, dim)
1923 }
2024}
You can’t perform that action at this time.
0 commit comments