Skip to content

Commit 38b5724

Browse files
committed
Fix inference checkpoint resume to skip completed episodes
1 parent 7a5c624 commit 38b5724

1 file changed

Lines changed: 12 additions & 4 deletions

File tree

internnav/habitat_extensions/vln/habitat_vln_evaluator.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -243,8 +243,9 @@ def parse_actions(self, output):
243243

244244
def resume_from_output_path(self) -> None:
245245
sucs, spls, oss, nes, ndtw = [], [], [], [], []
246+
completed_episodes = set()
246247
if self.rank != 0:
247-
return sucs, spls, oss, nes, ndtw
248+
return sucs, spls, oss, nes, ndtw, completed_episodes
248249

249250
# resume from previous results
250251
if os.path.exists(os.path.join(self.output_path, 'progress.json')):
@@ -257,13 +258,14 @@ def resume_from_output_path(self) -> None:
257258
nes.append(res['ne'])
258259
if 'ndtw' in res:
259260
ndtw.append(res['ndtw'])
260-
return sucs, spls, oss, nes, ndtw
261+
completed_episodes.add((res['scene_id'], res['episode_id']))
262+
return sucs, spls, oss, nes, ndtw, completed_episodes
261263

262264
def _run_eval_dual_system(self) -> tuple: # noqa: C901
263265
self.model.eval()
264266

265267
# resume from previous results
266-
sucs, spls, oss, nes, ndtw = self.resume_from_output_path()
268+
sucs, spls, oss, nes, ndtw, completed_episodes = self.resume_from_output_path()
267269

268270
# Episode loop is now driven by env.reset() + env.is_running
269271
process_bar = tqdm.tqdm(total=len(self.env.episodes), desc=f"Eval Epoch {self.epoch} Rank {self.rank}")
@@ -281,6 +283,12 @@ def _run_eval_dual_system(self) -> tuple: # noqa: C901
281283
scene_id = episode.scene_id.split('/')[-2]
282284
episode_id = int(episode.episode_id)
283285
episode_instruction = episode.instruction.instruction_text
286+
287+
# skip already completed episodes
288+
if (scene_id, episode_id) in completed_episodes:
289+
process_bar.update(1)
290+
continue
291+
284292
print("episode start", episode_instruction)
285293

286294
# save first frame per rank to validate sim quality
@@ -632,7 +640,7 @@ def _run_eval_system2(self) -> tuple:
632640
self.model.eval()
633641

634642
# resume from previous results
635-
sucs, spls, oss, nes, ndtw = self.resume_from_output_path()
643+
sucs, spls, oss, nes, ndtw, _ = self.resume_from_output_path()
636644

637645
# Episode loop is now driven by env.reset() + env.is_running
638646
process_bar = tqdm.tqdm(total=len(self.env.episodes), desc=f"Eval Epoch {self.epoch} Rank {self.rank}")

0 commit comments

Comments
 (0)