Skip to content

Commit ed5a141

Browse files
committed
fix empty episodes issue
1 parent 015efce commit ed5a141

6 files changed

Lines changed: 213 additions & 63 deletions

File tree

internnav/evaluator/vln_multi_evaluator.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import sys
12
from enum import Enum
23
from pathlib import Path
34
from time import time
5+
46
import numpy as np
7+
58
from internnav.configs.evaluator import EvalCfg
69
from internnav.evaluator.base import Evaluator
710
from internnav.evaluator.utils.common import set_seed_model
@@ -28,22 +31,23 @@ def transform_action_batch(actions, flash=False):
2831
if 'ideal_flag' in action.keys():
2932
ideal_flag = action['ideal_flag']
3033
if flash:
31-
assert ideal_flag == True
34+
assert ideal_flag is True
3235
else:
3336
ideal_flag = False
3437
if not ideal_flag:
3538
transformed_actions.append({'h1': {'vln_dp_move_by_speed': action['action'][0]}})
3639
continue
3740
a = action['action']
38-
if a == 0 or a == [0] or a==[[0]]:
41+
if a == 0 or a == [0] or a == [[0]]:
3942
transformed_actions.append({'h1': {'stop': []}})
4043
elif a == -1 or a == [-1] or a == [[-1]]:
4144
transformed_actions.append({'h1': {'stand_still': []}})
4245
else:
4346
move = f"move_by_{'discrete' if not flash else 'flash'}"
44-
transformed_actions.append({'h1': {move: a}}) # discrete e.g. [3]
47+
transformed_actions.append({'h1': {move: a}}) # discrete e.g. [3]
4548
return transformed_actions
4649

