We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9788d66 commit 48e7594Copy full SHA for 48e7594
ajet/backbone/main_verl.py
@@ -27,6 +27,7 @@
27
from torch.utils.data import Dataset as TorchDataset
28
29
# Create training and validation datasets.
30
+from ajet.backbone.warm_up import warm_up_process
31
from ajet.task_reader import RouterTaskReader, task_to_standard_dataset
32
from ajet.utils.process_dataset import create_rl_sampler
33
from ajet.utils.core_env_vars import get_runtime_env
@@ -116,6 +117,7 @@ def run(self, config):
116
117
from loguru import logger
118
from omegaconf import OmegaConf
119
from verl.utils.fs import copy_to_local
120
+ warm_up_process(config)
121
122
logger.info(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
123
pprint(OmegaConf.to_container(config, resolve=True))
0 commit comments