-
Notifications
You must be signed in to change notification settings - Fork 833
Expand file tree
/
Copy pathbackend.py
More file actions
1381 lines (1235 loc) · 55 KB
/
backend.py
File metadata and controls
1381 lines (1235 loc) · 55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import asyncio
import json
import logging
import math
import os
import shutil
import socket
import subprocess
import time
from types import TracebackType
from typing import AsyncIterator, Iterable, Literal, cast
import warnings
logger = logging.getLogger(__name__)
# Per-GPU hourly prices in USD used when no explicit price override is given.
# Each key is matched as a substring of the upper-cased CUDA device name.
_AUTO_GPU_HOURLY_PRICING_USD = {
    "H200": 3.0,
}
import aiohttp
import numpy as np
from openai import AsyncOpenAI
import polars as pl
import torch
from tqdm import auto as tqdm
from transformers import AutoImageProcessor, AutoTokenizer
from transformers.image_processing_utils import BaseImageProcessor
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from typing_extensions import Self
from art.utils.output_dirs import (
get_default_art_path,
get_model_dir,
get_output_dir_from_model_properties,
get_step_checkpoint_dir,
)
from art.utils.record_provenance import record_provenance
from art.utils.s3 import (
ExcludableOption,
pull_model_from_s3,
push_model_to_s3,
)
from mp_actors import close_proxy, move_to_child_process
from .. import dev
from .._backend_training import (
aggregate_rl_training_metrics,
build_rl_train_configs,
)
from ..backend import AnyTrainableModel, Backend
from ..costs import build_cost_calculator, get_model_pricing
from ..metrics_taxonomy import (
TRAIN_GRADIENT_STEPS_KEY,
build_training_summary_metrics,
summarize_trajectory_groups,
)
from ..model import Model, TrainableModel
from ..preprocessing.pack import (
PackedTensors,
packed_tensors_from_tokenized_results,
packed_tensors_to_dir,
plot_packed_tensors,
)
from ..preprocessing.tokenize import (
tokenize_sft_batch,
tokenize_trajectory_groups,
)
from ..trajectories import Trajectory, TrajectoryGroup
from ..types import LocalTrainResult, Message, TrainConfig, TrainSFTConfig
from ..utils import format_message, get_model_step
from .checkpoints import (
delete_checkpoints,
)
from .service import ModelService
class LocalBackend(Backend):
def __init__(
self,
*,
in_process: bool = False,
path: str | None = None,
gpu_cost_per_hour_usd: float | None = None,
) -> None:
"""
Initializes a local, directory-based Backend interface at the given path.
Note:
The local Backend uses Weights & Biases for training monitoring.
If you don't have a W&B account, you can create one at https://wandb.ai.
Args:
in_process: Whether to run the local service in-process.
path: The path to the local directory. Defaults to "{repo_root}/.art".
gpu_cost_per_hour_usd: Optional per-GPU hourly price override used for
automatic `costs/gpu` accounting on train steps. When unset,
ART auto-detects supported GPU types (H200 at $3/hr today) and
skips GPU cost logging for unknown devices instead of guessing.
"""
self._in_process = in_process
self._path = path or get_default_art_path()
self._gpu_cost_per_hour_usd = (
float(gpu_cost_per_hour_usd) if gpu_cost_per_hour_usd is not None else None
)
os.makedirs(self._path, exist_ok=True)
# Other initialization
self._services: dict[str, ModelService] = {}
self._tokenizers: dict[str, PreTrainedTokenizerBase] = {}
self._image_processors: dict[str, BaseImageProcessor | None] = {}
def supports_automatic_train_step_metrics(self) -> bool:
    """Report that this backend supports automatic train-step metrics (always True)."""
    supported = True
    return supported
def automatic_gpu_cost_per_hour_usd(self, model: Model) -> float | None:
    """Return the total hourly GPU cost in USD for ``model``, or None if unknown.

    The total is the resolved per-GPU price multiplied by the number of GPUs
    counted for the model. None is returned when no per-GPU price can be
    resolved or when the GPU count is not positive.
    """
    hourly_per_gpu = self._resolve_gpu_cost_per_hour_usd()
    if hourly_per_gpu is None:
        return None
    num_gpus = self._allocated_gpu_count(model)
    return hourly_per_gpu * num_gpus if num_gpus > 0 else None
def _resolve_gpu_cost_per_hour_usd(self) -> float | None:
if self._gpu_cost_per_hour_usd is not None:
return self._gpu_cost_per_hour_usd
if not torch.cuda.is_available():
return None
num_visible_gpus = torch.cuda.device_count()
if num_visible_gpus <= 0:
return None
resolved_costs: list[float] = []
for index in range(num_visible_gpus):
device_name = torch.cuda.get_device_name(index).upper()
for gpu_name, hourly_cost in _AUTO_GPU_HOURLY_PRICING_USD.items():
if gpu_name in device_name:
resolved_costs.append(hourly_cost)
break
else:
return None
if not resolved_costs:
return None
if len(set(resolved_costs)) != 1:
return None
return resolved_costs[0]
def _allocated_gpu_count(self, model: Model) -> int:
    """Count the GPUs attributed to ``model``.

    For a TrainableModel with an internal config, this is the number of
    distinct ids across ``trainer_gpu_ids`` and ``inference_gpu_ids`` when
    that set is non-empty. Otherwise it is the CUDA device count, or 0 when
    CUDA is unavailable.
    """
    if isinstance(model, TrainableModel) and model._internal_config is not None:
        internal = model._internal_config
        trainer_ids = set(internal.get("trainer_gpu_ids", []))
        inference_ids = set(internal.get("inference_gpu_ids", []))
        distinct_ids = trainer_ids.union(inference_ids)
        if distinct_ids:
            return len(distinct_ids)
    return torch.cuda.device_count() if torch.cuda.is_available() else 0
