@@ -613,6 +613,33 @@ def test_dask_reduce_axis_subset():
613613 )
614614
615615
616+ @pytest .mark .parametrize ("group_idx" , [[0 , 1 , 0 ], [0 , 0 , 1 ], [1 , 0 , 0 ], [1 , 1 , 0 ]])
617+ @pytest .mark .parametrize (
618+ "func" ,
619+ [
620+ # "first", "last",
621+ "nanfirst" ,
622+ "nanlast" ,
623+ ],
624+ )
625+ @pytest .mark .parametrize (
626+ "chunks" ,
627+ [
628+ None ,
629+ pytest .param (1 , marks = pytest .mark .skipif (not has_dask , reason = "no dask" )),
630+ pytest .param (2 , marks = pytest .mark .skipif (not has_dask , reason = "no dask" )),
631+ pytest .param (3 , marks = pytest .mark .skipif (not has_dask , reason = "no dask" )),
632+ ],
633+ )
634+ def test_first_last_useless (func , chunks , group_idx ):
635+ array = np .array ([[0 , 0 , 0 ], [0 , 0 , 0 ]], dtype = np .int8 )
636+ if chunks is not None :
637+ array = dask .array .from_array (array , chunks = chunks )
638+ actual , _ = groupby_reduce (array , np .array (group_idx ), func = func , engine = "numpy" )
639+ expected = np .array ([[0 , 0 ], [0 , 0 ]], dtype = np .int8 )
640+ assert_equal (actual , expected )
641+
642+
616643@pytest .mark .parametrize ("func" , ["first" , "last" , "nanfirst" , "nanlast" ])
617644@pytest .mark .parametrize ("axis" , [(0 , 1 )])
618645def test_first_last_disallowed (axis , func ):
@@ -1563,18 +1590,36 @@ def test_validate_reindex_map_reduce(
15631590 dask_expected , reindex , func , expected_groups , any_by_dask
15641591) -> None :
15651592 actual = _validate_reindex (
1566- reindex , func , "map-reduce" , expected_groups , any_by_dask , is_dask_array = True
1593+ reindex ,
1594+ func ,
1595+ "map-reduce" ,
1596+ expected_groups ,
1597+ any_by_dask ,
1598+ is_dask_array = True ,
1599+ array_dtype = np .dtype ("int32" ),
15671600 )
15681601 assert actual is dask_expected
15691602
15701603 # always reindex with all numpy inputs
15711604 actual = _validate_reindex (
1572- reindex , func , "map-reduce" , expected_groups , any_by_dask = False , is_dask_array = False
1605+ reindex ,
1606+ func ,
1607+ "map-reduce" ,
1608+ expected_groups ,
1609+ any_by_dask = False ,
1610+ is_dask_array = False ,
1611+ array_dtype = np .dtype ("int32" ),
15731612 )
15741613 assert actual
15751614
15761615 actual = _validate_reindex (
1577- True , func , "map-reduce" , expected_groups , any_by_dask = False , is_dask_array = False
1616+ True ,
1617+ func ,
1618+ "map-reduce" ,
1619+ expected_groups ,
1620+ any_by_dask = False ,
1621+ is_dask_array = False ,
1622+ array_dtype = np .dtype ("int32" ),
15781623 )
15791624 assert actual
15801625
@@ -1584,19 +1629,37 @@ def test_validate_reindex() -> None:
15841629 for method in methods :
15851630 with pytest .raises (NotImplementedError ):
15861631 _validate_reindex (
1587- True , "argmax" , method , expected_groups = None , any_by_dask = False , is_dask_array = True
1632+ True ,
1633+ "argmax" ,
1634+ method ,
1635+ expected_groups = None ,
1636+ any_by_dask = False ,
1637+ is_dask_array = True ,
1638+ array_dtype = np .dtype ("int32" ),
15881639 )
15891640
15901641 methods : list [T_Method ] = ["blockwise" , "cohorts" ]
15911642 for method in methods :
15921643 with pytest .raises (ValueError ):
15931644 _validate_reindex (
1594- True , "sum" , method , expected_groups = None , any_by_dask = False , is_dask_array = True
1645+ True ,
1646+ "sum" ,
1647+ method ,
1648+ expected_groups = None ,
1649+ any_by_dask = False ,
1650+ is_dask_array = True ,
1651+ array_dtype = np .dtype ("int32" ),
15951652 )
15961653
15971654 for func in ["sum" , "argmax" ]:
15981655 actual = _validate_reindex (
1599- None , func , method , expected_groups = None , any_by_dask = False , is_dask_array = True
1656+ None ,
1657+ func ,
1658+ method ,
1659+ expected_groups = None ,
1660+ any_by_dask = False ,
1661+ is_dask_array = True ,
1662+ array_dtype = np .dtype ("int32" ),
16001663 )
16011664 assert actual is False
16021665
@@ -1608,6 +1671,7 @@ def test_validate_reindex() -> None:
16081671 expected_groups = np .array ([1 , 2 , 3 ]),
16091672 any_by_dask = False ,
16101673 is_dask_array = True ,
1674+ array_dtype = np .dtype ("int32" ),
16111675 )
16121676
16131677 assert _validate_reindex (
@@ -1617,6 +1681,7 @@ def test_validate_reindex() -> None:
16171681 expected_groups = np .array ([1 , 2 , 3 ]),
16181682 any_by_dask = True ,
16191683 is_dask_array = True ,
1684+ array_dtype = np .dtype ("int32" ),
16201685 )
16211686 assert _validate_reindex (
16221687 None ,
@@ -1625,8 +1690,24 @@ def test_validate_reindex() -> None:
16251690 expected_groups = np .array ([1 , 2 , 3 ]),
16261691 any_by_dask = True ,
16271692 is_dask_array = True ,
1693+ array_dtype = np .dtype ("int32" ),
1694+ )
1695+
1696+ kwargs = dict (
1697+ method = "blockwise" ,
1698+ expected_groups = np .array ([1 , 2 , 3 ]),
1699+ any_by_dask = True ,
1700+ is_dask_array = True ,
16281701 )
16291702
1703+ for func in ["nanfirst" , "nanlast" ]:
1704+ assert not _validate_reindex (None , func , array_dtype = np .dtype ("int32" ), ** kwargs ) # type: ignore[arg-type]
1705+ assert _validate_reindex (None , func , array_dtype = np .dtype ("float32" ), ** kwargs ) # type: ignore[arg-type]
1706+
1707+ for func in ["first" , "last" ]:
1708+ assert not _validate_reindex (None , func , array_dtype = np .dtype ("int32" ), ** kwargs ) # type: ignore[arg-type]
1709+ assert not _validate_reindex (None , func , array_dtype = np .dtype ("float32" ), ** kwargs ) # type: ignore[arg-type]
1710+
16301711
16311712@requires_dask
16321713def test_1d_blockwise_sort_optimization ():
0 commit comments