diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index f17126527..47547a578 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -9,5 +9,6 @@ from .train_muzero_with_reward_model import train_muzero_with_reward_model from .train_rezero import train_rezero from .train_unizero import train_unizero +from .train_unizero_ppo import train_unizero_ppo from .train_unizero_segment import train_unizero_segment from .utils import * diff --git a/lzero/entry/train_unizero_ppo.py b/lzero/entry/train_unizero_ppo.py new file mode 100644 index 000000000..e564ec699 --- /dev/null +++ b/lzero/entry/train_unizero_ppo.py @@ -0,0 +1,179 @@ +import logging +import os +from functools import partial +from typing import Optional, Tuple, List, Dict, Any + +import torch +import wandb +import numpy as np +from ding.config import compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import get_rank, get_world_size, set_pkg_seed +from torch.utils.tensorboard import SummaryWriter +from ding.worker import BaseLearner +import torch.distributed as dist + +from lzero.worker.muzero_evaluator_ppo import MuZeroEvaluatorPPO as Evaluator +from lzero.worker.muzero_collector_ppo import MuZeroCollectorPPO + + +def train_unizero_ppo( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: int = int(1e10), +) -> None: + cfg, create_cfg = input_cfg + assert create_cfg.policy.type == 'unizero_ppo', "train_unizero_ppo expects policy type 'unizero_ppo'" + logging.info(f"Using policy type: {create_cfg.policy.type}") + + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f"Device set to: {cfg.policy.device}") + + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available()) + + rank = get_rank() + + if cfg.policy.use_wandb: + wandb.init( + project="LightZero", + config=cfg, + sync_tensorboard=False, + monitor_gym=False, + save_code=True, + ) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + logging.info("Policy created successfully!") + + # Load pretrained model if specified + if model_path is not None: + logging.info(f"Loading pretrained model from {model_path}...") + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info("Pretrained model loaded successfully!") + + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if rank == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + collector = MuZeroCollectorPPO( + env=collector_env, + policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=cfg.policy, + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + 
stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=cfg.policy, + ) + + learner.call_hook('before_run') + if cfg.policy.use_wandb: + policy.set_train_iter_env_step(learner.train_iter, collector.envstep) + + if cfg.policy.multi_gpu: + world_size = get_world_size() + else: + world_size = 1 + + transition_buffer: List[Dict[str, Any]] = [] + + while True: + # eval_stop = False + # if (learner.train_iter == 0 or evaluator.should_eval(learner.train_iter)) and rank == 0: + # logging.info(f"Training iteration {learner.train_iter}: Starting evaluation...") + # eval_stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + # logging.info(f"Training iteration {learner.train_iter}: Evaluation completed, stop condition: {eval_stop}, current reward: {reward}") + # if cfg.policy.multi_gpu and world_size > 1: + # stop_tensor = torch.tensor([int(eval_stop)], device=cfg.policy.device if torch.cuda.is_available() else torch.device('cpu')) + # dist.broadcast(stop_tensor, src=0) + # eval_stop = bool(stop_tensor.item()) + # if eval_stop: + # logging.info("Stopping condition met, training ends!") + # break + + collect_kwargs = dict(temperature=1.0, epsilon=0.0) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + logging.info(f"Rank {rank}, Training iteration {learner.train_iter}: New data collection completed!") + + transitions = new_data[0] + if transitions: + transition_buffer.extend(transitions) + + if len(transition_buffer) < cfg.policy.ppo.mini_batch_size: + continue + + if cfg.policy.ppo.get('advantage_normalization', True): + advantages = np.stack([item['advantage'] for item in transition_buffer]) + adv_mean = advantages.mean() + adv_std = advantages.std() + 1e-8 + for item in transition_buffer: + item['advantage'] = (item['advantage'] - adv_mean) / adv_std + + total_transitions = len(transition_buffer) + mini_batch_size = cfg.policy.ppo.mini_batch_size + for _ in range(cfg.policy.ppo.update_epochs): + permutation = np.random.permutation(total_transitions) + for start in range(0, total_transitions, mini_batch_size): + batch_indices = permutation[start:start + mini_batch_size] + if batch_indices.size == 0: + continue + + def stack(key: str) -> np.ndarray: + return np.stack([transition_buffer[i][key] for i in batch_indices]) + + batch_dict = dict( + prev_obs=stack('prev_obs'), + obs=stack('obs'), + action_mask=stack('action_mask'), + action=stack('action'), + old_log_prob=stack('old_log_prob'), + advantage=stack('advantage'), + return_=stack('return'), + prev_action=stack('prev_action'), + timestep=stack('timestep'), + ) + train_data = [batch_dict, None] + train_data.append(learner.train_iter) + learner.train(train_data, collector.envstep) + + transition_buffer.clear() + + if cfg.policy.multi_gpu and world_size > 1: + try: + dist.barrier() + except Exception as e: + logging.error(f'Rank {rank}: Synchronization barrier failed, error: {e}') + break + + if cfg.policy.use_wandb: + policy.set_train_iter_env_step(learner.train_iter, collector.envstep) + + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + logging.info("Reached max training condition") + break + + learner.call_hook('after_run') + collector.close() + evaluator.close() + if tb_logger is not None: + tb_logger.close() + if cfg.policy.use_wandb: + wandb.finish() diff --git a/lzero/model/unizero_world_models/world_model.py 
b/lzero/model/unizero_world_models/world_model.py index 7f1a0f68e..7f55f49e3 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -1596,7 +1596,53 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar latent_state_l2_norms=latent_state_l2_norms, ) + + def compute_loss_ppo( + self, + batch: Dict[str, torch.Tensor], + inverse_scalar_transform_handle, + clip_ratio: float, + value_coef: float, + entropy_coef: float, + ) -> Dict[str, torch.Tensor]: + """Compute PPO objectives (policy/value/entropy) for a mini-batch.""" + policy_logits = batch['policy_logits'] + action_mask = batch['action_mask'].bool() + actions = batch['actions'].long() + old_log_prob = batch['old_log_prob'].float() + advantages = batch['advantages'].float() + returns = batch['returns'].float() + + masked_logits = policy_logits.masked_fill(~action_mask, -1e9) + dist = Categorical(logits=masked_logits) + log_prob = dist.log_prob(actions) + entropy = dist.entropy() + + ratio = torch.exp(log_prob - old_log_prob) + surrogate1 = ratio * advantages + surrogate2 = torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio) * advantages + policy_loss = -torch.min(surrogate1, surrogate2).mean() + + value_pred = inverse_scalar_transform_handle(batch['values']).squeeze(-1) + value_loss = torch.nn.functional.mse_loss(value_pred, returns) + + entropy_mean = entropy.mean() + entropy_loss = -entropy_mean + + loss_total = policy_loss + value_coef * value_loss + entropy_coef * entropy_loss + + return { + 'loss_total': loss_total, + 'loss_policy': policy_loss, + 'loss_value': value_loss, + 'loss_entropy': entropy_loss, + 'entropy_mean': entropy_mean, + 'ratio_mean': ratio.mean(), + 'advantage_mean': advantages.mean(), + 'return_mean': returns.mean(), + } + # TODO: test correctness def _calculate_policy_loss_cont_simple(self, outputs, batch: dict): """ diff --git a/lzero/policy/unizero_ppo.py b/lzero/policy/unizero_ppo.py new file mode 100644 index 000000000..d48afb6ea --- /dev/null +++ b/lzero/policy/unizero_ppo.py @@ -0,0 +1,257
@@ +import copy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +import wandb +from torch.distributions import Categorical + +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_pad_batch +from lzero.policy import mz_network_output_unpack +from lzero.policy.unizero import UniZeroPolicy + + +@POLICY_REGISTRY.register('unizero_ppo') +class UniZeroPPOPolicy(UniZeroPolicy): + """UniZero policy variant that replaces MCTS-based improvement with PPO updates.""" + + config = copy.deepcopy(UniZeroPolicy.config) + config.update( + dict( + type='unizero_ppo', + ppo=dict( + rollout_length=64, + mini_batch_size=32, + update_epochs=4, + gamma=0.997, + gae_lambda=0.95, + clip_ratio=0.2, + value_coef=0.25, + entropy_coef=0.01, + advantage_normalization=True, + ), + ) + ) + + def _init_collect(self) -> None: + """Initialize structures used during data collection.""" + self._collect_model = self._model + env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros( + [env_num, self._cfg.model.observation_shape[0], 64, 64], + device=self._cfg.device, + ) + else: + self.last_batch_obs = torch.full( + [env_num, self._cfg.model.observation_shape], + fill_value=self.pad_token_id, + device=self._cfg.device, + ) + self.last_batch_action = [-1 for _ in range(env_num)] + + def _init_eval(self) -> None: + """Evaluation reuses collect path (no MCTS search).""" + self._eval_model = self._model + env_num = self._cfg.evaluator_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros( + [env_num, self._cfg.model.observation_shape[0], 64, 64], + device=self._cfg.device, + ) + else: + self.last_batch_obs = torch.full( + [env_num, self._cfg.model.observation_shape], + fill_value=self.pad_token_id, + device=self._cfg.device, + ) + self.last_batch_action = [-1 for _ in range(env_num)] + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: List[np.ndarray], + ready_env_id: Optional[np.ndarray] = None, + timestep: Optional[List[int]] = None, + deterministic: bool = False, + **kwargs: Any, + ) -> Dict[int, Dict[str, Any]]: + """Sample actions directly from the policy head and expose statistics for PPO.""" + self._collect_model.eval() + + if ready_env_id is None: + ready_env_id = np.arange(data.shape[0]) + elif isinstance(ready_env_id, (list, tuple)): + ready_env_id = np.asarray(ready_env_id) + + if timestep is None: + timestep = [0 for _ in ready_env_id] + + ready_env_list = ready_env_id.tolist() + prev_obs_snapshot = torch.stack( + [self.last_batch_obs[env_id] for env_id in ready_env_list] + ).clone().to(self._cfg.device) + prev_action_snapshot = [self.last_batch_action[env_id] for env_id in ready_env_list] + + with torch.no_grad(): + network_output = self._collect_model.initial_inference( + prev_obs_snapshot, prev_action_snapshot, data, timestep + ) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + del latent_state_roots, reward_roots + pred_values = self.value_inverse_scalar_transform_handle(pred_values) + + outputs: Dict[int, Dict[str, Any]] = {} + batch_action: List[int] = [] + for idx, env_id in enumerate(ready_env_list): + logits = policy_logits[idx] + mask = torch.tensor(action_mask[idx], dtype=torch.bool, device=logits.device) + masked_logits = logits.masked_fill(~mask, -1e9) + dist = Categorical(logits=masked_logits) + action = torch.argmax(masked_logits, dim=-1) if deterministic else 
dist.sample() + log_prob = dist.log_prob(action) + entropy = dist.entropy() + + action_int = int(action.item()) + batch_action.append(action_int) + outputs[env_id] = dict( + action=action_int, + log_prob=float(log_prob.item()), + entropy=float(entropy.item()), + predicted_value=float(pred_values[idx].item()), + policy_logits=logits.detach().cpu(), + action_mask=np.asarray(action_mask[idx]).copy(), + obs=data[idx].detach().cpu(), + timestep=int(timestep[idx]), + ) + + for idx, env_id in enumerate(ready_env_list): + self.last_batch_obs[env_id] = data[idx].detach().clone() + self.last_batch_action[env_id] = batch_action[idx] + + return outputs + + def _forward_eval( + self, + data: torch.Tensor, + action_mask: List[np.ndarray], + to_play: Optional[List[int]] = None, + ready_env_id: Optional[np.ndarray] = None, + timestep: Optional[List[int]] = None, + **kwargs: Any, + ) -> Dict[int, Dict[str, Any]]: + return self._forward_collect( + data=data, + action_mask=action_mask, + ready_env_id=ready_env_id, + timestep=timestep, + deterministic=True, + ) + + def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, float]: + batch_dict, _, train_iter = data + + self._learn_model.train() + self._target_model.train() + device = next(self._learn_model.parameters()).device + + prev_obs = torch.as_tensor(batch_dict['prev_obs'], device=device) + obs = torch.as_tensor(batch_dict['obs'], device=device) + action_mask = torch.as_tensor(batch_dict['action_mask'], device=device).bool() + actions = torch.as_tensor(batch_dict['action'], device=device).long() + old_log_prob = torch.as_tensor(batch_dict['old_log_prob'], device=device).float() + advantages = torch.as_tensor(batch_dict['advantage'], device=device).float() + returns = torch.as_tensor(batch_dict['return_'], device=device).float() + prev_actions = [int(a) for a in batch_dict['prev_action']] + timesteps = batch_dict['timestep'].tolist() + + prev_obs = prev_obs.float() if prev_obs.is_floating_point() else prev_obs.long() + obs = obs.float() if obs.is_floating_point() else obs.long() + + network_output = self._learn_model.initial_inference(prev_obs, prev_actions, obs, timesteps) + _, _, pred_values, policy_logits = mz_network_output_unpack(network_output) + + loss_tensors = self._learn_model.world_model.compute_loss_ppo( + dict( + policy_logits=policy_logits, + values=pred_values, + action_mask=action_mask, + actions=actions, + old_log_prob=old_log_prob, + advantages=advantages, + returns=returns, + ), + inverse_scalar_transform_handle=self.value_inverse_scalar_transform_handle, + clip_ratio=self._cfg.ppo.clip_ratio, + value_coef=self._cfg.ppo.value_coef, + entropy_coef=self._cfg.ppo.entropy_coef, + ) + + total_loss = loss_tensors['loss_total'] + + if (train_iter % self.accumulation_steps) == 0: + self._optimizer_world_model.zero_grad() + + (total_loss / self.accumulation_steps).backward() + + total_grad_norm_before_clip = torch.tensor(0., device=device) + if (train_iter + 1) % self.accumulation_steps == 0: + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.world_model.parameters(), self._cfg.grad_clip_value + ) + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + self._optimizer_world_model.step() + if self.accumulation_steps > 1 and torch.cuda.is_available(): + torch.cuda.empty_cache() + + if self._cfg.cos_lr_scheduler or self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + self._target_model.update(self._learn_model.state_dict()) + + if torch.cuda.is_available(): + 
torch.cuda.synchronize() + current_memory_allocated = torch.cuda.memory_allocated() / (1024 ** 3) + max_memory_allocated = torch.cuda.max_memory_allocated() / (1024 ** 3) + else: + current_memory_allocated = 0.0 + max_memory_allocated = 0.0 + + log_dict = { + 'loss_policy': loss_tensors['loss_policy'].item(), + 'loss_value': loss_tensors['loss_value'].item(), + 'loss_entropy': loss_tensors['loss_entropy'].item(), + 'loss_total': total_loss.item(), + 'ratio_mean': loss_tensors['ratio_mean'].item(), + 'advantage_mean': loss_tensors['advantage_mean'].item(), + 'return_mean': loss_tensors['return_mean'].item(), + 'entropy_mean': loss_tensors['entropy_mean'].item(), + 'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'], + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), + 'Current_GPU': current_memory_allocated, + 'Max_GPU': max_memory_allocated, + 'train_iter': train_iter, + } + + if self._cfg.use_wandb: + wandb.log({'learner_step/' + k: v for k, v in log_dict.items()}, step=self.env_step) + wandb.log({'learner_iter_vs_env_step': self.train_iter}, step=self.env_step) + + return log_dict + + def reset(self, env_id: Optional[List[int]] = None) -> None: + """Reset cached context for specified environments.""" + if env_id is None: + self._reset_collect(reset_init_data=True) + return + + if isinstance(env_id, int): + env_id = [env_id] + for e_id in env_id: + self.last_batch_obs[e_id] = initialize_pad_batch( + self._cfg.model.observation_shape, 1, self._cfg.device, pad_token_id=getattr(self, 'pad_token_id', 0) + )[0] + self.last_batch_action[e_id] = -1 diff --git a/lzero/worker/__init__.py b/lzero/worker/__init__.py index ece5213be..dbb223c74 100644 --- a/lzero/worker/__init__.py +++ b/lzero/worker/__init__.py @@ -3,3 +3,5 @@ from .muzero_collector import MuZeroCollector from .muzero_segment_collector import MuZeroSegmentCollector from .muzero_evaluator import MuZeroEvaluator +from .muzero_collector_ppo import MuZeroCollectorPPO +from .muzero_evaluator_ppo import MuZeroEvaluatorPPO diff --git a/lzero/worker/muzero_collector_ppo.py b/lzero/worker/muzero_collector_ppo.py new file mode 100644 index 000000000..1afd9c402 --- /dev/null +++ b/lzero/worker/muzero_collector_ppo.py @@ -0,0 +1,318 @@ +import os +import time +from collections import deque, namedtuple +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from ding.envs import BaseEnvManager +from ding.torch_utils import to_ndarray +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, allreduce_data +from ding.worker.collector.base_serial_collector import ISerialCollector + +from lzero.mcts.utils import prepare_observation + + +@SERIAL_COLLECTOR_REGISTRY.register('episode_muzero_ppo') +class MuZeroCollectorPPO(ISerialCollector): + """Collector that follows the original MuZeroCollector structure but gathers PPO rollouts.""" + + config = dict() + + def __init__( + self, + collect_print_freq: int = 100, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'collector', + policy_config: 'policy_config' = None, # noqa + ) -> None: + self._exp_name = exp_name + self._instance_name = instance_name + self._collect_print_freq = collect_print_freq + self._timer = EasyTimer() + self._end_flag = False + + self._rank = get_rank() + self._world_size = get_world_size() + if self._rank == 0: + if tb_logger is not 
None: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), + name=self._instance_name, + need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger, self._tb_logger = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name + ) + else: + self._logger, _ = build_logger( + path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False + ) + self._tb_logger = None + + self.policy_config = policy_config + self.rollout_length = self.policy_config.ppo.rollout_length + + self.reset(policy, env) + + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: + if _env is not None: + self._env = _env + self._env.launch() + self._env_num = self._env.env_num + else: + self._env.reset() + + def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: + assert hasattr(self, '_env'), "please set env first" + if _policy is not None: + self._policy = _policy + self._default_n_episode = _policy.get_attribute('cfg').get('n_episode', None) + self._logger.debug( + 'Set default n_episode mode(n_episode({}), env_num({}))'.format(self._default_n_episode, self._env_num) + ) + self._policy.reset() + + def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: + if _env is not None: + self.reset_env(_env) + if _policy is not None: + self.reset_policy(_policy) + + self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)} + self._episode_info: List[Dict[str, Any]] = [] + self._pending_buffers = {env_id: [] for env_id in range(self._env_num)} + self._obs_stacks = {env_id: deque([], maxlen=self.policy_config.model.frame_stack_num) for env_id in range(self._env_num)} + self._total_envstep_count = 0 + self._total_episode_count = 0 + self._total_duration = 0. 
+ self._last_train_iter = 0 + self._end_flag = False + + def _reset_stat(self, env_id: int) -> None: + self._env_info[env_id] = {'time': 0., 'step': 0} + self._pending_buffers[env_id] = [] + self._obs_stacks[env_id] = deque(maxlen=self.policy_config.model.frame_stack_num) + + @property + def envstep(self) -> int: + return self._total_envstep_count + + def close(self) -> None: + if self._end_flag: + return + self._end_flag = True + self._env.close() + if self._tb_logger: + self._tb_logger.flush() + self._tb_logger.close() + + def __del__(self) -> None: + self.close() + + def collect(self, + train_iter: int = 0, + policy_kwargs: Optional[dict] = None) -> List[Any]: + if policy_kwargs is None: + policy_kwargs = {} + + env_num = self._env_num + target_episode = policy_kwargs.get('n_episode', self._default_n_episode) + target_episode = max(target_episode, env_num) if target_episode is not None else env_num + transitions: List[Dict[str, Any]] = [] + collected_step = 0 + collected_episode = 0 + collected_duration = 0.0 + + init_obs = self._env.ready_obs + retry_waiting_time = 0.01 + while len(init_obs.keys()) != self._env_num: + time.sleep(retry_waiting_time) + init_obs = self._env.ready_obs + + frame_stack = self.policy_config.model.frame_stack_num + obs_stacks = getattr(self, '_obs_stacks', None) + if obs_stacks is None or not obs_stacks: + obs_stacks = { + env_id: deque([to_ndarray(init_obs[env_id]['observation']) for _ in range(frame_stack)], maxlen=frame_stack) + for env_id in range(env_num) + } + self._obs_stacks = obs_stacks + last_prev_action = {} + for env_id in range(env_num): + buffer = self._pending_buffers.get(env_id, []) + if buffer: + last_prev_action[env_id] = int(buffer[-1]['action']) + else: + last_prev_action[env_id] = -1 + + while collected_episode < target_episode: + ready_obs = self._env.ready_obs + if not ready_obs: + time.sleep(0.001) + continue + + for env_id in ready_obs.keys(): + if len(obs_stacks[env_id]) < frame_stack: + obs_value = to_ndarray(ready_obs[env_id]['observation']) + obs_stacks[env_id] = deque([obs_value for _ in range(frame_stack)], maxlen=frame_stack) + + ready_env_list = list(ready_obs.keys()) + + stacked = [np.array(list(obs_stacks[env_id])) for env_id in ready_env_list] + action_masks = [to_ndarray(ready_obs[env_id]['action_mask']) for env_id in ready_env_list] + timesteps = [ready_obs[env_id].get('timestep', -1) for env_id in ready_env_list] + + stacked_np = prepare_observation(stacked, self.policy_config.model.model_type) + stacked_tensor = torch.from_numpy(stacked_np).to(self.policy_config.device) + + policy_output = self._policy.forward( + stacked_tensor, + action_mask=action_masks, + ready_env_id=ready_env_list, + timestep=timesteps, + ) + + actions = {env_id: policy_output[env_id]['action'] for env_id in ready_env_list} + + with self._timer: + timesteps_output = self._env.step(actions) + interaction_duration = self._timer.value / max(len(timesteps_output), 1) + + for env_id, timestep_data in timesteps_output.items(): + obs_dict = timestep_data.obs + reward = float(timestep_data.reward) + done = bool(timestep_data.done) + next_obs = to_ndarray(obs_dict['observation']) + + info = policy_output[env_id] + prev_stack = np.array(list(obs_stacks[env_id])) + prev_action_value = last_prev_action[env_id] + + obs_stacks[env_id].append(next_obs) + next_stack = np.array(list(obs_stacks[env_id])) + last_prev_action[env_id] = info['action'] + + step_record = dict( + prev_obs=prev_stack, + obs=next_stack, + action_mask=info['action_mask'], + 
action=np.array(info['action'], dtype=np.int64), + old_log_prob=np.array(info['log_prob'], dtype=np.float32), + value=np.array(info['predicted_value'], dtype=np.float32), + reward=np.array(reward, dtype=np.float32), + done=np.array(done, dtype=np.float32), + prev_action=np.array(prev_action_value, dtype=np.int64), + timestep=np.array(info['timestep'], dtype=np.int64), + ) + self._pending_buffers[env_id].append(step_record) + + self._env_info[env_id]['time'] += interaction_duration + self._env_info[env_id]['step'] += 1 + collected_step += 1 + + if done: + episode_transitions, episode_reward = self._finalize_episode(env_id) + transitions.extend(episode_transitions) + collected_episode += 1 + collected_duration += self._env_info[env_id]['time'] + self._episode_info.append({ + 'reward': episode_reward, + 'step': self._env_info[env_id]['step'], + 'time': self._env_info[env_id]['time'], + }) + + self._policy.reset([env_id]) + self._reset_stat(env_id) + if env_id in ready_env_list: + ready_env_list.remove(env_id) + + last_prev_action[env_id] = -1 + + self._obs_stacks = obs_stacks + + if self._world_size > 1: + collected_step = allreduce_data(collected_step, 'sum') + collected_episode = allreduce_data(collected_episode, 'sum') + collected_duration = allreduce_data(collected_duration, 'sum') + + self._total_envstep_count += collected_step + self._total_episode_count += collected_episode + self._total_duration += collected_duration + + self._output_log(train_iter) + + return [transitions, {}] + + def _finalize_episode(self, env_id: int) -> Tuple[List[Dict[str, Any]], float]: + buffer = self._pending_buffers[env_id] + if not buffer: + return [], 0.0 + + gamma = self.policy_config.ppo.gamma + gae_lambda = self.policy_config.ppo.gae_lambda + + rewards = np.array([step['reward'] for step in buffer], dtype=np.float32) + values = np.array([step['value'] for step in buffer], dtype=np.float32) + dones = np.array([step['done'] for step in buffer], dtype=np.float32) + + advantages = np.zeros_like(rewards) + returns = np.zeros_like(rewards) + gae = 0.0 + for t in reversed(range(len(buffer))): + next_value = 0.0 if t == len(buffer) - 1 or dones[t] else values[t + 1] + delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t] + gae = delta + gamma * gae_lambda * (1 - dones[t]) * gae + advantages[t] = gae + returns[t] = gae + values[t] + + transitions: List[Dict[str, Any]] = [] + for t, step in enumerate(buffer): + transitions.append({ + 'prev_obs': step['prev_obs'], + 'obs': step['obs'], + 'action_mask': step['action_mask'], + 'action': step['action'], + 'old_log_prob': step['old_log_prob'], + 'advantage': advantages[t], + 'return': returns[t], + 'prev_action': step['prev_action'], + 'timestep': step['timestep'], + }) + + episode_reward = float(rewards.sum()) + self._pending_buffers[env_id] = [] + return transitions, episode_reward + + def _output_log(self, train_iter: int) -> None: + if self._rank != 0: + self._episode_info.clear() + return + if self._total_episode_count <= 0: + return + if (train_iter - self._last_train_iter) < self._collect_print_freq and train_iter != 0: + return + self._last_train_iter = train_iter + + reward_list = [info['reward'] for info in self._episode_info] + step_list = [info['step'] for info in self._episode_info] + time_list = [info['time'] for info in self._episode_info] + + avg_reward = float(np.mean(reward_list)) if reward_list else 0.0 + avg_steps = float(np.mean(step_list)) if step_list else 0.0 + avg_time = float(np.mean(time_list)) if time_list else 0.0 + + 
log_str = f"collect iter({train_iter}) envstep({self._total_envstep_count}) episode({self._total_episode_count}) " \ + f"avg_reward({avg_reward:.3f}) avg_step({avg_steps:.2f}) avg_time({avg_time:.2f})" + self._logger.info(log_str) + if self._tb_logger is not None: + self._tb_logger.add_scalar('collect/avg_reward', avg_reward, train_iter) + self._tb_logger.add_scalar('collect/avg_step', avg_steps, train_iter) + self._tb_logger.add_scalar('collect/avg_time', avg_time, train_iter) + + self._episode_info.clear() diff --git a/lzero/worker/muzero_evaluator_ppo.py b/lzero/worker/muzero_evaluator_ppo.py new file mode 100644 index 000000000..a83f84b6a --- /dev/null +++ b/lzero/worker/muzero_evaluator_ppo.py @@ -0,0 +1,223 @@ +import time +from collections import deque, namedtuple +from typing import Optional, Tuple, Dict, Any, List + +import numpy as np +import torch +import wandb +from ding.envs import BaseEnvManager +from ding.torch_utils import to_ndarray +from ding.utils import build_logger, EasyTimer, get_rank, get_world_size, broadcast_object_list +from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor + +from lzero.mcts.utils import prepare_observation + + +class MuZeroEvaluatorPPO(ISerialEvaluator): + config = dict( + eval_freq=50, + ) + + def __init__( + self, + eval_freq: int = 1000, + n_evaluator_episode: int = 3, + stop_value: int = 1e6, + env: BaseEnvManager = None, + policy: namedtuple = None, + tb_logger: 'SummaryWriter' = None, # noqa + exp_name: Optional[str] = 'default_experiment', + instance_name: Optional[str] = 'evaluator', + policy_config: 'policy_config' = None, # noqa + ) -> None: + self._eval_freq = eval_freq + self._default_n_episode = n_evaluator_episode + self._stop_value = stop_value + self._exp_name = exp_name + self._instance_name = instance_name + if get_rank() == 0: + if tb_logger is not None: + self._logger, _ = build_logger( + './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False + ) + self._tb_logger = tb_logger + else: + self._logger, self._tb_logger = build_logger( + './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name + ) + else: + self._logger, self._tb_logger = None, None + + self.policy_config = policy_config + self._timer = EasyTimer() + self.reset(policy, env) + + def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: + if _env is not None: + self._env = _env + self._env.launch() + self._env_num = self._env.env_num + else: + self._env.reset() + + def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: + assert hasattr(self, '_env'), "please set env first" + if _policy is not None: + self._policy = _policy + self._policy.reset() + + def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: + if _env is not None: + self.reset_env(_env) + if _policy is not None: + self.reset_policy(_policy) + self._last_eval_iter = 0 + self._max_episode_return = float('-inf') + self._end_flag = False + + def close(self) -> None: + if self._end_flag: + return + self._end_flag = True + self._env.close() + if self._tb_logger: + self._tb_logger.flush() + self._tb_logger.close() + + def __del__(self) -> None: + self.close() + + def should_eval(self, train_iter: int) -> bool: + if train_iter == self._last_eval_iter: + return False + if (train_iter - self._last_eval_iter) < self._eval_freq and train_iter != 0: + return False + self._last_eval_iter = train_iter + return True + + def eval( + self, + save_ckpt_fn: 
Optional[callable] = None, + train_iter: int = -1, + envstep: int = -1, + n_episode: Optional[int] = None, + return_trajectory: bool = False, + ) -> Tuple[bool, Dict[str, Any]]: + episode_info = None + stop_flag = False + if get_rank() == 0: + if n_episode is None: + n_episode = self._default_n_episode + self._env.reset() + self._policy.reset() + + monitor = VectorEvalMonitor(self._env.env_num, n_episode) + frame_stack = self.policy_config.model.frame_stack_num + obs_stacks = { + env_id: deque([to_ndarray(obs['observation']) for _ in range(frame_stack)], maxlen=frame_stack) + for env_id, obs in self._env.ready_obs.items() + } + + retry_wait = 0.01 + while len(obs_stacks) != self._env.env_num: + time.sleep(retry_wait) + obs_stacks = { + env_id: deque([to_ndarray(obs['observation']) for _ in range(frame_stack)], maxlen=frame_stack) + for env_id, obs in self._env.ready_obs.items() + } + + ready_env_id = set() + remain_episode = n_episode + episode_returns: List[float] = [] + + with self._timer: + while not monitor.is_finished(): + obs = self._env.ready_obs + new_available = set(obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available)[:remain_episode])) + remain_episode -= min(len(new_available), remain_episode) + if not ready_env_id: + time.sleep(0.01) + continue + + stacked = [np.array(list(obs_stacks[env_id])) for env_id in ready_env_id] + action_masks = [to_ndarray(obs[env_id]['action_mask']) for env_id in ready_env_id] + timesteps = [obs[env_id].get('timestep', -1) for env_id in ready_env_id] + + stacked_np = prepare_observation(stacked, self.policy_config.model.model_type) + stacked_tensor = torch.from_numpy(stacked_np).to(self.policy_config.device) + + policy_output = self._policy.forward( + stacked_tensor, + action_mask=action_masks, + ready_env_id=ready_env_id, + timestep=timesteps, + ) + actions = {env_id: policy_output[env_id]['action'] for env_id in ready_env_id} + + timesteps_output = self._env.step(actions) + + for env_id, timestep_data in timesteps_output.items(): + episode_obs = timestep_data.obs + reward = float(timestep_data.reward) + done = bool(timestep_data.done) + next_obs = to_ndarray(episode_obs['observation']) + + obs_stacks[env_id].append(next_obs) + + if done: + eval_reward = timestep_data.info.get('eval_episode_return', reward) + monitor.update_reward(env_id, eval_reward) + monitor.update_info(env_id, timestep_data.info.get('episode_info', {})) + episode_returns.append(eval_reward) + ready_env_id.remove(env_id) + remain_episode += 1 + if remain_episode > 0: + obs_reset = self._env.ready_obs + obs_stacks[env_id] = deque( + [to_ndarray(obs_reset[env_id]['observation']) for _ in range(frame_stack)], + maxlen=frame_stack + ) + self._policy.reset([env_id]) + else: + continue + + duration = self._timer.value + reward_mean = float(np.mean(episode_returns)) if episode_returns else 0.0 + reward_std = float(np.std(episode_returns)) if episode_returns else 0.0 + reward_max = float(np.max(episode_returns)) if episode_returns else 0.0 + reward_min = float(np.min(episode_returns)) if episode_returns else 0.0 + + info = dict( + train_iter=train_iter, + envstep_count=envstep, + episode_count=len(episode_returns), + evaluate_time=duration, + reward_mean=reward_mean, + reward_std=reward_std, + reward_max=reward_max, + reward_min=reward_min, + ) + if self._logger is not None: + self._logger.info(self._logger.get_tabulate_vars_hor(info)) + if self._tb_logger is not None: + for k, v in info.items(): + if isinstance(v, (int, float)): + 
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + if getattr(self.policy_config, 'use_wandb', False) and isinstance(v, (int, float)): + wandb.log({'{}_step/'.format(self._instance_name) + k: v}, step=envstep) + + episode_info = info + if reward_mean > self._max_episode_return: + if save_ckpt_fn: + save_ckpt_fn('ckpt_best.pth.tar') + self._max_episode_return = reward_mean + stop_flag = reward_mean >= self._stop_value and train_iter > 0 + + if get_world_size() > 1: + objects = [stop_flag, episode_info] + broadcast_object_list(objects, src=0) + stop_flag, episode_info = objects + + return stop_flag, episode_info diff --git a/zoo/jericho/configs/jericho_unizero_ppo_config.py b/zoo/jericho/configs/jericho_unizero_ppo_config.py new file mode 100644 index 000000000..097aecb39 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_ppo_config.py @@ -0,0 +1,241 @@ +import os +import argparse +from typing import Any, Dict + +from easydict import EasyDict + + +def main(env_id: str = 'detective.z5', seed: int = 0, max_env_step: int = int(1e6)) -> None: + """ + Main entry point for setting up environment configurations and launching training. + + Args: + env_id (str): Identifier of the environment, e.g., 'detective.z5'. + seed (int): Random seed used for reproducibility. + + Returns: + None + """ + env_id = 'detective.z5' + + collector_env_num: int = 4 # Number of collector environments + n_episode = int(collector_env_num) + batch_size=64 + + # ------------------------------------------------------------------ + # Base environment parameters (Note: these values might be adjusted for different env_id) + # ------------------------------------------------------------------ + # Define environment configurations + env_configurations = { + 'detective.z5': (12, 100), + 'omniquest.z5': (25, 100), + 'acorncourt.z5': (45, 50), + 'zork1.z5': (55, 500), + } + + # Set action_space_size and max_steps based on env_id + action_space_size, max_steps = env_configurations.get(env_id, (10, 50)) # Default values if env_id not found + + # ------------------------------------------------------------------ + # User frequently modified configurations + # ------------------------------------------------------------------ + evaluator_env_num: int = 1 # Number of evaluator environments + num_simulations: int = 4 # Number of simulations + + # Project training parameters + num_unroll_steps: int = 10 # Number of unroll steps (for rollout sequence expansion) + infer_context_length: int = 4 # Inference context length + + num_layers: int = 2 # Number of layers in the model + replay_ratio: float = 0.1 # Replay ratio for experience replay + embed_dim: int = 768 # Embedding dimension + + # Reanalysis (reanalyze) parameters: + # buffer_reanalyze_freq: Frequency of reanalysis (e.g., 1 means reanalyze once per epoch) + buffer_reanalyze_freq: float = 1 / 100000 + # reanalyze_batch_size: Number of sequences to reanalyze per reanalysis process + reanalyze_batch_size: int = 160 + # reanalyze_partition: Partition ratio from the replay buffer to use during reanalysis + reanalyze_partition: float = 0.75 + + # Model name or path - configurable according to the predefined model paths or names + encoder_option = 'legacy' # ['qwen', 'legacy']. 
Legacy uses the bge encoder + + if encoder_option == 'qwen': + model_name: str = 'Qwen/Qwen3-0.6B' + elif encoder_option == 'legacy': + model_name: str = 'BAAI/bge-base-en-v1.5' + else: + raise ValueError(f"Unsupported encoder option: {encoder_option}") + + # ------------------------------------------------------------------ + # TODO: Debug configuration - override some parameters for debugging purposes + # ------------------------------------------------------------------ + # max_env_step = int(2e5) + # batch_size = 10 + # num_simulations = 2 + # num_unroll_steps = 5 + # infer_context_length = 2 + # max_steps = 10 + # num_layers = 1 + # replay_ratio = 0.05 + # ------------------------------------------------------------------ + # Configuration dictionary for the Jericho Unizero environment and policy + # ------------------------------------------------------------------ + jericho_unizero_config: Dict[str, Any] = dict( + env=dict( + stop_value=int(1e6), + observation_shape=512, + max_steps=max_steps, + max_action_num=action_space_size, + tokenizer_path=model_name, + max_seq_len=512, + game_path=f"./zoo/jericho/envs/z-machine-games-master/jericho-game-suite/{env_id}", + for_unizero=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False), + ), + policy=dict( + type="unizero_ppo", + multi_gpu=False, + use_wandb=False, + learn=dict( + learner=dict( + hook=dict( + save_ckpt_after_iter=1000000, # To save memory, set a large value. If intermediate checkpoints are needed, reduce this value. + ), + ), + ), + accumulation_steps=1, # TODO: Accumulated gradient steps (currently default) + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_option=encoder_option, + encoder_url=model_name, + model_type="mlp", + continuous_action_space=False, + world_model_cfg=dict( + final_norm_option_in_obs_head='LayerNorm', + final_norm_option_in_encoder='LayerNorm', + predict_latent_loss_type='mse', + policy_entropy_weight=5e-2, + continuous_action_space=False, + max_blocks=num_unroll_steps, + # Note: Each timestep contains 2 tokens: observation and action. + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device="cuda", + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=24, + embed_dim=embed_dim, + obs_type="text", + env_num=max(collector_env_num, evaluator_env_num), + decode_loss_mode=None, # Controls where to compute reconstruction loss: after_backbone, before_backbone, or None. 
+ latent_recon_loss_weight=0.1 + ), + ), + update_per_collect=int(collector_env_num * max_steps * replay_ratio), # Important for DDP + action_type="varied_action_space", + model_path=None, + num_unroll_steps=num_unroll_steps, + reanalyze_ratio=0, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=0.0001, + cos_lr_scheduler=False, + fixed_temperature_value=0.25, + manual_temperature_decay=False, + num_simulations=num_simulations, + n_episode=n_episode, + train_start_after_envsteps=0, + replay_buffer_size=int(5e5), + eval_freq=int(3e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + buffer_reanalyze_freq=buffer_reanalyze_freq, + reanalyze_batch_size=reanalyze_batch_size, + reanalyze_partition=reanalyze_partition, + ppo=dict( + rollout_length=64, + mini_batch_size=32, + update_epochs=4, + gamma=0.997, + gae_lambda=0.95, + clip_ratio=0.2, + entropy_coef=0.01, + advantage_normalization=True, + ), + ), + ) + jericho_unizero_config = EasyDict(jericho_unizero_config) + + # ------------------------------------------------------------------ + # Create configuration for importing environment and policy modules + # ------------------------------------------------------------------ + jericho_unizero_create_config: Dict[str, Any] = dict( + env=dict( + type="jericho", + import_names=["zoo.jericho.envs.jericho_env"], + ), + # Use base env manager to avoid bugs present in subprocess env manager. + env_manager=dict(type="base"), + policy=dict( + type="unizero_ppo", + import_names=["lzero.policy.unizero_ppo"], + ), + ) + jericho_unizero_create_config = EasyDict(jericho_unizero_create_config) + + # ------------------------------------------------------------------ + # Combine configuration dictionaries and construct an experiment name + # ------------------------------------------------------------------ + main_config: EasyDict = jericho_unizero_config + create_config: EasyDict = jericho_unizero_create_config + + # Construct experiment name containing key parameters + main_config.exp_name = ( + f"data_lz/data_unizero_jericho_ppo_debug/bge-base-en-v1.5/{env_id}/uz_gpu_cen{collector_env_num}_rr{replay_ratio}_ftemp025_{env_id[:8]}_ms{max_steps}_ass-{action_space_size}_" + f"nlayer{num_layers}_embed{embed_dim}_Htrain{num_unroll_steps}-" + f"Hinfer{infer_context_length}_bs{batch_size}_seed{seed}" + ) + from lzero.entry import train_unizero_ppo + # Launch the training process + train_unizero_ppo( + [main_config, create_config], + seed=seed, + model_path=main_config.policy.model_path, + max_env_step=max_env_step, + ) + + +if __name__ == "__main__": + """ + Overview: + This config launches single-process training (multi_gpu is disabled in the policy config); a CUDA device is expected as set in world_model_cfg. + Run the script with: + python ./zoo/jericho/configs/jericho_unizero_ppo_config.py + """ + + parser = argparse.ArgumentParser(description='Process environment configuration and launch training.') + parser.add_argument( + '--env', + type=str, + help='Identifier of the environment, e.g., detective.z5 or zork1.z5', + default='detective.z5' + ) + parser.add_argument( + '--seed', + type=int, + help='Random seed for reproducibility', + default=0 + ) + args = parser.parse_args() + + # Disable tokenizer parallelism to prevent multi-process conflicts + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + # Start the main process with the provided arguments + main(args.env, args.seed)
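
For reference, the advantage and return targets produced in MuZeroCollectorPPO._finalize_episode follow the standard GAE(lambda) recursion, and train_unizero_ppo later normalizes the advantages before the clipped-ratio update. Below is a minimal standalone sketch of that recursion; the compute_gae helper and the sample numbers are illustrative only (not part of the diff), with gamma and gae_lambda matching the config defaults.

import numpy as np

def compute_gae(rewards, values, dones, gamma=0.997, gae_lambda=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    # gae_t = delta_t + gamma * lambda * (1 - done_t) * gae_{t+1}; return_t = gae_t + V(s_t)
    advantages = np.zeros_like(rewards)
    returns = np.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = 0.0 if t == len(rewards) - 1 or dones[t] else values[t + 1]
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        gae = delta + gamma * gae_lambda * (1 - dones[t]) * gae
        advantages[t] = gae
        returns[t] = gae + values[t]
    return advantages, returns

if __name__ == "__main__":
    # Hypothetical 3-step episode, purely for illustration.
    rewards = np.array([0.0, 0.0, 1.0], dtype=np.float32)
    values = np.array([0.5, 0.6, 0.7], dtype=np.float32)
    dones = np.array([0.0, 0.0, 1.0], dtype=np.float32)
    print(compute_gae(rewards, values, dones))

Because episodes are only finalized on done, the recursion bootstraps from zero at the terminal step, which matches the next_value = 0.0 branch in the collector.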