@@ -243,8 +243,9 @@ def parse_actions(self, output):
243243
244244 def resume_from_output_path (self ) -> None :
245245 sucs , spls , oss , nes , ndtw = [], [], [], [], []
246+ completed_episodes = set ()
246247 if self .rank != 0 :
247- return sucs , spls , oss , nes , ndtw
248+ return sucs , spls , oss , nes , ndtw , completed_episodes
248249
249250 # resume from previous results
250251 if os .path .exists (os .path .join (self .output_path , 'progress.json' )):
@@ -257,13 +258,14 @@ def resume_from_output_path(self) -> None:
257258 nes .append (res ['ne' ])
258259 if 'ndtw' in res :
259260 ndtw .append (res ['ndtw' ])
260- return sucs , spls , oss , nes , ndtw
261+ completed_episodes .add ((res ['scene_id' ], res ['episode_id' ]))
262+ return sucs , spls , oss , nes , ndtw , completed_episodes
261263
262264 def _run_eval_dual_system (self ) -> tuple : # noqa: C901
263265 self .model .eval ()
264266
265267 # resume from previous results
266- sucs , spls , oss , nes , ndtw = self .resume_from_output_path ()
268+ sucs , spls , oss , nes , ndtw , completed_episodes = self .resume_from_output_path ()
267269
268270 # Episode loop is now driven by env.reset() + env.is_running
269271 process_bar = tqdm .tqdm (total = len (self .env .episodes ), desc = f"Eval Epoch { self .epoch } Rank { self .rank } " )
@@ -281,6 +283,12 @@ def _run_eval_dual_system(self) -> tuple: # noqa: C901
281283 scene_id = episode .scene_id .split ('/' )[- 2 ]
282284 episode_id = int (episode .episode_id )
283285 episode_instruction = episode .instruction .instruction_text
286+
287+ # skip already completed episodes
288+ if (scene_id , episode_id ) in completed_episodes :
289+ process_bar .update (1 )
290+ continue
291+
284292 print ("episode start" , episode_instruction )
285293
286294 # save first frame per rank to validate sim quality
@@ -632,7 +640,7 @@ def _run_eval_system2(self) -> tuple:
632640 self .model .eval ()
633641
634642 # resume from previous results
635- sucs , spls , oss , nes , ndtw = self .resume_from_output_path ()
643+ sucs , spls , oss , nes , ndtw , _ = self .resume_from_output_path ()
636644
637645 # Episode loop is now driven by env.reset() + env.is_running
638646 process_bar = tqdm .tqdm (total = len (self .env .episodes ), desc = f"Eval Epoch { self .epoch } Rank { self .rank } " )
0 commit comments