136136 ZeroRedundancyOptimizer ,
137137)
138138from torch .nn .parallel import DistributedDataParallel as DDP
139+ from torch .profiler import (
140+ ProfilerActivity ,
141+ profile ,
142+ )
139143from torch .utils .data import (
140144 DataLoader ,
141145)
@@ -1346,14 +1350,27 @@ def run(self) -> None:
13461350
13471351 writer = SummaryWriter (log_dir = self .tensorboard_log_dir )
13481352 if self .enable_profiler or self .profiling :
1349- prof = torch .profiler .profile (
1350- schedule = torch .profiler .schedule (wait = 1 , warmup = 15 , active = 3 , repeat = 1 ),
1353+ # prof = torch.profiler.profile(
1354+ # schedule=torch.profiler.schedule(wait=1, warmup=15, active=3, repeat=1),
1355+ # on_trace_ready=torch.profiler.tensorboard_trace_handler(
1356+ # self.tensorboard_log_dir
1357+ # )
1358+ # if self.enable_profiler
1359+ # else None,
1360+ # record_shapes=True,
1361+ # with_stack=True,
1362+ # )
1363+ # prof.start()
1364+ prof = profile (
1365+ activities = [ProfilerActivity .CPU , ProfilerActivity .CUDA ],
1366+ schedule = torch .profiler .schedule (wait = 2 , warmup = 12 , active = 6 , repeat = 1 ),
13511367 on_trace_ready = torch .profiler .tensorboard_trace_handler (
13521368 self .tensorboard_log_dir
13531369 )
13541370 if self .enable_profiler
13551371 else None ,
13561372 record_shapes = True ,
1373+ profile_memory = True ,
13571374 with_stack = True ,
13581375 )
13591376 prof .start ()
@@ -1366,9 +1383,9 @@ def step(_step_id: int, task_key: str = "Default") -> None:
13661383 p = self .model_prob ,
13671384 )
13681385 task_key = self .model_keys [model_index ]
1369- # PyTorch Profiler
1370- if self .enable_profiler or self .profiling :
1371- prof .step ()
1386+ # # PyTorch Profiler
1387+ # if self.enable_profiler or self.profiling:
1388+ # prof.step()
13721389 cur_lr = self .lr_schedule .value (_step_id )
13731390 pref_lr = cur_lr
13741391 self .optimizer .zero_grad (set_to_none = True )
@@ -1829,7 +1846,13 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
18291846 self .last_display_step = 0
18301847
18311848 for step_id in range (self .start_step , self .num_steps ):
1849+ if self .enable_profiler or self .profiling :
1850+ if step_id >= 20 :
1851+ break
18321852 step (step_id )
1853+ # PyTorch Profiler
1854+ if self .enable_profiler or self .profiling :
1855+ prof .step ()
18331856 if JIT :
18341857 break
18351858
0 commit comments