Skip to content

Commit 5b79394

Browse files
committed
pre commit fixes
1 parent 1d047f0 commit 5b79394

File tree

19 files changed

+85
-147
lines changed

19 files changed

+85
-147
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
## ✈️ News
2020

21-
- 2026.3.26 Upgrade verl backend to 0.7.1 to support more models and increase training speed!
21+
- 2026.3.26 Upgrade verl backend to 0.7.1 to support more models and increase training speed! All [benchmark](https://benchmark.agentjet.top/) verified.
2222
- 2026.3.19 Support for latest Qwen3.5 models is [in progress](https://github.com/modelscope/AgentJet/pull/16).
2323
- 2026.3.12 Tuning Original OpenClaw Agent without Editing Any Agent Code. [EN Blog](https://modelscope.github.io/AgentJet/en/example_openclaw/) / [ZH Blog](https://modelscope.github.io/AgentJet/en/example_openclaw.zh/).
2424
- 2026.3.09 Non-shared-parameter Multiagent Training. [EN Blog](https://modelscope.github.io/AgentJet/en/example_train_multi_model/) / [ZH Blog](https://modelscope.github.io/AgentJet/en/example_train_multi_model.zh/).

ajet/backbone/main_verl.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import hydra
2323
import ray
2424
from omegaconf import DictConfig, OmegaConf
25-
from verl.trainer.ppo.reward import load_reward_manager
2625
from verl.utils.device import is_cuda_available
2726
from verl.utils.dataset.rl_dataset import collate_fn
2827
from torch.utils.data import Dataset as TorchDataset
@@ -156,7 +155,6 @@ def run(self, config):
156155
from verl.workers.megatron_workers import (
157156
ActorRolloutRefWorker,
158157
AjetAsyncActorRolloutRefWorker,
159-
CriticWorker,
160158
)
161159

162160
actor_rollout_cls = AjetAsyncActorRolloutRefWorker

ajet/backbone/trainer_verl.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,10 @@
2323
import torch
2424
from beast_logger import print_dict
2525
from loguru import logger
26-
from omegaconf import OmegaConf
2726
from tqdm import tqdm
2827
from verl import DataProto
2928
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
30-
from verl.experimental.agent_loop.agent_loop import AsyncLLMServerManager, AgentLoopWorker
31-
from verl.single_controller.ray import RayClassWithInitArgs
32-
from verl.single_controller.ray.base import create_colocated_worker_cls
29+
from verl.experimental.agent_loop.agent_loop import AsyncLLMServerManager
3330
from verl.trainer.config import AlgoConfig
3431
from verl.trainer.ppo import core_algos
3532
from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
@@ -40,7 +37,6 @@
4037
)
4138
from verl.trainer.ppo.ray_trainer import (
4239
RayPPOTrainer,
43-
Role,
4440
apply_kl_penalty,
4541
compute_response_mask,
4642
)

ajet/backbone/verl/actor_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from verl.workers.config import FSDPActorConfig
2-
from dataclasses import dataclass, field
2+
from dataclasses import dataclass
33

44

55
@dataclass

ajet/backbone/verl/dp_actor.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,10 @@
2222
import os
2323

2424
import torch
25-
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
26-
from torch.distributed.tensor import DTensor
2725

28-
import verl.utils.torch_functional as verl_F
2926
from verl import DataProto
3027
from verl.trainer.ppo.core_algos import agg_loss, get_policy_loss_fn, kl_penalty
3128
from verl.utils.device import get_device_id
32-
from verl.utils.fsdp_utils import FSDPModule, fsdp2_clip_grad_norm_
3329
from verl.utils.profiler import GPUMemoryLogger
3430
from verl.utils.py_functional import append_to_dict
3531
# ajet/backbone/verl/seqlen_balancing.py

ajet/backbone/verl/fsdp_workers.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -17,80 +17,39 @@
1717
"""
1818

1919
import datetime
20-
import json
2120
import logging
2221
import os
23-
import warnings
24-
from dataclasses import asdict
25-
26-
import psutil
27-
import torch
28-
import torch.distributed
29-
import torch.distributed as dist
30-
from codetiming import Timer
22+
3123
from omegaconf import DictConfig, OmegaConf, open_dict
32-
from omegaconf.errors import ConfigAttributeError
33-
from peft import LoraConfig, TaskType, get_peft_model
34-
from safetensors.torch import save_file
3524
from torch.distributed.device_mesh import init_device_mesh
36-
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
37-
from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType
3825

3926
try:
4027
# for torch 2.5+
41-
from torch.distributed.tensor import DTensor
28+
pass
4229
except ImportError:
43-
from torch.distributed._tensor import DTensor
30+
pass
4431

45-
from verl import DataProto
46-
from verl.models.transformers.monkey_patch import apply_monkey_patch
4732
from verl.single_controller.base import Worker
48-
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
49-
from verl.utils import hf_processor, hf_tokenizer
50-
from verl.utils.activation_offload import enable_activation_offloading
33+
from verl.single_controller.base.decorator import Dispatch, register
5134
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
5235
from verl.utils.config import omega_conf_to_dataclass
5336
from verl.utils.device import (
54-
get_device_id,
5537
get_device_name,
5638
get_nccl_backend,
57-
get_torch_device,
58-
set_expandable_segments,
5939
)
6040
from verl.utils.flops_counter import FlopsCounter
6141
from verl.utils.fs import copy_to_local
6242
from verl.utils.fsdp_utils import (
63-
CPUOffloadPolicy,
64-
MixedPrecisionPolicy,
65-
apply_fsdp2,
66-
collect_lora_params,
67-
fsdp2_load_full_state_dict,
6843
fsdp_version,
69-
get_fsdp_wrap_policy,
70-
get_init_weight_context_manager,
71-
get_shard_placement_fn,
72-
init_fn,
73-
layered_summon_lora_params,
74-
load_fsdp_model_to_gpu,
75-
load_fsdp_optimizer,
7644
offload_fsdp_model_to_cpu,
7745
offload_fsdp_optimizer,
78-
replace_lora_wrapper,
7946
)
8047
from verl.utils.import_utils import import_external_libs
8148
from verl.utils.memory_utils import aggressive_empty_cache
82-
from verl.utils.model import convert_weight_keys
83-
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
84-
from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
85-
from verl.utils.py_functional import convert_to_regular_types
49+
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage
8650

8751
# QAT support
88-
from verl.utils.qat import apply_qat, enable_qat_fuse
89-
from verl.utils.ray_utils import get_event_loop
90-
from verl.utils.transformers_compat import get_auto_model_for_vision2seq
91-
from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig
92-
from verl.workers.config.optimizer import build_optimizer
93-
from verl.workers.rollout import get_rollout_class
52+
from verl.workers.config import FSDPEngineConfig
9453
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
9554
from verl.workers.fsdp_workers import ActorRolloutRefWorker
9655

ajet/default_config/verl/config_schema_rollout.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from verl.workers.config.rollout import MultiTurnConfig
2-
from dataclasses import dataclass, field
1+
from dataclasses import dataclass
32
from typing import Optional
43
from verl.base_config import BaseConfig
54

@@ -23,5 +22,3 @@ class AjetMultiTurnConfig(BaseConfig):
2322
tokenization_sanity_check_mode: str = "strict"
2423
format: str = "hermes"
2524
num_repeat_rollouts: Optional[int] = None
26-
27-

ajet/task_rollout/async_llm_bridge.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
1313
except:
1414
from vllm.tool_parsers.hermes_tool_parser import Hermes2ProToolParser # vllm 0.17.x moved this class elsewhere
15-
from vllm.outputs import RequestOutput as VerlVllmRequestOutput
1615
from verl.workers.rollout.replica import TokenOutput
1716
from agentscope.model import ChatResponse as AgentScopeChatResponse
1817
from openai.types.chat.chat_completion import ChatCompletion as OpenAIChatCompletion

ajet/task_runner/base_runner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import gc
32
from threading import Lock
43
from typing import Any, Callable, Union, Type
54
from multiprocessing import Process, Queue

docs/en/example_vibe_rl_who_is_spy.zh.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,4 @@ task.task_id 有严重的问题,task_id应该是每个episode的随机数种
136136

137137
去SwanLab看看,不错,奖励平稳上升。
138138

139-
![alt text](https://img.alicdn.com/imgextra/i2/O1CN01qFvfeU20XTkCW2H89_!!6000000006859-2-tps-1994-522.png)
139+
![alt text](https://img.alicdn.com/imgextra/i2/O1CN01qFvfeU20XTkCW2H89_!!6000000006859-2-tps-1994-522.png)

0 commit comments

Comments
 (0)