Skip to content

Commit b543cc8

Browse files
committed
swap dim
1 parent ad23558 commit b543cc8

2 files changed

Lines changed: 9 additions & 4 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -956,10 +956,15 @@ def log_loss_valid(_task_key="Default"):
956956
self.wrapper.train()
957957
self.t0 = time.time()
958958
self.total_train_time = 0.0
959-
for step_id in range(self.start_step, self.num_steps):
960-
step(step_id)
961-
if JIT:
962-
break
959+
try:
960+
torch.cuda.memory._record_memory_history()
961+
for step_id in range(self.start_step, self.num_steps):
962+
step(step_id)
963+
if JIT:
964+
break
965+
finally:
966+
torch.cuda.memory._dump_snapshot("mem.pickle")
967+
logging.warning("Memory snapshot dumped to mem.pickle")
963968

964969
if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
965970
if not self.multi_task:

mem.pickle

22.3 MB
Binary file not shown.

0 commit comments

Comments
 (0)