Skip to content

Commit aab2933

Browse files
committed
Runner script to sparify reward models for PointMaze
1 parent 89dcc75 commit aab2933

1 file changed

Lines changed: 65 additions & 0 deletions

File tree

runners/sparsify_point_maze.sh

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#!/usr/bin/env bash
2+
# Copyright 2020 Adam Gleave
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# Script to sparsify pretrained reward models generated by `transfer_point_maze.sh`
17+
18+
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
19+
. ${DIR}/common.sh
20+
21+
ENV_TRAIN="imitation/PointMazeLeftVel-v0"
22+
TRANSITION_P=0.05
23+
24+
if [[ ${fast} == "true" ]]; then
25+
# intended for debugging
26+
COMPARISON_TIMESTEPS="fast"
27+
PM_OUTPUT=${OUTPUT_ROOT}/transfer_point_maze_fast
28+
SPARSE_OUTPUT=${OUTPUT_ROOT}/sparse_point_maze_fast
29+
else
30+
COMPARISON_TIMESTEPS=""
31+
EVAL_TIMESTEPS=100000
32+
PM_OUTPUT=${OUTPUT_ROOT}/transfer_point_maze
33+
SPARSE_OUTPUT=${OUTPUT_ROOT}/sparse_point_maze
34+
fi
35+
36+
MIXED_POLICY_PATH=${TRANSITION_P}:random:dummy:ppo2:${PM_OUTPUT}/expert/train/policies/final
37+
for name in comparison_expert comparison_mixture comparison_random; do
38+
if [[ ${name} == "comparison_expert" ]]; then
39+
extra_flags="dataset_factory_kwargs.policy_type=ppo2 \
40+
dataset_factory_kwargs.policy_path=${PM_OUTPUT}/expert/train/policies/final"
41+
elif [[ ${name} == "comparison_mixture" ]]; then
42+
extra_flags="dataset_factory_kwargs.policy_type=mixture \
43+
dataset_factory_kwargs.policy_path=${MIXED_POLICY_PATH}"
44+
elif [[ ${name} == "comparison_random" ]]; then
45+
extra_flags=""
46+
else
47+
echo "BUG: unknown name ${name}"
48+
exit 1
49+
fi
50+
parallel --header : --results ${SPARSE_OUTPUT}/parallel/${name} \
51+
$(call_script "model_comparison" "with") \
52+
env_name=${ENV_TRAIN} ${extra_flags} \
53+
ellp_loss no_rescale target_reward_type=evaluating_rewards/Zero-v0 \
54+
seed={seed} source_reward_type={source_reward_type} \
55+
source_reward_path=${PM_OUTPUT}/reward/{source_reward_path}/{source_reward_suffix} \
56+
${COMPARISON_TIMESTEPS} log_dir=${SPARSE_OUTPUT}/${name}/{source_reward_path}/{seed} \
57+
::: source_reward_type evaluating_rewards/PointMazeGroundTruthWithCtrl-v0 \
58+
evaluating_rewards/PointMazeGroundTruthNoCtrl-v0 \
59+
evaluating_rewards/RewardModel-v0 evaluating_rewards/RewardModel-v0 \
60+
imitation/RewardNet_unshaped-v0 imitation/RewardNet_unshaped-v0 \
61+
:::+ source_reward_path withctrl noctrl preferences regress irl_state_only irl_state_action \
62+
:::+ source_reward_suffix dummy dummy model model checkpoints/final/discrim/reward_net \
63+
checkpoints/final/discrim/reward_net \
64+
::: seed 0 1 2
65+
done

0 commit comments

Comments
 (0)