Skip to content

Commit 9df7e71

Browse files
committed
Delete legacy DPO implementation
1 parent 4bcec6a commit 9df7e71

17 files changed

Lines changed: 42 additions & 1050 deletions

src/maxtext/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737

3838
from maxtext.configs import pyconfig
3939
from maxtext.models import models
40-
from maxtext.trainers.post_train.dpo import dpo_utils
4140
from maxtext.utils import maxtext_utils
4241
from maxtext.utils import model_creation_utils
4342

src/maxtext/common/metric_logger.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ def _log_eval_metrics(self, metrics, step):
225225
f"avg_mtp_acceptance_rate={scalars['eval/avg_mtp_acceptance_rate_percent']:.2f}%",
226226
]
227227
)
228-
if self.config.use_dpo:
229-
log_parts.append(f"dpo_reward_accuracy={scalars['eval/dpo_reward_accuracy']:.3f}")
228+
if "eval/avg_dpo_reward_accuracy" in scalars:
229+
log_parts.append(f"dpo_reward_accuracy={scalars['eval/avg_dpo_reward_accuracy']:.3f}")
230230
max_logging.log(", ".join(log_parts))
231231

232232
def _log_running_eval_metrics(self, metrics, step):
@@ -421,10 +421,6 @@ def _accumulate_eval_metrics(self, metrics):
421421
scalar.get("evaluation/mtp_acceptance_rate_percent", 0.0)
422422
)
423423
self.cumulative_eval_metrics["scalar"]["eval/z_loss"] += float(scalar.get("evaluation/z_loss", 0.0))
424-
if self.config.use_dpo:
425-
self.cumulative_eval_metrics["scalar"]["eval/dpo_reward_accuracy"] += float(
426-
scalar.get("evaluation/dpo_reward_accuracy", 0.0)
427-
)
428424

429425
def record_train_metrics(self, metrics, step, step_time):
430426
"""Records training metrics for the current step."""
@@ -454,8 +450,7 @@ def _finalize_eval_metrics(self, train_step):
454450
cumulative["eval/avg_mtp_loss"] = cumulative["eval/mtp_loss"] / eval_step_count
455451
cumulative["eval/avg_mtp_acceptance_rate_percent"] = cumulative["eval/mtp_acceptance_rate_percent"] / eval_step_count
456452
cumulative["eval/avg_z_loss"] = cumulative["eval/z_loss"] / eval_step_count
457-
if self.config.use_dpo:
458-
cumulative["eval/dpo_reward_accuracy"] = cumulative["eval/dpo_reward_accuracy"] / eval_step_count
453+
459454
self.write_metrics(self.cumulative_eval_metrics, train_step, metric_type="eval")
460455
self._pending_eval_step_count = 0
461456
if self.config.target_eval_loss and eval_loss <= self.config.target_eval_loss:

src/maxtext/common/train_state_nnx.py

Lines changed: 5 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,31 +30,21 @@ class TrainStateNNX(nnx.Module):
3030
{"params": {...}, "opt_state": {}...}
3131
TrainStateNNX state pytree:
3232
{"model": {...}, "optimizer": {"opt_state": {...}}}
33-
34-
For DPO (Direct Preference Optimization), an optional `reference_model`
35-
carries a frozen copy of the same architecture used to compute reference
36-
log-probabilities. Only `model` is updated by `apply_gradients`; the
37-
reference is held alongside so it is sharded, jit-traced, and checkpointed
38-
with the rest of the train state.
3933
"""
4034

4135
def __init__(
4236
self,
4337
model: nnx.Module,
4438
optimizer: nnx.Optimizer | None,
45-
reference_model: nnx.Module | None = None,
4639
):
4740
self.model = model
4841
self.optimizer = optimizer
49-
if reference_model is not None:
50-
self.reference_model = reference_model
5142

5243
def apply_gradients(self, grads: Any):
5344
"""Mimics the Linen apply_gradients function.
5445
5546
Updates the optimizer state, applies updates to parameters, and increments
56-
the step counter. Only updates `self.model`; `self.reference_model` (if
57-
present) is left untouched.
47+
the step counter. Only updates `self.model`.
5848
"""
5949
if self.optimizer is None:
6050
raise RuntimeError(
@@ -88,9 +78,7 @@ def _cast_step(step, dtype):
8878
values.
8979
"""
9080
if isinstance(step, jax.ShapeDtypeStruct):
91-
return jax.ShapeDtypeStruct(
92-
step.shape, dtype, sharding=getattr(step, "sharding", None)
93-
)
81+
return jax.ShapeDtypeStruct(step.shape, dtype, sharding=getattr(step, "sharding", None))
9482
return jnp.asarray(step, dtype=dtype)
9583

