Skip to content

Commit 0560bd7

Browse files
author
Francisco
committed
refactor(training): import BASE_DEFAULTS and PROFILES from projectdavid-common
Previously the training BASE_DEFAULTS and PROFILES dicts were duplicated across training_config_resolver.py (server-side) and unsloth_train.py (trainer-side TRAINER_FALLBACKS). Any change to a profile value had to be made in both files, otherwise the copies would drift. The dicts are now hoisted into projectdavid_common.constants.training_profiles (requires projectdavid-common >=0.71.0) and imported from there: the resolver imports BASE_DEFAULTS and PROFILES, and the trainer imports BASE_DEFAULTS. TRAINER_FALLBACKS is now a thin derivation that adds target_modules on top of BASE_DEFAULTS (promoting target_modules to the API surface is Phase 2 item 1).
1 parent f463cb3 commit 0560bd7

2 files changed

Lines changed: 10 additions & 62 deletions

File tree

src/api/training/services/training_config_resolver.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,55 +3,9 @@
33

44
from typing import Any, Dict, Optional
55

6+
from projectdavid_common.constants import BASE_DEFAULTS, PROFILES
67
from projectdavid_common.schemas.training_schema import TrainingConfig, TrainingProfile
78

8-
# Canonical defaults. Represents the behaviour of the current codebase when
9-
# no config is supplied — the values currently hardcoded in unsloth_train.py
10-
# (SFTConfig + get_peft_model call sites). These are also the values baked
11-
# into PROFILES["standard"] for profile-scoped fields, so an empty config
12-
# reproduces the previous default-profile-standard behaviour.
13-
BASE_DEFAULTS: Dict[str, Any] = {
14-
# Profile-scoped (overridable by profile preset):
15-
"max_seq_length": 2048,
16-
"per_device_train_batch_size": 2,
17-
"gradient_accumulation_steps": 4,
18-
"max_steps": 60,
19-
"optim": "adamw_8bit",
20-
# SFTConfig-scoped:
21-
"learning_rate": 2e-4,
22-
"warmup_steps": 2,
23-
"weight_decay": 0.01,
24-
"lr_scheduler_type": "linear",
25-
"seed": 3407,
26-
"logging_steps": 50,
27-
"num_train_epochs": 3,
28-
# PEFT-scoped:
29-
"lora_r": 32,
30-
"lora_alpha": 32,
31-
"lora_dropout": 0.0,
32-
"bias": "none",
33-
}
34-
35-
# Must match PROFILES in unsloth_train.py. Kept duplicated for Phase 1;
36-
# Phase 2 cleanup should hoist this into a shared constants module imported
37-
# by both the resolver and the trainer.
38-
PROFILES: Dict[str, Dict[str, Any]] = {
39-
"laptop": {
40-
"max_seq_length": 1024,
41-
"per_device_train_batch_size": 1,
42-
"gradient_accumulation_steps": 8,
43-
"max_steps": 12500,
44-
"optim": "adamw_8bit",
45-
},
46-
"standard": {
47-
"max_seq_length": 2048,
48-
"per_device_train_batch_size": 2,
49-
"gradient_accumulation_steps": 4,
50-
"max_steps": 60,
51-
"optim": "adamw_8bit",
52-
},
53-
}
54-
559

5610
def resolve_training_config(user_config: Optional[TrainingConfig]) -> Dict[str, Any]:
5711
"""
@@ -67,6 +21,11 @@ def resolve_training_config(user_config: Optional[TrainingConfig]) -> Dict[str,
6721
6822
The returned dict is the complete execution plan. Worker and trainer
6923
read from it without further resolution logic.
24+
25+
BASE_DEFAULTS and PROFILES are the canonical dicts exported from
26+
projectdavid_common.constants. The trainer safety-net fallbacks in
27+
unsloth_train.py are derived from the same BASE_DEFAULTS, so default
28+
values cannot drift between resolver and trainer.
7029
"""
7130
resolved: Dict[str, Any] = dict(BASE_DEFAULTS)
7231

src/api/training/unsloth_train.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import unsloth # noqa: F401 — must precede trl/transformers/peft
1616
from datasets import load_dataset
17+
from projectdavid_common.constants import BASE_DEFAULTS
1718
from transformers import TrainerCallback
1819
from trl import SFTConfig, SFTTrainer
1920
from unsloth import FastLanguageModel, is_bfloat16_supported
@@ -25,22 +26,10 @@
2526
#
2627
# target_modules is fixed here (not exposed via the API in Phase 1) — Phase 2
2728
# will add base-model-aware validation before it becomes user-tunable.
29+
30+
2831
TRAINER_FALLBACKS = {
29-
"max_seq_length": 2048,
30-
"per_device_train_batch_size": 2,
31-
"gradient_accumulation_steps": 4,
32-
"max_steps": 60,
33-
"optim": "adamw_8bit",
34-
"learning_rate": 2e-4,
35-
"warmup_steps": 2,
36-
"weight_decay": 0.01,
37-
"lr_scheduler_type": "linear",
38-
"seed": 3407,
39-
"logging_steps": 50,
40-
"lora_r": 32,
41-
"lora_alpha": 32,
42-
"lora_dropout": 0.0,
43-
"bias": "none",
32+
**BASE_DEFAULTS,
4433
"target_modules": [
4534
"q_proj",
4635
"k_proj",

0 commit comments

Comments
 (0)