Commit bf4f8a0

Add deterministic and random seed support in RLColocateTrainer (#1613)
* add deterministic and random seed support in RLColocateTrainer
* fix test_producer.py mock for get_rollout_metadata
1 parent 65f1d77 commit bf4f8a0

7 files changed: +25 -4 lines changed


tests/rl/test_producer.py

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ async def test_sync_produce_strategy(self):
         mock_agent_loop = MagicMock()
         mock_agent_loop.rollout_ctl.continue_generation.remote = AsyncMock(return_value=None)
         mock_agent_loop.rollout_ctl.pause_generation.remote = AsyncMock(return_value=None)
+        mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(return_value={"server_url_dict": {}})
 
         async def mock_gen(rs):
             await asyncio.sleep(0.01 * rs[0].id)
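
The added stub mirrors the existing `continue_generation` / `pause_generation` mocks: the sync-produce path now also awaits `rollout_ctl.get_rollout_metadata.remote()`, so the test must make that call awaitable. A self-contained sketch of the pattern (the producer logic that consumes the metadata is assumed, not shown in this diff):

```python
# Self-contained sketch of the mocking pattern; the producer code that
# consumes the metadata is assumed and not part of this diff.
import asyncio
from unittest.mock import AsyncMock, MagicMock

mock_agent_loop = MagicMock()
mock_agent_loop.rollout_ctl.get_rollout_metadata.remote = AsyncMock(
    return_value={"server_url_dict": {}}
)

async def main():
    # Awaiting the AsyncMock resolves to the stubbed metadata dict, which is
    # what lets the test proceed without a real rollout controller.
    meta = await mock_agent_loop.rollout_ctl.get_rollout_metadata.remote()
    assert meta == {"server_url_dict": {}}

asyncio.run(main())
```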

xtuner/v1/rl/agent_loop/agent_loop.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 
 
 class AgentLoopConfig(ABC, BaseModel):
-    model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)  # TODO: extra="forbid"
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
     hf_checkpoint: str
     sample_params: SampleParams
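
Tightening `extra="allow"` to `extra="forbid"` resolves the old TODO: unknown keyword arguments now raise a `ValidationError` instead of being silently attached to the config. A minimal sketch with a stand-in model, since `AgentLoopConfig`'s full field list is not part of this diff:

```python
# Minimal sketch of the extra="allow" -> extra="forbid" change. StubConfig is
# a stand-in; AgentLoopConfig's full field list is not shown in this diff.
from pydantic import BaseModel, ConfigDict, ValidationError

class StubConfig(BaseModel):
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
    hf_checkpoint: str

StubConfig(hf_checkpoint="path/to/checkpoint")  # ok

try:
    # Under extra="allow" this typo was silently kept as an extra attribute;
    # under extra="forbid" it is a validation error.
    StubConfig(hf_checkpint="path/to/checkpoint")
except ValidationError as e:
    print(e)  # "Extra inputs are not permitted" (plus hf_checkpoint missing)
```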

xtuner/v1/rl/agent_loop/gsm8k_with_tool.py

Lines changed: 3 additions & 1 deletion
@@ -3,7 +3,7 @@
 import re
 from typing import cast
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 from xtuner.v1.data_proto import RolloutState, SampleParams
 from xtuner.v1.rl.agent_loop import AgentLoop, AgentLoopConfig
@@ -28,6 +28,8 @@ def build(self, rollout_controller, judger=None, logger=None) -> "GSM8KToolAgent
 
 
 class FunctionCall(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     name: str
     arguments: dict
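
The same strictness applies to `FunctionCall`: a tool-call payload carrying unexpected keys now fails at parse time instead of being absorbed. A small sketch using the two fields visible in this diff (the stray `"id"` key is a hypothetical example):

```python
# Sketch: FunctionCall (fields as in this diff) now rejects unexpected keys,
# so a malformed tool-call payload fails at parse time.
from pydantic import BaseModel, ConfigDict, ValidationError

class FunctionCall(BaseModel):
    model_config = ConfigDict(extra="forbid")
    name: str
    arguments: dict

FunctionCall(name="calculator", arguments={"expression": "1 + 1"})  # ok

try:
    # The stray "id" key is a hypothetical example of extra payload data.
    FunctionCall.model_validate({"name": "calculator", "arguments": {}, "id": "call_0"})
except ValidationError:
    print("extra key rejected")
```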

xtuner/v1/rl/replay_buffer.py

Lines changed: 5 additions & 1 deletion
@@ -7,7 +7,7 @@
 
 import pandas as pd
 import torch
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 from xtuner.v1.data_proto.rl_data import RolloutState, Status, update_group_status
 from xtuner.v1.rl.utils import (
@@ -410,11 +410,15 @@ async def resume(self, path: str | Path) -> None:
 
 
 class SyncReplayBufferConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     def build(self):
         return ReplayBuffer(policy=FIFOReplayPolicy(), storage_backend=NaiveStorage())
 
 
 class AsyncReplayBufferConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+
     def build(self):
         policy = StalenessReplayPolicy()
         return ReplayBuffer(policy=policy, storage_backend=NaiveStorage())
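
Both replay-buffer configs get the same guard. Since neither declares any fields, every keyword argument is now rejected; a hedged sketch (`staleness` below is a hypothetical unknown option, and `build()` is unchanged by this commit):

```python
# Hedged sketch of the stricter replay-buffer configs; "staleness" is a
# hypothetical unknown option, not a real parameter.
from pydantic import ValidationError
from xtuner.v1.rl.replay_buffer import SyncReplayBufferConfig

SyncReplayBufferConfig()  # still valid: the config declares no fields

try:
    SyncReplayBufferConfig(staleness=2)  # any unknown kwarg now raises
except ValidationError:
    print("unknown replay-buffer option rejected")
```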

xtuner/v1/train/rl_colocate_trainer.py

Lines changed: 5 additions & 1 deletion
@@ -8,6 +8,7 @@
 import ray
 import torch
 from mmengine.dist import get_rank
+from mmengine.runner import set_random_seed
 from pydantic import BaseModel, ConfigDict
 from typing_extensions import Literal, TypedDict
 
@@ -26,7 +27,7 @@
 from xtuner.v1.rl.trainer.worker import WorkerConfig, WorkerLogItem
 from xtuner.v1.rl.utils import AcceleratorResourcesConfig, AutoAcceleratorWorkers, asyncio_run
 from xtuner.v1.train.trainer import LoadCheckpointConfig, XTunerMeta
-from xtuner.v1.utils import get_logger, is_hf_model_path, timer
+from xtuner.v1.utils import get_logger, is_hf_model_path, set_deterministic, timer
 from xtuner.v1.utils.device import get_device, get_torch_device_module
 
 
@@ -283,6 +284,9 @@ def __init__(
         # self._total_epochs = total_epochs  # TODO
         self._cur_step = 0
         self._global_train_step = 0
+        self._seed = seed
+        set_deterministic()
+        set_random_seed(seed)
         self.global_batch_size = global_batch_size
 
         # main components
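
`set_deterministic()` is called before seeding, presumably so the cuBLAS workspace variable is exported before the first CUDA kernel runs; `set_random_seed(seed)` then seeds the RNGs. mmengine's `set_random_seed` seeds Python, NumPy, and torch, roughly equivalent to this sketch (`seed_everything` is a hypothetical stand-in, not xtuner or mmengine code):

```python
# Hypothetical stand-in for what mmengine's set_random_seed accomplishes:
# seed the Python, NumPy, and torch RNGs from a single seed.
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    random.seed(seed)                 # Python stdlib RNG
    np.random.seed(seed)              # NumPy global RNG
    torch.manual_seed(seed)           # torch CPU RNG
    torch.cuda.manual_seed_all(seed)  # all CUDA device RNGs (no-op without CUDA)

seed_everything(42)
print(torch.randn(2))  # reproducible across runs with the same seed
```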

xtuner/v1/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@
     get_padding_length,
     is_hf_model_path,
     record_git_info,
+    set_deterministic,
 )
 from .pad import pad_to_max_length, pad_to_multiple_of
 from .profile import profile_time, profile_time_and_memory, timer, timer_logger
@@ -62,4 +63,5 @@
     "clean_param_name",
     "CacheDict",
     "CacheObj",
+    "set_deterministic",
 ]

xtuner/v1/utils/misc.py

Lines changed: 8 additions & 0 deletions
@@ -9,6 +9,7 @@
 from types import FunctionType
 from typing import Annotated
 
+import torch
 from huggingface_hub import constants
 from mmengine import is_installed
 
@@ -24,6 +25,13 @@
 logger = get_logger()
 XTUNER_DETERMINISTIC = os.getenv("XTUNER_DETERMINISTIC") == "true"
 
+
+def set_deterministic():
+    if XTUNER_DETERMINISTIC:
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        torch.use_deterministic_algorithms(True, warn_only=True)
+
+
 # https://github.com/python/cpython/issues/82300#issuecomment-2169035092
 if sys.version_info >= (3, 13):
     SharedMemory = _mpshm.SharedMemory
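
The helper is a no-op unless `XTUNER_DETERMINISTIC=true` is exported before xtuner is imported, since the flag is read once at module import. `":16:8"` is one of the two workspace sizes cuBLAS accepts for deterministic execution (the other is `":4096:8"`), and `warn_only=True` makes ops without a deterministic implementation warn rather than raise. A quick sketch exercising it:

```python
# Sketch: exercising the new helper. The env var must be set before the
# xtuner import, because the module reads it once at import time.
import os

os.environ["XTUNER_DETERMINISTIC"] = "true"

import torch
from xtuner.v1.utils import set_deterministic

set_deterministic()
print(os.environ["CUBLAS_WORKSPACE_CONFIG"])                  # :16:8
print(torch.are_deterministic_algorithms_enabled())           # True
print(torch.is_deterministic_algorithms_warn_only_enabled())  # True
```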
