Skip to content

Commit 2b103ed

Browse files
committed
add prof
1 parent d87256f commit 2b103ed

1 file changed

Lines changed: 28 additions & 5 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@
136136
ZeroRedundancyOptimizer,
137137
)
138138
from torch.nn.parallel import DistributedDataParallel as DDP
139+
from torch.profiler import (
140+
ProfilerActivity,
141+
profile,
142+
)
139143
from torch.utils.data import (
140144
DataLoader,
141145
)
@@ -1346,14 +1350,27 @@ def run(self) -> None:
13461350

13471351
writer = SummaryWriter(log_dir=self.tensorboard_log_dir)
13481352
if self.enable_profiler or self.profiling:
1349-
prof = torch.profiler.profile(
1350-
schedule=torch.profiler.schedule(wait=1, warmup=15, active=3, repeat=1),
1353+
# prof = torch.profiler.profile(
1354+
# schedule=torch.profiler.schedule(wait=1, warmup=15, active=3, repeat=1),
1355+
# on_trace_ready=torch.profiler.tensorboard_trace_handler(
1356+
# self.tensorboard_log_dir
1357+
# )
1358+
# if self.enable_profiler
1359+
# else None,
1360+
# record_shapes=True,
1361+
# with_stack=True,
1362+
# )
1363+
# prof.start()
1364+
prof = profile(
1365+
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
1366+
schedule=torch.profiler.schedule(wait=2, warmup=12, active=6, repeat=1),
13511367
on_trace_ready=torch.profiler.tensorboard_trace_handler(
13521368
self.tensorboard_log_dir
13531369
)
13541370
if self.enable_profiler
13551371
else None,
13561372
record_shapes=True,
1373+
profile_memory=True,
13571374
with_stack=True,
13581375
)
13591376
prof.start()
@@ -1366,9 +1383,9 @@ def step(_step_id: int, task_key: str = "Default") -> None:
13661383
p=self.model_prob,
13671384
)
13681385
task_key = self.model_keys[model_index]
1369-
# PyTorch Profiler
1370-
if self.enable_profiler or self.profiling:
1371-
prof.step()
1386+
# # PyTorch Profiler
1387+
# if self.enable_profiler or self.profiling:
1388+
# prof.step()
13721389
cur_lr = self.lr_schedule.value(_step_id)
13731390
pref_lr = cur_lr
13741391
self.optimizer.zero_grad(set_to_none=True)
@@ -1829,7 +1846,13 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
18291846
self.last_display_step = 0
18301847

18311848
for step_id in range(self.start_step, self.num_steps):
1849+
if self.enable_profiler or self.profiling:
1850+
if step_id >= 20:
1851+
break
18321852
step(step_id)
1853+
# PyTorch Profiler
1854+
if self.enable_profiler or self.profiling:
1855+
prof.step()
18331856
if JIT:
18341857
break
18351858

0 commit comments

Comments
 (0)