Skip to content

Commit 5f543e6

Browse files
committed
Better tensorboard plotter, training on demo works now
1 parent 56ebf24 commit 5f543e6

2 files changed

Lines changed: 8 additions & 6 deletions

File tree

train_kwcoco_demo.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ TEST_FPATH=$BUNDLE_DPATH/vidshapes_rgb_test/data.kwcoco.json
108108
# Grab a checkpoint
109109
CKPT_FPATH=$(python -c "if 1:
110110
import pathlib
111-
ckpt_dpath = pathlib.Path('$TRAIN_DPATH') / 'train/kwcoco-demo/checkpoints'
112-
checkpoints = sorted(ckpt_dpath.glob('*'))
111+
root_dpath = pathlib.Path('$TRAIN_DPATH') / 'train/kwcoco-demo'
112+
checkpoints = sorted(root_dpath.glob('lightning_logs/*/checkpoints/*'))
113113
print(checkpoints[-1])
114114
")
115115
echo "CKPT_FPATH = $CKPT_FPATH"

yolo/utils/callbacks/tensorboard_plotter.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,8 @@ def _dump_measures(train_dpath, title='?name?', smoothing='auto', ignore_outlier
234234
else:
235235
smoothing_values = [smoothing]
236236

237-
plot_keys = [k for k in tb_data.keys() if '/' not in k]
237+
# plot_keys = [k for k in tb_data.keys() if '/' not in k]
238+
plot_keys = [k for k in tb_data.keys()]
238239
keys = set(tb_data.keys()).intersection(set(plot_keys))
239240
# no idea what hp metric is, but it doesn't seem important
240241
# keys = keys - {'hp_metric'}
@@ -243,7 +244,7 @@ def _dump_measures(train_dpath, title='?name?', smoothing='auto', ignore_outlier
243244
print('warning: no known keys to plot')
244245
print(f'available keys: {list(tb_data.keys())}')
245246

246-
USE_NEW_PLOT_PREF = 0
247+
USE_NEW_PLOT_PREF = 1
247248
if USE_NEW_PLOT_PREF:
248249
# TODO: finish this
249250
default_plot_preferences = kwutil.Yaml.loads(ub.codeblock(
@@ -419,7 +420,7 @@ def _dump_measures(train_dpath, title='?name?', smoothing='auto', ignore_outlier
419420
ax.set_title(title)
420421

421422
# png is smaller than jpg for this kind of plot
422-
fpath = out_dpath / (key + '.png')
423+
fpath = out_dpath / (key.replace('/', '-') + '.png')
423424
if verbose:
424425
print('Save plot: ' + str(fpath))
425426
ax.figure.savefig(fpath)
@@ -575,6 +576,7 @@ def main(cls, cmdline=1, **kwargs):
575576
if __name__ == '__main__':
576577
"""
577578
CommandLine:
578-
python -m callbacks.tensorboard_plotter .
579+
python -m yolo.utils.callbacks.tensorboard_plotter .
580+
python ~/code/YOLO-v9/yolo/utils/callbacks/tensorboard_plotter.py .
579581
"""
580582
TensorboardPlotterCLI.main()

0 commit comments

Comments
 (0)