Skip to content

Commit ff444cc

Browse files
committed
optimize arguments in Model
1 parent 60daf1f commit ff444cc

4 files changed

Lines changed: 15 additions & 29 deletions

File tree

libmultilabel/nn/model.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@ class MultiLabelModel(L.LightningModule):
1515
1616
Args:
1717
num_classes (int): Total number of classes.
18-
learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001.
1918
optimizer (str, optional): Optimizer name (i.e., sgd, adam, adamw, or adamax). Defaults to 'adam'.
20-
momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
21-
weight_decay (int, optional): Weight decay factor. Defaults to 0.
19+
optimizer_config (dict, optional): Optimizer parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the optimizer. Defaults to None, which is treated as an empty dictionary.
2220
metric_threshold (float, optional): The decision value threshold over which a label is predicted as positive. Defaults to 0.5.
2321
monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
2422
log_path (str): Path to a directory holding the log files and models.
@@ -30,10 +28,8 @@ class MultiLabelModel(L.LightningModule):
3028
def __init__(
3129
self,
3230
num_classes,
33-
learning_rate=0.0001,
3431
optimizer="adam",
35-
momentum=0.9,
36-
weight_decay=0,
32+
optimizer_config=None,
3733
lr_scheduler=None,
3834
scheduler_config=None,
3935
val_metric=None,
@@ -43,15 +39,13 @@ def __init__(
4339
multiclass=False,
4440
silent=False,
4541
save_k_predictions=0,
46-
**kwargs
42+
**kwargs,
4743
):
4844
super().__init__()
4945

5046
# optimizer
51-
self.learning_rate = learning_rate
5247
self.optimizer = optimizer
53-
self.momentum = momentum
54-
self.weight_decay = weight_decay
48+
self.optimizer_config = optimizer_config if optimizer_config is not None else {}
5549

5650
# lr_scheduler
5751
self.lr_scheduler = lr_scheduler
@@ -78,15 +72,13 @@ def configure_optimizers(self):
7872
parameters = [p for p in self.parameters() if p.requires_grad]
7973
optimizer_name = self.optimizer
8074
if optimizer_name == "sgd":
81-
optimizer = optim.SGD(
82-
parameters, self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay
83-
)
75+
optimizer = optim.SGD(parameters, **self.optimizer_config)
8476
elif optimizer_name == "adam":
85-
optimizer = optim.Adam(parameters, weight_decay=self.weight_decay, lr=self.learning_rate)
77+
optimizer = optim.Adam(parameters, **self.optimizer_config)
8678
elif optimizer_name == "adamw":
87-
optimizer = optim.AdamW(parameters, weight_decay=self.weight_decay, lr=self.learning_rate)
79+
optimizer = optim.AdamW(parameters, **self.optimizer_config)
8880
elif optimizer_name == "adamax":
89-
optimizer = optim.Adamax(parameters, weight_decay=self.weight_decay, lr=self.learning_rate)
81+
optimizer = optim.Adamax(parameters, **self.optimizer_config)
9082
else:
9183
raise RuntimeError("Unsupported optimizer: {self.optimizer}")
9284

libmultilabel/nn/nn_utils.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,8 @@ def init_model(
4141
embed_vecs=None,
4242
init_weight=None,
4343
log_path=None,
44-
learning_rate=0.0001,
4544
optimizer="adam",
46-
momentum=0.9,
47-
weight_decay=0,
45+
optimizer_config=None,
4846
lr_scheduler=None,
4947
scheduler_config=None,
5048
val_metric=None,
@@ -69,10 +67,8 @@ def init_model(
6967
For example, the `init_weight` of `torch.nn.init.kaiming_uniform_`
7068
is `kaiming_uniform`. Defaults to None.
7169
log_path (str): Path to a directory holding the log files and models.
72-
learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001.
7370
optimizer (str, optional): Optimizer name (i.e., sgd, adam, adamw, or adamax). Defaults to 'adam'.
74-
momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
75-
weight_decay (int, optional): Weight decay factor. Defaults to 0.
71+
optimizer_config (dict, optional): Optimizer parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the optimizer. Defaults to None.
7672
lr_scheduler (str, optional): Name of the learning rate scheduler. Defaults to None.
7773
scheduler_config (dict, optional): The configuration for learning rate scheduler. Defaults to None.
7874
val_metric (str, optional): The metric to select the best model for testing. Used by some of the schedulers. Defaults to None.
@@ -102,10 +98,8 @@ def init_model(
10298
word_dict=word_dict,
10399
network=network,
104100
log_path=log_path,
105-
learning_rate=learning_rate,
106101
optimizer=optimizer,
107-
momentum=momentum,
108-
weight_decay=weight_decay,
102+
optimizer_config=optimizer_config,
109103
lr_scheduler=lr_scheduler,
110104
scheduler_config=scheduler_config,
111105
val_metric=val_metric,

main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@ def add_all_arguments(parser):
7777
"--optimizer",
7878
default="adam",
7979
choices=["adam", "adamw", "adamax", "sgd"],
80-
type=str.lower,
8180
help="Optimizer (default: %(default)s)",
8281
)
8382
parser.add_argument(
@@ -266,6 +265,9 @@ def get_config():
266265
args.early_stopping_metric = args.val_metric
267266
if not hasattr(args, "scheduler_config"):
268267
args.scheduler_config = None
268+
args.optimizer_config = {"lr": args.learning_rate, "weight_decay": args.weight_decay}
269+
if args.optimizer == "sgd":
270+
args.optimizer_config["momentum"] = args.momentum
269271
config = AttributeDict(vars(args))
270272

271273
config.run_name = "{}_{}_{}".format(

torch_trainer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,8 @@ def _setup_model(
189189
embed_vecs=embed_vecs,
190190
init_weight=self.config.init_weight,
191191
log_path=log_path,
192-
learning_rate=self.config.learning_rate,
193192
optimizer=self.config.optimizer,
194-
momentum=self.config.momentum,
195-
weight_decay=self.config.weight_decay,
193+
optimizer_config=self.config.optimizer_config,
196194
lr_scheduler=self.config.lr_scheduler,
197195
scheduler_config=self.config.scheduler_config,
198196
val_metric=self.config.val_metric,

0 commit comments

Comments
 (0)