Skip to content

Commit 8233127

Browse files
committed
rework float32_matmul_precision hack
1 parent a51b233 commit 8233127

4 files changed

Lines changed: 68 additions & 52 deletions

File tree

train_kwcoco_demo.sh

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,35 +79,36 @@ validation: $VALI_FPATH
7979
8080
$CLASS_YAML
8181
"
82-
8382
echo "$CONFIG_YAML" > "$DATASET_CONFIG_FPATH"
8483

85-
84+
TRAIN_DPATH="$BUNDLE_DPATH/kwcoco-demo-train-dir"
8685
# This might only work in development mode; otherwise we will pick up the copy in site-packages.
8786
# That still might be fine, but we do want to fix this to run anywhere.
8887
cd "$REPO_DPATH"
8988
LOG_BATCH_VIZ_TO_DISK=1 python -m yolo.lazy \
9089
task=train \
9190
dataset=kwcoco-demo \
91+
use_tensorboard=True \
9292
use_wandb=False \
93-
out_path="$BUNDLE_DPATH"/training \
93+
out_path="$TRAIN_DPATH" \
9494
name=kwcoco-demo \
9595
cpu_num=0 \
9696
device=0 \
9797
accelerator=auto \
9898
task.data.batch_size=2 \
9999
"image_size=[640, 640]" \
100-
task.optimizer.args.lr=0.0003
100+
task.optimizer.args.lr=0.03
101101

102102

103103
### show how to run inference
104104