9684

@@ -117,10 +105,7 @@ def _wrap_mu_nu_with_params(state):
117105
"""Wraps mu/nu under an inner 'params' key (the Linen collection)."""
118106
if not isinstance(state, dict):
119107
return state
120-
return {
121-
k: {"params": v} if k in ("mu", "nu") and isinstance(v, dict) else v
122-
for k, v in state.items()
123-
}
108+
return {k: {"params": v} if k in ("mu", "nu") and isinstance(v, dict) else v for k, v in state.items()}
124109

125110

126111
def _as_chain_index(key):
@@ -172,23 +157,14 @@ def _strip_mu_nu_params(state):
172157
if not isinstance(state, dict):
173158
return state
174159
return {
175-
k: (
176-
v["params"]
177-
if k in ("mu", "nu") and isinstance(v, dict) and "params" in v
178-
else v
179-
)
180-
for k, v in state.items()
160+
k: (v["params"] if k in ("mu", "nu") and isinstance(v, dict) and "params" in v else v) for k, v in state.items()
181161
}
182162

183163

184164
def _opt_state_from_linen(opt_state):
185165
"""Inverse of `_opt_state_to_linen`: Linen list-with-None -> NNX int-keyed dict."""
186166
if isinstance(opt_state, list):
187-
return {
188-
i: _strip_mu_nu_params(e)
189-
for i, e in enumerate(opt_state)
190-
if isinstance(e, dict)
191-
}
167+
return {i: _strip_mu_nu_params(e) for i, e in enumerate(opt_state) if isinstance(e, dict)}
192168
if not isinstance(opt_state, dict):
193169
return opt_state
194170
return {0: _strip_mu_nu_params(opt_state)}

