# ---- File: examples/regression_selection_synthetic.py (new file) ----
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

# Example: conformal selection on a synthetic regression task.
#
# A small regression network is trained, residual scores are calibrated on a
# held-out split, and test points whose (unobserved) response is claimed to
# exceed a threshold are selected with the Benjamini-Hochberg procedure.

import torch
import torch.nn as nn
import torch.optim as optim

from examples.regression_cqr_synthetic import prepare_dataset
from torchcp.regression.utils import build_regression_model
from torchcp.selection.score import RES
from torchcp.selection.selector import ConformalSelector
from torchcp.selection.testing_correction import BH_procedure


# get dataloader
train_loader, cal_loader, test_loader = prepare_dataset(train_ratio=0.4, cal_ratio=0.2, batch_size=128)
# build regression model; the input width is read from the first training batch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = build_regression_model("NonLinearNet")(next(iter(train_loader))[0].shape[1], 1, 64, 0.5).to(device)

# train model
epochs = 100
criterion = nn.MSELoss()
lr = 0.01
optimizer = optim.Adam(model.parameters(), lr=lr)

model.train()
# Iterate over the whole training set `epochs` times; previously `epochs` was
# defined but unused, so the model only ever saw a single pass over the data.
for _ in range(epochs):
    for tmp_x, tmp_y in train_loader:
        outputs = model(tmp_x.to(device))
        loss = criterion(outputs, tmp_y.reshape(-1, 1).to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Conformal Selection
# One threshold per test point, created on the same device as the model so it
# can be combined with predictions and labels that the selector moves there.
thresholds = torch.full((len(test_loader.dataset),), 5.0, device=device)

selector = ConformalSelector(score_function=RES(), testing_correction=BH_procedure(), model=model)
selector.calibrate(cal_loader)
print(selector.select(test_loader, thresholds))

# ---- File: torchcp/classification/loss/__init__.py (modified; hunk continues below) ----
# from .cd import CDLoss
# from .conftr import ConfTrLoss
# from .confts import
ConfTSLoss -from .uncertainty_aware import UncertaintyAwareLoss from .scpo import SCPOLoss \ No newline at end of file diff --git a/torchcp/classification/trainer/__init__.py b/torchcp/classification/trainer/__init__.py index 52161e4..a3538ef 100644 --- a/torchcp/classification/trainer/__init__.py +++ b/torchcp/classification/trainer/__init__.py @@ -10,6 +10,5 @@ from .confts_trainer import ConfTSTrainer from .model_zoo import TemperatureScalingModel from .ts_trainer import TSTrainer -from .ua_trainer import UncertaintyAwareTrainer from .ordinal_trainer import OrdinalTrainer from .scpo_trainer import SCPOTrainer \ No newline at end of file diff --git a/torchcp/regression/predictor/__init__.py b/torchcp/regression/predictor/__init__.py index 7eb405f..77bcbc3 100644 --- a/torchcp/regression/predictor/__init__.py +++ b/torchcp/regression/predictor/__init__.py @@ -9,4 +9,4 @@ from .ensemble import EnsemblePredictor from .split import SplitPredictor from .agaci import AgACIPredictor -from .cpd import ConformalPredictiveDistribution \ No newline at end of file +from .cpd import ConformalPredictiveDistribution diff --git a/torchcp/regression/score/__init__.py b/torchcp/regression/score/__init__.py index a68921b..1519448 100644 --- a/torchcp/regression/score/__init__.py +++ b/torchcp/regression/score/__init__.py @@ -12,4 +12,4 @@ from .cqrm import CQRM from .cqrr import CQRR from .r2ccp import R2CCP -from .sign import Sign \ No newline at end of file +from .sign import Sign diff --git a/torchcp/selection/__init__.py b/torchcp/selection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchcp/selection/score/__init__.py b/torchcp/selection/score/__init__.py new file mode 100644 index 0000000..bb9d168 --- /dev/null +++ b/torchcp/selection/score/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2023-present, SUSTech-ML. +# All rights reserved. 
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from .clip import CLIP
from .res import RES

# ---- File: torchcp/selection/score/clip.py (new file) ----
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


import torch
from torchcp.regression.score.base import BaseScore


class CLIP(BaseScore):
    """
    CLIP (clipped) score (Jin et al., 2023), only apply to binary classification.
    paper: https://arxiv.org/pdf/2210.01408

    The score is ``M * 1{y > 0} - prediction``: a positive label pushes the score up by the
    large constant ``M``, a non-positive label leaves it at ``-prediction``.
    """

    def __call__(self, predicts, y_truth, M=100):
        """
        Compute the clipped score for each sample.

        Args:
            predicts (torch.Tensor): Model outputs. A 2-D tensor is flattened to 1-D.
            y_truth (torch.Tensor): Reference values, one per sample. At calibration time these
                are the labels; at selection time the selector passes the thresholds here
                (for binary labels this presumably means thresholds of 0 -- confirm with caller).
            M (float, optional): Constant added for positive references. Default is 100.

        Returns:
            torch.Tensor: ``M * (y_truth > 0) - predicts``, a 1-D tensor.
        """
        if len(predicts.shape) == 2:
            predicts = predicts.squeeze().view(-1)
        # The previous `torch.max(predicts, 0)` is a reduction over dim 0 that returns a
        # (values, indices) tuple, so the expression failed with a TypeError and never used
        # `y_truth`. The indicator is built from the reference values instead.
        y_truth = torch.as_tensor(y_truth, device=predicts.device)
        positive = (y_truth > 0).to(predicts.dtype)
        return M * positive - predicts

# ---- File: torchcp/selection/score/res.py (new file) ----
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from torchcp.regression.score.base import BaseScore


class RES(BaseScore):
    """
    RES (signed residual) score (Jin et al., 2023)
    paper: https://arxiv.org/pdf/2210.01408
    """

    def __call__(self, predicts, y_truth):
        """
        Compute the signed residual ``y_truth - predicts``.

        Args:
            predicts (torch.Tensor): Model outputs. A 2-D tensor is flattened to 1-D.
            y_truth (torch.Tensor): Reference values (labels at calibration time, thresholds at
                selection time). Assumed to be 1-D with one entry per sample -- TODO confirm.

        Returns:
            torch.Tensor: The element-wise difference ``y_truth - predicts``.
        """
        if len(predicts.shape) == 2:
            predicts = predicts.squeeze().view(-1)
        return y_truth - predicts

# ---- File: torchcp/selection/selector/__init__.py (new file) ----
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from .conformal_selector import ConformalSelector

# ---- File: torchcp/selection/selector/conformal_selector.py (new file) ----
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


import torch

from torchcp.regression.predictor.split import SplitPredictor
from torchcp.selection.utils.metrics import Metrics


class ConformalSelector(SplitPredictor):
    """
    Conformal Selection:
    a screening procedure that aims to select candidates whose unobserved outcomes exceed user-specified value.

    Args:
        score_function (callable): Score function called as ``score_function(predicts, y)``,
            e.g. ``torchcp.selection.score.RES``.
        testing_correction (callable): Multiple-testing procedure called as
            ``testing_correction(p_values, alpha)`` and returning the selected indices,
            e.g. ``torchcp.selection.testing_correction.BH_procedure``.
        model (torch.nn.Module): A PyTorch model whose outputs are used as point predictions.
            It is only evaluated here (never trained), so it should already be fitted.
        alpha (float, optional): The significance level passed to ``testing_correction``. Default is 0.1.
        device (torch.device, optional): The device on which the model is located. Default is None.

    Reference:
        Paper: Selection by Prediction with Conformal p-values (Jin et al., 2023)
        Link: https://arxiv.org/pdf/2210.01408
        Github: https://github.com/ying531/conformal-selection
    """

    def __init__(self, score_function, testing_correction, model, alpha=0.1, device=None):
        super().__init__(score_function, model, alpha, device)
        self.testing_correction = testing_correction
        self._metric = Metrics()

    def _predict_all(self, data_loader):
        """Run the model over ``data_loader`` and return concatenated ``(predicts, y_truth)`` tensors."""
        self._model.eval()
        predicts_list, y_truth_list = [], []
        with torch.no_grad():
            for examples in data_loader:
                tmp_x, tmp_labels = examples[0].to(self._device), examples[1].to(self._device)
                predicts_list.append(self._model(tmp_x).detach())
                y_truth_list.append(tmp_labels)
        predicts = torch.cat(predicts_list).float().to(self._device)
        y_truth = torch.cat(y_truth_list).to(self._device)
        return predicts, y_truth

    def calibrate(self, cal_dataloader):
        """
        Compute and store the calibration scores.

        Args:
            cal_dataloader (DataLoader): Yields ``(inputs, labels)`` batches of calibration data.
                The resulting scores are stored in ``self.cal_scores``.
        """
        predicts, y_truth = self._predict_all(cal_dataloader)
        self.cal_scores = self.score_function(predicts, y_truth)

    def select(self, data_loader, thresholds):
        """
        Run conformal selection on a test dataset and report the false discovery proportion
        (FDP) and power of the resulting selection set. ``calibrate`` must be called first.

        Args:
            data_loader (DataLoader): The DataLoader providing the test data batches.
            thresholds (torch.Tensor): A tensor of user-defined thresholds, one per test point,
                in the same order as the samples yielded by ``data_loader``.

        Returns:
            dict: A dictionary containing:
                - "false_discovery_proportion": The FDP of the selection set.
                - "power": The power of the selection set.

        Example::

            >>> selector.calibrate(cal_loader)
            >>> eval_results = selector.select(test_loader, thresholds)
            >>> print(eval_results)
        """
        predicts, y_truth = self._predict_all(data_loader)
        # Thresholds are combined with predictions and labels below, so they must share a device.
        thresholds = thresholds.to(self._device)
        # Test scores plug the threshold in where the (unobserved) label would go.
        scores = self.score_function(predicts, thresholds)

        n_cal, n_test = self.cal_scores.shape[0], scores.shape[0]

        # Compute p-values with random tie-breaking:
        #   p_j = (#{cal < test_j} + U_j * (1 + #{cal == test_j})) / (n_cal + 1)
        # `u` is drawn on the scores' device to avoid a CPU/GPU mismatch.
        u = torch.rand(n_test, device=scores.device)
        cal_row = self.cal_scores.view(1, n_cal)
        test_col = scores.view(n_test, 1)
        count_less = (cal_row < test_col).sum(dim=1)
        count_tie = (cal_row == test_col).sum(dim=1) + 1
        p_values = (count_less + count_tie * u) / (n_cal + 1)

        indices = self.testing_correction(p_values, self.alpha)

        # Evaluation
        res_dict = {
            "false_discovery_proportion": self._metric("false_discovery_proportion")(y_truth, thresholds, indices),
            "power": self._metric("power")(y_truth, thresholds, indices),
        }
        return res_dict

# ---- File: torchcp/selection/testing_correction/__init__.py (new file) ----
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from .base import Base
from .bh_procedure import BH_procedure

# ---- File: torchcp/selection/testing_correction/base.py (new file) ----
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from abc import ABCMeta, abstractmethod
import torch
from tqdm import tqdm

from torchcp.utils.common import get_device


class Base(metaclass=ABCMeta):
    """
    Abstract base class for all multiple testing correction algorithms.

    Subclasses implement ``__call__(p_values, alpha)``.
    """
    # The metaclass is passed in the class statement: the Python 2 style
    # `__metaclass__ = ABCMeta` attribute is ignored by Python 3, which left
    # `@abstractmethod` unenforced.

    def __init__(self) -> None:
        pass

    @abstractmethod
    def __call__(self, p_values, alpha):
        """Return the selected (rejected) hypotheses for ``p_values`` at level ``alpha``."""
        raise NotImplementedError

# ---- File: torchcp/selection/testing_correction/bh_procedure.py (new file) ----
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


import torch

from torchcp.regression.score.base import BaseScore


# NOTE(review): this class derives from the regression `BaseScore`, not from the
# testing-correction `Base` defined alongside it -- presumably it should use `Base`; confirm.
class BH_procedure(BaseScore):
    """
    Benjamini-Hochberg (BH) procedure:
    finds a p-value threshold from a list of p-values to determine which null hypotheses to reject, given a target
    FDR level 'alpha'.

    References:
        Paper: Controlling the False Discovery Rate: A Practical and Powerful Approach to Multiple Testing
               (Benjamini and Hochberg, 1995)
        Link: https://www.jstor.org/stable/2346101
    """
    def __init__(self):
        super().__init__()

    def __call__(self, p_values, alpha):
        """
        Apply the Benjamini-Hochberg procedure.

        Args:
            p_values (torch.Tensor): A 1D tensor of p-values.
            alpha (float): The desired False Discovery Rate (FDR) level (e.g., 0.1).

        Returns:
            torch.Tensor: A 1D tensor of indices corresponding to the p-values (hypotheses) that are rejected.
                It is empty when nothing is rejected and has length 1 (not 0-dim) for a single rejection.
        """
        p_values_sorted, _ = torch.sort(p_values)
        n_test = p_values_sorted.shape[0]

        # Compare the k-th smallest p-value with k * alpha / n for k = 1..n.
        k_range = torch.arange(1, n_test + 1, device=p_values_sorted.device)
        thresholds = k_range * alpha / n_test
        mask = p_values_sorted <= thresholds
        if not mask.any():
            # No sorted p-value falls under its threshold: reject nothing.
            return torch.empty(0, dtype=torch.long, device=p_values.device)

        # Largest k whose sorted p-value passes; every p-value up to that threshold is rejected.
        k_star = k_range[mask].max()
        threshold = k_star * alpha / n_test
        # `.view(-1)` keeps the result 1-D; `.squeeze()` collapsed a single rejection to 0-dim.
        indices = torch.nonzero(p_values <= threshold, as_tuple=False).view(-1)

        return indices

# ---- File: torchcp/selection/utils/__init__.py (new, empty file) ----

# ---- File: torchcp/selection/utils/metrics.py (new file) ----
# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from typing import Any

import torch

from torchcp.utils.registry import Registry

METRICS_REGISTRY_REGRESSION = Registry("METRICS")


@METRICS_REGISTRY_REGRESSION.register()
def false_discovery_proportion(y_truth, thresholds, indices):
    """
    Compute the false discovery proportion (the proportion of false discovery among all selected points) of the
    selection set.

    Args:
        y_truth (torch.Tensor): A tensor of ground truth values.
        thresholds (torch.Tensor): Tensor of user-defined thresholds.
        indices (torch.Tensor): A tensor containing the indices of selected points.

    Returns:
        float: The false discovery proportion of the selection set; 0.0 when nothing is selected.
    """
    # A single selected index may arrive as a 0-dim tensor; normalise to 1-D.
    if indices.dim() == 0:
        indices = indices.unsqueeze(0)
    n_selected = indices.shape[-1]
    if n_selected == 0:
        return 0.0

    # Align devices so labels and thresholds can be compared element-wise.
    thresholds = thresholds.to(y_truth.device)
    false_positives = torch.sum(y_truth[indices] <= thresholds[indices])
    return (false_positives / n_selected).item()


@METRICS_REGISTRY_REGRESSION.register()
def power(y_truth, thresholds, indices):
    """
    Compute the power (the proportion of desirable points that are correctly selected) of the selection set.

    Args:
        y_truth (torch.Tensor): A tensor of ground truth values.
        thresholds (torch.Tensor): Tensor of user-defined thresholds.
        indices (torch.Tensor): A tensor containing the indices of selected points.

    Returns:
        float: The power of the selection set; 0.0 when no point exceeds its threshold.
    """
    # A single selected index may arrive as a 0-dim tensor; normalise to 1-D.
    if indices.dim() == 0:
        indices = indices.unsqueeze(0)

    # Align devices so labels and thresholds can be compared element-wise.
    thresholds = thresholds.to(y_truth.device)
    n_desirable = torch.sum(y_truth > thresholds)
    if n_desirable.item() == 0:
        # Nothing could have been correctly selected; report 0.0 instead of 0/0 = NaN.
        return 0.0

    true_positives = torch.sum(y_truth[indices] > thresholds[indices])
    return (true_positives / n_desirable).item()


class Metrics:
    """Look up a selection metric function by its registered name."""

    def __call__(self, metric) -> Any:
        """
        Return the metric function registered under ``metric``.

        Raises:
            NameError: If ``metric`` is not a registered metric name.
        """
        if metric not in METRICS_REGISTRY_REGRESSION.registered_names():
            raise NameError(f"The metric: {metric} is not defined in TorchCP.")
        return METRICS_REGISTRY_REGRESSION.get(metric)

# ---- File: torchcp/utils/metrics.py (new, empty file) ----