diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index d5f10a15d..8f036ca99 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -8,6 +8,7 @@ from xtuner.v1.data_proto.sequence_context import SequenceContext from xtuner.v1.model.compose.base import BaseComposeConfig +from xtuner.v1.ray.utils import free_object_refs from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import ray_method @@ -263,6 +264,13 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: try: log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) finally: + # free pixel values ref + free_pixel_value_refs: list[ray.ObjectRef] = [] + for data in packed_data_batches: + if data["seq_ctx"].pixel_values is not None: + free_pixel_value_refs.extend(data["seq_ctx"].pixel_values) + if len(free_pixel_value_refs) > 0: + free_object_refs(free_pixel_value_refs) del packed_data_batches return log_infos diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index 5bd6cbcc9..126f6ad1b 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -35,7 +35,6 @@ from xtuner.v1.model.utils.misc import ModelForwardExtraLogInfo from xtuner.v1.ray.base import SingleAcceleratorWorker from xtuner.v1.ray.config import RolloutConfig -from xtuner.v1.ray.utils import free_object_refs from xtuner.v1.rl.base.loss import BaseRLLossContext from xtuner.v1.train.trainer import LoadCheckpointConfig from xtuner.v1.utils import ( @@ -485,11 +484,7 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo f"pixel_values should be list of tensor, got {type(pixel_values)}" ) pixel_value_refs = list(pixel_values) - try: - pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0) - finally: - free_object_refs(pixel_value_refs) - + pixel_values = torch.cat(ray.get(pixel_value_refs), dim=0) seq_ctx.pixel_values = pixel_values rollout_routed_experts = seq_ctx.rollout_routed_experts diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 62a5c7302..bd7081e19 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -120,7 +120,7 @@ def get_train_seq_ctx( ): seq_ctx = SequenceContext.from_input_ids((input_ids,), device="cpu") if multimodal_train_info and len(multimodal_train_info) > 0: - position_ids = multimodal_train_info.pop("position_ids") # (1,n) or (3,1,n) + position_ids = multimodal_train_info.get("position_ids") # (1,n) or (3,1,n) if position_ids is not None and len(position_ids.shape) == 3: # qwen3vl 需要特殊处理,其余的不需要额外处理 max_value = position_ids.max(dim=-1).values # (3,1) @@ -130,9 +130,8 @@ def get_train_seq_ctx( position_ids = torch.cat([position_ids, response_position_ids], dim=-1) seq_ctx.position_ids = position_ids # type: ignore[assignment] assert position_ids.size(-1) == input_ids.size(-1) - seq_ctx.pixel_values = multimodal_train_info.pop("pixel_values") - seq_ctx.image_grid_thw = multimodal_train_info.pop("image_grid_thw") - del multimodal_train_info + seq_ctx.pixel_values = multimodal_train_info.get("pixel_values") + seq_ctx.image_grid_thw = multimodal_train_info.get("image_grid_thw") return seq_ctx @@ -803,6 +802,8 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf seq_ctx.rollout_routed_experts = routed_experts # n,layer,expert data_batches.append(data_dict) + if multimodal_train_info is not None: + del multimodal_train_info random.shuffle(data_batches) rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float()