@@ -60,20 +60,16 @@ def on_save(self, args, state, control, **kwargs):
6060
6161
6262def _make_dir (config ):
63- config .tensorboard_dir = config .tensorboard_dir % config .name
6463 config .tensorboard_dir = config .tensorboard_dir % config .name
6564 config .checkpoint_folder = config .checkpoint_folder % config .name
6665 config .log_dir = config .log_dir % config .name
6766 config .output_dir = config .output_dir % config .name
6867 if not os .path .exists (config .tensorboard_dir ):
6968 os .makedirs (config .tensorboard_dir , exist_ok = True )
70- os .makedirs (config .tensorboard_dir , exist_ok = True )
7169 if not os .path .exists (config .checkpoint_folder ):
7270 os .makedirs (config .checkpoint_folder , exist_ok = True )
73- os .makedirs (config .checkpoint_folder , exist_ok = True )
7471 if not os .path .exists (config .log_dir ):
7572 os .makedirs (config .log_dir , exist_ok = True )
76- os .makedirs (config .log_dir , exist_ok = True )
7773
7874
7975def main (config , model_class , model_config_class ):
@@ -98,14 +94,12 @@ def main(config, model_class, model_config_class):
9894 world_size = int (os .getenv ('WORLD_SIZE' , '1' ))
9995 rank = int (os .getenv ('RANK' , '0' ))
10096
101-
10297 # Set CUDA device for each process
10398 device_id = local_rank
10499 torch .cuda .set_device (device_id )
105100 device = torch .device (f'cuda:{ device_id } ' )
106101 print (f"World size: { world_size } , Local rank: { local_rank } , Global rank: { rank } " )
107102
108-
109103 # Initialize distributed training environment
110104 if world_size > 1 :
111105 try :
@@ -116,7 +110,6 @@ def main(config, model_class, model_config_class):
116110 print (f"Distributed initialization FAILED: { str (e )} " )
117111 world_size = 1
118112
119- print ("=" * 50 )
120113 print ("=" * 50 )
121114 print ("After distributed init:" )
122115 print (f"LOCAL_RANK: { local_rank } " )
@@ -146,13 +139,10 @@ def main(config, model_class, model_config_class):
146139 print (f"Buffer { name } is on wrong device { buffer .device } , should be moved to { device } " )
147140 buffer .data = buffer .data .to (device )
148141
149-
150142 # If distributed training, wrap the model with DDP
151143 if world_size > 1 :
152144 model = torch .nn .parallel .DistributedDataParallel (
153- model , device_ids = [local_rank ],
154- output_device = local_rank ,
155- find_unused_parameters = True
145+ model , device_ids = [local_rank ], output_device = local_rank , find_unused_parameters = True
156146 )
157147 # ------------ load logger ------------
158148 train_logger_filename = os .path .join (config .log_dir , 'train.log' )
@@ -162,15 +152,10 @@ def main(config, model_class, model_config_class):
162152 level = logging .INFO ,
163153 format_str = '%(asctime)-15s %(message)s' ,
164154 filename = train_logger_filename ,
165- name = 'train' ,
166- level = logging .INFO ,
167- format_str = '%(asctime)-15s %(message)s' ,
168- filename = train_logger_filename ,
169155 )
170156 else :
171157 # Other processes use console logging
172158 train_logger = MyLogger (name = 'train' , level = logging .INFO , format_str = '%(asctime)-15s %(message)s' )
173- train_logger = MyLogger (name = 'train' , level = logging .INFO , format_str = '%(asctime)-15s %(message)s' )
174159 transformers_logger = logging .getLogger ("transformers" )
175160 if transformers_logger .hasHandlers ():
176161 transformers_logger .handlers = []
@@ -180,18 +165,6 @@ def main(config, model_class, model_config_class):
180165
181166 # ------------ load dataset ------------
182167 if config .model_name == "navdp" :
183- train_dataset_data = NavDP_Base_Datset (
184- config .il .root_dir ,
185- config .il .dataset_navdp ,
186- config .il .memory_size ,
187- config .il .predict_size ,
188- config .il .batch_size ,
189- config .il .image_size ,
190- config .il .scene_scale ,
191- preload = config .il .preload ,
192- random_digit = config .il .random_digit ,
193- prior_sample = config .il .prior_sample ,
194- )
195168 train_dataset_data = NavDP_Base_Datset (
196169 config .il .root_dir ,
197170 config .il .dataset_navdp ,
@@ -239,7 +212,6 @@ def main(config, model_class, model_config_class):
239212 config .il .lerobot_features_dir ,
240213 dataset_data = train_dataset_data ,
241214 batch_size = config .il .batch_size ,
242- batch_size = config .il .batch_size ,
243215 )
244216 collate_fn = rdp_collate_fn (global_batch_size = global_batch_size )
245217 elif config .model_name == 'navdp' :
@@ -255,7 +227,6 @@ def main(config, model_class, model_config_class):
255227 deepspeed = '' ,
256228 gradient_checkpointing = False ,
257229 bf16 = False , # fp16=False,
258- bf16 = False , # fp16=False,
259230 tf32 = False ,
260231 per_device_train_batch_size = config .il .batch_size ,
261232 gradient_accumulation_steps = 1 ,
@@ -267,7 +238,6 @@ def main(config, model_class, model_config_class):
267238 logging_steps = 10.0 ,
268239 num_train_epochs = config .il .epochs ,
269240 save_strategy = 'epoch' , # no
270- save_strategy = 'epoch' , # no
271241 save_steps = config .il .save_interval_epochs ,
272242 save_total_limit = 8 ,
273243 report_to = config .il .report_to ,
@@ -279,7 +249,6 @@ def main(config, model_class, model_config_class):
279249 dataloader_drop_last = True ,
280250 disable_tqdm = True ,
281251 log_level = "info" ,
282- log_level = "info" ,
283252 )
284253
285254 # Create the trainer
@@ -299,17 +268,14 @@ def main(config, model_class, model_config_class):
299268 except Exception as e :
300269 import traceback
301270
302-
303271 print (f"Unhandled exception: { str (e )} " )
304272 print ("Stack trace:" )
305273 traceback .print_exc ()
306274
307-
308275 # If distributed environment, ensure all processes exit
309276 if dist .is_initialized ():
310277 dist .destroy_process_group ()
311278
312-
313279 raise
314280
315281
0 commit comments