1+ import sys
12from enum import Enum
23from pathlib import Path
34from time import time
5+
46import numpy as np
7+
58from internnav .configs .evaluator import EvalCfg
69from internnav .evaluator .base import Evaluator
7- from internnav .evaluator .utils .common import set_seed_model , obs_to_image
10+ from internnav .evaluator .utils .common import set_seed_model
811from internnav .evaluator .utils .config import get_lmdb_path
912from internnav .evaluator .utils .data_collector import DataCollector
1013from internnav .evaluator .utils .dataset import ResultLogger , split_data
@@ -56,6 +59,9 @@ def __init__(self, config: EvalCfg):
5659
5760 # generate episode
5861 episodes = generate_episode (self .dataloader , config )
62+ if len (episodes ) == 0 :
63+ log .info ("No more episodes to evaluate. Episodes are saved in data/sample_episodes/" )
64+ sys .exit (0 )
5965 config .task .task_settings .update ({'episodes' : episodes })
6066 self .env_num = config .task .task_settings ['env_num' ]
6167 self .proc_num = (
@@ -211,7 +217,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
211217
212218 # need this status to reset
213219 reset_env_ids = np .where (self .runner_status == runner_status_code .NOT_RESET )[0 ].tolist ()
214-
220+
215221 if len (reset_env_ids ) > 0 :
216222 log .debug (f'env{ reset_env_ids } : start new episode!' )
217223 obs , new_reset_infos = self .env .reset (reset_env_ids )
@@ -225,9 +231,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
225231 self .runner_status [
226232 np .vectorize (lambda x : x )(reset_infos ) == None # noqa: E711
227233 ] = runner_status_code .TERMINATED
228- log .debug (
229- f'env{ np .vectorize (lambda x : x )(reset_infos ) == None } : states switch to TERMINATED.'
230- )
234+ log .debug (f'env{ np .vectorize (lambda x : x )(reset_infos ) == None } : states switch to TERMINATED.' )
231235 reset_infos = reset_infos .tolist ()
232236
233237 if np .logical_and .reduce (self .runner_status == runner_status_code .TERMINATED ):
@@ -241,8 +245,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
241245 )
242246 if self .vis_output :
243247 self .visualize_util .trace_start (
244- trajectory_id = self .now_path_key (reset_info ),
245- reference_path = reset_info .data ['reference_path' ]
248+ trajectory_id = self .now_path_key (reset_info ), reference_path = reset_info .data ['reference_path' ]
246249 )
247250 return False , reset_infos
248251
@@ -258,8 +261,7 @@ def eval(self):
258261 )
259262 if self .vis_output :
260263 self .visualize_util .trace_start (
261- trajectory_id = self .now_path_key (info ),
262- reference_path = info .data ['reference_path' ]
264+ trajectory_id = self .now_path_key (info ), reference_path = info .data ['reference_path' ]
263265 )
264266 log .info ('start new episode!' )
265267
@@ -281,18 +283,16 @@ def eval(self):
281283 env_term , reset_info = self .terminate_ops (obs , reset_info , terminated )
282284 if env_term :
283285 break
284-
286+
285287 # save step obs
286288 if self .vis_output :
287289 for ob , info , act in zip (obs , reset_info , action ):
288- if info is None or not 'rgb' in ob or ob ['fail_reason' ]:
290+ if info is None or 'rgb' not in ob or ob ['fail_reason' ]:
289291 continue
290292 self .visualize_util .save_observation (
291- trajectory_id = self .now_path_key (info ),
292- obs = ob ,
293- action = act [self .robot_name ]
293+ trajectory_id = self .now_path_key (info ), obs = ob , action = act [self .robot_name ]
294294 )
295-
295+
296296 self .env .close ()
297297 progress_log_multi_util .report ()
298298
0 commit comments