@@ -10,11 +10,18 @@ class YoloTrainer(lightning.Trainer):
1010
1111 def __init__ (self , * args , ** kwargs ):
1212 super ().__init__ (* args , ** kwargs )
13+ self ._hacked_torch_global_callback = TorchGlobals (float32_matmul_precision = 'auto' )
1314
14- def _run_stage (self , * args , ** kwargs ):
15+ def _run (self , * args , ** kwargs ):
1516 # All I want is to print this directly before training starts.
1617 # Is that so hard to do?
1718 self ._on_before_run ()
19+ super ()._run (* args , ** kwargs )
20+
21+ def _run_stage (self , * args , ** kwargs ):
22+ # All I want is to print this directly before training starts.
23+ # Is that so hard to do?
24+ self ._on_before_run_stage ()
1825 super ()._run_stage (* args , ** kwargs )
1926
2027 @property
@@ -32,6 +39,12 @@ def log_dpath(self):
3239 return ub .Path (self .logger .log_dir )
3340
3441 def _on_before_run (self ):
42+ """
43+ Our custom "callback"
44+ """
45+ self ._hacked_torch_global_callback .before_setup_environment (self )
46+
47+ def _on_before_run_stage (self ):
3548 """
3649 Our custom "callback"
3750 """
@@ -43,3 +56,45 @@ def _on_before_run_rank0(self):
4356 import rich
4457 dpath = self .log_dpath
4558 rich .print (f"Trainer log dpath:\n \n [link={ dpath } ]{ dpath } [/link]\n " )
59+
60+
61+ class TorchGlobals (lightning .pytorch .callbacks .Callback ):
62+ """
63+ Callback to setup torch globals.
64+
65+ Note: this needs to be called before the accelerators are setup, and
66+ existing callbacks don't have mechanisms for that, so we hack it in here.
67+
68+ Args:
69+ float32_matmul_precision (str):
70+ can be 'medium', 'high', 'default', or 'auto'.
71+ The 'default' value does not change any setting.
72+ The 'auto' value defaults to 'medium' if the training devices have
73+ ampere cores.
74+ """
75+
76+ def __init__ (self , float32_matmul_precision = 'default' ):
77+ self .float32_matmul_precision = float32_matmul_precision
78+
79+ def before_setup_environment (self , trainer ):
80+ import torch
81+ print ('Setup Torch Globals' )
82+ float32_matmul_precision = self .float32_matmul_precision
83+ if float32_matmul_precision == 'default' :
84+ float32_matmul_precision = None
85+ elif float32_matmul_precision == 'auto' :
86+ # Detect if we have Ampere tensor cores
87+ # Ampere (V8) and later leverage tensor cores, where medium
88+ # float32_matmul_precision becomes useful
89+ if torch .cuda .is_available ():
90+ device_versions = [torch .cuda .get_device_capability (device_id )[0 ]
91+ for device_id in trainer .device_ids ]
92+ if all (v >= 8 for v in device_versions ):
93+ float32_matmul_precision = 'medium'
94+ else :
95+ float32_matmul_precision = None
96+ else :
97+ float32_matmul_precision = None
98+ if float32_matmul_precision is not None :
99+ print (f'Update: float32_matmul_precision={ float32_matmul_precision } ' )
100+ torch .set_float32_matmul_precision (float32_matmul_precision )
0 commit comments