Skip to content

Commit 87e50ab

Browse files
authored
[Fix] Update type hint for position_ids, update max_lengths type in LengthGroupedSampler, and handle fp32 lmhead hf weight (#1636)
* [Fix] Update type hint for position_ids in MultimodalTrainInfo and assert max_lengths type in LengthGroupedSampler - Changed position_ids to allow None in MultimodalTrainInfo. - Updated max_lengths assertion to include Column type in LengthGroupedSampler. - Adjusted reward assignment in JudgerController to use "score" instead of "weighted_score". - Enhanced tensor handling in BaseModel for FSDP with specific conditions for lm_head.weight. * fix lint * fix ut * fix ut
1 parent 59120b6 commit 87e50ab

5 files changed

Lines changed: 28 additions & 7 deletions

File tree

tests/ray/test_judger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def test_gsm8k_multi_judger(self):
189189
pg = AutoCPUWorkers.build_placement_group(cpu_resources_config)
190190
judger_controller = JudgerController.remote(judger_cfg, pg)
191191
res3 = ray.get(judger_controller.run.remote(FAKE_JUDGER_INPUT_ITEM_MULTI_SOURCE))
192-
self.assertEqual(res3.reward["weighted_score"], 1.0) # weighted_score为固定字段,表示加权后的reward
192+
self.assertEqual(res3.reward["score"], 1.0)
193193

194194
def test_gsm8k_judger_score(self):
195195
"""Test the judger functionality with single and multiple data sources."""

xtuner/v1/data_proto/rl_data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class RLUIDItem(BaseModel):
8686
class MultimodalTrainInfo(TypedDict):
8787
pixel_values: NotRequired[torch.Tensor | RayObjectRef | None] # type: ignore[valid-type]
8888
image_grid_thw: NotRequired[torch.Tensor]
89-
position_ids: NotRequired[torch.Tensor]
89+
position_ids: NotRequired[torch.Tensor | None]
9090

9191

9292
class RLDatasetItem(BaseModel):
@@ -206,7 +206,7 @@ class RLJudgerResponseItem(BaseModel):
206206
207207
Attributes:
208208
uid (Optional[int]): A unique ID to identify which input the result corresponds to.
209-
reward (Dict[str, Any]): A dictionary of reward scores, e.g., {"judger_type": reward_score, "weighted_scores": score}.
209+
reward (Dict[str, Any]): A dictionary of reward scores, e.g., {"score": score}.
210210
extra_info (Dict[str, Any]): Additional user-defined information.
211211
"""
212212

xtuner/v1/datasets/sampler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
from .packing import _LegacySoftPackDataset
1616

1717

18+
try:
19+
from datasets.arrow_dataset import Column
20+
except ImportError:
21+
Column = type("Column", (), {}) # Dummy type for backward compatibility
22+
1823
logger = get_logger()
1924

2025

@@ -221,7 +226,7 @@ def __init__(
221226
self.group_size = self.world_size
222227

223228
self.max_lengths = self.dataset.longest
224-
assert isinstance(self.max_lengths, (list, tuple))
229+
assert isinstance(self.max_lengths, (list, tuple, Column))
225230

226231
self.global_batch_size = global_batch_size
227232

xtuner/v1/model/base.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,8 +1017,24 @@ def _get_same_hf_param(
10171017
buffer_name_list.append(load_spec.hf_keys[0])
10181018
continue
10191019
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
1020-
local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
1021-
tensor_size = self._get_tensor_size(param, dtype)
1020+
if (
1021+
self.fsdp_config is not None
1022+
and self.fsdp_config.fp32_lm_head
1023+
and load_spec.hf_keys[0] == "lm_head.weight"
1024+
):
1025+
logger.info(f"handling same hf param: {load_spec.hf_keys} separately")
1026+
lm_head_tensor_list = self._fsdp_foreach_allgather([local_tensor], [load_spec])
1027+
lm_head_tensor_list = [
1028+
self.param_to_safetensor(safetensor, name)
1029+
for safetensor, name in zip(lm_head_tensor_list, load_spec.hf_keys.copy())
1030+
]
1031+
lm_head_tensor_list = [t.to(device=device) for t in lm_head_tensor_list]
1032+
yield load_spec.hf_keys.copy(), lm_head_tensor_list
1033+
del lm_head_tensor_list, local_tensor
1034+
continue
1035+
else:
1036+
local_tensor = local_tensor.to(dtype=self._get_save_dtype(load_spec.hf_keys[0], torch.bfloat16))
1037+
tensor_size = self._get_tensor_size(param, dtype)
10221038
if safetensor_size + tensor_size > bucket_size and tensor_list:
10231039
if self.fsdp_mesh is not None:
10241040
gathered_tensor_list = self._fsdp_foreach_allgather(tensor_list, load_spec_list)

xtuner/v1/ray/judger/controller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ async def run(
260260
for name, weight in data_source.items():
261261
if name in item.reward:
262262
final_reward += item.reward[name]["score"] * weight
263-
item.reward["weighted_score"] = final_reward
263+
item.reward["score"] = final_reward
264264
else:
265265
judger_response_item = await self._call_custom_reward_judger(
266266
self.reward_judger, self.reward_judger_names, group_data_item

0 commit comments

Comments (0)