You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: libmultilabel/nn/model.py
+8-16Lines changed: 8 additions & 16 deletions
Original file line number
Diff line number
Diff line change
@@ -15,10 +15,8 @@ class MultiLabelModel(L.LightningModule):
15
15
16
16
Args:
17
17
num_classes (int): Total number of classes.
18
-
learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001.
19
18
optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
20
-
momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
21
-
weight_decay (int, optional): Weight decay factor. Defaults to 0.
19
+
optimizer_config (dict, optional): Optimizer parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the optimizer.
22
20
metric_threshold (float, optional): The decision value threshold over which a label is predicted as positive. Defaults to 0.5.
23
21
monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
24
22
log_path (str): Path to a directory holding the log files and models.
@@ -30,10 +28,8 @@ class MultiLabelModel(L.LightningModule):
Copy file name to clipboardExpand all lines: libmultilabel/nn/nn_utils.py
+3-9Lines changed: 3 additions & 9 deletions
Original file line number
Diff line number
Diff line change
@@ -41,10 +41,8 @@ def init_model(
41
41
embed_vecs=None,
42
42
init_weight=None,
43
43
log_path=None,
44
-
learning_rate=0.0001,
45
44
optimizer="adam",
46
-
momentum=0.9,
47
-
weight_decay=0,
45
+
optimizer_config=None,
48
46
lr_scheduler=None,
49
47
scheduler_config=None,
50
48
val_metric=None,
@@ -69,10 +67,8 @@ def init_model(
69
67
For example, the `init_weight` of `torch.nn.init.kaiming_uniform_`
70
68
is `kaiming_uniform`. Defaults to None.
71
69
log_path (str): Path to a directory holding the log files and models.
72
-
learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001.
73
70
optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
74
-
momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
75
-
weight_decay (int, optional): Weight decay factor. Defaults to 0.
71
+
optimizer_config (dict, optional): Optimizer parameters. The keys in the dictionary should match the parameter names defined by PyTorch for the optimizer.
76
72
lr_scheduler (str, optional): Name of the learning rate scheduler. Defaults to None.
77
73
scheduler_config (dict, optional): The configuration for learning rate scheduler. Defaults to None.
78
74
val_metric (str, optional): The metric to select the best model for testing. Used by some of the schedulers. Defaults to None.
0 commit comments