@@ -87,6 +87,12 @@ def _get_group_id_driver(nditem: NdItem, a):
8787 a [i ] = g .get_group_id (0 )
8888
8989
90+ def _get_group_linear_id_driver (nditem : NdItem , a ):
91+ i = nditem .get_global_linear_id ()
92+ g = nditem .get_group ()
93+ a [i ] = g .get_group_linear_id ()
94+
95+
9096def _get_group_range_driver (nditem : NdItem , a ):
9197 i = nditem .get_global_id (0 )
9298 g = nditem .get_group ()
@@ -206,21 +212,29 @@ def test_no_item():
206212 )
207213
208214
209- def test_get_group_id ():
210- global_size = 100
211- group_size = 20
212- num_groups = global_size // group_size
215+ @pytest .mark .parametrize (
216+ "driver,rng" ,
217+ [
218+ (_get_group_id_driver , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
219+ (_get_group_linear_id_driver , dpex .NdRange ((_SIZE ,), (_GROUP_SIZE ,))),
220+ (
221+ _get_group_linear_id_driver ,
222+ dpex .NdRange ((1 , 1 , _SIZE ), (1 , 1 , _GROUP_SIZE )),
223+ ),
224+ ],
225+ )
226+ def test_get_group_id (driver , rng ):
227+ num_groups = _SIZE // _GROUP_SIZE
213228
214- a = dpnp .empty (global_size , dtype = dpnp .int32 )
215- ka = dpnp .empty (global_size , dtype = dpnp .int32 )
216- expected = np .empty (global_size , dtype = np .int32 )
217- ndrange = NdRange ((global_size ,), (group_size ,))
218- dpex_exp .call_kernel (dpex_exp .kernel (_get_group_id_driver ), ndrange , a )
219- kapi_call_kernel (_get_group_id_driver , ndrange , ka )
229+ a = dpnp .empty (_SIZE , dtype = dpnp .int32 )
230+ ka = dpnp .empty (_SIZE , dtype = dpnp .int32 )
231+ expected = np .empty (_SIZE , dtype = np .int32 )
232+ dpex_exp .call_kernel (dpex_exp .kernel (driver ), rng , a )
233+ kapi_call_kernel (driver , rng , ka )
220234
221235 for gid in range (num_groups ):
222- for lid in range (group_size ):
223- expected [gid * group_size + lid ] = gid
236+ for lid in range (_GROUP_SIZE ):
237+ expected [gid * _GROUP_SIZE + lid ] = gid
224238
225239 assert np .array_equal (a .asnumpy (), expected )
226240 assert np .array_equal (ka .asnumpy (), expected )
0 commit comments