def __enter__(self) -> Self:
return self
async def __aenter__(self) -> Self:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
self._close()
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
await self.close()
async def close(self) -> None:
    """
    If running vLLM in a separate process, this will kill that process and close the communication threads.

    For each registered service, an ``aclose`` attribute is awaited when
    present; otherwise a ``close`` attribute is called when present. The
    service's proxy is then released in either case.
    """
    for svc in self._services.values():
        async_closer = getattr(svc, "aclose", None)
        if async_closer is not None:
            await async_closer()
        else:
            # Only look for the synchronous closer when there is no async one.
            sync_closer = getattr(svc, "close", None)
            if sync_closer is not None:
                sync_closer()
        close_proxy(svc)
def _close(self) -> None:
for service in self._services.values():
close = getattr(service, "close", None)
if close is not None:
close()
close_proxy(service)
async def register(
    self,
    model: Model,
) -> None:
    """
    Registers a model with the local Backend for logging and/or training.

    Points the model at this backend's path, writes its dumped fields to
    ``model.json`` in the model directory, migrates old trajectory files,
    and, for trainable models, optionally initialises W&B and attaches a
    cost calculator when pricing is known for the base model.

    Args:
        model: An art.Model instance.
    """
    # Ensure model state/logging uses the backend path
    model.base_path = self._path

    model_dir = get_model_dir(model=model, art_path=self._path)
    os.makedirs(model_dir, exist_ok=True)
    with open(f"{model_dir}/model.json", "w") as fh:
        json.dump(model.model_dump(), fh)

    # Auto-migrate any old JSONL trajectory files to Parquet
    from art.utils.trajectory_migration import auto_migrate_on_register

    auto_migrate_on_register(model_dir)

    # Everything below only applies to trainable models.
    if not model.trainable:
        return

    # Initialize wandb early if this is a trainable model
    # (wandb initialization is now handled by the model's _get_wandb_run method)
    if "WANDB_API_KEY" in os.environ:
        _ = model._get_wandb_run()

    trainable_model = cast(TrainableModel, model)
    pricing = get_model_pricing(trainable_model.base_model)
    if pricing is None:
        return
    trainable_model.set_cost_calculator(build_cost_calculator(pricing))
def _model_inference_name(self, model: Model, step: int | None = None) -> str:
    """Return the inference name for a model checkpoint.

    For LocalBackend with vLLM, the base model is served under its HF name,
    and LoRA adapters are served as `model.name@step`.

    Args:
        model: The model.
        step: If provided, returns name for specific checkpoint.
            If None, returns name for latest checkpoint (step 0 initially).

    Returns:
        The string ``f"{model.name}@{step}"`` for the resolved step.
    """
    # Remember what the caller passed so the debug log can show both values.
    requested_step = step
    if step is None and isinstance(model, TrainableModel):
        from ..dev.validate import is_dedicated_mode

        service = self._services.get(model.name)
        if service is not None and is_dedicated_mode(
            model._internal_config or dev.InternalModelConfig()
        ):
            # In dedicated mode, prefer the step recorded on the running
            # service, but only when it is actually an int.
            loaded_step = getattr(service, "_latest_step", None)
            if isinstance(loaded_step, int):
                step = loaded_step
    if step is None:
        # The checkpoint directory is written before dedicated-mode
        # vLLM finishes reloading the new adapter.
        step = self.__get_step(model)
    name = f"{model.name}@{step}"
    logger.debug(
        f"[BACKEND] _model_inference_name: step_arg={requested_step} "
        f"actual_step={step} -> {name}"
    )
    return name