src/maxtext/input_pipeline/grain_data_processing.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -263,32 +263,6 @@ def pretrain_preprocessing_pipeline(
263263
return dataset
264264

265265

266-
def dpo_preprocessing_pipeline(
267-
dataset,
268-
config,
269-
data_columns,
270-
tokenize,
271-
grain_worker_count,
272-
grain_per_worker_buffer_size,
273-
):
274-
"""Use grain to pre-process the dataset and return iterators for dpo fine-tuning"""
275-
dataset = data_processing_utils.parse_and_keep_features(dataset, config, data_columns, tokenize)
276-
tokenizer_model, pad_id = data_processing_utils.get_tokenizer_and_pad_id(config)
277-
278-
if tokenize:
279-
dataset = dataset.map(grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model))
280-
281-
batch_size = config.global_batch_size_to_load // jax.process_count()
282-
# DPO scores full sequences, so no shift.
283-
dataset = data_processing_utils.format_and_batch(
284-
dataset, config, batch_size, pad_id, data_columns, tokenizer_model, shift=False
285-
)
286-
dataset = data_processing_utils.apply_multiprocessing_and_prefetch(
287-
dataset, config, grain_worker_count, grain_per_worker_buffer_size
288-
)
289-
return dataset
290-
291-
292266
def _format_chat_template_grain(element, data_columns, tokenizer_model):
293267
"""Grain-compatible mapping function to format raw columns into conversational messages."""
294268
# Convert raw columns to conversational messages
@@ -376,8 +350,6 @@ def sft_preprocessing_pipeline(
376350

377351
def _get_pipeline_fn(config):
378352
"""Returns the appropriate preprocessing pipeline function based on config."""
379-
if config.use_dpo:
380-
return dpo_preprocessing_pipeline
381353
if config.use_sft:
382354
return sft_preprocessing_pipeline
383355
return pretrain_preprocessing_pipeline

src/maxtext/input_pipeline/tfds_data_processing.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def preprocessing_pipeline(
9191
shift: bool = True,
9292
drop_remainder: bool = True,
9393
prefetch_size=tf.data.experimental.AUTOTUNE,
94-
use_dpo: bool = False,
9594
hf_access_token: str = "",
9695
):
9796
"""pipeline for preprocessing TFDS dataset."""
@@ -115,15 +114,11 @@ def preprocessing_pipeline(
115114
"Set tokenize_train_data or tokenize_eval_data to True if your dataset needs tokenization."
116115
)
117116

118-
if not use_dpo:
119-
assert len(data_column_names) == 1
120-
dataset = dataset.map(
121-
lambda x: input_pipeline_utils.normalize_features(x, data_column_names[0]), num_parallel_calls=AUTOTUNE
122-
)
123-
else:
124-
dataset = dataset.map(lambda x: {col: x[col] for col in data_column_names}, num_parallel_calls=AUTOTUNE)
125-
126-
data_column_names = data_column_names if use_dpo else ("inputs", "targets")
117+
assert len(data_column_names) == 1
118+
dataset = dataset.map(
119+
lambda x: input_pipeline_utils.normalize_features(x, data_column_names[0]), num_parallel_calls=AUTOTUNE
120+
)
121+
data_column_names = ("inputs", "targets")
127122

128123
tokenizer_model = input_pipeline_utils.get_tokenizer(tokenizer_path, tokenizer_type, add_bos, add_eos, hf_access_token)
129124
if tokenizer_model.pad_id is not None:
@@ -144,7 +139,7 @@ def preprocessing_pipeline(
144139
if max_target_length > 0:
145140
# in pre-training we can take upto max_length+1 because there would be truncation by
146141
# 1 token for both inputs and targets
147-
extra_tokens = 1 if not use_dpo else 0
142+
extra_tokens = 1
148143
dataset = dataset.map(
149144
lambda x: input_pipeline_utils.truncate_to_max_allowable_length(x, max_target_length + extra_tokens),
150145
num_parallel_calls=AUTOTUNE,
@@ -157,13 +152,13 @@ def preprocessing_pipeline(
157152
dataset = dataset.repeat(num_epochs)
158153

159154
# Shift inputs for teacher-forced training
160-
if shift and not use_dpo:
155+
if shift:
161156
dataset = dataset.map(
162157
input_pipeline_utils.shift_data_by_truncation, num_parallel_calls=tf.data.AUTOTUNE, deterministic=True
163158
)
164159

165160
# Perform greedy sequence packing and batching
166-
if pack_examples and not use_dpo:
161+
if pack_examples:
167162
dataset = sequence_packing.pack_dataset(dataset, max_target_length, pad_id)
168163
dataset = dataset.batch(global_batch_size // jax.process_count(), drop_remainder=drop_remainder)
169164
else:
@@ -223,7 +218,6 @@ def make_tfds_train_iterator(
223218
add_eos=config.add_eos,
224219
num_epochs=config.num_epoch,
225220
pack_examples=config.packing,
226-
use_dpo=config.use_dpo,
227221
hf_access_token=config.hf_access_token,
228222
)
229223
return multihost_dataloading.MultiHostDataLoadIterator(
@@ -248,7 +242,6 @@ def make_tfds_train_iterator(
248242
add_eos=config.add_eos,
249243
num_epochs=config.num_epoch,
250244
pack_examples=config.packing,
251-
use_dpo=config.use_dpo,
252245
hf_access_token=config.hf_access_token,
253246
)
254247
global_shape = (config.global_batch_size_to_load, config.max_target_length)
@@ -289,7 +282,6 @@ def make_tfds_eval_iterator(
289282
add_bos=config.add_bos,
290283
add_eos=config.add_eos,
291284
pack_examples=config.packing,
292-
use_dpo=config.use_dpo,
293285
hf_access_token=config.hf_access_token,
294286
)
295287
return multihost_dataloading.MultiHostDataLoadIterator(
@@ -317,7 +309,6 @@ def make_tfds_eval_iterator(
317309
add_bos=config.add_bos,
318310
add_eos=config.add_eos,
319311
pack_examples=config.packing,
320-
use_dpo=config.use_dpo,
321312
hf_access_token=config.hf_access_token,
322313
)
323314
global_shape = (config.global_batch_size_to_load_eval, config.max_target_length)

0 commit comments

Comments
 (0)