From 67cc8f5e3b7ccab9a1bc0738692c2a42cfc44dbb Mon Sep 17 00:00:00 2001 From: Puning97 <114408373+Puning97@users.noreply.github.com> Date: Mon, 11 May 2026 23:12:41 +0800 Subject: [PATCH 1/4] update es, satimp, and eua We update the following components: Extraction Strength (ES) metric for retain data. Hyperparameter default setting for SatImp New method EUA, which is accepted in ICML2026 --- configs/eval/tofu.yaml | 1 + .../retain_extraction_strength.yaml | 15 +++++ configs/trainer/EUA.yaml | 15 +++++ configs/trainer/SatImp.yaml | 2 +- src/evals/metrics/__init__.py | 2 + src/evals/metrics/memorization.py | 49 ++++++++++++++++ src/trainer/__init__.py | 2 + src/trainer/unlearn/eua.py | 36 ++++++++++++ src/trainer/utils.py | 56 +++++++++++++++++++ 9 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 configs/eval/tofu_metrics/retain_extraction_strength.yaml create mode 100644 configs/trainer/EUA.yaml create mode 100644 src/trainer/unlearn/eua.py diff --git a/configs/eval/tofu.yaml b/configs/eval/tofu.yaml index 29e05e488..e1d4fd368 100644 --- a/configs/eval/tofu.yaml +++ b/configs/eval/tofu.yaml @@ -11,6 +11,7 @@ defaults: # include all defined metrics files - model_utility # populated in the metrics key as metrics.model_utility - privleak - extraction_strength + - retain_extraction_strength # - exact_memorization # - mia_min_k_plus_plus # - mia_min_k diff --git a/configs/eval/tofu_metrics/retain_extraction_strength.yaml b/configs/eval/tofu_metrics/retain_extraction_strength.yaml new file mode 100644 index 000000000..981851211 --- /dev/null +++ b/configs/eval/tofu_metrics/retain_extraction_strength.yaml @@ -0,0 +1,15 @@ +# @package eval.tofu.metrics.retain_extraction_strength +defaults: + - ../../data/datasets@datasets: TOFU_QA_retain_eval + - ../../collator@collators: DataCollatorForSupervisedDatasetwithIndex + # ^ get default dataset and generation config information + +handler: retain_extraction_strength +batch_size: ${eval.tofu.batch_size} + +datasets: + TOFU_QA_retain_eval: + args: + hf_args: + name: "retain_perturbed" + question_key: ${eval.tofu.question_key} \ No newline at end of file diff --git a/configs/trainer/EUA.yaml b/configs/trainer/EUA.yaml new file mode 100644 index 000000000..ad331b269 --- /dev/null +++ b/configs/trainer/EUA.yaml @@ -0,0 +1,15 @@ +defaults: + - GradDiff + +handler: EUA + +args: # HuggingFace TrainingArguments + learning_rate: 1e-5 + num_train_epochs: 10 + +method_args: + beta1: 0.1 + beta2: 1.0 + alpha: 1.0 #retain_loss + gamma: 0.05 #forget_loss + retain_loss_type: NLL \ No newline at end of file diff --git a/configs/trainer/SatImp.yaml b/configs/trainer/SatImp.yaml index f8d9c757b..3e27d8e14 100644 --- a/configs/trainer/SatImp.yaml +++ b/configs/trainer/SatImp.yaml @@ -9,7 +9,7 @@ args: # HuggingFace TrainingArguments method_args: beta1: 5.0 - beta2: 1.0 + beta2: 0.5 alpha: 1.0 gamma: 0.1 retain_loss_type: NLL \ No newline at end of file diff --git a/src/evals/metrics/__init__.py b/src/evals/metrics/__init__.py index 5afb04243..967e89a7c 100644 --- a/src/evals/metrics/__init__.py +++ b/src/evals/metrics/__init__.py @@ -7,6 +7,7 @@ rouge, truth_ratio, extraction_strength, + retain_extraction_strength, exact_memorization, ) from evals.metrics.privacy import ks_test, privleak, rel_diff @@ -62,6 +63,7 @@ def get_metrics(metric_cfgs: DictConfig, **kwargs): _register_metric(rel_diff) _register_metric(exact_memorization) _register_metric(extraction_strength) +_register_metric(retain_extraction_strength) # Register MIA metrics _register_metric(mia_loss) diff --git a/src/evals/metrics/memorization.py b/src/evals/metrics/memorization.py index c7bbe386c..9ddeb0a64 100644 --- a/src/evals/metrics/memorization.py +++ b/src/evals/metrics/memorization.py @@ -267,3 +267,52 @@ def _extraction_strength(model, batch): ) es_values = aggregate_to_1D(es_values) return {"agg_value": np.mean(es_values), "value_by_index": scores_by_index} + +@unlearning_metric(name="retain_extraction_strength") +def retain_extraction_strength(model, **kwargs): + data = kwargs["data"] + collator = kwargs["collators"] + batch_size = kwargs["batch_size"] + dataloader = DataLoader(data, batch_size=batch_size, collate_fn=collator) + + def _extraction_strength(model, batch): + log_probs_batch, labels_batch = tokenwise_vocab_logprobs( + model, batch, grad=False, return_labels=True + ) + es_batch = [] + for log_probs, labels in zip(log_probs_batch, labels_batch): + valid_len = len(labels) + preds = torch.argmax(log_probs, dim=-1) + for k in range(valid_len): + suff_preds = preds[k:] + suff_labels = labels[k:] + if torch.equal(suff_preds, suff_labels): + break + if valid_len == 0: + # Rarely, tokenization can result in a mismatch with no valid target + # tokens for loss computation (see preprocess_chat_instance() for + # reference). Since this condition makes no sense in terms of + # computing ES, we just choose to set ES=None + logger.warning( + "ES score for an instance is marked None, due to " + "tokenization issues that resulted in no valid target tokens." + ) + es_batch.append({"score": 0}) + else: + es_score = 1 - (k / valid_len) + es_batch.append({"score": es_score}) + return es_batch + + fun_args = {} + scores_by_index = run_batchwise_evals( + model, dataloader, _extraction_strength, fun_args, "Calculating ES" + ) + es_values = np.array( + [ + evals["score"] + for evals in scores_by_index.values() + if evals["score"] is not None + ] + ) + es_values = aggregate_to_1D(es_values) + return {"agg_value": np.mean(es_values), "value_by_index": scores_by_index} diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py index 447b2d2dc..be7e5be37 100644 --- a/src/trainer/__init__.py +++ b/src/trainer/__init__.py @@ -15,6 +15,7 @@ from trainer.unlearn.satimp import SatImp from trainer.unlearn.wga import WGA from trainer.unlearn.pdu import PDU +from trainer.unlearn.eua import EUA import logging @@ -99,3 +100,4 @@ def load_trainer( _register_trainer(SatImp) _register_trainer(WGA) _register_trainer(PDU) +_register_trainer(EUA) diff --git a/src/trainer/unlearn/eua.py b/src/trainer/unlearn/eua.py new file mode 100644 index 000000000..cdab20c7a --- /dev/null +++ b/src/trainer/unlearn/eua.py @@ -0,0 +1,36 @@ +from trainer.unlearn.grad_diff import GradDiff +import torch +import torch.nn.functional as F +from trainer.utils import compute_eua_loss + +class EUA(GradDiff): + def __init__( + self, beta1=0.3, beta2=1.0, gamma=1.0, alpha=0.1, *args, **kwargs + ): # attention, satimp requires two beta!!!! + super().__init__(*args, **kwargs) + self.beta1 = beta1 + self.beta2 = beta2 + self.gamma = gamma + self.alpha = alpha + if self.ref_model is None: + self.ref_model = self._prepare_ref_model(self.model) + + def compute_loss(self, model, inputs, return_outputs=False): + forget_inputs = inputs["forget"] + forget_inputs = { + "input_ids": forget_inputs["input_ids"], + "attention_mask": forget_inputs["attention_mask"], + "labels": forget_inputs["labels"], + } + + retain_inputs = inputs["retain"] + retain_inputs = { + "input_ids": retain_inputs["input_ids"], + "attention_mask": retain_inputs["attention_mask"], + "labels": retain_inputs["labels"], + } + retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs) + eua_loss, outputs = compute_eua_loss(model=model, forget_inputs=forget_inputs, retain_inputs=retain_inputs, beta1=self.beta1, beta2=self.beta2, ref_model=self.ref_model) + loss = self.gamma * eua_loss + self.alpha * retain_loss + + return (loss, outputs) if return_outputs else loss diff --git a/src/trainer/utils.py b/src/trainer/utils.py index 5bdb328f4..9da1342e5 100644 --- a/src/trainer/utils.py +++ b/src/trainer/utils.py @@ -132,3 +132,59 @@ def compute_satimp_loss(model, inputs, beta1, beta2): shift_labels.view(-1) != -100 ].mean() return forget_loss, outputs + +def compute_eua_loss(model, forget_inputs, retain_inputs,beta1, beta2, ref_model=None): + def get_preference_tensors(logits, ratio=0.1): + assert 0 < ratio < 1 + dim = logits.shape[1] + k = int(dim * ratio) + if k == 0: + raise ValueError("ratio too small, leading k=0.") + + # top ratio% + topk_values, topk_indices = torch.topk(logits, k, dim=1) + preference_positive = torch.zeros_like(logits) + preference_positive.scatter_(1, topk_indices, topk_values) + + # bottom ratio% + bottomk_values, bottomk_indices = torch.topk(-logits, k, dim=1) + preference_negative = torch.zeros_like(logits) + preference_negative.scatter_(1, bottomk_indices, logits.gather(1, bottomk_indices)) + + return preference_positive, preference_negative + #forget + outputs = model(**forget_inputs) + labels = forget_inputs["labels"] + labels = labels.to(outputs.logits.device) + + shift_logits = outputs.logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + en_out = -torch.logsumexp(shift_logits.view(-1, shift_logits.size(-1))/beta2, dim=1) + + #retain + retain_outputs = model(**retain_inputs) + retain_labels = retain_inputs["labels"] + retain_labels = retain_labels.to(retain_outputs.logits.device) + + shift_retain_logits = retain_outputs.logits[..., :-1, :].contiguous() + shift_retain_labels = retain_labels[..., 1:].contiguous() + en_in = -torch.logsumexp(shift_retain_logits.view(-1, shift_retain_logits.size(-1))/beta2, dim=1) + + with torch.no_grad(): + forget_outputs_oracle = ref_model(**forget_inputs) + retain_outputs_oracle = ref_model(**retain_inputs) + retain_logits_oracle = retain_outputs_oracle.logits[..., :-1, :].contiguous() + forget_logits_oracle = forget_outputs_oracle.logits[..., :-1, :].contiguous() + + forget_positive, forget_negative = get_preference_tensors(forget_logits_oracle.view(-1, forget_logits_oracle.size(-1)),ratio=beta1) + retain_positive, retain_negative = get_preference_tensors(retain_logits_oracle.view(-1, retain_logits_oracle.size(-1)),ratio=beta1) + + margin_out = -torch.logsumexp(forget_negative/beta2, dim=1) + margin_in = -torch.logsumexp(retain_positive/beta2, dim=1) + + eua_loss = (torch.pow(F.relu(en_in-margin_in), 2)[shift_retain_labels.view(-1) != -100].mean() + torch.pow(F.relu(margin_out-en_out), 2)[shift_labels.view(-1) != -100].mean()) + return eua_loss, outputs + + + + From 272e87b1712f1266acae58eb3246b7a75e9c82dd Mon Sep 17 00:00:00 2001 From: Puning97 <114408373+Puning97@users.noreply.github.com> Date: Fri, 15 May 2026 21:37:14 +0800 Subject: [PATCH 2/4] Update SatImp and ES for Retain --- README.md | 63 ++++++++++++++++-------------- configs/trainer/SatImp.yaml | 8 ++-- src/evals/metrics/memorization.py | 1 + src/trainer/__init__.py | 2 - src/trainer/unlearn/.DS_Store | Bin 0 -> 6148 bytes src/trainer/unlearn/eua.py | 36 ----------------- src/trainer/unlearn/satimp.py | 2 +- src/trainer/utils.py | 56 -------------------------- 8 files changed, 40 insertions(+), 128 deletions(-) create mode 100644 src/trainer/unlearn/.DS_Store delete mode 100644 src/trainer/unlearn/eua.py diff --git a/README.md b/README.md index 293bf34ca..1385f90a9 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@