1010import logging
1111import os
1212import subprocess
13+ import time
1314from collections import defaultdict
1415from dataclasses import dataclass , field
1516from typing import TYPE_CHECKING , Any , Callable , Collection , Dict , FrozenSet , \
3132from predicators .src .structs import NSRT , Action , Array , DummyOption , \
3233 EntToEntSub , GroundAtom , GroundAtomTrajectory , \
3334 GroundNSRTOrSTRIPSOperator , Image , LiftedAtom , LiftedOrGroundAtom , \
34- LowLevelTrajectory , NSRTOrSTRIPSOperator , Object , OptionSpec , \
35+ LowLevelTrajectory , Metrics , NSRTOrSTRIPSOperator , Object , OptionSpec , \
3536 ParameterizedOption , Predicate , Segment , State , STRIPSOperator , Task , \
3637 Type , VarToObjSub , Video , _GroundNSRT , _GroundSTRIPSOperator , _Option , \
3738 _TypedEntity
@@ -443,15 +444,16 @@ def observe(self, state: State, action: Optional[Action]) -> None:
443444 raise NotImplementedError ("Override me!" )
444445
445446
446- def run_policy (policy : Callable [[State ], Action ],
447- env : BaseEnv ,
448- train_or_test : str ,
449- task_idx : int ,
450- termination_function : Callable [[State ], bool ],
451- max_num_steps : int ,
452- exceptions_to_break_on : Optional [Set [
453- TypingType [Exception ]]] = None ,
454- monitor : Optional [Monitor ] = None ) -> LowLevelTrajectory :
447+ def run_policy (
448+ policy : Callable [[State ], Action ],
449+ env : BaseEnv ,
450+ train_or_test : str ,
451+ task_idx : int ,
452+ termination_function : Callable [[State ], bool ],
453+ max_num_steps : int ,
454+ exceptions_to_break_on : Optional [Set [TypingType [Exception ]]] = None ,
455+ monitor : Optional [Monitor ] = None
456+ ) -> Tuple [LowLevelTrajectory , Metrics ]:
455457 """Execute a policy starting from the initial state of a train or test task
456458 in the environment. The task's goal is not used.
457459
@@ -465,10 +467,14 @@ def run_policy(policy: Callable[[State], Action],
465467 state = env .reset (train_or_test , task_idx )
466468 states = [state ]
467469 actions : List [Action ] = []
470+ metrics : Metrics = defaultdict (float )
471+ metrics ["policy_call_time" ] = 0.0
468472 if not termination_function (state ):
469473 for _ in range (max_num_steps ):
470474 try :
475+ start_time = time .time ()
471476 act = policy (state )
477+ metrics ["policy_call_time" ] += time .time () - start_time
472478 except Exception as e :
473479 if exceptions_to_break_on is not None and \
474480 type (e ) in exceptions_to_break_on :
@@ -484,7 +490,7 @@ def run_policy(policy: Callable[[State], Action],
484490 if monitor is not None :
485491 monitor .observe (state , None )
486492 traj = LowLevelTrajectory (states , actions )
487- return traj
493+ return traj , metrics
488494
489495
490496def run_policy_with_simulator (
@@ -500,10 +506,10 @@ def run_policy_with_simulator(
500506 *** This function should not be used with any core code, because we want
501507 to avoid the assumption of a simulator when possible. ***
502508
503- This is similar to run_policy, with two major differences:
509+ This is similar to run_policy, with three major differences:
504510 (1) The initial state `init_state` can be any state, not just the initial
505511 state of a train or test task. (2) A simulator (function that takes state
506- as input) is assumed.
512+ as input) is assumed. (3) Metrics are not returned.
507513
508514 Note that the environment internal state is NOT updated.
509515
0 commit comments