-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
82 lines (70 loc) · 3.01 KB
/
run.py
File metadata and controls
82 lines (70 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import ast

from absl import app, flags

from egogym.evaluate import evaluate_policy
from egogym.utils import load_policy_from_file
# Command-line flags for the evaluation script; values are read through FLAGS.
FLAGS = flags.FLAGS

# --- Policy / task selection -------------------------------------------------
flags.DEFINE_string(
    'policy',
    'baselines/cap_policy.py',
    help='Path to policy file',
)
flags.DEFINE_string(
    'task',
    'pick',
    help='Task name (open, pick, close)',
)
flags.DEFINE_string(
    'objects_set',
    None,
    help='Objects set (all_pick, lite_pick, diverse_pick, all_open, cabinet, drawer)',
)
flags.DEFINE_string(
    'robot',
    'cap',
    help='Robot embodiment',
)
flags.DEFINE_string(
    'action_space',
    'delta',
    help='Action space type',
)

# --- Evaluation settings -----------------------------------------------------
flags.DEFINE_integer(
    'num_objs',
    1,
    help='Number of objects',
)
flags.DEFINE_integer(
    'num_episodes',
    1000,
    help='Number of evaluation episodes',
)
flags.DEFINE_integer(
    'max_steps',
    80,
    help='Maximum steps per episode',
)
flags.DEFINE_float(
    'reward_threshold',
    0.05,
    help='Reward threshold',
)
flags.DEFINE_integer(
    'num_envs',
    1,
    help='Number of parallel environments',
)
flags.DEFINE_integer(
    'seed',
    42,
    help='Random seed',
)

# --- Rendering / recording ---------------------------------------------------
flags.DEFINE_boolean(
    'render',
    False,
    help='Enable rendering',
)
flags.DEFINE_boolean(
    'record',
    False,
    help='Enable recording',
)
flags.DEFINE_integer(
    'render_freq',
    0,
    help='Render/record frequency: 0=per env step, N=every N mujoco steps',
)
flags.DEFINE_string(
    'render_size',
    '299,224',
    help='Render size as width,height',
)

# --- Miscellaneous -----------------------------------------------------------
flags.DEFINE_string(
    'vlm',
    None,
    help='VLM wrapper to use (moondream, gemini, molmo, chatgpt)',
)
flags.DEFINE_string(
    'logs_dir',
    'logs',
    help='Directory to save logs and recordings',
)
def parse_policy_config(argv):
    """Collect ``policy.<key>=<value>`` overrides from command-line arguments.

    An argument is considered only if it contains ``=`` and does not start
    with ``-``. It is split at the first ``=``; the left side must contain a
    ``.`` and its part before the first ``.`` must be exactly ``policy``.
    Everything after that first ``.`` becomes the override key, so
    ``policy.a.b=1`` produces the key ``"a.b"``. All other arguments are
    ignored. When a key appears more than once, the last occurrence wins.

    Values are parsed as Python literals (numbers, quoted strings, tuples,
    lists, dicts, sets, booleans, ``None``). Anything that is not a valid
    literal is kept as the raw string.

    Args:
        argv: Iterable of argument strings.

    Returns:
        A dict mapping override keys to their parsed values.
    """
    # Imported here as well so the function is self-contained.
    import ast

    config_overrides = {}
    for arg in argv:
        if arg.startswith('-') or '=' not in arg:
            continue

        key, _, raw_value = arg.partition('=')
        prefix, dot, actual_key = key.partition('.')
        if not dot or prefix != 'policy':
            continue

        # NOTE: this used to be a bare ``eval``. That executed arbitrary
        # expressions from the command line and resolved bare words against
        # builtins/module globals, so e.g. ``policy.task=open`` produced the
        # builtin ``open`` function instead of the string "open".
        # ``literal_eval`` only accepts literals, so bare words stay strings.
        try:
            value = ast.literal_eval(raw_value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Best effort: not a literal, keep the text as given.
            value = raw_value

        config_overrides[actual_key] = value
    return config_overrides
def main(argv):
    """Load the configured policy, evaluate it, and print the success rate.

    ``argv`` is passed to ``parse_policy_config`` to extract policy overrides;
    every other setting is read from ``FLAGS``.

    Args:
        argv: Argument list handed to this entry point.
    """
    overrides = parse_policy_config(argv)
    policy = load_policy_from_file(FLAGS.policy, overrides)

    # An explicitly given objects set wins; otherwise pick the default for the
    # task, falling back to "all_pick" for task names not listed here.
    objects_set = (
        FLAGS.objects_set
        if FLAGS.objects_set
        else {
            'open': 'all_open',
            'pick': 'all_pick',
            'close': "all_close",
        }.get(FLAGS.task, "all_pick")
    )

    # The flag is a comma-separated string; turn it into a tuple of ints.
    size = tuple(int(part) for part in FLAGS.render_size.split(','))

    eval_kwargs = dict(
        task_name=FLAGS.task,
        policy=policy,
        robot=FLAGS.robot,
        action_space=FLAGS.action_space,
        num_objs=FLAGS.num_objs,
        num_episodes=FLAGS.num_episodes,
        max_steps=FLAGS.max_steps,
        reward_threshold=FLAGS.reward_threshold,
        render=FLAGS.render,
        record=FLAGS.record,
        render_freq=FLAGS.render_freq,
        render_size=size,
        num_envs=FLAGS.num_envs,
        seed=FLAGS.seed,
        objects_set=objects_set,
        use_unprivileged_vlm=FLAGS.vlm,
        logs_dir=FLAGS.logs_dir,
    )
    success_rate = evaluate_policy(**eval_kwargs)

    print(f"Success Rate: {success_rate:.2f}%")
# Script entry point: hand control to absl, which calls main().
if __name__ == "__main__":
    app.run(main)