50+
4751
@Evaluator.register('vln_multi')
4852
class VlnMultiEvaluator(Evaluator):
4953
def __init__(self, config: EvalCfg):
@@ -61,6 +65,9 @@ def __init__(self, config: EvalCfg):
6165
)
6266
# generate episode
6367
episodes = generate_episode(self.dataloader, config)
68+
if len(episodes) == 0:
69+
log.info("No more episodes to evaluate")
70+
sys.exit(0)
6471
config.task.task_settings.update({'episodes': episodes})
6572
self.env_num = config.task.task_settings['env_num']
6673
self.proc_num = (
@@ -88,7 +95,6 @@ def __init__(self, config: EvalCfg):
8895
self.data_collector = DataCollector(self.dataloader.lmdb_path)
8996
self.robot_flash = config.task.robot_flash
9097

91-
9298
@property
9399
def ignore_obs_attr(self):
94100
return [
@@ -223,15 +229,11 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
223229
log.info(f'env{reset_env_ids}: states switch to WARM UP.')
224230
# modify original reset_info
225231
reset_infos = np.array(reset_infos)
226-
reset_infos[reset_env_ids] = (
227-
new_reset_infos if len(new_reset_infos) > 0 else None
228-
)
232+
reset_infos[reset_env_ids] = new_reset_infos if len(new_reset_infos) > 0 else None
229233
self.runner_status[
230234
np.vectorize(lambda x: x)(reset_infos) == None # noqa: E711
231235
] = runner_status_code.TERMINATED
232-
log.info(
233-
f'env{np.vectorize(lambda x: x)(reset_infos) == None}: states switch to TERMINATED.'
234-
)
236+
log.info(f'env{np.vectorize(lambda x: x)(reset_infos) == None}: states switch to TERMINATED.')
235237
reset_infos = reset_infos.tolist()
236238

237239
if np.logical_and.reduce(self.runner_status == runner_status_code.TERMINATED):

internnav/evaluator/vln_pe_evaluator.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import sys
12
from enum import Enum
23
from pathlib import Path
34
from time import time
5+
46
import numpy as np
7+
58
from internnav.configs.evaluator import EvalCfg
69
from internnav.evaluator.base import Evaluator
7-
from internnav.evaluator.utils.common import set_seed_model, obs_to_image
10+
from internnav.evaluator.utils.common import set_seed_model
811
from internnav.evaluator.utils.config import get_lmdb_path
912
from internnav.evaluator.utils.data_collector import DataCollector
1013
from internnav.evaluator.utils.dataset import ResultLogger, split_data
@@ -56,6 +59,9 @@ def __init__(self, config: EvalCfg):
5659

5760
# generate episode
5861
episodes = generate_episode(self.dataloader, config)
62+
if len(episodes) == 0:
63+
log.info("No more episodes to evaluate. Episodes are saved in data/sample_episodes/")
64+
sys.exit(0)
5965
config.task.task_settings.update({'episodes': episodes})
6066
self.env_num = config.task.task_settings['env_num']
6167
self.proc_num = (
@@ -211,7 +217,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
211217

212218
# need this status to reset
213219
reset_env_ids = np.where(self.runner_status == runner_status_code.NOT_RESET)[0].tolist()
214-
220+
215221
if len(reset_env_ids) > 0:
216222
log.debug(f'env{reset_env_ids}: start new episode!')
217223
obs, new_reset_infos = self.env.reset(reset_env_ids)
@@ -225,9 +231,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
225231
self.runner_status[
226232
np.vectorize(lambda x: x)(reset_infos) == None # noqa: E711
227233
] = runner_status_code.TERMINATED
228-
log.debug(
229-
f'env{np.vectorize(lambda x: x)(reset_infos) == None}: states switch to TERMINATED.'
230-
)
234+
log.debug(f'env{np.vectorize(lambda x: x)(reset_infos) == None}: states switch to TERMINATED.')
231235
reset_infos = reset_infos.tolist()
232236

233237
if np.logical_and.reduce(self.runner_status == runner_status_code.TERMINATED):
@@ -241,8 +245,7 @@ def terminate_ops(self, obs_ls, reset_infos, terminated_ls):
241245
)
242246
if self.vis_output:
243247
self.visualize_util.trace_start(
244-
trajectory_id=self.now_path_key(reset_info),
245-
reference_path=reset_info.data['reference_path']
248+
trajectory_id=self.now_path_key(reset_info), reference_path=reset_info.data['reference_path']
246249
)
247250
return False, reset_infos
248251

@@ -258,8 +261,7 @@ def eval(self):
258261
)
259262
if self.vis_output:
260263
self.visualize_util.trace_start(
261-
trajectory_id=self.now_path_key(info),
262-
reference_path=info.data['reference_path']
264+
trajectory_id=self.now_path_key(info), reference_path=info.data['reference_path']
263265
)
264266
log.info('start new episode!')
265267

@@ -281,18 +283,16 @@ def eval(self):
281283
env_term, reset_info = self.terminate_ops(obs, reset_info, terminated)
282284
if env_term:
283285
break
284-
286+
285287
# save step obs
286288
if self.vis_output:
287289
for ob, info, act in zip(obs, reset_info, action):
288-
if info is None or not 'rgb' in ob or ob['fail_reason']:
290+
if info is None or 'rgb' not in ob or ob['fail_reason']:
289291
continue
290292
self.visualize_util.save_observation(
291-
trajectory_id=self.now_path_key(info),
292-
obs=ob,
293-
action=act[self.robot_name]
293+
trajectory_id=self.now_path_key(info), obs=ob, action=act[self.robot_name]
294294
)
295-
295+
296296
self.env.close()
297297
progress_log_multi_util.report()
298298

tests/function_test/e2e_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@ def test_server():
5454
common_body(start_command)
5555

5656

57-
@pytest.mark.gpu
58-
def test_challenge():
59-
start_command = 'python ./tests/function_test/test_challenge.py'
57+
@pytest.mark.ray
58+
def test_evaluator():
59+
start_command = 'python ./tests/function_test/test_evaluator.py'
6060
common_body(start_command)
6161

6262

63-
@pytest.mark.ray
64-
def test_challenge_ray():
65-
start_command = 'python ./tests/function_test/test_challenge_ray.py'
63+
@pytest.mark.gpu
64+
def test_challenge():
65+
start_command = 'python ./tests/function_test/test_challenge.py'
6666
common_body(start_command)

tests/function_test/test_challenge.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
'''
66

77
import importlib.util
8-
import subprocess
98
import sys
109
import time
1110

1211
import numpy as np
12+
from test_server import start_server, stop_server
1313

1414
from internnav.configs.evaluator.default_config import get_config
1515
from internnav.evaluator import Evaluator
@@ -66,39 +66,32 @@ def load_eval_cfg(config_path, attr_name='eval_cfg'):
6666
evaluator.env.close()
6767

6868

69-
def start_server():
70-
server_cmd = [
71-
sys.executable,
72-
"internnav/agent/utils/server.py",
73-
"--config",
74-
"scripts/eval/configs/challenge_cfg.py",
75-
]
69+
def start_evaluator():
70+
from multiprocessing import get_context
7671

77-
proc = subprocess.Popen(
78-
server_cmd,
79-
stdout=None,
80-
stderr=None,
81-
)
82-
return proc
72+
ctx = get_context("spawn") # Use 'spawn' to avoid issues on some platforms
73+
p = ctx.Process(target=main)
74+
p.start()
75+
p.join()
76+
assert p.exitcode == 0
77+
print("Evaluator process completed successfully.")
8378

8479

8580
if __name__ == '__main__':
8681
try:
8782
proc = start_server()
8883
time.sleep(3)
89-
main()
84+
start_evaluator()
85+
9086
except Exception as e:
9187
print(f'exception is {e}')
9288
import traceback
9389

9490
traceback.print_exc()
9591
sys.exit(1)
92+
93+
except SystemExit as e:
94+
print(f"Caught SystemExit from env.close(): code={e.code}", flush=True)
95+
9696
finally:
97-
if proc and proc.poll() is None:
98-
print("Shutting down server...")
99-
proc.terminate()
100-
try:
101-
proc.wait(timeout=10)
102-
except subprocess.TimeoutExpired:
103-
print("Force killing server...")
104-
proc.kill()
97+
stop_server(proc)

0 commit comments

Comments (0)