diff --git a/src/decoupler/bm/_run.py b/src/decoupler/bm/_run.py index 7f4a35f9..9a330532 100644 --- a/src/decoupler/bm/_run.py +++ b/src/decoupler/bm/_run.py @@ -50,7 +50,9 @@ def _tensor_scores( def _tensor_truth(obs: pd.DataFrame, srcs: np.ndarray) -> pd.DataFrame: # Explode nested perturbs and pivot into mat - grts = obs.explode("source").pivot(columns="source", values="type_p").notna().astype(float).fillna(0.0) + grts = ( + obs.explode("source").pivot(columns="source", values="type_p").notna().astype(float).fillna(0.0).loc[obs.index] + ) miss_srcs = srcs[~np.isin(srcs, grts.columns)] miss_srcs = pd.DataFrame(0, index=grts.index, columns=miss_srcs) grts = pd.concat([grts, miss_srcs], axis=1) diff --git a/tests/bm/test_benchmark.py b/tests/bm/test_benchmark.py index ae5abd2d..14cf153d 100644 --- a/tests/bm/test_benchmark.py +++ b/tests/bm/test_benchmark.py @@ -13,7 +13,7 @@ [["auc"], None, "expr", False, 0.05, 5, False], [["auc", "fscore"], "group", "expr", False, 0.05, 5, False], [["auc", "fscore", "qrank"], None, "source", False, 0.05, 2, False], - [["auc", "fscore", "qrank"], "group", "source", False, 0.05, 1, False], + [["auc", "fscore", "qrank"], "class", "source", False, 0.05, 1, False], [["auc", "fscore", "qrank"], "bm_group", "expr", True, 0.05, 5, False], [["auc", "fscore", "qrank"], "source", "expr", True, 0.05, 5, False], ],