async def _get_service(self, model: TrainableModel) -> ModelService:
    """Return the service for ``model``, creating and caching it on first use.

    Services are cached in ``self._services`` keyed by ``model.name``. On a
    cache miss the model config is resolved and validated, a Tinker or
    Unsloth service class is chosen, and the instance is moved to a child
    process unless the backend runs in-process or in dedicated mode.

    Args:
        model: The trainable model whose service is needed.

    Returns:
        The cached (possibly proxied) service instance.
    """
    # Imported lazily inside the method rather than at module import time.
    from ..dev.get_model_config import get_model_config
    from ..dev.validate import is_dedicated_mode, validate_dedicated_config

    if model.name not in self._services:
        config = get_model_config(
            base_model=model.base_model,
            output_dir=get_model_dir(model=model, art_path=self._path),
            config=model._internal_config,
        )
        validate_dedicated_config(config)
        dedicated = is_dedicated_mode(config)
        # The presence of "tinker_args" in the config selects the Tinker service.
        is_tinker = config.get("tinker_args") is not None
        if is_tinker:
            from ..tinker.service import TinkerService

            service_class = TinkerService
        else:
            from ..unsloth.service import UnslothService

            service_class = UnslothService
            # When moving the service to a child process, import unsloth
            # early to maximize optimizations
            os.environ["IMPORT_UNSLOTH"] = "1"
        if dedicated:
            # Restrict this process's visible CUDA devices to the trainer GPUs.
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
                str(g) for g in config["trainer_gpu_ids"]
            )
        self._services[model.name] = service_class(
            model_name=model.name,
            base_model=model.base_model,
            config=config,
            output_dir=get_model_dir(model=model, art_path=self._path),
        )
        if not dedicated and not self._in_process:
            # Kill all "model-service" processes to free up GPU memory
            subprocess.run(["pkill", "-9", "model-service"])
            # Replace the in-process instance with a proxy to a child process.
            self._services[model.name] = move_to_child_process(
                self._services[model.name],
                process_name="tinker-service" if is_tinker else "model-service",
            )
    return self._services[model.name]
def _get_packed_tensors(
    self,
    model: AnyTrainableModel,
    trajectory_groups: list[TrajectoryGroup],
    advantage_balance: float,
    allow_training_without_logprobs: bool,
    scale_rewards: bool,
    plot_tensors: bool,
) -> PackedTensors | None:
    """Tokenize trajectory groups and pack them into fixed-length sequences.

    The tokenizer and (optional) image processor are loaded once per base
    model and cached on the backend.

    Args:
        model: The model whose base model selects the tokenizer.
        trajectory_groups: The groups to tokenize.
        advantage_balance: Forwarded to the packing routine.
        allow_training_without_logprobs: Forwarded to tokenization; also
            disables the "no logprobs" early exit below.
        scale_rewards: Forwarded to tokenization.
        plot_tensors: If True, plot the packed tensors into the model
            directory instead of printing a one-line summary.

    Returns:
        The packed tensors, or None when tokenization produced nothing or
        when every logprob is NaN and training without logprobs is not allowed.
    """
    if model.base_model not in self._tokenizers:
        self._tokenizers[model.base_model] = AutoTokenizer.from_pretrained(
            model.base_model
        )
    if model.base_model not in self._image_processors:
        try:
            self._image_processors[model.base_model] = (
                AutoImageProcessor.from_pretrained(model.base_model, use_fast=True)
            )
        except Exception:
            # Best effort: any failure to load an image processor is
            # recorded as "no image processor" for this base model.
            self._image_processors[model.base_model] = None
    tokenizer = self._tokenizers[model.base_model]
    tokenized_results = list(
        tokenize_trajectory_groups(
            tokenizer,
            trajectory_groups,
            allow_training_without_logprobs,
            scale_rewards,
            image_processor=self._image_processors[model.base_model],
        )
    )
    if not tokenized_results:
        return None
    max_tokens = max(len(result.token_ids) for result in tokenized_results)
    # Round up max_tokens to the nearest multiple of 2048
    sequence_length = math.ceil(max_tokens / 2048) * 2048
    # Cap sequence length at the model's max sequence length
    # (falls back to 32_768 when init_args.max_seq_length is not configured)
    sequence_length = min(
        sequence_length,
        (model._internal_config or dev.InternalModelConfig())
        .get("init_args", {})
        .get("max_seq_length", 32_768),
    )
    # The tokenizer's EOS token id doubles as the padding token id here.
    packed_tensors = packed_tensors_from_tokenized_results(
        tokenized_results,
        sequence_length,
        pad_token_id=tokenizer.eos_token_id,
        advantage_balance=advantage_balance,
    )
    if (
        not allow_training_without_logprobs
        and np.isnan(packed_tensors["logprobs"]).all()
    ):
        print(
            "There are no assistant logprobs to train on. Did you forget to include at least one Choice in Trajectory.messages_and_choices?"
        )
        return None
    if plot_tensors:
        plot_packed_tensors(
            packed_tensors, get_model_dir(model=model, art_path=self._path)
        )
    else:
        print(
            f"Packed {len(tokenized_results)} trajectories into {packed_tensors['tokens'].shape[0]} sequences of length {packed_tensors['tokens'].shape[1]}"
        )
    return packed_tensors
async def _get_step(self, model: AnyTrainableModel) -> int:
return self.__get_step(model)
def __get_step(self, model: Model) -> int:
if model.trainable:
model = cast(TrainableModel, model)
return get_model_step(model, self._path)
# Non-trainable models do not have checkpoints/steps; default to 0
return 0
async def _delete_checkpoint_files(
    self,
    model: AnyTrainableModel,
    steps_to_keep: list[int],
) -> None:
    """Delete checkpoint files, keeping only the specified steps.

    When the model's service is a TinkerService, deletion is delegated to
    the service. Otherwise the checkpoints under the model's local output
    directory are deleted.

    Args:
        model: The model whose checkpoints should be pruned.
        steps_to_keep: Checkpoint steps that must not be deleted.
    """
    output_dir = get_model_dir(model=model, art_path=self._path)
    service = await self._get_service(model)

    # Only the optional import is guarded. An ImportError raised while the
    # service is deleting checkpoints must propagate instead of silently
    # falling through to local deletion.
    try:
        from ..tinker.service import TinkerService as tinker_service_cls
    except ImportError:
        tinker_service_cls = None

    if tinker_service_cls is not None and isinstance(service, tinker_service_cls):
        await service.delete_checkpoints(steps_to_keep)
        return
    delete_checkpoints(output_dir, steps_to_keep)
