Commit 2d18ab0

Don't override configs in CLI
1 parent f573997 commit 2d18ab0

2 files changed

Lines changed: 42 additions & 19 deletions

configs/default.yaml

Lines changed: 2 additions & 0 deletions
@@ -29,7 +29,9 @@ lora_dropout: 0
 encoder_base_model: "jhu-clsp/mmBERT-base"
 encoder_max_length: 8192
 encoder_batch_size: 2
+encoder_eval_batch_size: 1
 encoder_gradient_accumulation_steps: 8
+encoder_eval_accumulation_steps: 1
 encoder_learning_rate: 2.0e-5
 encoder_num_epochs: 5
 encoder_warmup_ratio: 0.1

squeez/encoder/train.py

Lines changed: 40 additions & 19 deletions
@@ -37,6 +37,8 @@ def train(
     save_total_limit: int = 3,
     fp16: bool = False,
     bf16: bool = False,
+    eval_batch_size: int | None = None,
+    eval_accumulation_steps: int = 1,
 ) -> None:
     """Train the encoder line classifier."""
     import torch
@@ -81,6 +83,7 @@ def train(
 
     # Resize embeddings for the new [LINE_SEP] token
     model.encoder.resize_token_embeddings(len(tokenizer))
+    model.gradient_checkpointing_enable()
 
     # Load datasets
     logger.info(f"Loading train data from {train_file}")
@@ -102,7 +105,7 @@ def train(
         output_dir=output_dir,
         num_train_epochs=num_epochs,
         per_device_train_batch_size=batch_size,
-        per_device_eval_batch_size=max(1, batch_size // 2),
+        per_device_eval_batch_size=eval_batch_size or max(1, batch_size // 2),
         gradient_accumulation_steps=gradient_accumulation_steps,
         learning_rate=learning_rate,
         weight_decay=weight_decay,
@@ -120,6 +123,8 @@ def train(
         metric_for_best_model="eval_loss" if eval_dataset else None,
         report_to="none",
         dataloader_num_workers=0,
+        eval_accumulation_steps=eval_accumulation_steps,
+        gradient_checkpointing=True,
         remove_unused_columns=False,
     )
 
@@ -158,15 +163,17 @@ def build_parser(parser: argparse.ArgumentParser | None = None) -> argparse.Argu
 
     parser.add_argument("--train-file", required=True, help="Path to encoder_train.jsonl")
     parser.add_argument("--eval-file", default=None, help="Path to encoder_dev.jsonl")
-    parser.add_argument("--base-model", default="jhu-clsp/mmBERT-base", help="Base encoder model")
-    parser.add_argument("--output-dir", default="output/squeez_encoder")
-    parser.add_argument("--max-length", type=int, default=8192)
-    parser.add_argument("--batch-size", type=int, default=16)
-    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
-    parser.add_argument("--learning-rate", type=float, default=2e-5)
-    parser.add_argument("--num-epochs", type=int, default=5)
-    parser.add_argument("--warmup-ratio", type=float, default=0.1)
-    parser.add_argument("--weight-decay", type=float, default=0.01)
+    parser.add_argument("--base-model", default=None, help="Base encoder model")
+    parser.add_argument("--output-dir", default=None)
+    parser.add_argument("--max-length", type=int, default=None)
+    parser.add_argument("--batch-size", type=int, default=None)
+    parser.add_argument("--eval-batch-size", type=int, default=None)
+    parser.add_argument("--gradient-accumulation-steps", type=int, default=None)
+    parser.add_argument("--eval-accumulation-steps", type=int, default=None)
+    parser.add_argument("--learning-rate", type=float, default=None)
+    parser.add_argument("--num-epochs", type=int, default=None)
+    parser.add_argument("--warmup-ratio", type=float, default=None)
+    parser.add_argument("--weight-decay", type=float, default=None)
     parser.add_argument("--eval-steps", type=int, default=200)
     parser.add_argument("--save-steps", type=int, default=200)
     parser.add_argument("--logging-steps", type=int, default=25)
@@ -184,23 +191,37 @@ def main(argv: list[str] | None = None) -> int:
         format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
     )
 
+    import yaml
+
+    default_config_path = Path(__file__).parent.parent.parent / "configs" / "default.yaml"
+    config = {}
+    if default_config_path.exists():
+        with open(default_config_path) as f:
+            config = yaml.safe_load(f) or {}
+
     train(
         train_file=args.train_file,
         eval_file=args.eval_file,
-        base_model=args.base_model,
-        output_dir=args.output_dir,
-        max_length=args.max_length,
-        batch_size=args.batch_size,
-        gradient_accumulation_steps=args.gradient_accumulation_steps,
-        learning_rate=args.learning_rate,
-        num_epochs=args.num_epochs,
-        warmup_ratio=args.warmup_ratio,
-        weight_decay=args.weight_decay,
+        base_model=args.base_model or config.get("encoder_base_model", "jhu-clsp/mmBERT-base"),
+        output_dir=args.output_dir or config.get("encoder_output_dir", "output/squeez_encoder"),
+        max_length=args.max_length or config.get("encoder_max_length", 8192),
+        batch_size=args.batch_size or config.get("encoder_batch_size", 2),
+        gradient_accumulation_steps=(
+            args.gradient_accumulation_steps or config.get("encoder_gradient_accumulation_steps", 8)
+        ),
+        learning_rate=args.learning_rate or config.get("encoder_learning_rate", 2e-5),
+        num_epochs=args.num_epochs or config.get("encoder_num_epochs", 5),
+        warmup_ratio=args.warmup_ratio or config.get("encoder_warmup_ratio", 0.1),
+        weight_decay=args.weight_decay or config.get("weight_decay", 0.01),
         eval_steps=args.eval_steps,
         save_steps=args.save_steps,
         logging_steps=args.logging_steps,
         fp16=args.fp16,
         bf16=args.bf16,
+        eval_batch_size=args.eval_batch_size or config.get("encoder_eval_batch_size", 1),
+        eval_accumulation_steps=(
+            args.eval_accumulation_steps or config.get("encoder_eval_accumulation_steps", 1)
+        ),
     )
     return 0
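The diff above resolves every training option with the same precedence rule: CLI flag if given, else the matching key in configs/default.yaml, else a hard-coded default. A minimal standalone sketch of that rule (the resolve helper and sample values are illustrative, not part of the commit):

# Sketch of the CLI-over-config precedence used in main() above.
# `resolve` is a hypothetical helper, not part of the commit.
from typing import Any

def resolve(cli_value: Any, config: dict, key: str, default: Any) -> Any:
    """Return the CLI value if given, else the config value, else the default.

    Because this uses `or`, falsy CLI values (0, 0.0, "") also fall
    through to the config, matching the commit's behavior.
    """
    return cli_value or config.get(key, default)

config = {"encoder_batch_size": 2, "encoder_learning_rate": 2.0e-5}

assert resolve(None, config, "encoder_batch_size", 16) == 2   # flag omitted -> config value
assert resolve(4, config, "encoder_batch_size", 16) == 4      # flag given -> CLI wins
assert resolve(None, config, "encoder_num_epochs", 5) == 5    # key absent -> default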