Skip to content

Commit 8fa857a

Browse files
authored
Add kwarg verbose=False to train_ensemble() (#40)
* Move results/ dir creation into save_results_dict(); it was previously done in the example notebooks. Also delete runs/ dir creation, since TensorBoard handles that automatically. * Add kwarg verbose=False to train_ensemble().
1 parent 92f083a commit 8fa857a

6 files changed

Lines changed: 7 additions & 30 deletions

File tree

aviary/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ def train_ensemble(
293293
model_params: dict[str, Any],
294294
loss_dict: dict[str, Literal["L1", "L2", "CSE"]],
295295
patience: int = None,
296+
verbose: bool = False,
296297
) -> None:
297298
"""Convenience method to train multiple models in serial.
298299
@@ -313,6 +314,7 @@ def train_ensemble(
313314
to loss functions.
314315
patience (int, optional): Maximum number of epochs without improvement
315316
when early stopping. Defaults to None.
317+
verbose (bool, optional): Whether to show progress bars for each epoch.
316318
"""
317319
train_generator = DataLoader(train_set, **data_params)
318320
print(f"Training on {len(train_set):,} samples")
@@ -359,9 +361,7 @@ def train_ensemble(
359361

360362
if log:
361363
writer = SummaryWriter(
362-
log_dir=(
363-
f"runs/{model_name}/{model_name}-r{j}_{datetime.now():%d-%m-%Y_%H-%M-%S}"
364-
)
364+
f"runs/{model_name}/{model_name}-r{j}_{datetime.now():%d-%m-%Y_%H-%M-%S}"
365365
)
366366
else:
367367
writer = None
@@ -375,7 +375,7 @@ def train_ensemble(
375375
optimizer=None,
376376
normalizer_dict=normalizer_dict,
377377
action="val",
378-
verbose=True,
378+
verbose=verbose,
379379
)
380380

381381
val_score = {}
@@ -727,7 +727,7 @@ def save_results_dict(
727727
for col, data in results_dict[target_name].items():
728728

729729
# NOTE we save pre_logits rather than logits due to fact
730-
# that with the hetroskedastic setup we want to be able to
730+
# that with the heteroskedastic setup we want to be able to
731731
# sample from the Gaussian distributed pre_logits we parameterise.
732732
if "pre-logits" in col:
733733
for n_ens, y_pre_logit in enumerate(data):
@@ -760,6 +760,8 @@ def save_results_dict(
760760

761761
file_name = model_name.replace("/", "_")
762762

763+
os.makedirs("results", exist_ok=True)
764+
763765
csv_path = f"results/{file_name}.csv"
764766
df.to_csv(csv_path, index=False)
765767
print(f"\nSaved model predictions to '{csv_path}'")

examples/cgcnn-example.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,6 @@ def main( # noqa: C901
202202
"n_hidden": n_hidden,
203203
}
204204

205-
if log:
206-
os.makedirs("runs/", exist_ok=True)
207-
208-
os.makedirs("results/", exist_ok=True)
209-
210205
if train:
211206
train_ensemble(
212207
model_class=CrystalGraphConvNet,

examples/colab/Roost.ipynb

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,6 @@
283283
},
284284
"outputs": [],
285285
"source": [
286-
"import os \n",
287-
"import numpy as np\n",
288286
"import torch\n",
289287
"from sklearn.model_selection import train_test_split as split\n",
290288
"\n",
@@ -3874,9 +3872,6 @@
38743872
" \"out_hidden\": [64, 64],\n",
38753873
"}\n",
38763874
"\n",
3877-
"os.makedirs(f\"models/{model_name}\", exist_ok=True)\n",
3878-
"os.makedirs(f\"results/{model_name}\", exist_ok=True)\n",
3879-
"\n",
38803875
"train_ensemble(\n",
38813876
" model_class=Roost,\n",
38823877
" model_name=model_name,\n",

examples/colab/Wren.ipynb

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,6 @@
283283
},
284284
"outputs": [],
285285
"source": [
286-
"import os \n",
287-
"import numpy as np\n",
288286
"import torch\n",
289287
"from sklearn.model_selection import train_test_split as split\n",
290288
"\n",
@@ -3879,9 +3877,6 @@
38793877
" \"out_hidden\": [64, 64],\n",
38803878
"}\n",
38813879
"\n",
3882-
"os.makedirs(f\"models/{model_name}\", exist_ok=True)\n",
3883-
"os.makedirs(f\"results/{model_name}\", exist_ok=True)\n",
3884-
"\n",
38853880
"train_ensemble(\n",
38863881
" model_class=Wren,\n",
38873882
" model_name=model_name,\n",

examples/roost-example.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,6 @@ def main( # noqa: C901
186186
"out_hidden": [256, 128, 64],
187187
}
188188

189-
if log:
190-
os.makedirs("runs/", exist_ok=True)
191-
192-
os.makedirs("results/", exist_ok=True)
193-
194189
# TODO dump all args/kwargs to a file for reproducibility.
195190

196191
if train:

examples/wren-example.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -199,11 +199,6 @@ def main( # noqa: C901
199199
"trunk_hidden": [128, 64],
200200
}
201201

202-
if log:
203-
os.makedirs("runs/", exist_ok=True)
204-
205-
os.makedirs("results/", exist_ok=True)
206-
207202
if train:
208203
train_ensemble(
209204
model_class=Wren,

0 commit comments

Comments (0)