Skip to content

Commit 0122791

Browse files
committed
remove retrain config
1 parent 1e3297b commit 0122791

1 file changed

Lines changed: 6 additions & 6 deletions

File tree

search_params.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def prepare_retrain_config(best_config, best_log_dir, retrain):
150150
retrain (bool): Whether to retrain the model with merged training and validation data.
151151
"""
152152
if retrain:
153-
best_config.retrain = True
153+
best_config.merge_train_val = True
154154

155155
log_path = os.path.join(best_log_dir, "logs.json")
156156
if os.path.isfile(log_path):
@@ -165,15 +165,15 @@ def prepare_retrain_config(best_config, best_log_dir, retrain):
165165
optimal_idx = log_metric.argmax() if best_config.mode == "max" else log_metric.argmin()
166166
best_config.epochs = optimal_idx.item() + 1 # plus 1 for epochs
167167
else:
168-
best_config.retrain = False
168+
best_config.merge_train_val = False
169169

170170

171-
def load_static_data(config, retrain=False):
171+
def load_static_data(config, merge_train_val=False):
172172
"""Preload static data once for multiple trials.
173173
174174
Args:
175175
config (AttributeDict): Config of the experiment.
176-
retrain (bool): Whether to retrain the model with merged training and validation data.
176+
merge_train_val (bool, optional): Whether to merge the training and validation data.
177177
Defaults to False.
178178
179179
Returns:
@@ -184,7 +184,7 @@ def load_static_data(config, retrain=False):
184184
test_data=config.test_file,
185185
val_data=config.val_file,
186186
val_size=config.val_size,
187-
merge_train_val=retrain,
187+
merge_train_val=merge_train_val,
188188
tokenize_text="lm_weight" not in config.network_config,
189189
remove_no_label_data=config.remove_no_label_data,
190190
)
@@ -227,7 +227,7 @@ def retrain_best_model(exp_name, best_config, best_log_dir, retrain):
227227
with open(os.path.join(checkpoint_dir, "params.yml"), "w") as fp:
228228
yaml.dump(dict(best_config), fp)
229229

230-
data = load_static_data(best_config, retrain=best_config.retrain)
230+
data = load_static_data(best_config, merge_train_val=best_config.merge_train_val)
231231

232232
if retrain:
233233
logging.info(f"Re-training with best config: \n{best_config}")

0 commit comments

Comments (0)