-
Notifications
You must be signed in to change notification settings - Fork 826
Expand file tree
/
Copy path__init__.py
More file actions
101 lines (85 loc) · 2.7 KB
/
__init__.py
File metadata and controls
101 lines (85 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os

from dotenv import load_dotenv

# Load variables from a .env file (if present) before any of the
# os.getenv() checks below run — ordering here is deliberate.
load_dotenv()

# Suppress litellm's Pydantic serialization warnings by default; set
# SUPPRESS_LITELLM_SERIALIZATION_WARNINGS=0 to keep them visible.
if os.getenv("SUPPRESS_LITELLM_SERIALIZATION_WARNINGS", "1") == "1":
    from art.utils.suppress_litellm_serialization_warnings import (
        suppress_litellm_serialization_warnings,
    )

    suppress_litellm_serialization_warnings()
# trl still imports GuidedDecodingParams from vllm.sampling_params, but vLLM
# 0.13+ removed that class.  Inject a minimal attribute-bag stand-in so the
# trl import keeps working.
try:
    import vllm.sampling_params

    class GuidedDecodingParams:
        """Shim for vLLM 0.13+ where GuidedDecodingParams was removed."""

        def __init__(self, **kwargs):
            # Accept and store arbitrary keyword arguments as attributes,
            # mirroring the permissive interface trl expects.
            self.__dict__.update(kwargs)

    vllm.sampling_params.GuidedDecodingParams = GuidedDecodingParams  # type: ignore
except ImportError:
    pass  # vllm not installed
# torch.cuda.MemPool doesn't currently support expandable_segments, which is
# used in sleep mode — strip that option from the allocator config if present.
conf = os.getenv("PYTORCH_CUDA_ALLOC_CONF", "").split(",")
try:
    conf.remove("expandable_segments:True")
except ValueError:
    pass  # option not set; leave the environment variable untouched
else:
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(conf)
# Unsloth must be imported before transformers, peft, and trl for its runtime
# patches to take full effect; it is optional, so gate it behind an env flag.
if os.getenv("IMPORT_UNSLOTH", "0") == "1":
    import unsloth  # noqa: F401
# Best-effort application of local transformers patches: skip everything when
# transformers itself is absent, and swallow any failure from the patching
# step so a broken patch never prevents the package from importing.
try:
    import transformers  # noqa: F401
except ImportError:
    pass
else:
    try:
        from .transformers.patches import (
            patch_apply_chat_template,
            patch_preprocess_mask_arguments,
        )

        patch_preprocess_mask_arguments()
        patch_apply_chat_template()
    except Exception:
        pass
from . import dev
from .auto_trajectory import auto_trajectory, capture_auto_trajectory
from .backend import Backend
from .batches import trajectory_group_batches
from .gather import gather_trajectories, gather_trajectory_groups
from .model import Model, TrainableModel
from .serverless import ServerlessBackend
from .trajectories import Trajectory, TrajectoryGroup
from .types import (
LocalTrainResult,
Messages,
MessagesAndChoices,
ServerlessTrainResult,
Tools,
TrainConfig,
TrainResult,
TrainSFTConfig,
)
from .utils import retry
from .yield_trajectory import capture_yielded_trajectory, yield_trajectory
# Public API of the package; every name re-exported by the imports above is
# listed here so that `from art import *` and documentation tools pick up
# exactly this surface.
__all__ = [
    "dev",
    "auto_trajectory",
    "capture_auto_trajectory",
    "gather_trajectories",
    "gather_trajectory_groups",
    "trajectory_group_batches",
    "Backend",
    "LocalTrainResult",
    "ServerlessBackend",
    "ServerlessTrainResult",
    "Messages",
    "MessagesAndChoices",
    "Tools",
    "Model",
    "TrainableModel",
    "retry",
    "TrainSFTConfig",
    "TrainConfig",
    "TrainResult",
    "Trajectory",
    "TrajectoryGroup",
    "capture_yielded_trajectory",
    "yield_trajectory",
]