diff --git a/src/trainer/unlearn/ceu.py b/src/trainer/unlearn/ceu.py index 33da99c3c..18823f3b4 100644 --- a/src/trainer/unlearn/ceu.py +++ b/src/trainer/unlearn/ceu.py @@ -86,7 +86,7 @@ def __init__(self, ignore_first_n_answer_tokens=1, *args, **kwargs): super().__init__(*args, **kwargs) self.ignore_first_n_answer_tokens = ignore_first_n_answer_tokens - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] loss, outputs = compute_batch_ceu( model, diff --git a/src/trainer/unlearn/dpo.py b/src/trainer/unlearn/dpo.py index b64b474b4..e146cbad0 100644 --- a/src/trainer/unlearn/dpo.py +++ b/src/trainer/unlearn/dpo.py @@ -9,7 +9,7 @@ def __init__(self, beta=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"]["original"] alternate_inputs = inputs["forget"]["alternate"] diff --git a/src/trainer/unlearn/grad_ascent.py b/src/trainer/unlearn/grad_ascent.py index eda8b4812..4e7630c02 100644 --- a/src/trainer/unlearn/grad_ascent.py +++ b/src/trainer/unlearn/grad_ascent.py @@ -2,7 +2,7 @@ class GradAscent(UnlearnTrainer): - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/grad_diff.py b/src/trainer/unlearn/grad_diff.py index bfecc19a2..1eb5a3013 100644 --- a/src/trainer/unlearn/grad_diff.py +++ b/src/trainer/unlearn/grad_diff.py @@ -38,7 +38,7 @@ def compute_retain_loss(self, model, retain_inputs): ) return retain_loss - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/npo.py b/src/trainer/unlearn/npo.py index 7c782d968..3a9a14fe5 100644 --- a/src/trainer/unlearn/npo.py +++ b/src/trainer/unlearn/npo.py @@ -9,7 +9,7 @@ def __init__(self, beta=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] forget_loss, forget_outputs = compute_dpo_loss( diff --git a/src/trainer/unlearn/pdu.py b/src/trainer/unlearn/pdu.py index e79bcc58b..9ec1d5800 100644 --- a/src/trainer/unlearn/pdu.py +++ b/src/trainer/unlearn/pdu.py @@ -102,7 +102,7 @@ def post_epoch_dual_param_update(self): ) self.log({"retain_preference": self.preferences[1]}) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/rmu.py b/src/trainer/unlearn/rmu.py index d990d3a38..6091a92b1 100644 --- a/src/trainer/unlearn/rmu.py +++ b/src/trainer/unlearn/rmu.py @@ -136,7 +136,7 @@ def compute_retain_loss(self, model, retain_inputs): retain_loss = super().compute_retain_loss(model, retain_inputs) return retain_loss - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/satimp.py b/src/trainer/unlearn/satimp.py index b664390cd..d0dff9ab2 100644 --- a/src/trainer/unlearn/satimp.py +++ b/src/trainer/unlearn/satimp.py @@ -14,7 +14,7 @@ def __init__( if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"], diff --git a/src/trainer/unlearn/simnpo.py b/src/trainer/unlearn/simnpo.py index cb4f7f99c..5eed4d3e3 100644 --- a/src/trainer/unlearn/simnpo.py +++ b/src/trainer/unlearn/simnpo.py @@ -10,7 +10,7 @@ def __init__(self, delta=0.0, beta=1.0, *args, **kwargs): self.delta = delta self.beta = beta - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] forget_labels = forget_inputs["labels"] diff --git a/src/trainer/unlearn/undial.py b/src/trainer/unlearn/undial.py index e32147b30..653c1a293 100644 --- a/src/trainer/unlearn/undial.py +++ b/src/trainer/unlearn/undial.py @@ -9,7 +9,7 @@ def __init__(self, beta=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] forget_loss, forget_outputs = compute_undial_loss( model, self.ref_model, forget_inputs, self.beta diff --git a/src/trainer/unlearn/wga.py b/src/trainer/unlearn/wga.py index 08c4bf402..a371ae4c1 100644 --- a/src/trainer/unlearn/wga.py +++ b/src/trainer/unlearn/wga.py @@ -11,7 +11,7 @@ def __init__(self, beta=1.0, gamma=1.0, alpha=1.0, *args, **kwargs): if self.ref_model is None: self.ref_model = self._prepare_ref_model(self.model) - def compute_loss(self, model, inputs, return_outputs=False): + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): forget_inputs = inputs["forget"] forget_inputs = { "input_ids": forget_inputs["input_ids"],