async def _prepare_backend_for_training(
    self,
    model: AnyTrainableModel,
    config: dev.OpenAIServerConfig | None = None,
) -> tuple[str, str]:
    """Start the model's OpenAI-compatible server and begin monitoring it.

    Missing ``server_args`` entries are filled in: a free port is picked
    when no port is given, and an API key is generated when none is given.
    A background task then monitors the server; when that task finishes,
    the model's service is dropped from the cache and its proxy closed.

    Args:
        model: The trainable model to serve.
        config: Optional server configuration. It is copied, not mutated.

    Returns:
        A ``(base_url, api_key)`` tuple for talking to the server.
    """
    config_dict: dict = dict(config or {})
    server_args = dict(config_dict.get("server_args", {}))
    # Avoid binding collisions on busy hosts when no explicit port is provided.
    if "port" not in server_args:
        # Binding to port 0 lets the OS choose a currently free port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            server_args["port"] = s.getsockname()[1]
    # Ensure the server and client share the same API key.
    # If the caller did not supply one, generate a secure random key
    # so the vLLM server is never exposed with a well-known credential.
    if not server_args.get("api_key"):
        server_args["api_key"] = dev._generate_api_key()
    config_dict["server_args"] = server_args
    resolved_config = cast(dev.OpenAIServerConfig, config_dict)

    service = await self._get_service(model)
    host, port = await service.start_openai_server(config=resolved_config)
    base_url = f"http://{host}:{port}/v1"
    api_key = server_args["api_key"]

    def done_callback(_: asyncio.Task[None]) -> None:
        # The entry may already be gone (for example if another monitor for
        # the same model finished first), so pop with a default rather than
        # raising KeyError inside the event-loop callback.
        finished_service = self._services.pop(model.name, None)
        if finished_service is not None:
            close_proxy(finished_service)

    asyncio.create_task(
        self._monitor_openai_server(model, base_url, api_key)
    ).add_done_callback(done_callback)
    return base_url, api_key
async def _monitor_openai_server(
    self, model: AnyTrainableModel, base_url: str, api_key: str
) -> None:
    """Poll the inference server forever and raise once it keeps failing.

    Every 30 seconds the server's ``/metrics`` endpoint is read. If it
    reports no running and no waiting requests, a one-token completion is
    sent as a health check. Failures while the service reports that the
    vLLM engine is sleeping are ignored. After three consecutive failures
    the most recent exception is re-raised, which ends the task.

    Args:
        model: The model being served; used to look up its service and
            inference name.
        base_url: The server's base URL, ending in ``/v1``.
        api_key: API key used by the health-check client.
    """
    model_name = model.name
    openai_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
    )
    consecutive_failures = 0
    max_consecutive_failures = 3
    async with aiohttp.ClientSession() as session:
        while True:
            # Wait 30 seconds before checking again
            await asyncio.sleep(30)
            try:
                # If the server is sleeping, skip the check
                if await self._services[model_name].vllm_engine_is_sleeping():
                    consecutive_failures = 0
                    continue
                # Check the metrics with a timeout
                # (the metrics endpoint lives at the server root, not under /v1)
                async with session.get(
                    f"{base_url.split('/v1')[0]}/metrics",
                    timeout=aiohttp.ClientTimeout(total=10),
                ) as response:
                    metrics = await response.text()
                # Parse Prometheus metrics for running requests
                running_requests = 0
                pending_requests = 0
                for line in metrics.split("\n"):
                    if line.startswith("vllm:num_requests_running"):
                        running_requests = int(float(line.split()[1]))
                    elif line.startswith("vllm:num_requests_waiting"):
                        pending_requests = int(float(line.split()[1]))
                # If there are no running or pending requests, send a health check
                if running_requests == 0 and pending_requests == 0:
                    try:
                        # Send a health check with a short timeout
                        # (overridable via ART_SERVER_MONITOR_TIMEOUT, default 5.0)
                        await openai_client.completions.create(
                            model=self._model_inference_name(model),
                            prompt="Hi",
                            max_tokens=1,
                            timeout=float(
                                os.environ.get("ART_SERVER_MONITOR_TIMEOUT", 5.0)
                            ),
                        )
                    except Exception as e:
                        # If the server is sleeping, a failed health check is okay
                        if await self._services[
                            model_name
                        ].vllm_engine_is_sleeping():
                            consecutive_failures = 0
                            continue
                        # Hand the failure to the outer handler, which counts it.
                        raise e
                # Reset failure counter on success
                consecutive_failures = 0
            except Exception:
                # If the server is sleeping during an exception, it's okay
                try:
                    if await self._services[model_name].vllm_engine_is_sleeping():
                        consecutive_failures = 0
                        continue
                except Exception:
                    pass  # If we can't check sleeping status, count it as a failure
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    # Re-raise the exception that triggered this handler.
                    raise
                # Otherwise, continue and try again
