|
17 | 17 |
|
18 | 18 | import pandas as pd |
19 | 19 |
|
| 20 | +import bigframes.constants as constants |
20 | 21 | import bigframes.core as core |
21 | 22 | import bigframes.core.blocks as blocks |
22 | 23 | import bigframes.core.ordering as ordering |
@@ -576,3 +577,53 @@ def align_columns( |
576 | 577 | left_final = left_block.select_columns(left_column_ids) |
577 | 578 | right_final = right_block.select_columns(right_column_ids) |
578 | 579 | return left_final, right_final |
| 580 | + |
| 581 | + |
def idxmin(block: blocks.Block) -> blocks.Block:
    """Return the result of ``_idx_extrema`` for ``block`` in "min" mode.

    Args:
        block: The block whose value columns are examined.

    Returns:
        The block produced by ``_idx_extrema(block, "min")``.
    """
    return _idx_extrema(block, min_or_max="min")
| 584 | + |
| 585 | + |
def idxmax(block: blocks.Block) -> blocks.Block:
    """Return the result of ``_idx_extrema`` for ``block`` in "max" mode.

    Args:
        block: The block whose value columns are examined.

    Returns:
        The block produced by ``_idx_extrema(block, "max")``.
    """
    return _idx_extrema(block, min_or_max="max")
| 588 | + |
| 589 | + |
def _idx_extrema(
    block: blocks.Block, min_or_max: typing.Literal["min", "max"]
) -> blocks.Block:
    """Shared implementation behind ``idxmin`` and ``idxmax``.

    For every value column, a window ordered by that column (ascending for
    "min", descending for "max", with the index column as a tiebreak) is used
    to take the first index value. The per-column results are then stacked
    into a single column labelled with the original index name.

    Args:
        block: Input block; must have exactly one index column.
        min_or_max: "min" to order ascending, "max" to order descending.

    Returns:
        The block produced by stacking the per-column results.

    Raises:
        NotImplementedError: If ``block`` does not have exactly one index
            column.
    """
    if len(block.index_columns) != 1:
        # TODO: Need support for tuple dtype
        raise NotImplementedError(
            f"idx{min_or_max} not supported for multi-index. "
            f"{constants.FEEDBACK_LINK}"
        )

    original_block = block

    # Neither of these depends on the value column, so compute them once.
    idx_col = original_block.index_columns[0]
    direction = (
        ordering.OrderingDirection.ASC
        if min_or_max == "min"
        else ordering.OrderingDirection.DESC
    )

    result_cols = []
    for value_col in original_block.value_columns:
        # Order by the value column first; the (single, guaranteed by the
        # guard above) index column acts as the tiebreak.
        order_refs = [
            ordering.OrderingColumnReference(value_col, direction),
            ordering.OrderingColumnReference(idx_col),
        ]
        window_spec = core.WindowSpec(ordering=order_refs)
        # Each window op appends a new column to the running block.
        block, result_col = block.apply_window_op(
            idx_col, agg_ops.first_op, window_spec
        )
        result_cols.append(result_col)

    block = block.select_columns(result_cols).with_column_labels(
        original_block.column_labels
    )
    # Stack the entire column axis to produce single-column result
    # Assumption: uniform dtype for stackability
    # NOTE(review): with no value columns, block.dtypes[0] presumably fails
    # on an empty sequence -- confirm whether callers can reach that case.
    return block.aggregate_all_and_stack(
        agg_ops.AnyValueOp(), dtype=block.dtypes[0]
    ).with_column_labels([original_block.index.name])
0 commit comments