105105
BUNDLE_DPATH=$HOME/demo-yolo-kwcoco-train
106+
TRAIN_DPATH="$BUNDLE_DPATH/kwcoco-demo-train-dir"
106107
TEST_FPATH=$BUNDLE_DPATH/vidshapes_rgb_test/data.kwcoco.json
107108
# Grab a checkpoint
108109
CKPT_FPATH=$(python -c "if 1:
109110
import pathlib
110-
ckpt_dpath = pathlib.Path('$BUNDLE_DPATH') / 'training/train/kwcoco-demo/checkpoints'
111+
ckpt_dpath = pathlib.Path('$TRAIN_DPATH') / 'train/kwcoco-demo/checkpoints'
111112
checkpoints = sorted(ckpt_dpath.glob('*'))
112113
print(checkpoints[-1])
113114
")
@@ -133,9 +134,11 @@ python yolo/lazy.py \
133134
### Show how to run validation
134135

135136
# Grab a checkpoint
137+
BUNDLE_DPATH=$HOME/demo-yolo-kwcoco-train
138+
TRAIN_DPATH="$BUNDLE_DPATH/kwcoco-demo-train-dir"
136139
CKPT_FPATH=$(python -c "if 1:
137140
import pathlib
138-
ckpt_dpath = pathlib.Path('$BUNDLE_DPATH') / 'training/train/kwcoco-demo/checkpoints'
141+
ckpt_dpath = pathlib.Path('$TRAIN_DPATH') / 'train/kwcoco-demo/checkpoints'
139142
checkpoints = sorted(ckpt_dpath.glob('*'))
140143
print(checkpoints[-1])
141144
")
@@ -146,7 +149,7 @@ LOG_BATCH_VIZ_TO_DISK=1 python -m yolo.lazy \
146149
task=validation \
147150
dataset=kwcoco-demo \
148151
use_wandb=False \
149-
out_path="$BUNDLE_DPATH"/training \
152+
out_path="$TRAIN_DPATH" \
150153
name=kwcoco-demo \
151154
cpu_num=0 \
152155
device=0 \

yolo/utils/callbacks.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

yolo/utils/logging_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,21 +339,20 @@ def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=Tru
339339
callbacks.append(YOLORichModelSummary())
340340

341341
if 1:
342-
from yolo.utils.callbacks import TorchGlobals
343-
callbacks.append(TorchGlobals(float32_matmul_precision='auto'))
342+
import lightning
344343
checkpoint_init_args = {
345344
'monitor': 'train_loss',
346345
'mode': 'min',
347346
'save_top_k': 5,
348347
'filename': '{epoch:04d}-{step:06d}-trainloss{train_loss:.3f}.ckpt',
349348
'save_last': True,
350349
}
351-
import lightning
352350
checkpointer = lightning.pytorch.callbacks.ModelCheckpoint(**checkpoint_init_args)
353351
callbacks.append(checkpointer)
354352

355353
callbacks.append(ImageLogger())
356354

355+
print(f'cfg.use_tensorboard={cfg.use_tensorboard}')
357356
if cfg.use_tensorboard:
358357
loggers.append(TensorBoardLogger(log_graph="all", save_dir=save_path))
359358
if cfg.use_wandb:

yolo/utils/trainer.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,18 @@ class YoloTrainer(lightning.Trainer):
1010

1111
def __init__(self, *args, **kwargs):
1212
super().__init__(*args, **kwargs)
13+
self._hacked_torch_global_callback = TorchGlobals(float32_matmul_precision='auto')
1314

14-
def _run_stage(self, *args, **kwargs):
15+
def _run(self, *args, **kwargs):
1516
# Torch globals must be configured before the accelerators are set up.
1617
# Existing callbacks have no hook that runs that early, so we do it here.
1718
self._on_before_run()
19+
super()._run(*args, **kwargs)
20+
21+
def _run_stage(self, *args, **kwargs):
22+
# All I want is to print this directly before training starts.
23+
# Is that so hard to do?
24+
self._on_before_run_stage()
1825
super()._run_stage(*args, **kwargs)
1926

2027
@property
@@ -32,6 +39,12 @@ def log_dpath(self):
3239
return ub.Path(self.logger.log_dir)
3340

3441
def _on_before_run(self):
42+
"""
43+
Our custom "callback": sets up torch globals before the environment is set up.
44+
"""
45+
self._hacked_torch_global_callback.before_setup_environment(self)
46+
47+
def _on_before_run_stage(self):
3548
"""
3649
Our custom "callback"
3750
"""
@@ -43,3 +56,45 @@ def _on_before_run_rank0(self):
4356
import rich
4457
dpath = self.log_dpath
4558
rich.print(f"Trainer log dpath:\n\n[link={dpath}]{dpath}[/link]\n")
59+
60+
61+
class TorchGlobals(lightning.pytorch.callbacks.Callback):
62+
"""
63+
Callback to set up torch globals.
64+
65+
Note: this needs to be called before the accelerators are setup, and
66+
Note: this needs to be called before the accelerators are set up, and
67+
68+
Args:
69+
float32_matmul_precision (str):
70+
can be 'medium', 'high', 'default', or 'auto'.
71+
The 'default' value does not change any setting.
72+
The 'auto' value selects 'medium' if CUDA is available and every
73+
training device has compute capability >= 8 (Ampere or newer); otherwise no setting is changed.
74+
"""
75+
76+
def __init__(self, float32_matmul_precision='default'):
77+
self.float32_matmul_precision = float32_matmul_precision
78+
79+
def before_setup_environment(self, trainer):
80+
import torch
81+
print('Setup Torch Globals')
82+
float32_matmul_precision = self.float32_matmul_precision
83+
if float32_matmul_precision == 'default':
84+
float32_matmul_precision = None
85+
elif float32_matmul_precision == 'auto':
86+
# Detect if we have Ampere tensor cores
87+
# Ampere (compute capability 8.x) and later leverage tensor cores, where medium
88+
# float32_matmul_precision becomes useful
89+
if torch.cuda.is_available():
90+
device_versions = [torch.cuda.get_device_capability(device_id)[0]
91+
for device_id in trainer.device_ids]
92+
if all(v >= 8 for v in device_versions):
93+
float32_matmul_precision = 'medium'
94+
else:
95+
float32_matmul_precision = None
96+
else:
97+
float32_matmul_precision = None
98+
if float32_matmul_precision is not None:
99+
print(f'Update: float32_matmul_precision={float32_matmul_precision}')
100+
torch.set_float32_matmul_precision(float32_matmul_precision)

0 commit comments

Comments
 (0)