# Note: _log() method has been moved to the Model class (frontend)
def _trajectory_log(self, trajectory: Trajectory) -> str:
"""Format a trajectory into a readable log string."""
header = f"reward: {trajectory.reward} {' '.join(f'{k}: {v}' for k, v in trajectory.metrics.items())}\n\n"
formatted_messages = []
for message_or_choice in trajectory.messages_and_choices:
if isinstance(message_or_choice, dict):
message = message_or_choice
else:
message = cast(Message, message_or_choice.message.model_dump()) # ty:ignore[possibly-missing-attribute]
formatted_messages.append(format_message(message))
return header + "\n".join(formatted_messages)
async def train(  # type: ignore[override]
    self,
    model: AnyTrainableModel,
    trajectory_groups: Iterable[TrajectoryGroup],
    *,
    # Core training parameters
    learning_rate: float = 5e-6,
    loss_fn: Literal["cispo", "ppo"] = "cispo",
    loss_fn_config: dict | None = None,
    normalize_advantages: bool = True,
    adam_params: object | None = None,
    # KL-penalized advantage adjustment
    kl_penalty_coef: float = 0.0,
    kl_penalty_reference_step: int | None = None,
    kl_ref_adapter_path: str | None = None,
    epsilon: float | None = None,
    epsilon_high: float | None = None,
    # Advantage computation
    advantage_balance: float = 0.0,
    scale_rewards: bool = True,
    # Importance sampling
    importance_sampling_level: Literal[
        "token", "sequence", "average", "geometric_average"
    ] = "token",
    max_negative_advantage_importance_sampling_weight: float | None = None,
    mask_prob_ratio: bool = False,
    # Experimental parameters
    kimi_k2_tau: float | None = None,
    precalculate_logprobs: bool = False,
    # LocalBackend-specific parameters
    allow_training_without_logprobs: bool = False,
    plot_tensors: bool = False,
    truncated_importance_sampling: float | None = None,
    scale_learning_rate_by_reward_std_dev: bool = False,
    logprob_calculation_chunk_size: int = 1024,
    num_trajectories_learning_rate_multiplier_power: float = 0.0,
    # Checkpoint behavior
    save_checkpoint: bool = True,
    # Verbosity
    verbose: bool = False,
) -> LocalTrainResult:
    """Train the model on the given trajectory groups.

    This is the recommended way to train models. Unlike model.train(), this
    method does NOT automatically log trajectories or metrics. Call model.log()
    explicitly before and/or after training if you want to log data.

    Args:
        model: The trainable model to train.
        trajectory_groups: Batches of trajectories to train on.
        learning_rate: Learning rate for training. Defaults to 5e-6.
        loss_fn: RL loss function. LocalBackend currently supports
            "cispo" and "ppo".
        loss_fn_config: Additional loss-function config. Not supported by
            LocalBackend.
        normalize_advantages: Whether to normalize advantages. LocalBackend
            currently requires True.
        adam_params: Custom optimizer params. Not supported by
            LocalBackend.
        kl_penalty_coef: Coefficient for KL-penalized advantage adjustment.
            Tokens diverging more from the reference get reduced advantages.
            Defaults to 0.0 (disabled).
        kl_penalty_reference_step: Checkpoint step of the training model to
            use as the KL reference. If None, uses the base model (LoRA
            disabled) as reference.
        kl_ref_adapter_path: Direct filesystem path to a LoRA adapter
            checkpoint to use as the KL reference. Alternative to
            kl_penalty_reference_step.
        epsilon: Clip epsilon for importance sampling. Defaults based on loss_fn.
        epsilon_high: Asymmetric upper clip bound. Defaults to epsilon.
        advantage_balance: Balance between negative and positive advantages
            in range [-1.0, 1.0]. Defaults to 0.0 (balanced).
        scale_rewards: Whether to scale rewards by standard deviation.
            Defaults to True.
        importance_sampling_level: Level at which to compute importance
            sampling weights. Defaults to "token".
        max_negative_advantage_importance_sampling_weight: Maximum weight
            for negative advantage samples.
        mask_prob_ratio: Whether to mask probability ratios. Defaults to False.
        kimi_k2_tau: Tau parameter for Kimi K2 algorithm.
        precalculate_logprobs: Whether to precalculate logprobs.
        allow_training_without_logprobs: Allow training even when no logprobs
            are available. Defaults to False.
        plot_tensors: Whether to plot training tensors for debugging.
            Defaults to False.
        truncated_importance_sampling: Truncation threshold for importance
            sampling weights.
        scale_learning_rate_by_reward_std_dev: Whether to scale learning rate
            by reward standard deviation. Defaults to False.
        logprob_calculation_chunk_size: Chunk size for logprob calculation.
            Defaults to 1024.
        num_trajectories_learning_rate_multiplier_power: Power for learning
            rate multiplier based on number of trajectories.
        save_checkpoint: Whether to report the checkpoint path for the
            resulting step. Defaults to True.
        verbose: Whether to print verbose output. Defaults to False.

    Returns:
        LocalTrainResult with step number, training metrics, and checkpoint path.
        The checkpoint path is None when save_checkpoint is False or when
        the step's checkpoint directory does not exist on disk.

    Raises:
        ValueError: If loss_fn is not "cispo" or "ppo", if loss_fn_config
            or adam_params is not None, or if normalize_advantages is False.

    Example:
        # Before (deprecated):
        await model.train(trajectory_groups, config=TrainConfig(learning_rate=5e-6))

        # After (recommended):
        await model.log(trajectory_groups, split="train")
        result = await backend.train(model, trajectory_groups, learning_rate=5e-6)
        # Optionally log training metrics:
        # await model.log(metrics=result.metrics, step=result.step)
    """
    # Materialize once: the groups are used for training and again for
    # metric aggregation, and the input may be a one-shot iterable.
    groups_list = list(trajectory_groups)
    if loss_fn not in {"cispo", "ppo"}:
        raise ValueError("LocalBackend only supports loss_fn='cispo' or 'ppo'.")
    if loss_fn_config is not None:
        raise ValueError("LocalBackend requires loss_fn_config=None.")
    if not normalize_advantages:
        raise ValueError("LocalBackend requires normalize_advantages=True.")
    if adam_params is not None:
        raise ValueError("LocalBackend requires adam_params=None.")
    # An explicit adapter path takes precedence; otherwise a reference step
    # is translated into that step's checkpoint directory.
    resolved_kl_ref_adapter_path = kl_ref_adapter_path
    if (
        resolved_kl_ref_adapter_path is None
        and kl_penalty_reference_step is not None
    ):
        resolved_kl_ref_adapter_path = get_step_checkpoint_dir(
            get_model_dir(model=model, art_path=self._path),
            kl_penalty_reference_step,
        )
    config, dev_config = build_rl_train_configs(
        learning_rate=learning_rate,
        advantage_balance=advantage_balance,
        scale_rewards=scale_rewards,
        importance_sampling_level=importance_sampling_level,
        mask_prob_ratio=mask_prob_ratio,
        ppo=loss_fn == "ppo",
        precalculate_logprobs=precalculate_logprobs,
        epsilon=epsilon,
        epsilon_high=epsilon_high,
        max_negative_advantage_importance_sampling_weight=max_negative_advantage_importance_sampling_weight,
        kimi_k2_tau=kimi_k2_tau,
        kl_penalty_coef=kl_penalty_coef,
        allow_training_without_logprobs=allow_training_without_logprobs,
        plot_tensors=plot_tensors,
        truncated_importance_sampling=truncated_importance_sampling,
        scale_learning_rate_by_reward_std_dev=scale_learning_rate_by_reward_std_dev,
        logprob_calculation_chunk_size=logprob_calculation_chunk_size,
        num_trajectories_learning_rate_multiplier_power=num_trajectories_learning_rate_multiplier_power,
        kl_ref_adapter_path=resolved_kl_ref_adapter_path,
    )
    # Collect metrics from training
    training_metrics: list[dict[str, float]] = []
    trainer_started = time.monotonic()
    async for metrics in self._train_model(
        model, groups_list, config, dev_config, verbose
    ):
        training_metrics.append(metrics)
    avg_metrics = aggregate_rl_training_metrics(
        training_metrics=training_metrics,
        trajectory_groups=groups_list,
        trainer_started=trainer_started,
    )
    # Get step and checkpoint path
    step = await self._get_step(model)
    checkpoint_path: str | None = None
    if save_checkpoint:
        checkpoint_path = get_step_checkpoint_dir(
            get_model_dir(model=model, art_path=self._path), step
        )
        # Report a path only if the directory is really there.
        if not os.path.exists(checkpoint_path):
            checkpoint_path = None
    # Record provenance on the latest W&B artifact
    wandb_run = model._get_wandb_run()
    if wandb_run is not None:
        record_provenance(wandb_run, "local-rl")
    return LocalTrainResult(
        step=step,
        metrics=avg_metrics,
        checkpoint_path=checkpoint_path,
    )
