|
| 1 | +import json |
| 2 | +from copy import deepcopy |
| 3 | +from os.path import join |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +try: |
| 7 | + from typing import Literal |
| 8 | +except ImportError: |
| 9 | + from typing_extensions import Literal |
| 10 | + |
| 11 | +import torch |
| 12 | + |
| 13 | +from pydantic import Extra |
| 14 | + |
| 15 | +from sample_factory.algorithms.appo.actor_worker import transform_dict_observations |
| 16 | +from sample_factory.algorithms.appo.learner import LearnerWorker |
| 17 | +from sample_factory.algorithms.appo.model import create_actor_critic |
| 18 | +from sample_factory.algorithms.appo.model_utils import get_hidden_size |
| 19 | +from sample_factory.envs.create_env import create_env |
| 20 | +from sample_factory.utils.utils import AttrDict |
| 21 | + |
| 22 | +from agents.utils_agents import AlgoBase, run_algorithm |
| 23 | +from learning.epom_config import Environment |
| 24 | +from learning.grid_memory import MultipleGridMemory |
| 25 | +from pomapf_env.wrappers import MatrixObservationWrapper |
| 26 | + |
| 27 | +from train_epom import validate_config, register_custom_components |
| 28 | + |
| 29 | + |
class EpomConfig(AlgoBase, extra=Extra.forbid):
    """Settings for the EPOM algorithm.

    Extends ``AlgoBase``; unknown fields are rejected (``Extra.forbid``).
    """

    # Discriminator value identifying this algorithm.
    name: Literal["EPOM"] = "EPOM"
    # Directory that EPOM reads ``cfg.json`` and the policy checkpoints from.
    path_to_weights: str = 'weights/epom'
