Skip to content

Commit a6241a7

Browse files
ronuchit and tomsilver authored
avg_suc_time only cares about solve time, not policy execution time (#591)
Co-authored-by: Tom Silver <tomssilver@gmail.com>
1 parent 1bef203 commit a6241a7

4 files changed

Lines changed: 87 additions & 63 deletions

File tree

src/datasets/demo_only.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def create_demo_data(env: BaseEnv, train_tasks: List[Task]) -> Dataset:
3434
# get_last_plan(). We do this because we want to run the full plan.
3535
plan = oracle_approach.get_last_plan()
3636
# Stop run_policy() when OptionPlanExhausted() is hit.
37-
traj = utils.run_policy(
37+
traj, _ = utils.run_policy(
3838
utils.option_plan_to_policy(plan),
3939
env,
4040
"train",

src/main.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -209,7 +209,7 @@ def _generate_interaction_results(
209209
for request in requests:
210210
monitor = TeacherInteractionMonitorWithVideo(env.render, request,
211211
teacher)
212-
traj = utils.run_policy(
212+
traj, _ = utils.run_policy(
213213
request.act_policy,
214214
env,
215215
"train",
@@ -240,7 +240,7 @@ def _run_testing(env: BaseEnv, approach: BaseApproach) -> Metrics:
240240
total_num_execution_failures = 0
241241
video_prefix = utils.get_config_path_str()
242242
for test_task_idx, task in enumerate(test_tasks):
243-
start = time.time()
243+
solve_start = time.time()
244244
try:
245245
policy = approach.solve(task, timeout=CFG.timeout)
246246
except (ApproachTimeout, ApproachFailure) as e:
@@ -253,20 +253,23 @@ def _run_testing(env: BaseEnv, approach: BaseApproach) -> Metrics:
253253
outfile = f"{video_prefix}__task{test_task_idx+1}_failure.mp4"
254254
utils.save_video(outfile, video)
255255
continue
256+
solve_time = time.time() - solve_start
256257
num_found_policy += 1
257258
try:
258259
if CFG.make_test_videos:
259260
monitor = utils.VideoMonitor(env.render)
260261
else:
261262
monitor = None
262-
traj = utils.run_policy(policy,
263-
env,
264-
"test",
265-
test_task_idx,
266-
task.goal_holds,
267-
max_num_steps=CFG.horizon,
268-
monitor=monitor)
263+
traj, execution_metrics = utils.run_policy(
264+
policy,
265+
env,
266+
"test",
267+
test_task_idx,
268+
task.goal_holds,
269+
max_num_steps=CFG.horizon,
270+
monitor=monitor)
269271
solved = task.goal_holds(traj.states[-1])
272+
solve_time += execution_metrics["policy_call_time"]
270273
except utils.EnvironmentFailure as e:
271274
logging.info(f"Task {test_task_idx+1} / {len(test_tasks)}: "
272275
f"Environment failed with error: {e}")
@@ -280,7 +283,7 @@ def _run_testing(env: BaseEnv, approach: BaseApproach) -> Metrics:
280283
if solved:
281284
logging.info(f"Task {test_task_idx+1} / {len(test_tasks)}: SOLVED")
282285
num_solved += 1
283-
total_suc_time += (time.time() - start)
286+
total_suc_time += solve_time
284287
else:
285288
logging.info(f"Task {test_task_idx+1} / {len(test_tasks)}: Policy "
286289
f"failed to reach goal")

src/utils.py

Lines changed: 19 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import logging
1111
import os
1212
import subprocess
13+
import time
1314
from collections import defaultdict
1415
from dataclasses import dataclass, field
1516
from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, FrozenSet, \
@@ -31,7 +32,7 @@
3132
from predicators.src.structs import NSRT, Action, Array, DummyOption, \
3233
EntToEntSub, GroundAtom, GroundAtomTrajectory, \
3334
GroundNSRTOrSTRIPSOperator, Image, LiftedAtom, LiftedOrGroundAtom, \
34-
LowLevelTrajectory, NSRTOrSTRIPSOperator, Object, OptionSpec, \
35+
LowLevelTrajectory, Metrics, NSRTOrSTRIPSOperator, Object, OptionSpec, \
3536
ParameterizedOption, Predicate, Segment, State, STRIPSOperator, Task, \
3637
Type, VarToObjSub, Video, _GroundNSRT, _GroundSTRIPSOperator, _Option, \
3738
_TypedEntity
@@ -443,15 +444,16 @@ def observe(self, state: State, action: Optional[Action]) -> None:
443444
raise NotImplementedError("Override me!")
444445

445446

446-
def run_policy(policy: Callable[[State], Action],
447-
env: BaseEnv,
448-
train_or_test: str,
449-
task_idx: int,
450-
termination_function: Callable[[State], bool],
451-
max_num_steps: int,
452-
exceptions_to_break_on: Optional[Set[
453-
TypingType[Exception]]] = None,
454-
monitor: Optional[Monitor] = None) -> LowLevelTrajectory:
447+
def run_policy(
448+
policy: Callable[[State], Action],
449+
env: BaseEnv,
450+
train_or_test: str,
451+
task_idx: int,
452+
termination_function: Callable[[State], bool],
453+
max_num_steps: int,
454+
exceptions_to_break_on: Optional[Set[TypingType[Exception]]] = None,
455+
monitor: Optional[Monitor] = None
456+
) -> Tuple[LowLevelTrajectory, Metrics]:
455457
"""Execute a policy starting from the initial state of a train or test task
456458
in the environment. The task's goal is not used.
457459
@@ -465,10 +467,14 @@ def run_policy(policy: Callable[[State], Action],
465467
state = env.reset(train_or_test, task_idx)
466468
states = [state]
467469
actions: List[Action] = []
470+
metrics: Metrics = defaultdict(float)
471+
metrics["policy_call_time"] = 0.0
468472
if not termination_function(state):
469473
for _ in range(max_num_steps):
470474
try:
475+
start_time = time.time()
471476
act = policy(state)
477+
metrics["policy_call_time"] += time.time() - start_time
472478
except Exception as e:
473479
if exceptions_to_break_on is not None and \
474480
type(e) in exceptions_to_break_on:
@@ -484,7 +490,7 @@ def run_policy(policy: Callable[[State], Action],
484490
if monitor is not None:
485491
monitor.observe(state, None)
486492
traj = LowLevelTrajectory(states, actions)
487-
return traj
493+
return traj, metrics
488494

489495

490496
def run_policy_with_simulator(
@@ -500,10 +506,10 @@ def run_policy_with_simulator(
500506
*** This function should not be used with any core code, because we want
501507
to avoid the assumption of a simulator when possible. ***
502508
503-
This is similar to run_policy, with two major differences:
509+
This is similar to run_policy, with three major differences:
504510
(1) The initial state `init_state` can be any state, not just the initial
505511
state of a train or test task. (2) A simulator (function that takes state
506-
as input) is assumed.
512+
as input) is assumed. (3) Metrics are not returned.
507513
508514
Note that the environment internal state is NOT updated.
509515

tests/test_utils.py

Lines changed: 54 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -240,21 +240,23 @@ def test_run_policy():
240240
env = CoverEnv()
241241
policy = lambda _: Action(env.action_space.sample())
242242
task = env.get_task("test", 0)
243-
traj = utils.run_policy(policy,
244-
env,
245-
"test",
246-
0,
247-
task.goal_holds,
248-
max_num_steps=5)
243+
traj, metrics = utils.run_policy(policy,
244+
env,
245+
"test",
246+
0,
247+
task.goal_holds,
248+
max_num_steps=5)
249249
assert not task.goal_holds(traj.states[-1])
250250
assert len(traj.states) == 6
251251
assert len(traj.actions) == 5
252-
traj2 = utils.run_policy(policy,
253-
env,
254-
"test",
255-
0,
256-
lambda s: True,
257-
max_num_steps=5)
252+
assert "policy_call_time" in metrics
253+
assert metrics["policy_call_time"] > 0.0
254+
traj2, _ = utils.run_policy(policy,
255+
env,
256+
"test",
257+
0,
258+
lambda s: True,
259+
max_num_steps=5)
258260
assert not task.goal_holds(traj2.states[-1])
259261
assert len(traj2.states) == 1
260262
assert len(traj2.actions) == 0
@@ -266,12 +268,12 @@ def _onestep_terminal(_):
266268
executed = True
267269
return terminate
268270

269-
traj3 = utils.run_policy(policy,
270-
env,
271-
"test",
272-
0,
273-
_onestep_terminal,
274-
max_num_steps=5)
271+
traj3, _ = utils.run_policy(policy,
272+
env,
273+
"test",
274+
0,
275+
_onestep_terminal,
276+
max_num_steps=5)
275277
assert not task.goal_holds(traj3.states[-1])
276278
assert len(traj3.states) == 2
277279
assert len(traj3.actions) == 1
@@ -288,15 +290,28 @@ def _policy(_):
288290
task.goal_holds,
289291
max_num_steps=5)
290292
assert "mock error" in str(e)
291-
traj4 = utils.run_policy(_policy,
292-
env,
293-
"test",
294-
0,
295-
task.goal_holds,
296-
max_num_steps=5,
297-
exceptions_to_break_on={ValueError})
293+
traj4, _ = utils.run_policy(_policy,
294+
env,
295+
"test",
296+
0,
297+
task.goal_holds,
298+
max_num_steps=5,
299+
exceptions_to_break_on={ValueError})
298300
assert len(traj4.states) == 1
299301

302+
# Test policy call time.
303+
def _policy(_):
304+
time.sleep(0.1)
305+
return Action(env.action_space.sample())
306+
307+
traj, metrics = utils.run_policy(_policy,
308+
env,
309+
"test",
310+
0,
311+
task.goal_holds,
312+
max_num_steps=3)
313+
assert metrics["policy_call_time"] >= 3 * 0.1
314+
300315

301316
def test_run_policy_with_simulator():
302317
"""Tests for run_policy_with_simulator()."""
@@ -1754,13 +1769,13 @@ def test_VideoMonitor():
17541769
monitor = utils.VideoMonitor(env.render)
17551770
policy = lambda _: Action(env.action_space.sample())
17561771
task = env.get_task("test", 0)
1757-
traj = utils.run_policy(policy,
1758-
env,
1759-
"test",
1760-
0,
1761-
task.goal_holds,
1762-
max_num_steps=2,
1763-
monitor=monitor)
1772+
traj, _ = utils.run_policy(policy,
1773+
env,
1774+
"test",
1775+
0,
1776+
task.goal_holds,
1777+
max_num_steps=2,
1778+
monitor=monitor)
17641779
assert not task.goal_holds(traj.states[-1])
17651780
assert len(traj.states) == 3
17661781
assert len(traj.actions) == 2
@@ -1774,13 +1789,13 @@ def test_SimulateVideoMonitor():
17741789
task = env.get_task("test", 0)
17751790
monitor = utils.SimulateVideoMonitor(task, env.render_state)
17761791
policy = lambda _: Action(env.action_space.sample())
1777-
traj = utils.run_policy(policy,
1778-
env,
1779-
"test",
1780-
0,
1781-
task.goal_holds,
1782-
max_num_steps=2,
1783-
monitor=monitor)
1792+
traj, _ = utils.run_policy(policy,
1793+
env,
1794+
"test",
1795+
0,
1796+
task.goal_holds,
1797+
max_num_steps=2,
1798+
monitor=monitor)
17841799
assert not task.goal_holds(traj.states[-1])
17851800
assert len(traj.states) == 3
17861801
assert len(traj.actions) == 2

0 commit comments

Comments
 (0)