#!/usr/bin/env bash
# Copyright 2020 Adam Gleave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Script to sparsify pretrained reward models generated by
# `transfer_point_maze.sh`.
#
# For each comparison dataset (expert, mixture, random) this launches
# `model_comparison` jobs through GNU parallel, one job per
# (source reward, seed) combination, with
# target_reward_type=evaluating_rewards/Zero-v0.
#
# Names used here but not defined in this file (expected to come from
# common.sh or the calling environment):
#   OUTPUT_ROOT  - root directory under which inputs are read and results
#                  are written.
#   fast         - when "true", selects the debugging configuration.
#   call_script  - helper whose output is used as the command prefix handed
#                  to parallel.

DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)" || {
  echo "ERROR: cannot determine the directory of ${BASH_SOURCE[0]}" >&2
  exit 1
}

# Fail early if the shared helpers are missing; otherwise the script would
# carry on with OUTPUT_ROOT unset and build paths rooted at "/".
if [[ ! -r "${DIR}/common.sh" ]]; then
  echo "ERROR: cannot read ${DIR}/common.sh" >&2
  exit 1
fi
# shellcheck source=/dev/null
. "${DIR}/common.sh"

: "${OUTPUT_ROOT:?OUTPUT_ROOT must be set (by common.sh or the environment)}"

ENV_TRAIN="imitation/PointMazeLeftVel-v0"
# First field of the mixture policy path below.
# NOTE(review): presumably the probability of switching between the random
# and expert policies -- confirm against the mixture policy loader.
TRANSITION_P=0.05

if [[ "${fast:-}" == "true" ]]; then
  # intended for debugging
  COMPARISON_TIMESTEPS="fast"
  PM_OUTPUT="${OUTPUT_ROOT}/transfer_point_maze_fast"
  SPARSE_OUTPUT="${OUTPUT_ROOT}/sparse_point_maze_fast"
else
  COMPARISON_TIMESTEPS=""
  # NOTE(review): EVAL_TIMESTEPS is not referenced anywhere else in this
  # script; kept in case a helper from common.sh reads it -- TODO confirm.
  EVAL_TIMESTEPS=100000
  PM_OUTPUT="${OUTPUT_ROOT}/transfer_point_maze"
  SPARSE_OUTPUT="${OUTPUT_ROOT}/sparse_point_maze"
fi

EXPERT_POLICY_PATH="${PM_OUTPUT}/expert/train/policies/final"
MIXED_POLICY_PATH="${TRANSITION_P}:random:dummy:ppo2:${EXPERT_POLICY_PATH}"

for name in comparison_expert comparison_mixture comparison_random; do
  # Per-dataset flags are kept in an array so each flag stays exactly one
  # argument instead of depending on word-splitting of a string.
  extra_flags=()
  case "${name}" in
    comparison_expert)
      extra_flags=(
        "dataset_factory_kwargs.policy_type=ppo2"
        "dataset_factory_kwargs.policy_path=${EXPERT_POLICY_PATH}"
      )
      ;;
    comparison_mixture)
      extra_flags=(
        "dataset_factory_kwargs.policy_type=mixture"
        "dataset_factory_kwargs.policy_path=${MIXED_POLICY_PATH}"
      )
      ;;
    comparison_random)
      # No dataset flags are passed for this configuration.
      ;;
    *)
      echo "BUG: unknown name ${name}" >&2
      exit 1
      ;;
  esac

  # The output of call_script is intentionally word-split into the command
  # prefix. The `${arr[@]+...}` and `${var:+...}` forms expand to zero
  # arguments when empty, and remain safe if `set -u` is active.
  # With `--header :` the first value after each `:::`/`:::+` names the
  # column; `:::+` links a column element-wise to the previous one, so
  # type/path/suffix advance together, and `::: seed` is crossed with them.
  # shellcheck disable=SC2046
  parallel --header : --results "${SPARSE_OUTPUT}/parallel/${name}" \
    $(call_script "model_comparison" "with") \
    "env_name=${ENV_TRAIN}" \
    ${extra_flags[@]+"${extra_flags[@]}"} \
    ellp_loss no_rescale target_reward_type=evaluating_rewards/Zero-v0 \
    "seed={seed}" "source_reward_type={source_reward_type}" \
    "source_reward_path=${PM_OUTPUT}/reward/{source_reward_path}/{source_reward_suffix}" \
    ${COMPARISON_TIMESTEPS:+"${COMPARISON_TIMESTEPS}"} \
    "log_dir=${SPARSE_OUTPUT}/${name}/{source_reward_path}/{seed}" \
    ::: source_reward_type \
      evaluating_rewards/PointMazeGroundTruthWithCtrl-v0 \
      evaluating_rewards/PointMazeGroundTruthNoCtrl-v0 \
      evaluating_rewards/RewardModel-v0 \
      evaluating_rewards/RewardModel-v0 \
      imitation/RewardNet_unshaped-v0 \
      imitation/RewardNet_unshaped-v0 \
    :::+ source_reward_path \
      withctrl noctrl preferences regress irl_state_only irl_state_action \
    :::+ source_reward_suffix \
      dummy dummy model model \
      checkpoints/final/discrim/reward_net \
      checkpoints/final/discrim/reward_net \
    ::: seed 0 1 2
done