diff --git a/pyproject.toml b/pyproject.toml index 21bdbcd..36b8375 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "cell-eval" -version = "0.7.1" +version = "0.7.2" description = "Evaluation metrics for single-cell perturbation predictions" readme = "README.md" authors = [ diff --git a/src/cell_eval/_evaluator.py b/src/cell_eval/_evaluator.py index 30e4dd1..a7f777a 100644 --- a/src/cell_eval/_evaluator.py +++ b/src/cell_eval/_evaluator.py @@ -18,6 +18,19 @@ logger = logging.getLogger(__name__) +def _available_cpus() -> int: + """Return CPUs the current process is allowed to use. + + Uses ``os.sched_getaffinity`` on Linux so SLURM/cgroup/taskset limits are + respected; falls back to ``mp.cpu_count`` on macOS/Windows where that API + is unavailable (those platforms typically run locally without cgroup caps). + """ + try: + return len(os.sched_getaffinity(0)) + except AttributeError: + return mp.cpu_count() + + class MetricsEvaluator: """ Evaluates benchmarking metrics of a predicted and real anndata object. @@ -70,6 +83,9 @@ def __init__( # Enable a global string cache for categorical columns pl.enable_string_cache() + if num_threads == -1: + num_threads = _available_cpus() + if os.path.exists(outdir): logger.warning( f"Output directory {outdir} already exists, potential overwrite occurring" @@ -91,7 +107,7 @@ def __init__( anndata_pair=self.anndata_pair, de_pred=de_pred, de_real=de_real, - num_threads=num_threads if num_threads != -1 else mp.cpu_count(), + num_threads=num_threads, allow_discrete=allow_discrete, outdir=outdir, prefix=prefix,