|
17 | 17 | """ |
18 | 18 |
|
19 | 19 | import datetime |
20 | | -import json |
21 | 20 | import logging |
22 | 21 | import os |
23 | | -import warnings |
24 | | -from dataclasses import asdict |
25 | | - |
26 | | -import psutil |
27 | | -import torch |
28 | | -import torch.distributed |
29 | | -import torch.distributed as dist |
30 | | -from codetiming import Timer |
| 22 | + |
31 | 23 | from omegaconf import DictConfig, OmegaConf, open_dict |
32 | | -from omegaconf.errors import ConfigAttributeError |
33 | | -from peft import LoraConfig, TaskType, get_peft_model |
34 | | -from safetensors.torch import save_file |
35 | 24 | from torch.distributed.device_mesh import init_device_mesh |
36 | | -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
37 | | -from torch.distributed.fsdp.api import FullStateDictConfig, ShardedStateDictConfig, StateDictType |
38 | 25 |
|
39 | 26 | try: |
40 | 27 | # for torch 2.5+ |
41 | | - from torch.distributed.tensor import DTensor |
| 28 | + pass |
42 | 29 | except ImportError: |
43 | | - from torch.distributed._tensor import DTensor |
| 30 | + pass |
44 | 31 |
|
45 | | -from verl import DataProto |
46 | | -from verl.models.transformers.monkey_patch import apply_monkey_patch |
47 | 32 | from verl.single_controller.base import Worker |
48 | | -from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register |
49 | | -from verl.utils import hf_processor, hf_tokenizer |
50 | | -from verl.utils.activation_offload import enable_activation_offloading |
| 33 | +from verl.single_controller.base.decorator import Dispatch, register |
51 | 34 | from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager |
52 | 35 | from verl.utils.config import omega_conf_to_dataclass |
53 | 36 | from verl.utils.device import ( |
54 | | - get_device_id, |
55 | 37 | get_device_name, |
56 | 38 | get_nccl_backend, |
57 | | - get_torch_device, |
58 | | - set_expandable_segments, |
59 | 39 | ) |
60 | 40 | from verl.utils.flops_counter import FlopsCounter |
61 | 41 | from verl.utils.fs import copy_to_local |
62 | 42 | from verl.utils.fsdp_utils import ( |
63 | | - CPUOffloadPolicy, |
64 | | - MixedPrecisionPolicy, |
65 | | - apply_fsdp2, |
66 | | - collect_lora_params, |
67 | | - fsdp2_load_full_state_dict, |
68 | 43 | fsdp_version, |
69 | | - get_fsdp_wrap_policy, |
70 | | - get_init_weight_context_manager, |
71 | | - get_shard_placement_fn, |
72 | | - init_fn, |
73 | | - layered_summon_lora_params, |
74 | | - load_fsdp_model_to_gpu, |
75 | | - load_fsdp_optimizer, |
76 | 44 | offload_fsdp_model_to_cpu, |
77 | 45 | offload_fsdp_optimizer, |
78 | | - replace_lora_wrapper, |
79 | 46 | ) |
80 | 47 | from verl.utils.import_utils import import_external_libs |
81 | 48 | from verl.utils.memory_utils import aggressive_empty_cache |
82 | | -from verl.utils.model import convert_weight_keys |
83 | | -from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer |
84 | | -from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max |
85 | | -from verl.utils.py_functional import convert_to_regular_types |
| 49 | +from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage |
86 | 50 |
|
87 | 51 | # QAT support |
88 | | -from verl.utils.qat import apply_qat, enable_qat_fuse |
89 | | -from verl.utils.ray_utils import get_event_loop |
90 | | -from verl.utils.transformers_compat import get_auto_model_for_vision2seq |
91 | | -from verl.workers.config import FSDPCriticConfig, FSDPEngineConfig, HFModelConfig, RolloutConfig |
92 | | -from verl.workers.config.optimizer import build_optimizer |
93 | | -from verl.workers.rollout import get_rollout_class |
| 52 | +from verl.workers.config import FSDPEngineConfig |
94 | 53 | from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager |
95 | 54 | from verl.workers.fsdp_workers import ActorRolloutRefWorker |
96 | 55 |
|
|
0 commit comments