Commit 1e3297b

modify no_merge_train_val to retrain
1 parent f325f48 commit 1e3297b

1 file changed: search_params.py (32 additions, 22 deletions)

@@ -141,16 +141,16 @@ def init_search_algorithm(search_alg, metric=None, mode=None):
     logging.info(f"{search_alg} search is found, run BasicVariantGenerator().")


-def prepare_retrain_config(best_config, best_log_dir, merge_train_val):
+def prepare_retrain_config(best_config, best_log_dir, retrain):
     """Prepare the configuration for re-training.

     Args:
         best_config (AttributeDict): The best hyper-parameter configuration.
         best_log_dir (str): The directory of the best trial of the experiment.
-        merge_train_val (bool): Whether to merge the training and validation data.
+        retrain (bool): Whether to retrain the model with merged training and validation data.
     """
-    if merge_train_val:
-        best_config.merge_train_val = True
+    if retrain:
+        best_config.retrain = True

     log_path = os.path.join(best_log_dir, "logs.json")
     if os.path.isfile(log_path):
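Note that this hunk is an API change, not just a cosmetic rename: the third argument of prepare_retrain_config and the attribute it writes to best_config both change, so callers must be updated in lockstep (as the later hunks do). In short:

    # Before: prepare_retrain_config(best_config, best_log_dir, merge_train_val)
    #         -> sets best_config.merge_train_val
    # After:  prepare_retrain_config(best_config, best_log_dir, retrain)
    #         -> sets best_config.retrain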
@@ -165,15 +165,15 @@ def prepare_retrain_config(best_config, best_log_dir, merge_train_val):
         optimal_idx = log_metric.argmax() if best_config.mode == "max" else log_metric.argmin()
         best_config.epochs = optimal_idx.item() + 1  # plus 1 for epochs
     else:
-        best_config.merge_train_val = False
+        best_config.retrain = False


-def load_static_data(config, merge_train_val=False):
+def load_static_data(config, retrain=False):
     """Preload static data once for multiple trials.

     Args:
         config (AttributeDict): Config of the experiment.
-        merge_train_val (bool, optional): Whether to merge the training and validation data.
+        retrain (bool): Whether to retrain the model with merged training and validation data.
             Defaults to False.

     Returns:
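The unchanged lines around this hunk show why retraining needs prepare_retrain_config at all: the best epoch count is recovered from the trial's logs.json via argmax/argmin over the logged validation metric. A minimal, self-contained sketch of that selection (the metric values and the flat-array shape of the log are assumptions for illustration, not taken from the commit):

    import numpy as np

    # Assumed: per-epoch validation scores parsed from a trial's logs.json.
    log_metric = np.array([0.61, 0.67, 0.72, 0.70])

    mode = "max"  # best_config.mode; "min" for loss-like metrics
    optimal_idx = log_metric.argmax() if mode == "max" else log_metric.argmin()
    epochs = optimal_idx.item() + 1  # logs are 0-indexed, epochs are 1-indexed
    print(epochs)  # -> 3; re-train for exactly this many epochs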
@@ -184,7 +184,7 @@ def load_static_data(config, merge_train_val=False):
         test_data=config.test_file,
         val_data=config.val_file,
         val_size=config.val_size,
-        merge_train_val=merge_train_val,
+        merge_train_val=retrain,
         tokenize_text="lm_weight" not in config.network_config,
         remove_no_label_data=config.remove_no_label_data,
     )
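Worth noting: the rename deliberately stops at this call boundary. The underlying dataset-loading routine keeps its merge_train_val keyword, so the script-level retrain flag is simply forwarded as merge_train_val=retrain; only search_params.py's own names change in this commit.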
@@ -205,31 +205,31 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
     }


-def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
+def retrain_best_model(exp_name, best_config, best_log_dir, retrain):
     """Re-train the model with the best hyper-parameters.
-    A new model is trained on the combined training and validation data if `merge_train_val` is True.
+    A new model is trained on the combined training and validation data if `retrain` is True.
     If a test set is provided, it will be evaluated by the obtained model.

     Args:
         exp_name (str): The directory to save trials generated by ray tune.
         best_config (AttributeDict): The best hyper-parameter configuration.
         best_log_dir (str): The directory of the best trial of the experiment.
-        merge_train_val (bool): Whether to merge the training and validation data.
+        retrain (bool): Whether to retrain the model with merged training and validation data.
     """
     best_config.silent = False
     checkpoint_dir = os.path.join(best_config.result_dir, exp_name, "trial_best_params")
     os.makedirs(checkpoint_dir, exist_ok=True)
-    with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
-        yaml.dump(dict(best_config), fp)
     best_config.run_name = "_".join(exp_name.split("_")[:-1]) + "_best"
     best_config.checkpoint_dir = checkpoint_dir
     best_config.log_path = os.path.join(best_config.checkpoint_dir, "logs.json")
-    prepare_retrain_config(best_config, best_log_dir, merge_train_val)
+    prepare_retrain_config(best_config, best_log_dir, retrain)
     set_seed(seed=best_config.seed)
+    with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
+        yaml.dump(dict(best_config), fp)

-    data = load_static_data(best_config, merge_train_val=best_config.merge_train_val)
+    data = load_static_data(best_config, retrain=best_config.retrain)

-    if merge_train_val:
+    if retrain:
         logging.info(f"Re-training with best config: \n{best_config}")
         trainer = TorchTrainer(config=best_config, **data)
         trainer.train()
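Beyond the rename, this hunk reorders the logic: the params.yml dump now happens after prepare_retrain_config (and after run_name, checkpoint_dir, and log_path are set), so the saved file records the configuration actually used for re-training. A compressed sketch of why the order matters (values hypothetical):

    import yaml

    best_config = {"mode": "max", "epochs": 50}

    # prepare_retrain_config mutates the config before it is saved:
    best_config["retrain"] = True  # flag recorded under the new name
    best_config["epochs"] = 3      # optimal epoch recovered from logs.json

    # Dumping only now means params.yml matches the re-training run;
    # with the old order, it froze the config before these mutations.
    with open("params.yml", "w") as fp:
        yaml.dump(best_config, fp)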
@@ -247,7 +247,7 @@ def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):

     if "test" in data["datasets"]:
         test_results = trainer.test()
-        if merge_train_val:
+        if retrain:
             logging.info(f"Test results after re-training: {test_results}")
         else:
             logging.info(f"Test results of best config: {test_results}")
@@ -260,8 +260,18 @@ def main():
         "--config",
         help="Path to configuration file (default: %(default)s). Please specify a config with all arguments in LibMultiLabel/main.py::get_config.",
     )
-    parser.add_argument("--cpu_count", type=int, default=4, help="Number of CPU per trial (default: %(default)s)")
-    parser.add_argument("--gpu_count", type=int, default=1, help="Number of GPU per trial (default: %(default)s)")
+    parser.add_argument(
+        "--cpu_count",
+        type=int,
+        default=4,
+        help="Number of CPU per trial (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--gpu_count",
+        type=int,
+        default=1,
+        help="Number of GPU per trial (default: %(default)s)",
+    )
     parser.add_argument(
         "--num_samples",
         type=int,
@@ -275,9 +285,9 @@ def main():
         help="Search algorithms (default: %(default)s)",
     )
     parser.add_argument(
-        "--no_merge_train_val",
+        "--no_retrain",
         action="store_true",
-        help="Do not add the validation set in re-training the final model after hyper-parameter search.",
+        help="Do not retrain the model with validation set after hyperparameter search.",
     )
     args, _ = parser.parse_known_args()

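Migration note for users: scripts that previously passed --no_merge_train_val must now pass --no_retrain. The default behavior is unchanged: when the flag is absent, the final model is still re-trained on the merged training and validation data.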
@@ -343,7 +353,7 @@ def main():
     # Save best model after parameter search.
     best_config = analysis.get_best_config(f"val_{config.val_metric}", config.mode, scope="all")
     best_log_dir = analysis.get_best_logdir(f"val_{config.val_metric}", config.mode, scope="all")
-    retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not config.no_merge_train_val)
+    retrain_best_model(exp_name, best_config, best_log_dir, retrain=not config.no_retrain)


 if __name__ == "__main__":
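Since --no_retrain is a store_true flag negated at the call site, the wiring is easy to misread; a small self-contained check of the resulting semantics (using only the flag definition from this diff):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        "--no_retrain",
        action="store_true",
        help="Do not retrain the model with validation set after hyperparameter search.",
    )

    # Flag absent (the default): no_retrain=False, so retrain=True.
    args = parser.parse_args([])
    print(not args.no_retrain)  # True -> re-train on merged train+val

    # Flag present: no_retrain=True, so retrain=False.
    args = parser.parse_args(["--no_retrain"])
    print(not args.no_retrain)  # False -> skip the re-training pass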
