File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -449,6 +449,25 @@ def group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
449449 return df
450450
451451
def maybe_group_terms_polars(df: pl.DataFrame) -> pl.DataFrame:
    """
    Merge duplicate terms, skipping the group-by when nothing would merge.

    The grouping key is the ``labels`` column (when present) followed by
    every column whose name starts with ``vars``. If each row already has
    a distinct key (e.g. ``x - y`` has ``_term=2`` but no repeated
    variable), the costly ``group_terms_polars`` call is bypassed and the
    columns are merely reordered so the result has the same layout as the
    grouped output: key columns, then ``coeffs``, then everything else.

    Parameters
    ----------
    df : pl.DataFrame
        Long-format frame containing a ``coeffs`` column, one or more
        ``vars*`` columns and optionally a ``labels`` column.

    Returns
    -------
    pl.DataFrame
        Either the output of ``group_terms_polars(df)`` or ``df`` with
        its columns reordered.
    """
    columns = df.columns
    group_keys = [name for name in columns if name.startswith("vars")]
    if "labels" in columns:
        # ``labels`` always leads the key list, as in group_terms_polars.
        group_keys.insert(0, "labels")

    # Count distinct key combinations by packing the key columns in a struct.
    distinct_keys = df.select(pl.struct(group_keys).n_unique()).item()

    if distinct_keys >= df.height:
        # Every row is unique: no aggregation needed, only match the
        # column order of group_terms (group-by keys, coeffs, rest).
        remaining = [
            name
            for name in columns
            if name != "coeffs" and name not in group_keys
        ]
        return df.select([*group_keys, "coeffs", *remaining])

    return group_terms_polars(df)
469+
470+
452471def save_join (* dataarrays : DataArray , integer_dtype : bool = False ) -> Dataset :
453472 """
454473 Join multiple xarray Dataarray's to a Dataset and warn if coordinates are not equal.
Original file line number Diff line number Diff line change 4040 generate_indices_for_printout ,
4141 get_dims_with_index_levels ,
4242 get_label_position ,
43- group_terms_polars ,
4443 has_optimized_model ,
4544 iterate_slices ,
45+ maybe_group_terms_polars ,
4646 maybe_replace_signs ,
4747 print_coord ,
4848 print_single_constraint ,
@@ -622,7 +622,7 @@ def to_polars(self) -> pl.DataFrame:
622622
623623 long = filter_nulls_polars (long )
624624 if ds .sizes .get ("_term" , 1 ) > 1 :
625- long = group_terms_polars (long )
625+ long = maybe_group_terms_polars (long )
626626 check_has_nulls_polars (long , name = f"{ self .type } { self .name } " )
627627
628628 # Build short DataFrame (labels, rhs, sign) without xarray broadcast.
Original file line number Diff line number Diff line change 6060 has_optimized_model ,
6161 is_constant ,
6262 iterate_slices ,
63+ maybe_group_terms_polars ,
6364 print_coord ,
6465 print_single_expression ,
6566 to_dataframe ,
@@ -1463,13 +1464,7 @@ def to_polars(self) -> pl.DataFrame:
14631464
14641465 df = to_polars (self .data )
14651466 df = filter_nulls_polars (df )
1466- if df ["vars" ].n_unique () < df .height :
1467- df = group_terms_polars (df )
1468- else :
1469- # Match column order of group_terms (group-by keys, coeffs, rest)
1470- varcols = [c for c in df .columns if c .startswith ("vars" )]
1471- rest = [c for c in df .columns if c not in varcols and c != "coeffs" ]
1472- df = df .select (varcols + ["coeffs" ] + rest )
1467+ df = maybe_group_terms_polars (df )
14731468 check_has_nulls_polars (df , name = self .type )
14741469 return df
14751470
You can’t perform that action at this time.
0 commit comments