@@ -17,7 +17,6 @@ def initialize_de_comparison(
1717 pred : pl .DataFrame ,
1818 target_col : str = "target" ,
1919 feature_col : str = "feature" ,
20- fold_change_col : str = "fold_change" ,
2120 log2_fold_change_col : str = "log2_fold_change" ,
2221 abs_log2_fold_change_col : str = "abs_log2_fold_change" ,
2322 pvalue_col : str = "p_value" ,
@@ -27,7 +26,6 @@ def initialize_de_comparison(
2726 DEResults ,
2827 target_col = target_col ,
2928 feature_col = feature_col ,
30- fold_change_col = fold_change_col ,
3129 log2_fold_change_col = log2_fold_change_col ,
3230 abs_log2_fold_change_col = abs_log2_fold_change_col ,
3331 )
@@ -47,7 +45,6 @@ class DEResults:
4745 # Column names configuration
4846 target_col : str = "target"
4947 feature_col : str = "feature"
50- fold_change_col : str = "fold_change"
5148 log2_fold_change_col : str = "log2_fold_change"
5249 abs_log2_fold_change_col : str = "abs_log2_fold_change"
5350 pvalue_col : str = "p_value"
@@ -58,7 +55,7 @@ def __post_init__(self) -> None:
5855 required_cols = {
5956 self .target_col ,
6057 self .feature_col ,
61- self .fold_change_col ,
58+ self .log2_fold_change_col ,
6259 self .pvalue_col ,
6360 self .fdr_col ,
6461 }
@@ -67,7 +64,6 @@ def __post_init__(self) -> None:
6764 raise ValueError (f"Missing required columns: { missing } " )
6865
6966 numeric_cols = [
70- self .fold_change_col ,
7167 self .pvalue_col ,
7268 self .fdr_col ,
7369 self .log2_fold_change_col ,
@@ -80,31 +76,32 @@ def __post_init__(self) -> None:
8076 ]
8177
8278 logger .info (f"Checking DE data integrity... ({ self .name } )" )
83- fc_num_null = self .data .filter (pl .col (self .fold_change_col ).is_null ()).height
84- fc_num_inf = self .data .filter (pl .col (self .fold_change_col ).is_infinite ()).height
85- fc_num_nan = self .data .filter (pl .col (self .fold_change_col ).is_nan ()).height
86- if fc_num_null > 0 :
79+ lfc_num_null = self .data .filter (
80+ pl .col (self .log2_fold_change_col ).is_null ()
81+ ).height
82+ lfc_num_inf = self .data .filter (
83+ pl .col (self .log2_fold_change_col ).is_infinite ()
84+ ).height
85+ lfc_num_nan = self .data .filter (
86+ pl .col (self .log2_fold_change_col ).is_nan ()
87+ ).height
88+ if lfc_num_null > 0 :
8789 logger .warning (
88- f"Identified { fc_num_null } missing fold change values ({ self .name } )"
90+ f"Identified { lfc_num_null } missing log2 fold change values ({ self .name } )"
8991 )
90- if fc_num_inf > 0 :
92+ if lfc_num_inf > 0 :
9193 logger .warning (
92- f"Identified { fc_num_inf } infinite fold change values ({ self .name } )"
94+ f"Identified { lfc_num_inf } infinite log2 fold change values ({ self .name } )"
9395 )
94- if fc_num_nan > 0 :
96+ if lfc_num_nan > 0 :
9597 logger .warning (
96- f"Identified { fc_num_nan } NaN fold change values ({ self .name } )"
98+ f"Identified { lfc_num_nan } NaN log2 fold change values ({ self .name } )"
9799 )
98100 logger .info (f"DE data integrity check complete. ({ self .name } )" )
99101
100- # Add log2 fold change columns if not present
101- if self .log2_fold_change_col not in self .data .columns :
102+ # Derive abs(log2_fold_change) if not already provided.
103+ if self .abs_log2_fold_change_col not in self .data .columns :
102104 self .data = self .data .with_columns (
103- pl .col (self .fold_change_col )
104- .log (base = 2 )
105- .alias (self .log2_fold_change_col )
106- .fill_nan (0.0 )
107- ).with_columns (
108105 pl .col (self .log2_fold_change_col )
109106 .abs ()
110107 .alias (self .abs_log2_fold_change_col )
@@ -153,7 +150,10 @@ def get_top_genes(
153150 # Set FDR threshold if not provided
154151 fdr_threshold = fdr_threshold if fdr_threshold is not None else 0.05
155152
156- descending = sort_by in {DESortBy .FOLD_CHANGE , DESortBy .ABS_FOLD_CHANGE }
153+ descending = sort_by in {
154+ DESortBy .LOG2_FOLD_CHANGE ,
155+ DESortBy .ABS_LOG2_FOLD_CHANGE ,
156+ }
157157
158158 # Create a rank matrix where each row is the ordinal rank of a gene and each column is a perturbation.
159159 # The rank is sensitive to the sort-by column and is computed post-filtering for FDR.
@@ -219,7 +219,7 @@ def compute_overlap(
219219 k : int | None ,
220220 metric : Literal ["overlap" , "precision" ] = "overlap" ,
221221 fdr_threshold : float | None = None ,
222- sort_by : DESortBy = DESortBy .ABS_FOLD_CHANGE ,
222+ sort_by : DESortBy = DESortBy .ABS_LOG2_FOLD_CHANGE ,
223223 ) -> dict [str , float ]:
224224 """
225225 Compute overlap metrics across perturbations.
0 commit comments