Skip to content

Commit 0712fef

Browse files
KemingWu and kcz358 authored
Add Bagel Trainer and fix config, bagel data processor (#126)
* fix bagel trainer config processor * fix bagel trainer config processor * Add bagel fsdp trainer into register and in ema support trainer list --------- Co-authored-by: kcz358 <kaichenzhang358@outlook.com>
1 parent fd38011 commit 0712fef

5 files changed

Lines changed: 494 additions & 62 deletions

File tree

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
trainer_type: bagel_fsdp2_trainer
2+
3+
dataset_config:
4+
extra_kwargs: {}
5+
dataset_type: bagel_iterable
6+
dataset_format: yaml
7+
processor_config:
8+
processor_name: lmms-lab/BAGEL-7B-MoT-ver.LE
9+
processor_type: bagel
10+
# Align with original Bagel edit dataset:
11+
# - user/source image is encoded as BOTH VAE-condition (loss=0, timestep=-inf) and ViT-condition
12+
# - optional CFG-style conditional dropouts (set to original defaults; set to 0.0 to disable)
13+
extra_kwargs:
14+
user_image_as_vae_condition: true
15+
text_cond_dropout_prob: 0.0
16+
vit_cond_dropout_prob: 0.0
17+
vae_cond_dropout_prob: 0.0
18+
# Update this to your dataset path
19+
# Example dataset: https://huggingface.co/datasets/kcz358/bagel-example
20+
datasets:
21+
- path: /path/to/your/dataset/
22+
data_folder: /path/to/your/data_folder/
23+
data_type: arrow
24+
shuffle: true
25+
eval_dataset_path: null
26+
object_storage: none
27+
bucket_name: null
28+
packing: true
29+
packing_strategy: first_fit
30+
packing_length: 10240
31+
filter_overlong: true
32+
filter_overlong_workers: 8
33+
max_length: null
34+
video_backend: qwen_vl_utils
35+
36+
trainer_args:
37+
output_dir: ./output/bagel_training
38+
overwrite_output_dir: false
39+
do_train: true
40+
do_eval: false
41+
do_predict: false
42+
eval_strategy: 'no'
43+
prediction_loss_only: false
44+
per_device_train_batch_size: 2
45+
per_device_eval_batch_size: 8
46+
gradient_accumulation_steps: 1
47+
eval_accumulation_steps: null
48+
eval_delay: 0
49+
torch_empty_cache_steps: null
50+
learning_rate: 1.0e-06
51+
weight_decay: 0.0
52+
adam_beta1: 0.9
53+
adam_beta2: 0.999
54+
adam_epsilon: 1.0e-08
55+
max_grad_norm: 1.0
56+
num_train_epochs: 1
57+
max_steps: 1000
58+
lr_scheduler_type: cosine
59+
lr_scheduler_kwargs: {}
60+
warmup_ratio: 0.1
61+
warmup_steps: 20
62+
log_level: passive
63+
log_level_replica: warning
64+
log_on_each_node: true
65+
logging_dir: ./output/bagel_training/runs
66+
logging_strategy: steps
67+
logging_first_step: false
68+
logging_steps: 10
69+
logging_nan_inf_filter: true
70+
save_strategy: steps
71+
save_steps: 500
72+
save_total_limit: 2
73+
save_safetensors: true
74+
save_on_each_node: false
75+
save_only_model: false
76+
restore_callback_states_from_checkpoint: false
77+
no_cuda: false
78+
use_cpu: false
79+
use_mps_device: false
80+
seed: 42
81+
data_seed: null
82+
jit_mode_eval: false
83+
bf16: true
84+
fp16: false
85+
fp16_opt_level: O1
86+
half_precision_backend: auto
87+
bf16_full_eval: false
88+
fp16_full_eval: false
89+
tf32: null
90+
local_rank: 0
91+
ddp_backend: null
92+
tpu_num_cores: null
93+
tpu_metrics_debug: false
94+
debug: []
95+
dataloader_drop_last: false
96+
eval_steps: null
97+
dataloader_num_workers: 8
98+
dataloader_prefetch_factor: null
99+
past_index: -1
100+
run_name: bagel_training
101+
disable_tqdm: false
102+
remove_unused_columns: true
103+
label_names: null
104+
load_best_model_at_end: false
105+
metric_for_best_model: null
106+
greater_is_better: null
107+
ignore_data_skip: false
108+
fsdp: []
109+
fsdp_min_num_params: 0
110+
fsdp_config:
111+
transformer_layer_cls_to_wrap:
112+
- Qwen2MoTDecoderLayer
113+
reshard_after_forward: false
114+
min_num_params: 0
115+
xla: false
116+
xla_fsdp_v2: false
117+
xla_fsdp_grad_ckpt: false
118+
fsdp_transformer_layer_cls_to_wrap: null
119+
accelerator_config:
120+
split_batches: false
121+
dispatch_batches: null
122+
even_batches: true
123+
use_seedable_sampler: true
124+
non_blocking: false
125+
gradient_accumulation_kwargs: null
126+
parallelism_config: null
127+
deepspeed: null
128+
label_smoothing_factor: 0.0
129+
optim: adamw_torch_fused
130+
optim_args: null
131+
adafactor: false
132+
group_by_length: true
133+
length_column_name: length
134+
report_to:
135+
- wandb
136+
project: huggingface
137+
trackio_space_id: trackio
138+
ddp_find_unused_parameters: null
139+
ddp_bucket_cap_mb: null
140+
ddp_broadcast_buffers: null
141+
dataloader_pin_memory: true
142+
dataloader_persistent_workers: false
143+
skip_memory_metrics: true
144+
use_legacy_prediction_loop: false
145+
push_to_hub: false
146+
resume_from_checkpoint: null
147+
hub_model_id: null
148+
hub_strategy: every_save
149+
hub_token: <HUB_TOKEN>
150+
hub_private_repo: null
151+
hub_always_push: false
152+
hub_revision: null
153+
gradient_checkpointing: true
154+
gradient_checkpointing_kwargs: null
155+
include_inputs_for_metrics: false
156+
include_for_metrics: []
157+
eval_do_concat_batches: true
158+
fp16_backend: auto
159+
push_to_hub_model_id: null
160+
push_to_hub_organization: null
161+
mp_parameters: ''
162+
auto_find_batch_size: false
163+
full_determinism: false
164+
torchdynamo: null
165+
ray_scope: last
166+
ddp_timeout: 1800
167+
torch_compile: false
168+
torch_compile_backend: null
169+
torch_compile_mode: null
170+
include_tokens_per_second: false
171+
include_num_input_tokens_seen: 'no'
172+
neftune_noise_alpha: null
173+
optim_target_modules: null
174+
batch_eval_metrics: false
175+
eval_on_start: false
176+
use_liger_kernel: false
177+
liger_kernel_config: null
178+
eval_use_gather_object: false
179+
average_tokens_across_devices: true
180+
use_muon: false
181+
freeze_modules:
182+
- vae_model
183+
use_rmpad: false
184+
fsdp2: true
185+
sp_ulysses_degree: 1
186+
reduce_dtype: bfloat16
187+
output_dtype: bfloat16
188+
print_batch_input_steps: 5
189+
enable_profiler: false
190+
profiler_config:
191+
start_step: 1
192+
end_step: 3
193+
194+
model_config:
195+
extra_kwargs:
196+
visual_und: false # Enable/disable visual understanding
197+
load_from_pretrained_path: lmms-lab/BAGEL-7B-MoT-ver.LE
198+
load_from_config: null
199+
attn_implementation: flash_attention_2
200+
model_type: null
201+
torch_dtype: bfloat16
202+
overwrite_config: null
203+
# Optional: Enable Native Sparse Attention
204+
# monkey_patch_kwargs:
205+
# patch_type: ["nsa"]
206+
# block_size: 64
207+
# compress_type: "weightedpool"
208+
# kernel_size: 32
209+
# kernel_stride: 16
210+
# topk: 16
211+
# init_blocks: 1
212+
# local_blocks: 2
213+
# window_size: 512
214+
215+
extra_kwargs: null

0 commit comments

Comments
 (0)