async def _train_model(
    self,
    model: TrainableModel,
    trajectory_groups: list[TrajectoryGroup],
    config: TrainConfig,
    dev_config: dev.TrainConfig,
    verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
    """Pack the trajectories, run the service's trainer, and yield metrics.

    One metrics dict is yielded for every result produced by the service.
    When packing produces no tensors, the current checkpoint directory (if
    it exists) is copied to the next step, a single metrics dict with zero
    trainable groups, tokens and gradient steps is yielded, and the
    generator ends.

    Args:
        model: The model to train.
        trajectory_groups: Trajectory groups to train on.
        config: Training configuration passed to the service.
        dev_config: Developer configuration; also supplies packing options.
        verbose: Whether to print progress messages.

    Yields:
        A metrics dict combining the summary metrics, the service's result,
        and the gradient-step count under ``TRAIN_GRADIENT_STEPS_KEY``.
    """
    if verbose:
        print("Starting _train_model")
    service = await self._get_service(model)
    # Note: Logging is now handled by the frontend (Model.train() calls Model.log())
    if verbose:
        print("Packing tensors...")
    summary = summarize_trajectory_groups(trajectory_groups)
    base_metrics = build_training_summary_metrics(
        summary,
        include_trainable_groups=True,
    )
    packed_tensors = self._get_packed_tensors(
        model,
        trajectory_groups,
        advantage_balance=dev_config.get("advantage_balance", 0.0),
        allow_training_without_logprobs=dev_config.get(
            "allow_training_without_logprobs", False
        ),
        scale_rewards=dev_config.get("scale_rewards", True),
        plot_tensors=dev_config.get("plot_tensors", False),
    )
    if packed_tensors is None:
        print(
            "Skipping tuning as there is no suitable data. "
            "This can happen when all the trajectories in the same group "
            "have the same reward and thus no advantage to train on."
        )
        # Still advance the step by copying the checkpoint directory
        current_step = self.__get_step(model)
        next_step = current_step + 1
        logger.info(
            f"[BACKEND] _train_model SKIP: current_step={current_step} "
            f"next_step={next_step} (all rewards equal)"
        )
        current_checkpoint_dir = get_step_checkpoint_dir(
            get_model_dir(model=model, art_path=self._path), current_step
        )
        next_checkpoint_dir = get_step_checkpoint_dir(
            get_model_dir(model=model, art_path=self._path), next_step
        )
        # If the current checkpoint exists, copy it to the next step
        if os.path.exists(current_checkpoint_dir):
            shutil.copytree(
                current_checkpoint_dir,
                next_checkpoint_dir,
                dirs_exist_ok=True,
            )
            logger.info(
                f"[BACKEND] _train_model SKIP: copied checkpoint "
                f"{current_step} -> {next_step}, calling register_lora_for_step..."
            )
            try:
                # Register the copied checkpoint as a new LoRA adapter
                # so it's available for inference at the new step
                if hasattr(service, "register_lora_for_step"):
                    await service.register_lora_for_step(  # type: ignore[attr-defined]
                        next_step, next_checkpoint_dir
                    )
                    logger.info(
                        f"[BACKEND] _train_model SKIP: register_lora_for_step "
                        f"completed for step {next_step}"
                    )
            except ModuleNotFoundError:
                pass  # Unsloth is not installed
        # Yield metrics showing no groups were trainable
        # (the frontend will handle logging)
        yield {
            **base_metrics,
            "data/step_num_groups_trainable": 0.0,
            "data/step_trainer_tokens": 0.0,
            TRAIN_GRADIENT_STEPS_KEY: 0.0,
        }
        return
    # Trainer tokens are counted as the sum of the assistant mask.
    base_metrics["data/step_trainer_tokens"] = float(
        packed_tensors["assistant_mask"].sum().item()
    )
    disk_packed_tensors = packed_tensors_to_dir(
        packed_tensors, f"{get_model_dir(model=model, art_path=self._path)}/tensors"
    )
    # Note: scale_learning_rate_by_reward_std_dev is now handled by the frontend (Model.train())
    grad_accumulation_sequences = max(
        1, int(config.grad_accumulation_sequences or 1)
    )
    # Used for the progress bar and metrics until (or unless) the service
    # reports its own gradient-step count.
    fallback_gradient_steps = math.ceil(
        disk_packed_tensors["num_sequences"] / grad_accumulation_sequences
    )
    pbar = tqdm.tqdm(total=fallback_gradient_steps, desc="train")
    reported_gradient_steps: int | None = None
    async for result in service.train(
        disk_packed_tensors, config, dev_config, verbose
    ):
        # The step count is removed from the result and re-added below as a float.
        raw_num_gradient_steps = result.pop(TRAIN_GRADIENT_STEPS_KEY, None)
        if raw_num_gradient_steps is not None:
            num_gradient_steps = int(raw_num_gradient_steps)
            if reported_gradient_steps is None:
                # First report: remember it and resize the progress bar.
                reported_gradient_steps = num_gradient_steps
                if pbar.total != num_gradient_steps:
                    pbar.total = num_gradient_steps
                    pbar.refresh()
            else:
                # Later reports must agree with the first one.
                assert num_gradient_steps == reported_gradient_steps, (
                    f"num_gradient_steps {num_gradient_steps} != reported_gradient_steps {reported_gradient_steps}"
                )
        else:
            num_gradient_steps = reported_gradient_steps or fallback_gradient_steps
        yield {
            **base_metrics,
            **result,
            TRAIN_GRADIENT_STEPS_KEY: float(num_gradient_steps),
        }
        pbar.update(1)
        pbar.set_postfix(result)
    pbar.close()
    # Note: Metrics logging is now handled by the frontend (Model.train())
    if verbose:
        print("_train_model complete")
# Note: _get_reward_std_dev_learning_rate_multiplier and _log_metrics
# have been moved to the Model class (frontend)
async def _train_sft(
    self,
    model: AnyTrainableModel,
    trajectories: Iterable[Trajectory],
    config: TrainSFTConfig,
    dev_config: dev.TrainSFTConfig,
    verbose: bool = False,
) -> AsyncIterator[dict[str, float]]:
    """Train the model using supervised fine-tuning.

    All trajectories are tokenized into batches up front, then handed to
    the model's service; one metrics dict is yielded per service result.

    Args:
        model: The trainable model to fine-tune
        trajectories: Iterable of Trajectory objects
        config: SFT configuration with batch_size and learning rates.
            If learning_rate is a list, one value is consumed per batch;
            otherwise the single value is used for every batch.
        dev_config: Developer configuration
        verbose: Whether to print detailed logs

    Yields:
        Dictionary containing training metrics for each batch
    """
    if verbose:
        print("Starting _train_sft")
    # Get tokenizer
    if model.base_model not in self._tokenizers:
        self._tokenizers[model.base_model] = AutoTokenizer.from_pretrained(
            model.base_model
        )
    tokenizer = self._tokenizers[model.base_model]
    from ..utils.sft import resolve_sft_batch_size

    batch_size = resolve_sft_batch_size(
        batch_size=config.batch_size,
        default_batch_size=self._default_sft_batch_size(),
    )
    # The service receives a copy of the config carrying the resolved batch size.
    service_config = config.model_copy(update={"batch_size": batch_size})
    # Auto-detect instruction/response parts from model
    from ..utils.model_config import get_instruction_response_parts

    instruction_part, response_part = get_instruction_response_parts(
        model.base_model, tokenizer
    )
    if verbose:
        print(f"Using instruction_part: {instruction_part!r}")
        print(f"Using response_part: {response_part!r}")
    import itertools
    from typing import Iterator

    from ..preprocessing.tokenize import SFTBatch

    if isinstance(config.learning_rate, list):
        learning_rates_iter: Iterator[float] = iter(config.learning_rate)
    else:
        learning_rates_iter = itertools.repeat(config.learning_rate)
    # Build all batches in memory
    trajectory_list = list(trajectories)
    batches: list[SFTBatch] = []
    for i in range(0, len(trajectory_list), batch_size):
        batch_trajectories = trajectory_list[i : i + batch_size]
        # NOTE(review): if learning_rate is a list with fewer entries than
        # there are batches, next() raises StopIteration here, which an
        # async generator surfaces as RuntimeError - confirm callers always
        # supply enough rates.
        batches.append(
            tokenize_sft_batch(
                trajectory_batch=batch_trajectories,
                learning_rate=next(learning_rates_iter),
                tokenizer=tokenizer,
                instruction_part=instruction_part,
                response_part=response_part,
            )
        )
    # Get the service and train
    service = await self._get_service(model)
    pbar = tqdm.tqdm(total=len(batches), desc="sft train")
    # These totals cover the whole call and are repeated in every yielded dict.
    total_trainable_tokens = sum(batch.num_trainable_tokens for batch in batches)
    total_trajectories = len(trajectory_list)
    batch_count = 0
    async for result in service.train_sft(batches, service_config, verbose):
        pbar.update(1)
        pbar.set_postfix({"loss": f"{result.get('loss/train', 0):.4f}"})
        batch_count += 1
        yield {
            **result,
            "data/step_num_trajectories": float(total_trajectories),
            "data/step_trainer_tokens": float(total_trainable_tokens),
            TRAIN_GRADIENT_STEPS_KEY: float(len(batches)),
        }
    pbar.close()
    # Warn when training ran but no token was marked as trainable.
    if batch_count > 0 and total_trainable_tokens == 0:
        print(
            "WARNING: No trainable tokens found! "
            "Check instruction_part and response_part settings."
        )
    if verbose:
        print("_train_sft complete")
def _default_sft_batch_size(self) -> int:
return 2
# ------------------------------------------------------------------
# Experimental support for S3
# ------------------------------------------------------------------
async def _experimental_pull_model_checkpoint(
self,
model: "TrainableModel",
*,
step: int | Literal["latest"] | None = None,
local_path: str | None = None,
s3_bucket: str | None = None,
prefix: str | None = None,
verbose: bool = False,
) -> str:
"""Pull a model checkpoint to a local path.
For LocalBackend, this:
1. When step is "latest" or None, checks both local storage and S3 (if provided)
to find the latest checkpoint, preferring local if steps are equal
2. If checkpoint exists locally, uses it (optionally copying to local_path)
3. If checkpoint doesn't exist locally but s3_bucket is provided, pulls from S3
4. Returns the final checkpoint path
Args:
model: The model to pull checkpoint for.
step: The step to pull. Can be an int for a specific step,
or "latest" to pull the latest checkpoint. If None, pulls latest.
local_path: Custom directory to save/copy the checkpoint to.
If None, returns checkpoint from backend's default art path.
s3_bucket: S3 bucket to check/pull from. When step is "latest", both
local storage and S3 are checked to find the true latest.
prefix: S3 prefix.
verbose: Whether to print verbose output.
Returns:
Path to the local checkpoint directory.
"""
# Determine which step to use