| 33 | + |
| 34 | + |
class EPOM:
    """Inference wrapper around a trained EPOM actor-critic policy.

    On construction the training configuration (``cfg.json``) and the policy
    checkpoint are read from ``algo_cfg.path_to_weights``.  Afterwards the
    object is driven through ``act`` / ``after_step`` / ``after_reset``.
    """

    def __init__(self, algo_cfg):
        """Create the actor-critic model and restore its weights.

        :param algo_cfg: ``EpomConfig`` instance. ``path_to_weights`` and
            ``device`` are read here; ``seed`` is read in ``after_reset``.
        """
        self.algo_cfg: EpomConfig = algo_cfg

        weights_dir = algo_cfg.path_to_weights
        register_custom_components()

        self.path = weights_dir
        self.env = None

        with open(join(weights_dir, 'cfg.json'), "r", encoding="utf-8") as f:
            stored = json.load(f)
        # Only the flat config is needed; the first returned value is unused.
        _, train_cfg = validate_config(stored['full_config'])

        # The environment is created only to obtain the observation/action
        # spaces; close it even if model construction fails.
        env = create_env(train_cfg.env, cfg=train_cfg, env_config={})
        try:
            actor_critic = create_actor_critic(train_cfg, env.observation_space, env.action_space)
        finally:
            env.close()

        self.device = self._select_device(algo_cfg.device)
        actor_critic.model_to_device(self.device)

        checkpoint_dir = join(weights_dir, f'checkpoint_p{train_cfg.policy_index}')
        checkpoints = LearnerWorker.get_checkpoints(checkpoint_dir)
        checkpoint_dict = LearnerWorker.load_checkpoint(checkpoints, self.device)
        actor_critic.load_state_dict(checkpoint_dict['model'])

        self.ppo = actor_critic
        self.cfg = train_cfg

        # The grid-memory radius depends only on the stored config, so it is
        # resolved once here instead of re-parsing the config on every step.
        env_cfg: Environment = Environment(**train_cfg.full_config['environment'])
        gm_radius = env_cfg.grid_memory_obs_radius
        self._gm_obs_radius = gm_radius if gm_radius else env_cfg.grid_config.obs_radius

        self.rnn_states = None
        self.mgm = MultipleGridMemory()
        self._step = 0

    @staticmethod
    def _select_device(requested):
        """Return the CPU device if requested or CUDA is unavailable, else CUDA."""
        if requested == 'cpu' or not torch.cuda.is_available():
            return torch.device('cpu')
        return torch.device('cuda')

    def after_reset(self):
        """Re-seed torch and drop all per-episode state."""
        torch.manual_seed(self.algo_cfg.seed)
        # Discard recurrent state so a new episode never starts from the
        # previous episode's hidden state; act() re-allocates zeros lazily.
        self.rnn_states = None
        self.mgm.clear()
        self._step = 0

    def get_additional_info(self):
        """Return extra metrics reported alongside the results."""
        return {"rl_used": 1.0}

    def get_name(self):
        """Return the last component of the weights path."""
        return Path(self.path).name

    def act(self, observations, rewards=None, dones=None, infos=None):
        """Compute one action per agent for the given observations.

        :param observations: per-agent observations; ``len(observations)`` is
            used as the number of agents.
        :param rewards: accepted but not used.
        :param dones: accepted but not used.
        :param infos: accepted but not used.
        :return: actions as a numpy array (``policy_outputs.actions`` moved to CPU).
        """
        # Work on a copy: the grid-memory call below does not return a value,
        # so it presumably edits the observations in place.
        observations = deepcopy(observations)

        num_agents = len(observations)
        if self.rnn_states is None or len(self.rnn_states) != num_agents:
            self.rnn_states = torch.zeros([num_agents, get_hidden_size(self.cfg)], dtype=torch.float32,
                                          device=self.device)

        self.mgm.update(observations)
        self.mgm.modify_observation(observations, obs_radius=self._gm_obs_radius)
        observations = MatrixObservationWrapper.to_matrix(observations)

        with torch.no_grad():
            obs_torch = AttrDict(transform_dict_observations(observations))
            for key, x in obs_torch.items():
                obs_torch[key] = torch.from_numpy(x).to(self.device).float()
            policy_outputs = self.ppo(obs_torch, self.rnn_states, with_action_distribution=True)

        self.rnn_states = policy_outputs.rnn_states
        self._step += 1
        return policy_outputs.actions.cpu().numpy()

    def clear_hidden(self, agent_idx):
        """Zero the recurrent state of a single agent, if any state exists."""
        if self.rnn_states is not None:
            self.rnn_states[agent_idx] = torch.zeros([get_hidden_size(self.cfg)], dtype=torch.float32,
                                                     device=self.device)

    def after_step(self, dones):
        """Reset recurrent state for finished agents.

        :param dones: iterable of per-agent done flags.  When every flag is
            set, the whole recurrent state and the grid memory are dropped.
        """
        dones = list(dones)
        if all(dones):
            # Everything is discarded, so zeroing agents one by one is unnecessary.
            self.rnn_states = None
            self.mgm.clear()
            return

        for agent_idx, done_flag in enumerate(dones):
            if done_flag:
                self.clear_hidden(agent_idx)
| 126 | + |
| 127 | + |
def example_epom(map_name='sc1-AcrosstheCape', max_episode_steps=512, seed=None, num_agents=64, main_dir='./',
                 animate=False):
    """Run EPOM with the weights stored under ``<main_dir>/weights/epom``.

    :param map_name: forwarded to ``run_algorithm``.
    :param max_episode_steps: forwarded to ``run_algorithm``.
    :param seed: forwarded to ``run_algorithm``.
    :param num_agents: forwarded to ``run_algorithm``.
    :param main_dir: directory that contains ``weights/epom``.
    :param animate: forwarded to ``run_algorithm``.
    :return: whatever ``run_algorithm`` returns.
    """
    weights_dir = main_dir / Path('weights/epom')
    config = EpomConfig(path_to_weights=str(weights_dir))
    policy = EPOM(config)
    return run_algorithm(policy, map_name, max_episode_steps, seed, num_agents, animate)
| 132 | + |
| 133 | + |
if __name__ == '__main__':
    # Script entry point: weights are looked up relative to the parent directory.
    outcome = example_epom(main_dir='../')
    print(outcome)
0 commit comments