From ee2c76e165d09abd6ac928b30c416f2772205dc0 Mon Sep 17 00:00:00 2001 From: tAnGjIa520 <1157507000@qq.com> Date: Mon, 5 Jan 2026 14:44:12 +0800 Subject: [PATCH] efficientzerov2 --- .../ctree/ctree_efficientzero_v2/Makefile | 72 + .../ctree/ctree_efficientzero_v2/__init__.py | 3 + .../ctree/ctree_efficientzero_v2/ez_tree.pxd | 124 ++ .../ctree/ctree_efficientzero_v2/ez_tree.pyx | 158 ++ .../ctree_efficientzero_v2/lib/cminimax.cpp | 71 + .../ctree_efficientzero_v2/lib/cminimax.h | 45 + .../ctree_efficientzero_v2/lib/cnode.cpp | 1792 +++++++++++++++++ .../ctree/ctree_efficientzero_v2/lib/cnode.h | 134 ++ .../test_batch_sequential_halving | Bin 0 -> 88184 bytes .../test_batch_sequential_halving.py | 259 +++ lzero/mcts/tree_search/__init__.py | 2 +- lzero/mcts/tree_search/mcts_ctree.py | 332 ++- lzero/policy/efficientzero_v2.py | 861 ++++++++ .../config/atari_efficientzero_v2_config.py | 98 + .../cartpole_efficientzero_v2_config.py | 89 + 15 files changed, 4030 insertions(+), 10 deletions(-) create mode 100644 lzero/mcts/ctree/ctree_efficientzero_v2/Makefile create mode 100644 lzero/mcts/ctree/ctree_efficientzero_v2/__init__.py create mode 100644 lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pxd create mode 100644 lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pyx create mode 100644 lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.cpp create mode 100644 lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.h create mode 100644 lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.cpp create mode 100644 lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.h create mode 100755 lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving create mode 100644 lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving.py create mode 100644 lzero/policy/efficientzero_v2.py create mode 100644 zoo/atari/config/atari_efficientzero_v2_config.py create mode 100644 zoo/classic_control/cartpole/config/cartpole_efficientzero_v2_config.py diff --git 
a/lzero/mcts/ctree/ctree_efficientzero_v2/Makefile b/lzero/mcts/ctree/ctree_efficientzero_v2/Makefile new file mode 100644 index 000000000..0f1774f48 --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/Makefile @@ -0,0 +1,72 @@ +CXX = g++ +CXXFLAGS = -std=c++11 -Wall -Wextra -O2 +INCLUDES = -I. + +# Source files +SOURCES = lib/cnode.cpp lib/cminimax.cpp +TEST_SOURCE = test_cnode.cpp +TEST_SH_SOURCE = test_sequential_halving.cpp +TEST_BATCH_SH_SOURCE = test_batch_sequential_halving.cpp + +# Object files +OBJECTS = $(SOURCES:.cpp=.o) +TEST_OBJECTS = $(TEST_SOURCE:.cpp=.o) +TEST_SH_OBJECTS = $(TEST_SH_SOURCE:.cpp=.o) +TEST_BATCH_SH_OBJECTS = $(TEST_BATCH_SH_SOURCE:.cpp=.o) + +# Executables +TEST_EXECUTABLE = test_cnode +TEST_SH_EXECUTABLE = test_sequential_halving +TEST_BATCH_SH_EXECUTABLE = test_batch_sequential_halving + +.PHONY: all clean test test_sh test_batch_sh + +all: $(TEST_EXECUTABLE) + +$(TEST_EXECUTABLE): $(OBJECTS) $(TEST_OBJECTS) + @echo "Linking test executable..." + $(CXX) $(CXXFLAGS) -o $(TEST_EXECUTABLE) $(OBJECTS) $(TEST_OBJECTS) + @echo "Build successful!" + +lib/%.o: lib/%.cpp lib/%.h lib/cminimax.h + @echo "Compiling $<..." + $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ + +test_cnode.o: test_cnode.cpp + @echo "Compiling test_cnode.cpp..." + $(CXX) $(CXXFLAGS) $(INCLUDES) -c test_cnode.cpp -o test_cnode.o + +test_sequential_halving.o: test_sequential_halving.cpp + @echo "Compiling test_sequential_halving.cpp..." + $(CXX) $(CXXFLAGS) $(INCLUDES) -c test_sequential_halving.cpp -o test_sequential_halving.o + +$(TEST_SH_EXECUTABLE): $(OBJECTS) $(TEST_SH_OBJECTS) + @echo "Linking test_sequential_halving executable..." + $(CXX) $(CXXFLAGS) -o $(TEST_SH_EXECUTABLE) $(OBJECTS) $(TEST_SH_OBJECTS) + @echo "Build successful!" + +test_batch_sequential_halving.o: test_batch_sequential_halving.cpp + @echo "Compiling test_batch_sequential_halving.cpp..." 
+ $(CXX) $(CXXFLAGS) $(INCLUDES) -c test_batch_sequential_halving.cpp -o test_batch_sequential_halving.o + +$(TEST_BATCH_SH_EXECUTABLE): $(OBJECTS) $(TEST_BATCH_SH_OBJECTS) + @echo "Linking test_batch_sequential_halving executable..." + $(CXX) $(CXXFLAGS) -o $(TEST_BATCH_SH_EXECUTABLE) $(OBJECTS) $(TEST_BATCH_SH_OBJECTS) + @echo "Build successful!" + +test: all + @echo "Running tests..." + @./$(TEST_EXECUTABLE) + +test_sh: $(TEST_SH_EXECUTABLE) + @echo "Running Sequential Halving tests..." + @./$(TEST_SH_EXECUTABLE) + +test_batch_sh: $(TEST_BATCH_SH_EXECUTABLE) + @echo "Running batch Sequential Halving tests..." + @./$(TEST_BATCH_SH_EXECUTABLE) + +clean: + @echo "Cleaning up..." + rm -f $(OBJECTS) $(TEST_OBJECTS) $(TEST_SH_OBJECTS) $(TEST_BATCH_SH_OBJECTS) $(TEST_EXECUTABLE) $(TEST_SH_EXECUTABLE) $(TEST_BATCH_SH_EXECUTABLE) + @echo "Clean complete!" diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/__init__.py b/lzero/mcts/ctree/ctree_efficientzero_v2/__init__.py new file mode 100644 index 000000000..a2ee6590c --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/__init__.py @@ -0,0 +1,3 @@ +from . 
import ez_tree + +__all__ = ['ez_tree'] diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pxd b/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pxd new file mode 100644 index 000000000..b935d5b4e --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pxd @@ -0,0 +1,124 @@ +# distutils:language=c++ +# cython:language_level=3 +from libcpp.vector cimport vector + + +cdef extern from "lib/cminimax.cpp": + pass + + +cdef extern from "lib/cminimax.h" namespace "tools": + cdef cppclass CMinMaxStats: + CMinMaxStats() except + + int c_visit + float c_scale + float maximum, minimum, value_delta_max + + void set_delta(float value_delta_max) + void set_static_val(float value_delta_max, int c_visit, float c_scale) + void update(float value) + void clear() + float normalize(float value) + + cdef cppclass CMinMaxStatsList: + CMinMaxStatsList() except + + CMinMaxStatsList(int num) except + + int num + vector[CMinMaxStats] stats_lst + + void set_delta(float value_delta_max) + void set_static_val(float value_delta_max, int c_visit, float c_scale) + +cdef extern from "lib/cnode.cpp": + pass + + +cdef extern from "lib/cnode.h" namespace "tree": + cdef cppclass CNode: + CNode() except + + CNode(float prior, vector[int] & legal_actions) except + + int visit_count, to_play, current_latent_state_index, batch_index, best_action + float value_prefixs, prior, value_sum, parent_value_prefix + + void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefixs, + vector[float] policy_logits) + void add_exploration_noise(float exploration_fraction, vector[float] noises) + float compute_mean_q(int isRoot, float parent_q, float discount_factor) + + int expanded() + float value() + vector[int] get_trajectory() + vector[int] get_children_distribution() + CNode * get_child(int action) + + cdef cppclass CRoots: + CRoots() except + + CRoots(int root_num, vector[vector[int]] legal_actions_list) except + + int root_num + vector[CNode] roots + + void 
prepare(float root_noise_weight, const vector[vector[float]] & noises, + const vector[float] & value_prefixs, const vector[vector[float]] & policies, + vector[int] to_play_batch) + void prepare_no_noise(const vector[float] & value_prefixs, const vector[vector[float]] & policies, + vector[int] to_play_batch) + void clear() + vector[vector[int]] get_trajectories() + vector[vector[int]] get_distributions() + vector[float] get_values() + vector[vector[float]] get_root_policies(CMinMaxStatsList *min_max_stats_lst) + vector[int] get_best_actions() + # visualize related code + # CNode* get_root(int index) + + cdef cppclass CSearchResults: + CSearchResults() except + + CSearchResults(int num) except + + int num + vector[int] latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens + vector[int] virtual_to_play_batchs + vector[CNode *] nodes + + cdef void cbackpropagate(vector[CNode *] & search_path, CMinMaxStats & min_max_stats, + int to_play, float value, float discount_factor) + void cbatch_backpropagate(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, + vector[float] values, vector[vector[float]] policies, + CMinMaxStatsList *min_max_stats_lst, CSearchResults & results, + vector[int] is_reset_list, vector[int] & to_play_batch) + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, + vector[float] values, vector[vector[float]] policies, + CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + vector[int] is_reset_list, vector[int] &to_play_batch, vector[int] &no_inference_lst, + vector[int] &reuse_lst, vector[float] &reuse_value_lst) + # ========== MuZero/UCB 风格的遍历(备份) ========== + void cbatch_traverse_ucb(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, + CMinMaxStatsList *min_max_stats_lst, CSearchResults & results, + vector[int] & virtual_to_play_batch) + void cbatch_traverse_with_reuse(CRoots *roots, int 
pb_c_base, float pb_c_init, float discount_factor, + CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + vector[int] &virtual_to_play_batch, vector[int] &true_action, vector[float] &reuse_value) + + # ========== EfficientZero V2 风格的遍历(Sequential Halving 集成) ========== + void cbatch_traverse(CRoots *roots, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + int num_simulations, int simulation_idx, const vector[vector[float]]& gumble_noise, + int current_num_top_actions, vector[int] &virtual_to_play_batch) + + # ========== EfficientZero V2 Sequential Halving ========== + vector[int] c_batch_sequential_halving(CRoots *roots, vector[vector[float]] gumbel_noises, + CMinMaxStatsList *min_max_stats_lst, int current_phase, + int current_num_top_actions) + +cdef class MinMaxStatsList: + cdef CMinMaxStatsList *cmin_max_stats_lst + +cdef class ResultsWrapper: + cdef CSearchResults cresults + +cdef class Roots: + cdef readonly int root_num + cdef CRoots *roots + cdef readonly int num_actions + cdef public object legal_actions_list + +cdef class Node: + cdef CNode cnode diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pyx b/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pyx new file mode 100644 index 000000000..1a589db0b --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pyx @@ -0,0 +1,158 @@ +# distutils:language=c++ +# cython:language_level=3 +import cython +from libcpp.vector cimport vector + +cdef class MinMaxStatsList: + @cython.binding + def __cinit__(self, int num): + self.cmin_max_stats_lst = new CMinMaxStatsList(num) + + @cython.binding + def set_delta(self, float value_delta_max): + self.cmin_max_stats_lst[0].set_delta(value_delta_max) + + def __dealloc__(self): + del self.cmin_max_stats_lst + +cdef class ResultsWrapper: + @cython.binding + def __cinit__(self, int num): + self.cresults = CSearchResults(num) + + @cython.binding + def get_search_len(self): + return self.cresults.search_lens + +cdef class Roots: + 
@cython.binding + def __cinit__(self, int root_num, vector[vector[int]] legal_actions_list): + self.root_num = root_num + self.roots = new CRoots(root_num, legal_actions_list) + # Store legal_actions for access from Python + self.legal_actions_list = legal_actions_list + + @cython.binding + def prepare(self, float root_noise_weight, list noises, list value_prefix_pool, + list policy_logits_pool, vector[int] & to_play_batch): + self.roots[0].prepare(root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play_batch) + + @cython.binding + def prepare_no_noise(self, list value_prefix_pool, list policy_logits_pool, vector[int] & to_play_batch): + self.roots[0].prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play_batch) + + @cython.binding + def get_trajectories(self): + return self.roots[0].get_trajectories() + + @cython.binding + def get_distributions(self): + return self.roots[0].get_distributions() + + @cython.binding + def get_values(self): + return self.roots[0].get_values() + + @cython.binding + def get_root_policies(self, MinMaxStatsList min_max_stats_lst): + return self.roots[0].get_root_policies(min_max_stats_lst.cmin_max_stats_lst) + + @cython.binding + def get_best_actions(self): + return self.roots[0].get_best_actions() + + # visualize related code + #def get_root(self, int index): + # return self.roots[index] + + @cython.binding + def clear(self): + self.roots[0].clear() + + @cython.binding + def get_legal_actions(self): + """Get the legal actions list""" + return list(self.legal_actions_list) + + def __dealloc__(self): + del self.roots + + @property + def num(self): + return self.root_num + +cdef class Node: + def __cinit__(self): + pass + + def __cinit__(self, float prior, vector[int] & legal_actions): + pass + + @cython.binding + def expand(self, int to_play, int current_latent_state_index, int batch_index, float value_prefix, + list policy_logits): + cdef vector[float] cpolicy = policy_logits + self.cnode.expand(to_play, 
current_latent_state_index, batch_index, value_prefix, cpolicy) + +@cython.binding +def batch_backpropagate(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies, + MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list is_reset_list, + list to_play_batch): + cdef int i + cdef vector[float] cvalue_prefixs = value_prefixs + cdef vector[float] cvalues = values + cdef vector[vector[float]] cpolicies = policies + + cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies, + min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch) + +@cython.binding +def batch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies, + MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list is_reset_list, + list to_play_batch, list no_inference_lst, list reuse_lst, list reuse_value_lst): + cdef int i + cdef vector[float] cvalue_prefixs = value_prefixs + cdef vector[float] cvalues = values + cdef vector[vector[float]] cpolicies = policies + cdef vector[float] creuse_value_lst = reuse_value_lst + + cbatch_backpropagate_with_reuse(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies, + min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch, no_inference_lst, reuse_lst, creuse_value_lst) + +# ========== MuZero/UCB 风格的遍历(备份版本) ========== +@cython.binding +def batch_traverse_ucb(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst, + ResultsWrapper results, list virtual_to_play_batch): + """MuZero/UCB 风格的批量树遍历(备份版本)""" + cbatch_traverse_ucb(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, + results.cresults, virtual_to_play_batch) + + return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, 
results.cresults.last_actions, results.cresults.virtual_to_play_batchs + +# ========== EfficientZero V2 风格的遍历(Sequential Halving 集成) ========== +@cython.binding +def batch_traverse(Roots roots, MinMaxStatsList min_max_stats_lst, ResultsWrapper results, + int num_simulations, int simulation_idx, list gumbel_noises, + int current_num_top_actions, list virtual_to_play_batch): + """EfficientZero V2 风格的批量树遍历,集成 Sequential Halving 逻辑""" + cdef vector[vector[float]] c_gumbel_noises = gumbel_noises + cbatch_traverse(roots.roots, min_max_stats_lst.cmin_max_stats_lst, results.cresults, + num_simulations, simulation_idx, c_gumbel_noises, + current_num_top_actions, virtual_to_play_batch) + + return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs + +@cython.binding +def batch_traverse_with_reuse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst, + ResultsWrapper results, list virtual_to_play_batch, list true_action, list reuse_value): + cbatch_traverse_with_reuse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults, + virtual_to_play_batch, true_action, reuse_value) + + return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs + +@cython.binding +def batch_sequential_halving(Roots roots, list gumbel_noises, MinMaxStatsList min_max_stats_lst, + int current_phase, int current_num_top_actions): + cdef vector[vector[float]] c_gumbel_noises = gumbel_noises + return c_batch_sequential_halving(roots.roots, c_gumbel_noises, min_max_stats_lst.cmin_max_stats_lst, + current_phase, current_num_top_actions) diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.cpp b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.cpp new file mode 
100644 index 000000000..b438ad51e --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.cpp @@ -0,0 +1,71 @@ +// C++11 + +#include "cminimax.h" +#include +#include + +namespace tools { + + // ========== CMinMaxStats Implementation ========== + CMinMaxStats::CMinMaxStats() { + this->maximum = FLOAT_MIN; + this->minimum = FLOAT_MAX; + this->value_delta_max = 0.01; + this->c_visit = 1; + this->c_scale = 1.0; + } + + CMinMaxStats::~CMinMaxStats() {} + + void CMinMaxStats::set_delta(float value_delta_max) { + this->value_delta_max = value_delta_max; + } + + void CMinMaxStats::set_static_val(float value_delta_max, int c_visit, float c_scale) { + this->value_delta_max = value_delta_max; + this->c_visit = c_visit; + this->c_scale = c_scale; + } + + void CMinMaxStats::update(float value) { + this->maximum = std::max(this->maximum, value); + this->minimum = std::min(this->minimum, value); + } + + void CMinMaxStats::clear() { + this->maximum = FLOAT_MIN; + this->minimum = FLOAT_MAX; + } + + float CMinMaxStats::normalize(float value) { + float delta = this->maximum - this->minimum; + delta = std::max(delta, this->value_delta_max); + return (value - this->minimum) / delta; + } + + // ========== CMinMaxStatsList Implementation ========== + CMinMaxStatsList::CMinMaxStatsList() { + this->num = 0; + } + + CMinMaxStatsList::CMinMaxStatsList(int num) { + this->num = num; + for (int i = 0; i < num; ++i) { + this->stats_lst.push_back(CMinMaxStats()); + } + } + + CMinMaxStatsList::~CMinMaxStatsList() {} + + void CMinMaxStatsList::set_delta(float value_delta_max) { + for (int i = 0; i < this->num; ++i) { + this->stats_lst[i].set_delta(value_delta_max); + } + } + + void CMinMaxStatsList::set_static_val(float value_delta_max, int c_visit, float c_scale) { + for (int i = 0; i < this->num; ++i) { + this->stats_lst[i].set_static_val(value_delta_max, c_visit, c_scale); + } + } +} diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.h 
b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.h new file mode 100644 index 000000000..4debf35c9 --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.h @@ -0,0 +1,45 @@ +// C++11 + +#ifndef CMINIMAX_H +#define CMINIMAX_H + +#include +#include + +const float FLOAT_MAX = 1000000.0; +const float FLOAT_MIN = -FLOAT_MAX; +const float EPSILON = 0.000001; + +namespace tools { + + class CMinMaxStats { + public: + int c_visit; + float c_scale; + float maximum, minimum, value_delta_max; + + CMinMaxStats(); + ~CMinMaxStats(); + + void set_delta(float value_delta_max); + void set_static_val(float value_delta_max, int c_visit, float c_scale); + void update(float value); + void clear(); + float normalize(float value); + }; + + class CMinMaxStatsList { + public: + int num; + std::vector stats_lst; + + CMinMaxStatsList(); + CMinMaxStatsList(int num); + ~CMinMaxStatsList(); + + void set_delta(float value_delta_max); + void set_static_val(float value_delta_max, int c_visit, float c_scale); + }; +} + +#endif diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.cpp b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.cpp new file mode 100644 index 000000000..4a2afc645 --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.cpp @@ -0,0 +1,1792 @@ +// C++11 + +#include +#include "cnode.h" +#include "cminimax.h" +#include +#include +#include +#include + +#ifdef _WIN32 +#include "..\..\common_lib\utils.cpp" +#else +#include "../../common_lib/utils.cpp" +#endif + + +namespace tree +{ + + CSearchResults::CSearchResults() + { + /* + Overview: + Initialization of CSearchResults, the default result number is set to 0. + */ + this->num = 0; + } + + CSearchResults::CSearchResults(int num) + { + /* + Overview: + Initialization of CSearchResults with result number. 
+ */ + this->num = num; + for (int i = 0; i < num; ++i) + { + this->search_paths.push_back(std::vector()); + } + } + + CSearchResults::~CSearchResults() {} + + //********************************************************* + + CNode::CNode() + { + /* + Overview: + Initialization of CNode. + */ + this->prior = 0; + this->legal_actions = legal_actions; + + this->is_reset = 0; + this->visit_count = 0; + this->value_sum = 0; + this->best_action = -1; + this->to_play = 0; + this->value_prefix = 0.0; + this->parent_value_prefix = 0.0; + } + + CNode::CNode(float prior, std::vector &legal_actions) + { + /* + Overview: + Initialization of CNode with prior value and legal actions. + Arguments: + - prior: the prior value of this node. + - legal_actions: a vector of legal actions of this node. + */ + this->prior = prior; + this->legal_actions = legal_actions; + + this->is_reset = 0; + this->visit_count = 0; + this->value_sum = 0; + this->best_action = -1; + this->to_play = 0; + this->value_prefix = 0.0; + this->parent_value_prefix = 0.0; + this->current_latent_state_index = -1; + this->batch_index = -1; + } + + CNode::~CNode() {} + + void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector &policy_logits) + { + /* + Overview: + Expand the child nodes of the current node. + Arguments: + - to_play: which player to play the game in the current node. + - current_latent_state_index: the x/first index of hidden state vector of the current node, i.e. the search depth. + - batch_index: the y/second index of hidden state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. + - value_prefix: the value prefix of the current node. + - policy_logits: the policy logit of the child nodes. 
+ */ + this->to_play = to_play; + this->current_latent_state_index = current_latent_state_index; + this->batch_index = batch_index; + this->value_prefix = value_prefix; + + int action_num = policy_logits.size(); + if (this->legal_actions.size() == 0) + { + for (int i = 0; i < action_num; ++i) + { + this->legal_actions.push_back(i); + } + } + float temp_policy; + float policy_sum = 0.0; + + #ifdef _WIN32 + // 创建动态数组 + float* policy = new float[action_num]; + #else + float policy[action_num]; + #endif + + float policy_max = FLOAT_MIN; + for (auto a : this->legal_actions) + { + if (policy_max < policy_logits[a]) + { + policy_max = policy_logits[a]; + } + } + + for (auto a : this->legal_actions) + { + temp_policy = exp(policy_logits[a] - policy_max); + policy_sum += temp_policy; + policy[a] = temp_policy; + } + + float prior; + for (auto a : this->legal_actions) + { + prior = policy[a] / policy_sum; + std::vector tmp_empty; + this->children[a] = CNode(prior, tmp_empty); // only for muzero/efficient zero, not support alphazero + } + #ifdef _WIN32 + // 释放数组内存 + delete[] policy; + #else + #endif + } + + void CNode::add_exploration_noise(float exploration_fraction, const std::vector &noises) + { + /* + Overview: + Add a noise to the prior of the child nodes. + Arguments: + - exploration_fraction: the fraction to add noise. + - noises: the vector of noises added to each child node. + */ + float noise, prior; + for (int i = 0; i < this->legal_actions.size(); ++i) + { + noise = noises[i]; + CNode *child = this->get_child(this->legal_actions[i]); + + prior = child->prior; + child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction; + } + } + + float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor) + { + /* + Overview: + Compute the mean q value of the current node. + Arguments: + - isRoot: whether the current node is a root node. + - parent_q: the q value of the parent node. + - discount_factor: the discount_factor of reward. 
+ */ + float total_unsigned_q = 0.0; + int total_visits = 0; + float parent_value_prefix = this->value_prefix; + for (auto a : this->legal_actions) + { + CNode *child = this->get_child(a); + if (child->visit_count > 0) + { + float true_reward = child->value_prefix - parent_value_prefix; + if (this->is_reset == 1) + { + true_reward = child->value_prefix; + } + float qsa = true_reward + discount_factor * child->value(); + total_unsigned_q += qsa; + total_visits += 1; + } + } + + float mean_q = 0.0; + if (isRoot && total_visits > 0) + { + mean_q = (total_unsigned_q) / (total_visits); + } + else + { + mean_q = (parent_q + total_unsigned_q) / (total_visits + 1); + } + return mean_q; + } + + int CNode::expanded() + { + /* + Overview: + Return whether the current node is expanded. + */ + return this->children.size() > 0; + } + + float CNode::value() + { + /* + Overview: + Return the estimated value of the current tree. + Current implementation: uses value_sum / visit_count (traditional MCTS style). + Returns: + The mean value of all backpropagations through this node. + */ + float true_value = 0.0; + if (this->visit_count == 0) + { + return true_value; + } + else + { + true_value = this->value_sum / this->visit_count; + return true_value; + } + } + + float CNode::value_v2() + { + /* + Overview: + EfficientZero V2 original style value estimation using estimated_value_lst. + This method is NOT USED in current implementation, kept for reference. + + Difference from value(): + - value(): uses value_sum / visit_count (O(1) memory, cumulative sum) + - value_v2(): uses sum(estimated_value_lst) / len(estimated_value_lst) (O(n) memory, keeps history) + + Mathematical equivalence: + Both methods produce the same average value. + + Why keep estimated_value_lst? + 1. Reanalyze: Can re-evaluate nodes with new network + 2. Uncertainty estimation: Can compute std(estimated_value_lst) + 3. Value distribution: Can analyze value distribution + 4. 
Advanced features: Ensemble learning, distributed training + + Current status: + UNUSED - The estimated_value_lst is initialized but not populated. + To enable: uncomment the push_back lines in cbackpropagate(). + */ + if (this->estimated_value_lst.empty()) + { + return 0.0; + } + + float sum = 0.0; + for (float v : this->estimated_value_lst) + { + sum += v; + } + return sum / this->estimated_value_lst.size(); + } + + std::vector CNode::get_trajectory() + { + /* + Overview: + Find the current best trajectory starts from the current node. + Returns: + - traj: a vector of node index, which is the current best trajectory from this node. + */ + std::vector traj; + + CNode *node = this; + int best_action = node->best_action; + while (best_action >= 0) + { + traj.push_back(best_action); + + node = node->get_child(best_action); + best_action = node->best_action; + } + return traj; + } + + std::vector CNode::get_children_distribution() + { + /* + Overview: + Get the distribution of child nodes in the format of visit_count. + Returns: + - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + */ + std::vector distribution; + if (this->expanded()) + { + for (auto a : this->legal_actions) + { + CNode *child = this->get_child(a); + distribution.push_back(child->visit_count); + } + } + return distribution; + } + + CNode *CNode::get_child(int action) + { + /* + Overview: + Get the child node corresponding to the input action. + Arguments: + - action: the action to get child. + */ + return &(this->children[action]); + } + + //********************************************************* + + CRoots::CRoots() + { + /* + Overview: + The initialization of CRoots. + */ + this->root_num = 0; + } + + CRoots::CRoots(int root_num, std::vector > &legal_actions_list) + { + /* + Overview: + The initialization of CRoots with root num and legal action lists. + Arguments: + - root_num: the number of the current root. 
+ - legal_action_list: the vector of the legal action of this root. + */ + this->root_num = root_num; + this->legal_actions_list = legal_actions_list; + + for (int i = 0; i < root_num; ++i) + { + this->roots.push_back(CNode(0, this->legal_actions_list[i])); + } + } + + CRoots::~CRoots() {} + + void CRoots::prepare(float root_noise_weight, const std::vector > &noises, const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch) + { + /* + Overview: + Expand the roots and add noises. + Arguments: + - root_noise_weight: the exploration fraction of roots + - noises: the vector of noise add to the roots. + - value_prefixs: the vector of value prefixs of each root. + - policies: the vector of policy logits of each root. + - to_play_batch: the vector of the player side of each root. + */ + for (int i = 0; i < this->root_num; ++i) + { + this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]); + this->roots[i].add_exploration_noise(root_noise_weight, noises[i]); + this->roots[i].visit_count += 1; + + // Initialize selected_children_idx with all legal actions for Sequential Halving + this->roots[i].selected_children_idx.clear(); + for (int action : this->roots[i].legal_actions) { + this->roots[i].selected_children_idx.push_back(action); + } + } + } + + void CRoots::prepare_no_noise(const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch) + { + /* + Overview: + Expand the roots without noise. + Arguments: + - value_prefixs: the vector of value prefixs of each root. + - policies: the vector of policy logits of each root. + - to_play_batch: the vector of the player side of each root. 
+ */ + for (int i = 0; i < this->root_num; ++i) + { + this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]); + this->roots[i].visit_count += 1; + + // Initialize selected_children_idx with all legal actions for Sequential Halving + this->roots[i].selected_children_idx.clear(); + for (int action : this->roots[i].legal_actions) { + this->roots[i].selected_children_idx.push_back(action); + } + } + } + + void CRoots::clear() + { + /* + Overview: + Clear the roots vector. + */ + this->roots.clear(); + } + + std::vector > CRoots::get_trajectories() + { + /* + Overview: + Find the current best trajectory starts from each root. + Returns: + - traj: a vector of node index, which is the current best trajectory from each root. + */ + std::vector > trajs; + trajs.reserve(this->root_num); + + for (int i = 0; i < this->root_num; ++i) + { + trajs.push_back(this->roots[i].get_trajectory()); + } + return trajs; + } + + std::vector > CRoots::get_distributions() + { + /* + Overview: + Get the children distribution of each root. + Returns: + - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + */ + std::vector > distributions; + distributions.reserve(this->root_num); + + for (int i = 0; i < this->root_num; ++i) + { + distributions.push_back(this->roots[i].get_children_distribution()); + } + return distributions; + } + + std::vector CRoots::get_values() + { + /* + Overview: + Return the estimated value of each root. + */ + std::vector values; + for (int i = 0; i < this->root_num; ++i) + { + values.push_back(this->roots[i].value()); + } + return values; + } + + std::vector> CRoots::get_root_policies(tools::CMinMaxStatsList *min_max_stats_lst) + { + /* + Overview: + Get improved policies for each root based on MCTS search results. + The improved policy is computed as softmax(prior + transformed_Q). + Arguments: + - min_max_stats_lst: min-max statistics for Q value normalization. 
+ Returns: + - policies: vector of policy distributions for each root. + */ + std::vector> policies; + policies.reserve(this->root_num); + + for (int i = 0; i < this->root_num; ++i) + { + // Get transformed completed Q values (with Sigma transform) + std::vector transformed_completed_Qs = + get_transformed_completed_Qs(&(this->roots[i]), min_max_stats_lst->stats_lst[i], 0); + + // Get improved policy based on Q values: softmax(prior + transformed_Q) + std::vector improved_policy = + this->roots[i].get_improved_policy(transformed_completed_Qs); + + policies.push_back(improved_policy); + } + return policies; + } + + std::vector CRoots::get_best_actions() + { + /* + Overview: + Get the best action for each root after Sequential Halving. + The best action is the first element in selected_children_idx, + which has the highest score (Gumbel noise + prior + transformed Q). + Returns: + - best_actions: vector of best action indices for each root. + */ + std::vector best_actions(this->root_num, -1); + + for (int i = 0; i < this->root_num; ++i) + { + // selected_children_idx[0] is the action with highest score after Sequential Halving + best_actions[i] = this->roots[i].selected_children_idx[0]; + } + return best_actions; + } + + //********************************************************* + // + void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players) + { + /* + Overview: + Update the q value of the root and its child nodes. + Arguments: + - root: the root that update q value from. + - min_max_stats: a tool used to min-max normalize the q value. + - discount_factor: the discount factor of reward. + - players: the number of players. 
+ */ + std::stack node_stack; + node_stack.push(root); + float parent_value_prefix = 0.0; + int is_reset = 0; + while (node_stack.size() > 0) + { + CNode *node = node_stack.top(); + node_stack.pop(); + + if (node != root) + { + // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node, + // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node. + // true_reward = node.value_prefix - (- parent_value_prefix) + float true_reward = node->value_prefix - node->parent_value_prefix; + + if (is_reset == 1) + { + true_reward = node->value_prefix; + } + float qsa; + if (players == 1) + { + qsa = true_reward + discount_factor * node->value(); + } + else if (players == 2) + { + // TODO(pu): why only the last reward multiply the discount_factor? + qsa = true_reward + discount_factor * (-1) * node->value(); + } + + min_max_stats.update(qsa); + } + + for (auto a : node->legal_actions) + { + CNode *child = node->get_child(a); + if (child->expanded()) + { + child->parent_value_prefix = node->value_prefix; + node_stack.push(child); + } + } + + is_reset = node->is_reset; + } + } + + void cbackpropagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor) + { + /* + Overview: + Update the value sum and visit count of nodes along the search path. + Arguments: + - search_path: a vector of nodes on the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - to_play: which player to play the game in the current node. + - value: the value to propagate along the search path. + - discount_factor: the discount factor of reward. 
+ */ + assert(to_play == -1 || to_play == 1 || to_play == 2); + if (to_play == -1) + { + // for play-with-bot-mode + float bootstrap_value = value; + int path_len = search_path.size(); + for (int i = path_len - 1; i >= 0; --i) + { + CNode *node = search_path[i]; + + // ========== Current Implementation: value_sum (Traditional MCTS) ========== + node->value_sum += bootstrap_value; + node->visit_count += 1; + + // ========== Alternative Implementation: estimated_value_lst (EfficientZero V2 Original) ========== + // Uncomment the line below to use EfficientZero V2 original style: + // (node->estimated_value_lst).push_back(bootstrap_value); + // node->visit_count += 1; + // + // Then in value calculation, use: + // node->value_v2() instead of node->value() + // + // Difference: + // - value_sum: O(1) memory, stores cumulative sum + // - estimated_value_lst: O(n) memory, stores all values (enables reanalyze, uncertainty estimation) + // - Result: Both produce the same average value mathematically + // ============================================================================ + + float parent_value_prefix = 0.0; + int is_reset = 0; + if (i >= 1) + { + CNode *parent = search_path[i - 1]; + parent_value_prefix = parent->value_prefix; + is_reset = parent->is_reset; + } + + float true_reward = node->value_prefix - parent_value_prefix; + min_max_stats.update(true_reward + discount_factor * node->value()); + + if (is_reset == 1) + { + // parent is reset + true_reward = node->value_prefix; + } + + bootstrap_value = true_reward + discount_factor * bootstrap_value; + } + } + else + { + // for self-play-mode + float bootstrap_value = value; + int path_len = search_path.size(); + for (int i = path_len - 1; i >= 0; --i) + { + CNode *node = search_path[i]; + + // ========== Current Implementation: value_sum (Traditional MCTS) ========== + if (node->to_play == to_play) + { + node->value_sum += bootstrap_value; + } + else + { + node->value_sum += -bootstrap_value; + } + 
node->visit_count += 1; + + // ========== Alternative Implementation: estimated_value_lst (EfficientZero V2 Original) ========== + // For self-play mode with estimated_value_lst: + // if (node->to_play == to_play) + // { + // (node->estimated_value_lst).push_back(bootstrap_value); + // } + // else + // { + // (node->estimated_value_lst).push_back(-bootstrap_value); + // } + // node->visit_count += 1; + // ============================================================================ + + float parent_value_prefix = 0.0; + int is_reset = 0; + if (i >= 1) + { + CNode *parent = search_path[i - 1]; + parent_value_prefix = parent->value_prefix; + is_reset = parent->is_reset; + } + + // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node, + // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node. + float true_reward = node->value_prefix - parent_value_prefix; + + min_max_stats.update(true_reward + discount_factor * node->value()); + + if (is_reset == 1) + { + // parent is reset + true_reward = node->value_prefix; + } + if (node->to_play == to_play) + { + bootstrap_value = -true_reward + discount_factor * bootstrap_value; + } + else + { + bootstrap_value = true_reward + discount_factor * bootstrap_value; + } + } + } + } + + void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch) + { + /* + Overview: + Expand the nodes along the search path and update the infos. + Arguments: + - current_latent_state_index: The index of latent state of the leaf node in the search path. + - discount_factor: the discount factor of reward. + - value_prefixs: the value prefixs of nodes along the search path. 
+ - values: the values to propagate along the search path. + - policies: the policy logits of nodes along the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - results: the search results. + - is_reset_list: the vector of is_reset nodes along the search path, where is_reset represents for whether the parent value prefix needs to be reset. + - to_play_batch: the batch of which player is playing on this node. + */ + for (int i = 0; i < results.num; ++i) + { + results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i]); + // reset + results.nodes[i]->is_reset = is_reset_list[i]; + + cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor); + } + } + + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch, std::vector &no_inference_lst, std::vector &reuse_lst, std::vector &reuse_value_lst) + { + /* + Overview: + Expand the nodes along the search path and update the infos. + Arguments: + - current_latent_state_index: The index of latent state of the leaf node in the search path. + - discount_factor: the discount factor of reward. + - value_prefixs: the value prefixs of nodes along the search path. + - values: the values to propagate along the search path. + - policies: the policy logits of nodes along the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - results: the search results. + - to_play_batch: the batch of which player is playing on this node. + - no_inference_lst: the list of the nodes which does not need to expand. + - reuse_lst: the list of the nodes which should use reuse-value to backpropagate. 
+ - reuse_value_lst: the list of the reuse-value. + */ + int count_a = 0; + int count_b = 0; + int count_c = 0; + float value_propagate = 0; + for (int i = 0; i < results.num; ++i) + { + if (i == no_inference_lst[count_a]) + { + count_a = count_a + 1; + value_propagate = reuse_value_lst[i]; + } + else + { + results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, count_b, value_prefixs[count_b], policies[count_b]); + if (i == reuse_lst[count_c]) + { + value_propagate = reuse_value_lst[i]; + count_c = count_c + 1; + } + else + { + value_propagate = values[count_b]; + } + count_b = count_b + 1; + } + results.nodes[i]->is_reset = is_reset_list[i]; + cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], value_propagate, discount_factor); + } + } + + int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players) + { + /* + Overview: + Select the child node of the roots according to ucb scores. + Arguments: + - root: the roots to select the child node. + - min_max_stats: a tool used to min-max normalize the score. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - mean_q: the mean q value of the parent node. + - players: the number of players. + Returns: + - action: the action to select. 
+ */ + float max_score = FLOAT_MIN; + const float epsilon = 0.000001; + std::vector max_index_lst; + for (auto a : root->legal_actions) + { + CNode *child = root->get_child(a); + float temp_score = cucb_score(child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players); + + if (max_score < temp_score) + { + max_score = temp_score; + + max_index_lst.clear(); + max_index_lst.push_back(a); + } + else if (temp_score >= max_score - epsilon) + { + max_index_lst.push_back(a); + } + } + + int action = 0; + if (max_index_lst.size() > 0) + { + int rand_index = rand() % max_index_lst.size(); + action = max_index_lst[rand_index]; + } + return action; + } + + int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value) + { + /* + Overview: + Select the child node of the roots according to ucb scores. + Arguments: + - root: the roots to select the child node. + - min_max_stats: a tool used to min-max normalize the score. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - mean_q: the mean q value of the parent node. + - players: the number of players. + - true_action: the action chosen in the trajectory. + - reuse_value: the value obtained from the search of the next state in the trajectory. + Returns: + - action: the action to select. 
+ */ + + float max_score = FLOAT_MIN; + const float epsilon = 0.000001; + std::vector max_index_lst; + for (auto a : root->legal_actions) + { + + CNode *child = root->get_child(a); + float temp_score = 0.0; + if (a == true_action) + { + temp_score = carm_score(child, min_max_stats, mean_q, root->is_reset, reuse_value, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players); + } + else + { + temp_score = cucb_score(child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players); + } + + if (max_score < temp_score) + { + max_score = temp_score; + + max_index_lst.clear(); + max_index_lst.push_back(a); + } + else if (temp_score >= max_score - epsilon) + { + max_index_lst.push_back(a); + } + } + + int action = 0; + if (max_index_lst.size() > 0) + { + int rand_index = rand() % max_index_lst.size(); + action = max_index_lst[rand_index]; + } + // printf("select root child ends"); + return action; + } + + float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players) + { + /* + Overview: + Compute the ucb score of the child. + Arguments: + - child: the child node to compute ucb score. + - min_max_stats: a tool used to min-max normalize the score. + - parent_mean_q: the mean q value of the parent node. + - is_reset: whether the value prefix needs to be reset. + - total_children_visit_counts: the total visit counts of the child nodes of the parent node. + - parent_value_prefix: the value prefix of parent node. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - players: the number of players. + Returns: + - ucb_value: the ucb score of the child. 
+ */ + float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; + pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init; + pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1)); + + prior_score = pb_c * child->prior; + if (child->visit_count == 0) + { + value_score = parent_mean_q; + } + else + { + float true_reward = child->value_prefix - parent_value_prefix; + if (is_reset == 1) + { + true_reward = child->value_prefix; + } + + if (players == 1) + { + value_score = true_reward + discount_factor * child->value(); + } + else if (players == 2) + { + value_score = true_reward + discount_factor * (-child->value()); + } + } + + value_score = min_max_stats.normalize(value_score); + + if (value_score < 0) + { + value_score = 0; + } + else if (value_score > 1) + { + value_score = 1; + } + + return prior_score + value_score; // ucb_value + } + + float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float reuse_value, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players) + { + /* + Overview: + Compute the ucb score of the child. + Arguments: + - child: the child node to compute ucb score. + - min_max_stats: a tool used to min-max normalize the score. + - parent_mean_q: the mean q value of the parent node. + - is_reset: whether the value prefix needs to be reset. + - total_children_visit_counts: the total visit counts of the child nodes of the parent node. + - parent_value_prefix: the value prefix of parent node. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - players: the number of players. + Returns: + - ucb_value: the ucb score of the child. 
+ */ + float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; + pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init; + pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1)); + + prior_score = pb_c * child->prior; + if (child->visit_count == 0) + { + value_score = parent_mean_q; + } + else + { + float true_reward = child->value_prefix - parent_value_prefix; + if (is_reset == 1) + { + true_reward = child->value_prefix; + } + + if (players == 1) + { + value_score = true_reward + discount_factor * reuse_value; + } + else if (players == 2) + { + value_score = true_reward + discount_factor * (-reuse_value); + } + } + + value_score = min_max_stats.normalize(value_score); + + if (value_score < 0) + { + value_score = 0; + } + else if (value_score > 1) + { + value_score = 1; + } + + float ucb_value = 0.0; + if (child->visit_count == 0) + { + ucb_value = prior_score + value_score; + } + else + { + ucb_value = value_score; + } + // printf("carmscore ends"); + return ucb_value; + } + + /** + * ===================================================================== + * 【EfficientZero V2】动作选择函数 + * ===================================================================== + * 根据当前节点类型和搜索进度采用不同的动作选择策略: + * 1. 根节点:Sequential Halving + 等量访问(Gumbel-Max 采样初始化) + * 2. 
非根节点:改进策略 + 访问计数平衡 + * + * 参数: + * - node: 当前树节点(可能是根节点或内部节点) + * - min_max_stats: 最小最大值统计,用于Q值归一化 + * - num_simulations: 总模拟数(序列削减的总步数) + * - simulation_idx: 当前模拟索引(序列削减的当前阶段,从0开始) + * - gumble_noise: 该样本的Gumbel噪声向量(长度=动作数) + * - current_num_top_actions: 当前保留的顶部候选动作数量 + * + * 返回:选中的动作索引 + * ===================================================================== + */ + int select_action(CNode* node, tools::CMinMaxStats &min_max_stats, + int num_simulations, int simulation_idx, + const std::vector& gumble_noise, + int current_num_top_actions){ + + int action = -1; + int num_actions = node->legal_actions.size(); + + // Check if this is a root node by examining if selected_children_idx is populated + // (Root nodes go through Sequential Ha2lving and have selected_children_idx set) + bool is_root = !node->selected_children_idx.empty() || node->visit_count == 0; + + if(is_root){ + // ============================================================ + // 【根节点处理】Sequential Halving + 等量访问策略 + // ============================================================ + + if(simulation_idx == 0){ + // 【第一阶段 (simulation_idx == 0)】 + // 目的:基于Gumbel-Max采样初始化顶部候选动作集合 + + // 1. 获取根节点所有子节点的先验概率(来自网络) + std::vector children_prior = node->get_children_priors(); + + // 2. 计算初始得分 = Gumbel噪声 + log(先验) + // 这实现了Gumbel-Max采样:max_a[g_a + log(p_a)] + // Gumbel噪声用于探索,先验用于指导 + std::vector children_scores; + for(int a = 0; a < num_actions; ++a){ + // g_a: Gumbel噪声(从高斯分布生成) + // p_a: 网络输出的先验策略 + children_scores.push_back(gumble_noise[a] + children_prior[a]); + } + + // 3. 对得分进行降序排序,获取排序后的动作索引 + std::vector idx(children_scores.size()); + std::iota(idx.begin(), idx.end(), 0); // idx初始化为[0, 1, 2, ...] + std::sort(idx.begin(), idx.end(), + [&children_scores](size_t i1, size_t i2) { + return children_scores[i1] > children_scores[i2]; + }); + + // 4. 
保存前K个最佳动作到节点中(作为候选集合) + // 【关键】这些候选在后续迭代中会被等量访问 + // 实现了序列削减的"筛选"阶段 + node->selected_children_idx.clear(); + for(int a = 0; a < current_num_top_actions; ++a){ + node->selected_children_idx.push_back(idx[a]); + } + } + + // 【后续阶段 (simulation_idx > 0)】 + // 使用等量访问策略:从候选集合中轮流选择动作 + // 目的:平衡评估各个候选动作,同时让高价值候选更早收敛 + action = node->do_equal_visit(num_simulations); + } + else{ + // ============================================================ + // 【非根节点处理】改进策略 + 访问计数平衡 + // ============================================================ + + // 1. 计算所有子节点的【变换后完整Q值】 + // 完整Q = R + γ*V (当前节点的累积奖励 + 折扣未来值) + // 通过min-max统计进行归一化到[0,1]范围内 + std::vector transformed_completed_Qs = + get_transformed_completed_Qs(node, min_max_stats, 0); + + // 2. 计算【改进策略】= softmax(transformed_Q) + // 这是MCTS搜索中学到的改进策略,优于网络初始策略 + std::vector improved_policy = + node->get_improved_policy(transformed_completed_Qs); + + // 3. 获取原始网络先验策略(用于对比或调试) + std::vector ori_policy = node->get_children_priors(); + + // 4. 获取每个子节点的访问次数(反映已探索程度) + std::vector children_visits = node->get_children_visits(); + + // 5. 计算每个动作的【选择得分】 + // 得分 = + // 其中:访问惩罚 = 子节点访问次数 / (1 + 父节点访问次数) + // + // 设计原理: + // - 改进策略项:偏向高Q值的动作(利用) + // - 访问惩罚项:偏向访问较少的动作(探索) + // - 除以父节点访问次数:动态平衡,搜索越深惩罚越大 + std::vector children_scores(num_actions, 0.0); + for(int a = 0; a < num_actions; ++a){ + float visit_penalty = children_visits[a] / (1.0f + float(node->visit_count)); + float score = improved_policy[a] - visit_penalty; + children_scores[a] = score; + } + + // 6. 选择得分最高的动作(贪心选择) + action = argmax(children_scores); + } + + return action; + } + + // ========== MuZero/UCB 风格的批量遍历(备份版本) ========== + void cbatch_traverse_ucb(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch) + { + /* + Overview: + Search node path from the roots using UCB/PUCT selection (MuZero style). + Arguments: + - roots: the roots that search from. 
+ - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - min_max_stats: a tool used to min-max normalize the score. + - results: the search results. + - virtual_to_play_batch: the batch of which player is playing on this node. + */ + // set seed + get_time_and_set_rand_seed(); + + int last_action = -1; + float parent_q = 0.0; + results.search_lens = std::vector(); + + int players = 0; + int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // 0 or 2 + if (largest_element == -1) + { + players = 1; + } + else + { + players = 2; + } + + for (int i = 0; i < results.num; ++i) + { + CNode *node = &(roots->roots[i]); + int is_root = 1; + int search_len = 0; + results.search_paths[i].push_back(node); + + while (node->expanded()) + { + float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor); + is_root = 0; + parent_q = mean_q; + + int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players); + if (players > 1) + { + assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); + if (virtual_to_play_batch[i] == 1) + { + virtual_to_play_batch[i] = 2; + } + else + { + virtual_to_play_batch[i] = 1; + } + } + + node->best_action = action; + // next + node = node->get_child(action); + last_action = action; + results.search_paths[i].push_back(node); + search_len += 1; + } + + CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; + + results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); + results.latent_state_index_in_batch.push_back(parent->batch_index); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + } + + /** + * 
===================================================================== + * 【EfficientZero V2 核心】批量树遍历函数 - cbatch_traverse + * ===================================================================== + * 该函数是 EfficientZero V2 MCTS 搜索的核心遍历算法,集成了 Sequential Halving: + * 1. 对批次中的每个样本进行树遍历 + * 2. 使用 select_action() 进行动作选择(集成 Sequential Halving) + * 3. 遍历直至到达叶节点(未扩展的节点) + * 4. 记录搜索路径、动作序列、叶节点等信息供后续回传使用 + * + * 参数说明: + * roots - 批量样本的根节点集合(CRoots对象数组) + * min_max_stats_lst - 批量min-max统计对象集合(用于Q值归一化) + * results - 搜索结果缓冲区(保存搜索路径、动作、叶节点等) + * num_simulations - 总模拟次数(序列削减的总步数) + * simulation_idx - 当前模拟索引(当前处于第几次迭代) + * gumble_noise - 批量Gumbel噪声矩阵(大小:batch_size × num_actions) + * current_num_top_actions - 当前保留的顶部候选动作数(随削减递减) + * virtual_to_play_batch - 虚拟玩家批次(用于两人游戏,单人游戏时为[-1,-1,...]) + * ===================================================================== + */ + void cbatch_traverse(CRoots *roots, + tools::CMinMaxStatsList *min_max_stats_lst, + CSearchResults &results, + int num_simulations, + int simulation_idx, + const std::vector>& gumble_noise, + int current_num_top_actions, + std::vector &virtual_to_play_batch) { + + // 初始化结果容器 + int last_action = -1; + results.search_lens = std::vector(); + + // 判断游戏类型(单人/双人) + int players = 0; + int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); + if (largest_element == -1) { + players = 1; // 单人游戏 + } else { + players = 2; // 双人游戏 + } + + // 对批次中每个样本独立遍历 + for (int i = 0; i < results.num; ++i) { + CNode *node = &(roots->roots[i]); + int search_len = 0; + results.search_paths[i].push_back(node); + + // 从根节点遍历到叶节点 + while (node->expanded()) { + // 【关键】使用 select_action 替代 cselect_child + // 该函数内部集成了 Sequential Halving 逻辑: + // - 根节点:在 simulation_idx==0 时初始化候选集,后续等量访问 + // - 非根节点:使用改进策略 + 访问计数平衡 + int action = select_action( + node, + min_max_stats_lst->stats_lst[i], + num_simulations, // Sequential Halving 参数 + simulation_idx, // Sequential Halving 参数 + gumble_noise[i], // Sequential Halving 参数 
+ current_num_top_actions // Sequential Halving 参数 + ); + + // 【两人游戏兼容】切换玩家 + if (players > 1) { + assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); + if (virtual_to_play_batch[i] == 1) { + virtual_to_play_batch[i] = 2; + } else { + virtual_to_play_batch[i] = 1; + } + } + + node->best_action = action; + node = node->get_child(action); + last_action = action; + results.search_paths[i].push_back(node); + search_len += 1; + } + + // 记录叶节点父节点的信息 + CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; + + // 记录父节点的隐状态索引位置(后续用于神经网络推理定位隐状态) + results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); + results.latent_state_index_in_batch.push_back(parent->batch_index); + + // 记录此次搜索的其他信息 + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + } + + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch, std::vector &true_action, std::vector &reuse_value) + { + /* + Overview: + Search node path from the roots. + Arguments: + - roots: the roots that search from. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - min_max_stats: a tool used to min-max normalize the score. + - results: the search results. + - virtual_to_play_batch: the batch of which player is playing on this node. + - true_action: the action chosen in the trajectory. + - reuse_value: the value obtained from the search of the next state in the trajectory. 
+ */ + // set seed + get_time_and_set_rand_seed(); + + int last_action = -1; + float parent_q = 0.0; + results.search_lens = std::vector(); + + int players = 0; + int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // 0 or 2 + if (largest_element == -1) + { + players = 1; + } + else + { + players = 2; + } + + for (int i = 0; i < results.num; ++i) + { + CNode *node = &(roots->roots[i]); + int is_root = 1; + int search_len = 0; + results.search_paths[i].push_back(node); + + while (node->expanded()) + { + float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor); + parent_q = mean_q; + + int action = 0; + if (is_root) + { + action = cselect_root_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players, true_action[i], reuse_value[i]); + } + else + { + action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players); + } + + if (players > 1) + { + assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); + if (virtual_to_play_batch[i] == 1) + { + virtual_to_play_batch[i] = 2; + } + else + { + virtual_to_play_batch[i] = 1; + } + } + + node->best_action = action; + // next + node = node->get_child(action); + last_action = action; + results.search_paths[i].push_back(node); + search_len += 1; + + if(is_root && action == true_action[i]) + { + break; + } + + is_root = 0; + } + + if (node->expanded()) + { + results.latent_state_index_in_search_path.push_back(-1); + results.latent_state_index_in_batch.push_back(i); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + else + { + CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; + + results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); + 
results.latent_state_index_in_batch.push_back(parent->batch_index); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + } + } + + // ==================== EfficientZero V2 Sequential Halving 实现 ==================== + + // Softmax 函数:将 logits 转换为概率分布 + std::vector softmax(const std::vector &logits){ + std::vector policy(logits.size(), 0.0); + + // 数值稳定:减去最大值防止溢出 + float max_logit = -1e9; + for(size_t a = 0; a < logits.size(); ++a){ + if(logits[a] > max_logit) max_logit = logits[a]; + } + + for(size_t a = 0; a < logits.size(); ++a){ + policy[a] = exp(logits[a] - max_logit); + } + + // 归一化 + float policy_sum = 0.0; + for(size_t a = 0; a < policy.size(); ++a){ + policy_sum += policy[a]; + } + for(size_t a = 0; a < policy.size(); ++a){ + policy[a] = policy[a] / policy_sum; + } + return policy; + } + + // 获取子节点的先验概率 + std::vector CNode::get_children_priors(){ + std::vector priors(this->legal_actions.size(), 0.0); + for(size_t i = 0; i < this->legal_actions.size(); ++i){ + int action = this->legal_actions[i]; + priors[i] = this->children[action].prior; + } + return priors; + } + + /** + * ===================================================================== + * 获取子节点访问次数向量 + * ===================================================================== + * 将该节点的所有子节点的访问次数合并为一个向量 + * 向量大小与合法动作数量相同 + * + * 返回:visit_count 向量,索引对应合法动作顺序 + * ===================================================================== + */ + std::vector CNode::get_children_visits(){ + // 创建与合法动作数量相同大小的向量,初始化为0 + std::vector visits(this->legal_actions.size(), 0); + + // 遍历所有合法动作 + for(size_t i = 0; i < this->legal_actions.size(); ++i){ + // 获取该动作的索引 + int action = this->legal_actions[i]; + + // 从子节点中获取访问次数 + visits[i] = this->children[action].visit_count; + } + + return visits; + } + + // 获取该节点的奖励 + float CNode::get_reward(){ + if(this->is_reset){ + // 
如果是重置点,直接返回 value_prefix + return this->value_prefix; + } else { + // 否则返回与父节点的差值 + return this->value_prefix - this->parent_value_prefix; + } + } + + // 获取单个动作的 Q 值:Q(s,a) = R(a) + γ*V(s') + float CNode::get_qsa(int action){ + CNode* child = &(this->children[action]); + // Q(s,a) = R(a) + γ*V(s') + float qsa = child->get_reward() + this->discount * child->value(); + return qsa; + } + + // 计算混合值估计(用于未访问子节点的乐观估计) + float CNode::get_v_mix(){ + // 1. 获取当前节点的策略分布 + std::vector priors = this->get_children_priors(); + std::vector pi_lst = softmax(priors); + + float pi_sum = 0.0; // 已访问动作的策略概率和 + float pi_qsa_sum = 0.0; // π(a) * Q(a) 的加权和 + + // 2. 遍历所有合法动作,只统计已扩展的子节点 + for(size_t i = 0; i < this->legal_actions.size(); ++i){ + int action = this->legal_actions[i]; + if(this->children[action].expanded()){ + // 子节点已访问:累加策略概率 + pi_sum += pi_lst[i]; + // 累加加权 Q 值 + pi_qsa_sum += pi_lst[i] * this->get_qsa(action); + } + } + + // 3. 计算 v_mix + float v_mix = 0.0; + const float EPSILON = 0.000001; + + if(pi_sum < EPSILON) { + // 没有子节点被访问,直接用节点值 + v_mix = this->value(); + } + else{ + // v_mix = (1/(1+N)) * [V + N * Σ(π*Q) / Σ(π)] + v_mix = (1.0 / (1.0 + this->visit_count)) * (this->value() + this->visit_count * pi_qsa_sum / pi_sum); + } + + return v_mix; + } + + // 获取完整的Q值(包括已访问和未访问节点的估计) + std::vector CNode::get_completed_Q(tools::CMinMaxStats &min_max_stats, int to_normalize){ + // 1. 创建返回值向量 + std::vector completed_Qs(this->legal_actions.size(), 0.0); + + // 2. 计算混合值(用于未访问的子节点) + float v_mix = this->get_v_mix(); + + // 3. 对每个动作计算完整的 Q 值 + for(size_t i = 0; i < this->legal_actions.size(); ++i){ + int action = this->legal_actions[i]; + float Q = 0.0; + + // 判断子节点是否已扩展(访问过) + if(this->children[action].expanded()){ + // 子节点已访问:Q = R + γ*V + Q = this->get_qsa(action); + } + else { + // 子节点未访问:用 v_mix 乐观估计 + Q = v_mix; + } + + // 4. 
根据 to_normalize 标志选择归一化方式 + if (to_normalize == 1) { + // 模式1:使用 min-max 统计进行归一化 + completed_Qs[i] = min_max_stats.normalize(Q); + if (completed_Qs[i] < 0.0) completed_Qs[i] = 0.0; + if (completed_Qs[i] > 1.0) completed_Qs[i] = 1.0; + } + else { + completed_Qs[i] = Q; + } + } + + // 5. 如果选择最终归一化模式(to_normalize == 2) + if (to_normalize == 2){ + float v_max = -1e9; + float v_min = 1e9; + for(float q : completed_Qs){ + if(q > v_max) v_max = q; + if(q < v_min) v_min = q; + } + + if(v_max > v_min){ + for(size_t i = 0; i < completed_Qs.size(); ++i){ + completed_Qs[i] = (completed_Qs[i] - v_min) / (v_max - v_min); + } + } + } + + return completed_Qs; + } + + // 获取改进后的策略(基于 Q 值) + std::vector CNode::get_improved_policy(const std::vector &transformed_completed_Qs){ + // 获取子节点的先验 + std::vector logits = this->get_children_priors(); + + // EfficientZero V2 改进:logits = prior + transformed_Q + for(size_t i = 0; i < logits.size(); ++i){ + logits[i] = logits[i] + transformed_completed_Qs[i]; + } + + // 转换为概率分布 + std::vector policy = softmax(logits); + return policy; + } + + /** + * ===================================================================== + * 【EfficientZero V2 根节点策略】等量访问函数 - do_equal_visit + * ===================================================================== + * 在 Sequential Halving 中,根节点采用等量访问策略来评估候选动作。 + * 该函数从已筛选的候选集合中选择访问次数最少的动作。 + * + * 算法原理: + * 1. 遍历 selected_children_idx 中所有候选动作 + * 2. 找到访问次数最少的动作 + * 3. 
返回该动作的索引 + * + * 目的:确保在同一阶段中各个候选动作被公平地评估 + * + * 参数: + * num_simulations - 总模拟数(用于初始化最小值) + * + * 返回:访问次数最少的候选动作索引 + * ===================================================================== + */ + int CNode::do_equal_visit(int num_simulations){ + // 初始化最小访问次数为一个很大的值(num_simulations + 1) + // 这样确保任何实际的访问次数都会更小 + int min_visit_count = num_simulations + 1; + int action = -1; // 初始化为-1,表示未选择 + + // 遍历所有候选动作(在 selected_children_idx 中) + for(int selected_child_idx : this->selected_children_idx){ + // 获取该候选动作的访问次数 + int visit_count = (this->get_child(selected_child_idx))->visit_count; + + // 如果该动作的访问次数更少,更新最小访问次数和选中的动作 + if(visit_count < min_visit_count){ + action = selected_child_idx; + min_visit_count = visit_count; + } + } + + // 返回选中的动作(访问次数最少的候选) + return action; + } + + // ==================== 辅助函数 ==================== + + // 获取向量中的最大浮点数 + float max_float(const std::vector &arr){ + if(arr.empty()) return -1e9; + float max_val = arr[0]; + for(size_t i = 1; i < arr.size(); ++i){ + if(arr[i] > max_val) max_val = arr[i]; + } + return max_val; + } + + // 获取向量中的最小浮点数 + float min_float(const std::vector &arr){ + if(arr.empty()) return 1e9; + float min_val = arr[0]; + for(size_t i = 1; i < arr.size(); ++i){ + if(arr[i] < min_val) min_val = arr[i]; + } + return min_val; + } + + // 获取向量中的最大整数 + int max_int(const std::vector &arr){ + if(arr.empty()) return -1e9; + int max_val = arr[0]; + for(size_t i = 1; i < arr.size(); ++i){ + if(arr[i] > max_val) max_val = arr[i]; + } + return max_val; + } + + // 求浮点数向量的和 + float sum_float(const std::vector &arr){ + float res = 0.0; + for(float a : arr) res += a; + return res; + } + + // 求整数向量的和 + int sum_int(const std::vector &arr){ + int res = 0; + for(int a : arr) res += a; + return res; + } + + // 获取最大值的索引 + int argmax(const std::vector &arr){ + if(arr.empty()) return -1; + int index = 0; + float max_val = arr[0]; + for(size_t i = 1; i < arr.size(); ++i){ + if(arr[i] > max_val){ + max_val = arr[i]; + index = i; + } + } + return index; + 
} + + // 获取转换后的完整 Q 值(带 Sigma 变换) + std::vector get_transformed_completed_Qs(CNode* node, tools::CMinMaxStats &min_max_stats, int final){ + // 1. 获取完整Q值(根据 final 参数选择归一化模式) + int to_normalize = (final == 0) ? 1 : 2; + std::vector completed_Qs = node->get_completed_Q(min_max_stats, to_normalize); + + // 2. 计算最大访问数 + int max_child_visit_count = 0; + for(int action : node->legal_actions){ + if(node->children.count(action) > 0){ + int visit_count = node->children[action].visit_count; + if(visit_count > max_child_visit_count){ + max_child_visit_count = visit_count; + } + } + } + + // 3. Sigma 变换(缩放):Q' = (c_visit + max_visit) * c_scale * Q + for(size_t i = 0; i < completed_Qs.size(); ++i){ + completed_Qs[i] = (min_max_stats.c_visit + max_child_visit_count) * min_max_stats.c_scale * completed_Qs[i]; + } + + return completed_Qs; + } + + // Sequential Halving:逐步淘汰差的动作 + int sequential_halving(CNode* root, const std::vector& gumbel_noise, + tools::CMinMaxStats &min_max_stats, int current_phase, int current_num_top_actions){ + // 1. 获取子节点的先验概率和转换后的Q值 + std::vector children_prior = root->get_children_priors(); + std::vector transformed_completed_Qs = get_transformed_completed_Qs(root, min_max_stats, 0); + + // 获取当前已选的动作列表 + std::vector selected_children_idx = root->selected_children_idx; + std::vector children_scores; + + // 2. 计算分数:gumbel噪声 + 先验 + 转换后的Q值 + // 直接用 action 作为索引(action ∈ [0, num_actions-1]) + for(int action : selected_children_idx){ + float score = gumbel_noise[action] + children_prior[action] + transformed_completed_Qs[action]; + children_scores.push_back(score); + } + + // 3. 创建索引数组并初始化为 [0, 1, 2, ...] + std::vector idx(children_scores.size()); + std::iota(idx.begin(), idx.end(), 0); + + // 4. 按分数从高到低排序索引 + std::sort(idx.begin(), idx.end(), + [&children_scores](size_t index_1, size_t index_2) { + return children_scores[index_1] > children_scores[index_2]; + }); + + // 5. 
清空已选动作,只保留分数最高的 top-m 个 + root->selected_children_idx.clear(); + int keep_count = std::min(current_num_top_actions, (int)selected_children_idx.size()); + for(int i = 0; i < keep_count; ++i){ + root->selected_children_idx.push_back(selected_children_idx[idx[i]]); + } + + // 6. 返回分数最高的动作 + int best_action = root->selected_children_idx[0]; + return best_action; + } + + // 批量 Sequential Halving:对多个搜索树进行动作选择 + std::vector c_batch_sequential_halving(CRoots *roots, const std::vector>& gumbel_noises, + tools::CMinMaxStatsList *min_max_stats_lst, + int current_phase, int current_num_top_actions){ + std::vector best_actions(roots->root_num, -1); + + for(int i = 0; i < roots->root_num; ++i){ + int action = sequential_halving(&(roots->roots[i]), gumbel_noises[i], + min_max_stats_lst->stats_lst[i], current_phase, current_num_top_actions); + best_actions[i] = action; + } + + return best_actions; + } +} \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.h b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.h new file mode 100644 index 000000000..491e93001 --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.h @@ -0,0 +1,134 @@ +// C++11 + +#ifndef CNODE_H +#define CNODE_H + +#include "cminimax.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const int DEBUG_MODE = 0; + +namespace tree { + class CNode { + public: + int visit_count, to_play, current_latent_state_index, batch_index, best_action, is_reset; + float value_prefix, prior, value_sum; + float parent_value_prefix; + std::vector children_index; + std::map children; + + std::vector legal_actions; + + // ========== V2 新增字段 ========== + std::vector selected_children_idx; // Sequential Halving 选中的动作 + std::vector estimated_value_lst; // 多值估计列表(EfficientZero V2 原版风格,未启用) + float discount; // 折扣因子(V2 用) + + CNode(); + CNode(float prior, std::vector &legal_actions); + ~CNode(); + + void expand(int to_play, int 
current_latent_state_index, int batch_index, float value_prefix, const std::vector &policy_logits); + void add_exploration_noise(float exploration_fraction, const std::vector &noises); + float compute_mean_q(int isRoot, float parent_q, float discount_factor); + void print_out(); + + int expanded(); + + float value(); + float value_v2(); // EfficientZero V2 原版风格的值估计(基于 estimated_value_lst,未启用) + + std::vector get_trajectory(); + std::vector get_children_distribution(); + CNode* get_child(int action); + + // ========== EfficientZero V2 新增方法 ========== + std::vector get_children_priors(); + std::vector get_children_visits(); // 获取子节点访问次数 + float get_reward(); + float get_qsa(int action); + float get_v_mix(); + std::vector get_completed_Q(tools::CMinMaxStats &min_max_stats, int to_normalize); + std::vector get_improved_policy(const std::vector &transformed_completed_Qs); + int do_equal_visit(int num_simulations); // Sequential Halving 等量访问策略 + }; + + class CRoots{ + public: + int root_num; + std::vector roots; + std::vector > legal_actions_list; + + CRoots(); + CRoots(int root_num, std::vector > &legal_actions_list); + ~CRoots(); + + void prepare(float root_noise_weight, const std::vector > &noises, const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch); + void prepare_no_noise(const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch); + void clear(); + std::vector > get_trajectories(); + std::vector > get_distributions(); + std::vector get_values(); + std::vector > get_root_policies(tools::CMinMaxStatsList *min_max_stats_lst); + std::vector get_best_actions(); + CNode* get_root(int index); + }; + + class CSearchResults{ + public: + int num; + std::vector latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens; + std::vector virtual_to_play_batchs; + std::vector nodes; + std::vector > search_paths; + + CSearchResults(); + CSearchResults(int num); + 
~CSearchResults(); + + }; + + + //********************************************************* + void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players); + void cbackpropagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor); + void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch); + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch, std::vector &no_inference_lst, std::vector &reuse_lst, std::vector &reuse_value_lst); + int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players); + int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value); + float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players); + float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float reuse_value, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players); + + // ========== MuZero/UCB 风格的遍历(备份) ========== + void cbatch_traverse_ucb(CRoots *roots, int pb_c_base, float pb_c_init, float 
discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch); + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch, std::vector &true_action, std::vector &reuse_value); + + // ========== EfficientZero V2 风格的遍历(Sequential Halving 集成) ========== + void cbatch_traverse(CRoots *roots, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + int num_simulations, int simulation_idx, const std::vector>& gumble_noise, + int current_num_top_actions, std::vector &virtual_to_play_batch); + + // ========== EfficientZero V2 Sequential Halving 相关 ========== + std::vector softmax(const std::vector &logits); + std::vector get_transformed_completed_Qs(CNode* node, tools::CMinMaxStats &min_max_stats, int final); + int sequential_halving(CNode* root, const std::vector& gumbel_noise, tools::CMinMaxStats &min_max_stats, int current_phase, int current_num_top_actions); + std::vector c_batch_sequential_halving(CRoots *roots, const std::vector>& gumbel_noises, tools::CMinMaxStatsList *min_max_stats_lst, int current_phase, int current_num_top_actions); + + // 辅助函数 + float max_float(const std::vector &arr); + float min_float(const std::vector &arr); + int max_int(const std::vector &arr); + float sum_float(const std::vector &arr); + int sum_int(const std::vector &arr); + int argmax(const std::vector &arr); +} + +#endif \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving b/lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving new file mode 100755 index 0000000000000000000000000000000000000000..b308acee9814e63dc66649a2d581f8ae7ae6da9d GIT binary patch literal 88184 zcmeFa3wTu3)jvFeL`Mmppixt$8ttePKu8pnDN+-XkaO?^QUQ%}2^fOVB1D*hs9b`X 
z5s$~wXtnxQt<-93Yi+Idf>tphK=2Nz6|W$Q>~TOM2u4Kn|NZtpXD$KSzJ34i_k7>? zI6UN>eP4U+wbx#2?X~w=Qf`jyk&)qXe?2`v@yO5!6;tdaG*l(?6nlnxdV7jJ`JPif zM<6{D|J{7upPPCh5kL3oRY6ZrJXtP}FNb*cRnZ>z>9)=%%N_2wzK=?ApPn2QX1OfB z=JOAJ`SNe#`PT1L*nOs!lkKf{@^^chXRCPknOcsxsn2H1c=kkopR`xy zl=E*=^>lw;;56Vq-S&Qpa>UEO`*YCgx!7s%aO2SI^xJ)^QSx{O&6;_|d4mVdns)ZA znX^L;XEzKz@9gsi51cn=;5iZm!j+Hzq^XM0<2@Lb7REW%d3rReuv;I~a-DSMpLNo0 zC-gn*!X=+BD0`%I@VI`rys&5`^AK*9VIKZipEKl%upe*Cqww)pyk~kC$6o>dZ>W3i zs<{#KKT5R_>qUHzv3|U=N=~hbRgwW z@VtMR`cFShJsS=K=jg-a|M6ks#~h}f!H3CTco_I+947wHhl#)9F!eVb2A+(=)N{#U z^51Zn{Id@eKmIWAWFIE}*N3TR_F>?O9VUP1Fyqeol#l=2f1e+wo(~Tb|H)zEr=enJ z$w_Mh;2*Nx0f&hvyFU~?e~WfcPper4TJbm?++zJRDtx4q=MLoW<2l(=oKFbAPt?QX zsj0d0syVZ3<^`wD4c63nYK(D}HPh)28%$!y?yPWN&m7F=dc5dBOb+dzqsZfFy z96EE(JX!x>WA@BoSyB0XRnPfu@2;3St!C=1S#xTUP*dA5wdM-oYIaTS)L`ulWSusv z$|-SPYKhvZ^MWjJWkW+vecjx7b7oJSH8Xfk&HO@DJu=Uk2OLbj%0Yrr3*d`tXH1<7 z)J&ZjoQKYpUyQhO(fOIRZf;{-&A7sv^0DXDR9!Uo+?rXkwsVH2Ha&Ij+^N^p%&v0> zhTsAJ=hlaU5=m+JD>}EPQXvp<%nZujBlD2yvqJM`IJlTK55xKIz*ajBkPQK39E;ft zu96LR!*Hp@utenCI!QmW3xjLtgn|x&>p%`JdecU9THU`G(Ro+T4NmvWnseoJPhCU( zbWhFr*$ZZ(wu|fLg{}fY%Vwv~Z2p6k3;SEGu-1B;4OCMOt33!M}2lvY=lj4*172A-Rm0)Ej0|9aworhM|P2V|Th zJV(k{74G5jdXDlOosugP`FkOp;mL+(px!bN?ui&MB>Xb)?`Et4E%nZYO+=Ob191+SsQz z{_=B(f|(x82?q;qhlt(7a}tF*{9O6pmAki$Q1K0peDB6r^;hwY-QpXa_#3;$uP;&g zo4ds?ck1`%s|CQV-wm%nRmD?=^s6f5-%Tm;7@)*&OGt3C!zXbN_<{Qye}o*UEe5w_e+Uie~RiC*F*KD5;vv9V;~d1Eh+IB+Qd(! 
z#5+SM(>hY(*|z)ZOo=B>6F;ZFgvVW@ndVK2cgKeDIVtg&0Eu5-O8n7@SdS+^B|bYP z-k%bWiIMmfroyUeMSkrrNrNu690pg_~w-OlTzaEPKiG`C4PBIylbmq z^1~_dr=;Y6EG7Qbl=u}X@u#K4x1_|Mo)W(S~r@8d`S8HYCC{u0M(;%`t_QiK@e`yFo{}W*vK+yk0n1&DZKM|(E1N~2g zY3M-z6JZ)S(Emi3h7I&T5vD-{{ZE8x$Uy%SVHz;d|3sLE3-mt`rojUJPlRcxK>rhA z8r;zTM3@E(^gj`%p#uH?*6BZ};m&VK4P_ol4+uvNq=y%E3x~Reuj&@A>lU8WEj+$k zxUyTgyjys9xA2f|;j_Di&*&EZVYl${-NL=Qg|oVazx^g1AA7rncXkVJ>lXf?TllSR z;a9tb*K`X%-!1%PxA3Fg!Vh%|-`6dCXSZg z>Q(Y`6{->yun6n+b$%6XG1}TCmU6zO)iBNAHsk-Gt-5!e~b)$K(N^rUyaX-!^yJAh!iD)$-49ubX+h<9Yfbu(1*%PLf* zLeHzvP!;-%3gxTNZ&b*uLcc)Bw6+Tytu`A8C>JslkP$?**nB|cxe}3LH~vU&ZLJY! z5>(_Gji+hP@)cG_b_UPY?J{43Zcp@8o6#v=Bk=zGJl!6qH??`i87LgzDQ-c9;tzN? zQMeDDtmG(E!io)RwRlygZ5MUO$h4g-(11v{_LE52r<^<~`pi48)NQ9#Z0ZBD}9nSjyovd4dBB*5bY z2o+UEJ_&l+Y0an#B2oO>%^6OY)Yd0QYX(C!X}qs9YoPb zg0#?=Bd&xuglYdg$Fys`m4R)+Nodqqy1E4D-jKO-`K5Z$H00QfW8IwC8+uuv9S(PVQBqhu8st2`iYaPPz0S^e(^K;_MER-< z8yV{b!A5qr2cB|;xDRnrh)tx`G?oo-?+K+O6?%+)e)mze2Oc^^xpTA$TXJYv2+J_xTk?No#E#=TVdZi zL=?4%5vYR0ba@of(IF?JJOS+QM3<`QH-D$1bDih`D*AXQ_qY30CC97iv;L@Zf9gbk zM~cMo*Q06A3xxy#p$3NZ_OP#!FZhYID!*V$nuDdas&0?G*e` zvVsCXpE=Hx0x-k*&XZk5Z_pH4;k4r~UI=}~>Z3QcX6n`$udPy|RTulA=l zTcn*@?6G?2O{>uEPxGC2n_4o^MZ(t9DceOox>H(@UX!N>{tGEPXe(Tqqls%D66CZ%V3^k!cnI$&D+W#eT&UAIhMwYbYE{1^*c`;*|6`TU|E zX#zE`KRB^a^hRQN(I-Xk#mA_A?$P=jp*O7s<`#JOXgOI8f1r+WS(y(51m?gxG9aHZ zv5!KhK5Riwa7t>qrfDsQ1tfLThur$HjqxE|uQNv}cP;8tpy)G@%fO(E? 
zi2MbZc?%E#9Wk7PFzOCHh_KPLI?rf2P^d+Uk>Lg6nH1iECx3kJRIl>g^+gb`?7_f9 ztI!+M=hy8Zmi=~B0@T6M1b7)okps?}K&>n-IksW7FwB~DdmPbU=9?lmUCHc2e)C)A z-X$($6PV~%Ta3Qf*NlO7hGesDO~k_S^bM+=li3cI-SZJHdQZ|BaMnUx-Hk%PRU2@h zbpVrpk5-)}@DvwCfP}UT;^B>=_nKAya9aLXoct^B)T0f#;>2W`*8!;C;W3ZxR)DU% zk@>>D9CWiA`Q(T=dZ8%fc+M+#oDD1FO&RAF)Bckz@4|cm+OEvlCQKkhyEJRtZvc9+ zdVdG+h8;?dx-l%26Z#!mrfkwbqz3u-iXM~Wd@Z?ubcpp3akfkB`9cBTO9DTn*@=GN ziT-jK`}CZzQ1DDG4$1o)+HT+v91xZm35KA9n;a+?Df6 zS>UT(3Kw%aPzS*2OX{vFx4tkfEPmsBIVAL;tm?%G1ED&h=yk)}%->5}C$p(Zur~?r8aUdLpwg7%#mVFKH z&(a!%^d3pvorP!zcYk6I4i0j3$O$R2n}ui!PDE&>{e(nlWn`Zgz6vdtG%hBXMGDOM z5~UA*u29-nLaF$%5}__e5A(%z5;f3L&4HW2QCxv8NkUbWoJPMGI<%_jS4<$`qf~=F>=|4^(G@w-hN-nfbw%;|iM}c6U16ceZebz0j$kb?tW9z~;Q}&} zBTM{==ceKJxId=sPZ^*shaC5}-a*9Nb+r%|m9UB6+8&EjN)i{5jW z1@bGG#geCTS=fN57`((;7BWL`==MZRd6b=lb>#(4C-K&ft}aQ3L88_N)jKqS zB~6Y09g-%*+Zd)4K^vB(I+RB@PH?(0pFu2fw_ML|3>oA^*ErD&7|mI&m`@bSPX?yh zGXNg(L0@$_{h||nqZ7TH(HPZyIdU05 z>q<~iQA^Q2$jl(D#}LII-&dr(w>a-9v=_h7;jgZodWS3l3gjeBlU21_*RxJr<om3qS9#3^dS+?#B$pv*JMVJsy${wTnmq7(G{@0@V7({v}wyM zMgb9YU;!{7G3S_pS-u&;HKPnYLzJQuq}dtF;3b1^fb|Z1 zP)<*T!StdQ@SfgR0Dz>v#7Pf^f&Z6f zJ&Z{KhS3TcE^}#nCTalWpfxOE-q1ft*?unWKe&M%0x;l>B3lOj1!GL@9r1n_CErxW zDCH2y+T@M>0EKvHB638B+~nZ$OrhrVPZm4ukYsh+b&3=2MJ9&(AWXP&IQvgId?G-wqg}LwY$N{enTT zxJMQtLv%=v6TOhp6sf?%ObS@v=}~2@D0zlrhq$g%=HGu09vDB<&F^HkHkUlZwXH-l z2vN!0p!-6C;aWWaH^cmqMd!%0ux}`zNaNE@%X89-oisk(w0y?F>T{y_{ig)_In|1^ z1Kr7(=#U8xl)sZ1&?^Nc2v0HVk`=5^{N5yV3Y)1=R;9G&LRrmL0VSUsdYpmkjflipvJ(0hD~Y zX)ctFP8y$XS`x|$gbB6FRwOPB1?3MNC@*Iagfo}q8XfW@Oh*D=>O@aa(O>+RiXPxZ zm#OFuC;E65jq+`Gple4sxlU2Jy6#eicYUIoI#Na7jjU{HlM~%>Ey6A$B&fee3a8TV zph+`@TtlHc%rb~spNN?&^LB}!fz>kQu|!I>leak$Gg8LTzG^x(*C*2YIk~G6F?mkz zYH^ILoA%XACoV_)cp=-PebwVaiBV8ce)ut=B%jQ6a{0y4@-1Osg|M>1$u(5`8M*@V z=ZiNd@hvQkantnvnbJC^MgESerNXpy^(GL}gh?e6OD2^}E~#lN>WZzF)=IhF-yyqC z#_`@#7$rV-Q?LSyKTKoQ2^B#LKmR4nlIH@5*x&G`N8Y<8OMJN%aZC8;SA;Iw`z(y% zv$gD5bjV%5;A8j-Jj6ZBGaq@x)hgq6zeL6dRL0pV<8YJ}bD8mn$m7`G4eQO=9axi6 
zg4I~lQi6lhgV(2L{uDeoCG+L!sVmb{PfSleHa)mKJ+q$P?r+jlA4>$Gsl6HdIuUGR zaB+HYYZ4%eDad&yJ@aY=9l3^?Xb@ho8AFP> z_{vA5*x|iDWKi0lkHu>&3(-{m1J~!%sp^bY)vG8{Ha-fS9=%gg6bb zuFYD6d>NBa){|P^vEenMAOIiHRdT6e?J2p$u>Jw_dDEvEhV@p-I2b&P)*Tm)*R5BL zjMZa}4F}A?r&{zmlrikQ_|a~EEEtGKWa0ODk>f zo)Rr5`w#Kc0lyv{g*;G6*XgYS+1BaS27N=P5qM)!P`6f@(ag8?Xy)sNUC2V2f7J%H zE_hkD{3#qn$>Y(?uXO7Rv>W<}_A{etXAj6HY}?HE3~lRym$b;wfIU5G9)*6fBj?DL z*T=>p+Ud7R%VT=&3ZH?mwD9kNq>@XC(S6*IFsuXP$Li4fGriaGP|l%ZxUXnEQqr<(;A4%LVu$xQwmbrnx`j;nRpI-r-j=I zzq6NO?ZrOIH>N$`YuW>K`$F9=&8f7$tSs1}7XbaM7YsM;Qq;amFW7Gc-qe=ej>46J z&x6O9)>ewI*M7(uH_XU5U|O|_rd@iJ#;K=KB&jl>Qf=82h@z1Arw>rtT9Z;CXYokJ zVk=)RqL(3HWbDwb{%8OaQ%WP#o85Nw$j0Cg2~Hfj2{DSBFnXHyRcJ>79}OPwF$1tc z{S?Ct+1|v4LcwnFEDQAfQF^_J{g0xZu_?%oNea~yKlBHy%P-%@6g&-onmN4-I`!d= zNYieKW6kAkdz-ZJUCDUEdKLRsK=2+(8j+1!_)=ss0F+^MR7R%*u`LF0Z#3=7Fzkj- zv!Dai&MAP2RrVy*hn#WGW9ID#^H^`#CvZwcGco+!EHI-UOoRE0wLvdw%RVLmgH|oL z*09bqtV_Ixt#_K%TGRSC-izaIj|F+Q$zpdR~==Rh`IT*q5*| zWQm33v{LxfA;1SdT<~WXpQaVVnhrpz$NYtRk=7<6!9gg@L_m305W3{1*m@)z)-k%3 zQmV0%R&@B_iRVyN>hcAmQPta6^! 
z&U1?MoZ&p{oo9o3E^FiyKfQg~jp~iB&B<3t*Ybzh<;mBFomaJi37(UfKkN^ufNB4p z`N3**`!B4-2 zRmgEvIZ1``Ra&+RNe7PO$R}g_H(`1w;APao@GA~)zaYCH+gOcK282*zN!9$gj9G|X zPMP7iGKk$L$fLnC;WPm=@CRZzv)dD(QZMq0<)m3~y9Q)MuF$&%XvNRsM@h)0Aig_1 zTz2Fp_*ifhzYYsa4hm@D_bKgzz_9gDjFw<|hjlJ|pNxh7g58jZX{3egB=@O64%m2P zmLjJhCijdy;9Co~0YTxi-WR$ze07+Rd?IzHPgHaWp9@@Z8F-Kyw7woVm) zq?7D*o;l7l&w1uMPrvglbe=<5Q&uG29r+SonWwceA z(_X47mLdzbUg4T;t%>~?djG%EKde3Nr4Gw0ZBWoOHS&p{-rmH?%pYEpEU)Qq70KA; z$=8RIua7ye?WKMBChp#s=(_51N>^QEq<4AWTP|G_30OOtmdx;Iceb?2^~FAtRz=WX zA3R^gf>)6b6XUu)+T>P+wEyOM;c{uUzfAt@@V6SjZqLcr1M#5GrLQD^-uMQoa19(c z;-Q#f&w^Rrb^Ci7Nd`AT7!3(H9y~2ko<2HjT8Ngva^*$g9V!N-0kmnw6MD*KeQBF)$B^xO5ZY0RDV~#DvLPxKgE-OO zOQ9558#MoV@>0XvOP?a^Tegq~;?4Fob!N;zs%d|h7P+|xsKFi&epu%<3f?qY5A-aL z27O&kUxS%egLRgJ8{T*}UMg$9Gg`OzjNOc`Lf#r>kNa__G6qi^l~-n8k&E%u5b+6y zZVQI)ea%m21NxTjq6HickMm{Yuj$LA(O7JRnWLoXKnC}9Fy=Bk2U$W{2wem~FU$1> zF3QxxFkvC%=NTxq0dl0b1gulV67o@YUn*%6@@3vTk#~=aHx}EzH`arDC1y;2Pr=o<-s}Qb` znu5w@wZzbOzz|?fAQP}+fMwb0vDkYU4`q?q+mYemtTaD?rs`D^OQ^FnAAXUc? 
z$VB{QStm9E0l$SIYs4bhrw(w{jt;q&dm+)`w<1bxd>z$De&NtGC$;s~SnFj<3TU&; zm#g`)8BqX(_mzOw#o;3}W0JC@R>Z!>~NUva0( zY-BU0FJJR_I0FN6QGiVRB7TqJGsyQM^96nR$OLdFe}PPJdq4vnF%GW8RiB{547x71 z0>1!8Mi3X8e_gC6hb}2RN;)QgN$?dp2Ct5_-bjprSMviw)vNJ8>nzzXV4K(<$9a#D zz<5kXYM+U|T>J$@qCvHc41jtkV1Pz}B;Y|C_?fosb2SPZrp@&Vekam7 z&XJZ-VM)_kI-9+R7H7YWJS3kIt>1^}lPPKF4d8ME{>wg9uSyP=H1#_3pe6nn(km$f zS+SoC$ku88uVZs1`MdK0y=D4BIW%9##>=6hBilk>rl@>P4hmk``}Wwk5S3B}#L>?Z zyOn9ubDn9>_L}_%8i6(1GTsNMtA!C`rdxX<6nTLjL@F;#NW)tAJ46sk`(Ekl8XX<% zHz;UZFq}e~u7oEC7N)gYI4*O2Mkmj9qxIV!;w)rmGS;s}SV}_P$^tAF-}ML*!M)Fb zLEW@EFN)?qE%CuULXbjDD=174tP1TY+LA;rOr>^*x{UbHi%b6Vt73|0II@TXI+0e?F7WtTtM6`<^#iad}K*0PBe zzP6P9Q3$jvu*#pP2i6CR^@8=t+e5cn_13T9W7?THf3V!jbJhtYW?#=uRUiz+(~3>^ zx$(66aYsGD`ez)nehr`m&({mqmDw3IsbIV%`1ua_3ZHiF(vraX zGR?#!!On1?WA4r}tEH`Eh3CAw)$uPeKZuVbb^B7exPcymeVuWX;|3&!_N4T;Ert43 zzC?N0xmmoA)OX5yCW?Q_2Ho-PhM$T<;D3PZ0BTe91OG555&dOaMGJf348td^7ui&Q zKnpmLBm0xkV4<-z?Z`jKaMEq&fzFhg(B1#}|8?dM>wHL2x8VGNr^Ncwj7mEPoK>*h z@M-)fYTeP>;VS{ILbtH?ksb2AePpM6T7O%PCnj(SY$Ky}TLyKziSmu^IAh-cE2JarWD+LnPD zmn7_q?6)^*bfe?|6SCxNZDLHj>Nd5`ndo-&^yZ{K@D?@d7XLUm^YG zjMg!Zn=pXMlOyDLNB|oVHvp%} z{?Y>fm1BXIn&Z+PcXI@2EPJ7teUzds1gO3#Xb4MN9+x%v$RZP_931}Lmn!$A+I^Yg zzRYl6>fM(H_hqOTZ(`kQc+7jKNdjzaajg+gAqK*M9-RbBTzMXJE+*-E^ir=LZB$ut ztdWsbv;j2`_8AP;^9}tUM+eFCP3WRNJ@B>`evR^r{WEL|edd*^I$W-Df0v`mIIO{Fk z-0bmHpku%S8{@{%F>&yUVsvdki3DmCiR9Y6VI%UF#1q6564+b}W(^yezvu)Gq969p zViR=e7dB7vVR?*>^y^J4n zfe+>#D-9k^YrJCXDgYLUFZi5p3^xLn1bJ!^YX17t)ubv+RF{vI)WjY|A!ptJ?-ERK zypbloFzfXwwofo``{>dOQH_&4x20noffG7fd^3{U)mSXC5!b2P4|k#|aS36;Mj$Dm zK_*0-U4ojSXE6^49mIk25%M5A*p!-dp<>K2sfATzk*0nTI~Ki1ItCo^$hA?swD%J@ z_k)=mKxQ!5MAq%FNHPa0h|Cwl*yMyZmB`hqa&c;pXD&`{ssO1q-G@8}VKQ{AgT*8Y z+r%@iNv^D6qPSdbAl}pEho&8crd?5=GKoPr@Dm$W(V#XtAbJRSO2SmP{$~11N~;W2?+)sSY*`(#s}8 z=p8>BeEVPV1FjpAADnzB&OW~rownB6mkiaTll=;`2H4ZlB1HupTwAL}t|n%z&$#se z?#Gxkvk-3XwvceW&rhuPH7~f-v?ewPQwr<7#Fw*&4_QU+s??sBMpB1Iny3uGpQ#)( zXYsCObo}QA4ZQxPaCwbS5U`p5oH=|2*aF7B!g%otxn^7xqwHBAtg<#k&+NaXvdI~ue#M}T9e6cM8<-uKzC49 
zgGJ?$PZkeAVOoy*y+S#Zh##7qndbxXx;r!B)K(6ug_aUo0FxZHD1&UGUDCU7OcGnq!>ye z(ycp(FVd266mrp$F#rvZiuQFZ8Sn_>6yvUw9@wOXzd?3TU(r4>?eE;vsstE|>Y45} zqanTMJrn#d+YG#}h2b>^CgHIM-wzQPJYm3nvMsBF(k6$uHw6Xo1<71NNH|Yz>v_ZK zAt%UVev-dpLUja>C_(I^(;R7!f&$EfkljQL-20M1z;ObmQP@L@%Qz9>IMXK+RKJe? z7+|CHOuzmr@(2!J`WPiAhcrmJ=)x2Rwj`2bC!(a=R_tg5#A!eQj2Rb0-V^1|CTG+S zFuM-Pf2Gy$VD;A8Mj`(=Omf}M^{I0k1(099?h!mqx6e)J)tT3WU@~t2d1p!s2kZ$U ztV8f(*e=fCxfX{PWQp?uO}AdAGrs95;5I&gkbajs->Xgg%6f2I%miAD@mHE8v8<&E z%Vjvjf00~Id6S8W?lN=fM}o-z3N?rulNq=xh!lfwp0+;eEve0;1g@Xq_Ao zpLW9o>@*gmLLFRsXfTT>xrXw+SPQR2nso}#^`bC(!Xy-%=F2piR)eoh^Le$%QIa+% zq7X#dLh?s!oSiJ5eHw#V5XTXhyY<#@dgx#6(Q7xB+0*WX3h7+4QGtzBj<=qmYAtdD z!D(7uUFK?(lCrVB0eDLNSub5FW&FNASa!ARgR*fK;@}_!tpqey%`{3kN+LDRl?4a* z)M&*5>B;?&96JMX=z-k95uNCPyWIFOX;Ms5fXCr)*1jU?JV)=SmvM~$VZx7DUPVC` zY6!>*?a5s)(qYPitd4wu{GvEUCrjaAIQ8%?XxCOQ8DcQ8XkS+asksW2N!b9xLYc1` z2tlRrnZQmlfPY>FLVW4WYP+Y3KJa|4AT z#&%GOX@H>?vNARf*%Gw5GGDM*yK{}CLAxhJy@YN_^GKWATa*Vf<@GxNvz`E$QJ&F+ zTv};uF3b2T#^`1xCIZgYaef|i5H|_6% zokr}Dw|f9E2Zfpw|-kSe1y{Dfnc3E85- zIYqC~K142`FjFYBu^V@V?R@P(XY`46gy7o6k$l&aWEVZDU;(MMq&BjolcM|K{n1H z#Qq2za-OwQ&Gp1Sk_~k)s>z}YWl_1+#0p}av?xTcn~~N^c%Wor$p+3LLm)7E^zV&m zA}#k*gSzG3Bv)%3FeP8CPs`h{d(Ag7FS5-75ajv^k`rBLcdK>;veIk_%hq-IX<6Uw zo^?Mkz^qS0L-(wIC9dD}TnIgN1p4OR%sy zDI7uL~*Q%BAMmD}FZQe`!a;{8o0PG8&RTOB#~4(~xvT zS;CTpMO#^t@MDdLC>>eyHx!Lu0b!c|7fRIr3e$YB!#0 zR>7gK4@L})i0EUoS-dg{9yH0MxhY9NddhTKhE}k&PjTlG?V{vu;AfCjz{z%jA?MO@ zyssZ3n)WwVt=_N1JLhX+eNXts!}%bMoT#@nEuMjO_ek^$kMw<3E&L#3 z?aB4AxH`q>@-AlJCAufJi2Jaat7%zsh;{&HRo7gCfych)i zT!ZVtO7hBv8Ipcp)2*+WX}b~FOZpk@SR!FgCjC&m$bf(c`k~#%TR99XTW|z4NCPW; ziB)`{lFKBh2rehYS zNJkh~iNF8c9`U>%4Ez7KJ;F~e#WmkZvqzi;Ch6)RZQ^f_0we#XJ>qTTbqQwzxRLul z3GZp>NT(<112)6aPx8n%XoY~}hrgze;wdc4_&d=zAdTX{z+6?O7<43672SQQk}s$W zd-qUP24cB^FIc*&Vo&5W?=#Dn-K_K!TEMBN6ypD2vaS?8MQ)Zrza4=keN^7C%UwIT z+%FO7dWtk6!p60zSPEuh@is|!7hpBe^>LvCk2QCy+BaFaU2&> z74PHh;6`=|3v8iW-^AE6-R98!X};{(f%NoqkscF>quNtI)|i35xs?1{Omb%*fclPM zz0OT@6bDwI2chA6Ss8wWuo}S3DADd*jiqKI 
zA-RzW(Xfv!zXY~8W-2bRkA%^`+^JSD@^$dab;HrYTm(W+BPgx&gQ%wz}xQ{Xl zYAuY=013=OxFBA&@Nda@kWy7hp%_-R*U)wAbnRi%JMRmQrY1ONy~3W2bJMw0gMmFHE$!vjf-?cWnWk zB8LQCN&MMT?ZSd8<)X@HgC9U~NKNZGHC8=W$PvcysbPB@17PCBk)wjyiWumosovm? zpN$h!f$V--m|wkxaZf78k(rpZmDb)$*rtIfFjvx+!RzqQfqbLqGSyD8SGTJ>jA09) zuzdjh@l2UggbZr|#$tgVX;9bRmuarDhBQCYijf9o?LDTc$(Jl!jPz=xPr*$=Dt$Tn znM|)X?58;~jA2hV;@THT6xE~-nH^(KfYNcgRt22M>ngXl}C_Kbj5`DmMo%MkP_A zO`Lifs>6jeL{V%Uj4d+Sh`!K>2$6x**j6A$2NoHlNAK`0$5Sjj6)C@SQb<}NgcyfD zRFUXdFXpXROhf{}QnVfMu}pMqHj?s1iAqwW9pMejY=Dym$wS5IY~-}mMsgSvmN6J~ z7tG7^5xZCRsT-JEPC?2uOi8I1M4KblA_XL$6O*TckV+aRP^2W8<3Z*?K;|lUFgBE3 z0!K>Z`gK7{t{$?b-q+)rnP+=r4S<^5Pu`I!v-u z!R$Z1UfM0KY+uSkjWrpobYhc69(l$!1heOVU=^4 zH>^qAmz3Bp9M(`@^;q-xrBMb?LP}C=2vLYLtRDQ#LpoyU=yo&ekLUzY*<MAImdt=;T(&bR}q22 z(lKrXe0W!dK9$F1#jNEh936gNB`CW2eil^v`biI_>+AAz0Cg0Gd*DNL>wR2nubvWrmZ~khAjf4V6PUQ zL>VB3alu1=6{2AD7yIuc(^>5laOzOyAR@zqS&Q^zX~^*P9I?ubPx;=!cU}S1$QS%7 z17AYjW=`@ZobIHTsANoWe`^2fFS-U8<{i0Ek_|TT z6mrK@zE#OL$qWb3P4sDpFG1B3Lr+FY)(IxT`a+K-+e>u%LcKp+w%|QPM}7L6EPD%_ zTONMB1~2ZJB*;KQ+)E@kcoR3n&lzC$%P6; z?w%HCp%+_J@5Fl18ta(?59r-0TSFcJ@9+(wU6t0_XD>w8X$h;Rv|WG=;jh8!WZfpj zb>1|hJ>Q}?bzBJQt=*_+Y!u^Pz@RPtjyRJoWA}+GO!P?^1Dk82{66u6=N<64HC@jQ z`HPb981_$avLBZajO--VpWy`*+6}j$GBqD`>k#)-7)@(2qke*_LNDm{FDQZN!+z0> zNAO79k}p6e8Ehnu?rLA>>g(~>83?Aeqct888yada-uTSMg`!Mw}TY@>#cYC`y=N=d#GAcuJ}93{~EJf@+GX}NvvxTN<{JlJ8=HM{T2WCFUU4i`Nt&a z(cz2WI4=3pdY)c;1lbZMEHv=ga2Sk;iR9^; zXkV!{cQ^fl@oV8nn3b1v?!`4V<<{%sCP0PjXYvdS%#$JzxRtzA5Gy1uGL+fhUD+j{ zFvN>z@0ONr=~p4|76dJo^S0+kWC6Fh1R_x8J$m$l2_Wyxa;!tq-0LOvL}XQvQ4PtU~z}?%?^&SI!x=ecrWaPI7{I1SpKkdj>-#+kp9@TJJKKf zi(r53jx0P=r$;S#B2Cn}^7^gNAoRcwO%Cf8DpTt)n6Kkk-PA5u$sEzS9$>?-`gt)B zK&KgGA;!?2WkRklxx~crOqi!0xW}WY1#+8=v%##=+OGF@F z2HGq9&a+THaRlg8m{PHKdb*DBHw(6Kg@E(KgGff*R;eEc+!oIq6CE%gr;oTzWwdVV z881lAk57fj=R7RUX+arGX0ZM|;adrYh*QLau;-bNWxa(78J~+kwV?0zbvgR5>+(Wi zz2C`OZn{L+2jl%JjJF;3v0sT*U5jwj71NyEPg&RCI*$-!~=byTl?{rq#%JRkte9nX7W}gij>rggw}&mSJ)g z+>-J-Jh#OAs(_ghW48@ZFbZ_Ax!~YRSgpL3{RjB;_Hxuu>Mi#=Pf5*Mgjcr|y%@hW 
z!GCd~f*+q*$v5$tl~b@^k^D%@<97n1oFRd2p%}9ynCIO{p(uo=OY`-ffE6!=u?xQ= zgC#qvKnAh#4&;JAy4Qo4A&hB3jCcfZ?z}k$%u2nt$NDAS*Tjb4-I-71MP_t?-xZdm zPp3Tk02=z4SNguK!w_J?1Bwyq4vfu+Fo(^6&0rfxr@dJAz7p3+=LakvRZa9(@tn92IK~`p?#&*ZJApFMo*HaI0KLn|O#N zRryl%>$=p%bW|nxD`7uuFIH2Y8bKlAd78w4$(y4PJVEgz$UD5j1UkT%Dv4Z<3quzm zmI#u@19yHaD5w7dA`4bL2`H}xk;Bx{7pmGg6l44C5ZVwtO9qR9?r!iLLxlWy;NjID z3Z6Ic=91;UyOQq*Qr`O{*nfNM#FY0RrgK>6bG*;P+NRc2xhr33tsV$uTy3Uyb_8p-P(nBz2Hr%aFG*^Ajytyzy&U_bv$$< z@nuCD_z1kCg^$W)axOQadQT|9yY+!#y`Q*84jO||%aiPo)D5gTP?wwouC1YuP3y^K z)B`q*1EkNhX487~Zag3Ypf8U;j8{A({OpaxSIydU_mskzz$>yQ$^_BpcIL{tleu{| z1bel}(kBz&cB)>F;xYj?m{6M@jnE`$+Ok>GmTl>ltOjZF!$(%6kpM(T>PR3gNO(@x z9I3~{r9T52h;BzEHri!Ll9_Hl$sdCq$=-t(Pz zUEGiTTMq6FG*^sR98OgZN4y-Q!N4IGP|kf45c>|b`C~;o0EuqhDfl33+(-i%Cw+UV zT8zLOpe00@!AK~Q_VZP;*(ZFBK%9XMog4Xte%XvCf8=&EoqmgLU$YUe-#95`Ukr}2 zk^IHcB9Y*Zr%DI^LAq|=IYTybv#K0;XwYp1Kdx|RLu8%!bVD&9z7DKxvzQ3iLGn~s z9{u8lCrQa&;;vg{-$L&K#r3Q)iVu#Uw3z)E5~86S8Sxk+Cdr8L%Z+GQn&o+5H-vLo z%dtE>Y}n7qNCW$Q1)G3LP8uD@&mog+A}|FwUWDtobHFJo5w{+3>h3!|@P@Wbtuhf7 zG6L5nR)XZ!>E=~}oPIF+}^$qW5rgsA7e5>>IYktD~5r1d%*ObAvIR~Hd9jGGUI zsIJ%v{NR+<>kf*ZzF8(FKmla|dr=AV-pw^J-F$;hZC* zDX~57ZmYDC*f*B7)YLKM7^Q6fD97zkKwix--vm@ zI%CT@v6{@>PETwzU*+Yy3v3K#5B}MSKU~tlWn#D8mVCc26V085J#jC#uPUv#Mnwm_ zMT}u*RT}D`IKa<#m%M=U0Vgtlggy9yqK&5QF*1ax>Eu{LUtuMfAXJ^~LX@HAa0Qtu z;t^~*5-yT=ks;}UyT~TUz~l?Qyj712j2iU7tHFr|6gphA)QAio+^R?mm$Y)pZ9iP% z5d5)IUkNULv}G)WrNR2Jr^!4*ujotLWgEC%6y?Nku+;PFhnGkfki;{}17~Ts@`@6| zoV1#UG^l~8r&L76O$qs#gOzqm#!%uPqB5=ez<;Xjfq|fUSLTCWp6OHA-wDWo{i649 z8HA*JG!4~c36$;uAF!ik7B2Ub2+z|UWdG5L$N{qd=pcKO*u9VV#DUn1&zgDwpEoId z-sIv_zCizlpV(Xw>1)AhhV@DkpFaTx9DF{`B9-tP9O2Z56M$G4Wi&STefqG+B|=}; zqtD7lAjYt?7=slUrhq$hYCXT#|qM6gML-Y`FOLqq|<6}j^u!~B! 
zXlTtnGeGgAWW&C*5&1^h7i@)Pz$!Z2?Zhh7 zn-eHPf9`U!6 z@=!Xfkot)<`PjJb*dD|qD`y$jfn?xz@0}+n?8uz|9SQ&79CC9z-EsRw12U<*0 zt}l`G>np~qUsl4TC=a48dlZ10ybPrRyy#o~*+>4Uex^J)3#<2ER>4v~f3Z$B42uvZ z46Q-fFEI<+bXmw5q(OFVYJqO!qdH9GTtQA)a-JaWl zX9u30ctXB>r_~^ zbgcpx1t{Zj`~@KQ?sI_LXRepXRr2J)_|^UyetSw4<$s7@T_M;D1VdL!BDg>CatRS! z`cDw7=%PD*eL%>;_#KU|#{ZU%-;)qW{NC%}7cvTe>G&;2rhkrKTvz1ccMWLABk?=5 z2>88h=|9Eq|33YKCnYGDsz?fo{Y3FtqPX~((ByX0VENWEJc1r5?p9Vv<$#P&u{80KBZ^B3fPqsd=9Pw*uWh+A)~% zn{W0gY;BJ4k}CMsH-w>_kH~St*u{(oFMtgYtO)uw7N4dT@K5kiYZUmW4*od`{8I=2 z90mTlDfB+hyq-f6f1p=|ACrI1(eizP20xOK?4CJJQ)j-u!!| z4g=>?wg$NJMS6Zm!1{Awg z*X1{j;UjE~#k+GEPw7wui@CW4A~IplJ~cxrT9IG{D{5ke8}jiqbnxK$bqjjgT|T}XahzZOR6gc6y8nau z)o>P2`(F#X&#!TqReN}=9n_5+Zqpj}F>$Zle}h6)f;A6x0mVA7!O%=-k|o-`Ezl}L zTPv-P%EK+8Q;J#wFPU(kGquWflGI6;)=N<3%qSF?z?FHy6R@E!K19IUw8=!h#ouEp0dE~{KciX{0R;jt9X~t zX*!614w@?q%*bCfg?C*xHRO3rt-MXw$~MuEi)Qh)6*Vgp_UA>oXAx`yGw5fm;nDf1 zVHj#SMb>a#o(J=C6KZKe4e{q7-!v`pdqPF3GHj$ccRa@)gkV7WWdc6!QnV2_`n1UV z`jx*0Td{drU#4&IDuscTWR@g>b+cu^q?`S>3K*II29e3|q~*eY~( z0k(Gc0tS0g6*iZ3GvgzcYcJKwC~zje8vmyt*9;lP6d5T#$YRQ>B)(Gjb5I&ZRyq5~ z+~;xkamLR;+7$4)Dlq67xXc|M+jaG*x?Wb@UfQ6ZO^tlwr-gaGC%IQ3UAMPz$pBS% z22=W5Z(6@?LI2wsETyF6Z@NOci;BGS6mKXM~W{;LlUXyNk+%t83)jxg~* zBLyCtH3B4Qxw358jMJL^B-0Lt?HN@ZuwirIVHWTI&~iLdzjxr&)GqHLMH&+y;}bvi zJ*;lHdLIeI)hqLoxU!ea$ltP{ZrN-Km{5-=nRYO2KUApZ~g`l zJGV_$^4nr`^)qds&$PakcvJTk)iIm;4ur9Ebqn&)?{=&64(DI`aN?s8pocf~zT5eU~`(Y9hCcf}1ie;b1~$vAh=2 z>O%?9nhC~OLhaC>BCScdUw}r>-x8Xiy&{uU!9^xxHTOC|*`~$rNB3iaL$~E+KJ0d^ z0A=(Q?f^-97916IZDcD9$i=h-d+B5te-x6>JEvlK-VbLRHN@@lquiHM#2bjl$GfN- z52mTi`MAX__>{yy)9nF1@mJ)~t++m{uPFeQ_FSi=Qp#vbcLs z7&aO5L=FnF=7dM=W@HgRLp|al{Ey7Q?_s@&#J31v6RafTgawc)o_k}(yc1=09R;8I*!!g4#TnT%gOG`0I~Hd z4)}KQGvGt)<)8J4A=auH_?UlAksxjsPdRx%zY~!!BE>;!+^M)N{-zt#Tf~sgYWkqC zc;5+RB9P!GQUURJ8ZT_G$T|A$A6-D1VTH#Rd=dW*N)Byl`?KA@j|E;_k<+I5E1R{> zmofM=CbRqZ2bqq&8=Q5ifSXdy80=Z};ai_EUWhSD{_oy)iU+FPAHZr*gCaSXqHxZb z95?{sveAk7Zbc1vCBBK%JwN0cF9`xqNg;Sj0#UBJX%MRY2X0qlM0>?P#Jy3fU&tlt 
zlknRoMm8?~5w>!@O`l;~DED?f3LX+2cbkv$YahsfEx8bocrUxGPv##8%~j>kM*b=s z<cM?jEz)pC9y|=KqY9J>y&~mUK%8(5kz3Je+ojK8zsw=s*zMrC z^8J}tUs1mlOu%Q_-g2Q+p6X!Ii13s=my5qJi|*0p91WKw*q*b`2je*(HqV{G^QFs)aOq-> zhxf&D=rN>PpAkE(Q8)!8LfiCc`3?!7=usJ+w#?6+eV;tGv7IYQbmR`6>oTK*BWw*; z)XI!6McL1&72>`(?pR%lfLzBMjUIyR!E2V1XN>2-ulQR6Ucyxzl0F7V*2ld-j2ZVU zt#;f1mM>SM%IKB3v6E3keiaS{YqeE8{XR7j?`4NsaP zRvwMgdi3}+_^^v%8eSh`0M*6Ti$b6FfKR91kG7dN$K7Ong ziLfTeNgw-W(Y{Y2f~+8&C17)~467JaPVoa&&tqLMjNsS-xH7G29PO+8R$O)l*mb}zrVzK$TnFUwI5Ss_jgYn0s2Rc6Cl#1_vZ$J2Ulsf!=a8TBHxcy*P^Q3eDs=Slnt78X5 zR|4R)A%)7e;0h8V_9ieB_KFfTF*-T`RNqscg~Pe2pFX5@l;y z*46X*C9^c}68>BNc74bFcj}Wf$RS3|BQ5e5fbQafHU(VdIpDAp=_2y1H|mIOpC8!9 z*|$G~W_iX6fsR+Fv_g>OM(rN^Z9~r|Rfm0Og5yETkz_#G3~Xu1qfmOrOLVFSBkKs#awZRnpmFge z%JX0#*j%fb(JC|JbJAM)00y25thHb9y?=jy4=58ovEa3{2;&52`!3$#}x&XA}Bk!@#6LtVSvF|6n z+Jp~PVpWVyK@P}%*J$GLD?m4}S00F6fW=Xqi(t|*g+mAv+*;^$@JLX49X2Gf2lyoD zE`xM;ItnO;&@PU^JGD}1Azi{kK;*JL8*XTb-7V%LDjs*(hZxH4VG2#bQ_XLAFdGcQ znUY```mi~miCfS!T$Se4`=5d{AM7V)Y5EzSU|lNFv@Mv=mEt_KM@q@SC)6vkyus^3 z9wWjL6r5j&OR-+C$F`FQm)n2)!kfIA=O<2^ASzeYfW*bE$ zZ;|3PL_wZCm$&Ued*el}3#i6Ij*7kUQIfu`WPUAiw~Nf%9P70$b$OEd;`jan{{w%4 z{*y?H|6~7;|I^;LfJaqb`=2uphyG!p^@L`_Hr5(qFRp!lFC!^{bp zGMO1?CJ?Ocr5B}WgXZ5_rGmZD`oODgW7SHlR$AX)R9da&R#fV>Ew$K7|9`9W*6ZBg z+I#JpGbb~MkM@4w|I7Z8S?B!D-fQo@_S$Q&z0W!O2reJwc>(WO=-R}Kl$~?x%h@eL z0pN_ohk&)dV*UYozcv4S&3p{m^K8%8%s1T<}JA{sF=Da)EOC%h^YUXx5ElA+b1j8&D+MCQy{ zYlSn3)La>%X-OhtO+IVlR3no}#L{!-G%b$C7l-;fGNDWwXf&RgYIMaCAw{suU?yQE zW1;oNnP(c){Klo1B54kAp5rA3EgXu+6B#3%vQP%YRt6&$O0Bk1My5LyH+n*SW;%L_ zRfQa6D6B^&+-+u3p|w^jZB=pM+Y^aQdUBCc7OhTOG@U_YPTOIHQsM4)E8QCtHMNy% zE^2+&#C$oTsZ4JuW^y&nRUwq+!stp8fT-`C0Lc zLZYRq>apmm&xPaAjp}eRsY#WYGbd%Gt<+j8cWy35;X|=lB5X$ExRenMSERio*2TS6 z>E_;WcGa5rGx%RZ_v#*;wP^;czy_u^a_HkL(3+0!Gkbapuv71JE!LVwLvW5!gYb}EFg;xFj!S!Knf z{EW1fft4{Xu~G@6Dq>}xHm!kgdz416Tagm0rT4PyBBu*N&lkh1;1Os83 z5nz$d;bI##MpbuobvLPiY(HfHC_c4xla$>YV->}bl{KPiqlPPLZ-?4kJ2Z+`G*Y4X z8Y&^;a~ml6PhV*SQ;F5+Tj}UpE2nT(i(#0eNi>U4z4Rn4uWwS|^(6$Z%4_Fp#iEc{ 
zh*Z}|_lCpBtgBb-L#}C5x5ilP2>rt#)lGAi#%t3=u$2dO62_g?UsDB}Y6x2??E0A; z-!{Hrbcs?ogw>DmReP(ajn!7%N@385P#cqelLUq8r{^Gn!Ah>pF&uq{GNUlfn3r#> z^c+KjWyeSpjboG%5k$kvzOAW-nPaT6k{Kgc3lXz`6ZA1A`Xe}!GN^7@dtr?j6>5+~ z*@m3XQyoe)iL4C9+*D?ck#E*lH+|jE>J35aB&n6eFUmw&h+aRPn?@Uh#=K5qUFfT1 z2Q&>&7@SEDFW_m2v!34FcBX!^`;56BP1^*$J-i^BP2#w0VK%!Jv=VfPC^ijEj@Gnp z&<8+wg4SdC=ktrR*{S2;1Ay+s(&P1@{U6h`-DjqidU0zH0dHXAuX)24y$1KkK3#4YCQK?gpmX|+Kt zh2cq;1E7PTC!LecZa!JlYCvzsqi0ne+3cIRt-k{FPT-qCgTQy-2KW=8dqG#Ao=@QM zgD0rIcpAp=0tZckR>iZ~*NEatt~%@v)CU>_ZNYOJ5zsO`&A1(OE9mo}JMb`P1LU+B z55x9@_Tf3;y>t$G5OfEqh8tlAK#vFAjK{g#K=*=P2D%gP5gY`qzZ7yf2IcxvHk$!m z@lV<8F3>*E=Rx;^9ssS!BYf|GZU(Kug9T53P5?a&S`WJ7E0AM3%DEA8K)W3V-Hi5n z0`IoI1G*D03L9v*tyiGEL3dt>c%TQcsYM&w?{Gil2HJNO;)B)?KrU#<1E5cUuJ{J( zi*`JK2NMTC>$jnPXve16p+#(xF|u ze}HtLm3MfHx= z8P)$W$PIMoINt&K1n5rCGVBG` z{VE>r2d(;#Z1z&n&7k*yu6PaCPepvt2GGNxeV}`>G`|;g^8x4^HdyQa18~sFgWwC= z0=f^h5A+b7|0$cDi+)1^%_xF?NxP=^Rp=*=9bHEEfeA;JTA!88PNW#kqoFyPHUa;2 z_}|b38PU?PuVTKha^c6y*OhJ1&OT|*+S ze*7Oseyzw~oA0Z*+7<9sUgZw>j4hs4Usb=i#aFXw)B<1q#?h`Hjq%kq`l{ynjCsDw zdA^EyzOwn{uBF4kHn4gP|A+8c!Cj!*BI1iQ%y%B~Z6>~r#J9!Q;JVCL-#Cm<7p?sZqRAMQW6a20sep}oqOTPzY+C=;|j&dy?0JKfNtI=1v$&LDL^ti4a zrEAWo>&c}H*weKj?pCCuXY59UYG_|>Gst+0N3>bK zzRkW)_n(T^7fDHnDd?93+k}RX-ga+;rk$+xd@8PyPIvZXv%doMTnFg{d=;BW_cpp0 z`3%?HvXqTNKqNcT)xl=8KgGWo2!!C?BpPk^*UW!IgHn1PF{nEceCO(E6)+2u~@ zDlmT5C1B?`-4@dAero4Uq}v-wx7P_RMnlVfr$rFA1?k$5?qzB((RUY89~14B*Ux6^ zW6i#x`wRM*VhY+^?)n8TAP-01ZiQ-U_4p@$<`InBDe4EvXA8Aczq^^r*F@!;?+X$S z>Up^mwv=n6QwNc@ERAjYIjvo%L>Yx0XreOC_jS7dyJ!&ul2<+ai8lBNw4<~Zd=LE$ zx<5trEmkdR)62kX0KNu>AQ1@L3T!*D^KICjz~~(&JGKj09sCf|MGBPmd0Wy9V9wiB4im}r0aKac!4n_W(rqo1$OsZFb0ZA81YAl*Tv`!%J5Jc{X!>-&y& zfx)46xeRFs?$2h=%eBkWobC+k>p|C7i?$25@gDHH`GKS6wYey-1+rYHz}MUP;8FAX zT2Wrjl2I$D;CP$m@CNaVPZPR`BwaA1~Tz*cb%+xDjJz(7hh4 zs5{w*qQ-Iqd`?08_cdg*x8T_GG#or%Fh39hqqZ^1Fn+>c`>^FEUNP|M{^^I8r?77* z`&j7i8&vj_Pu+f>dD~P&sja}drh#N^rJ?kBTFWvqZ7 zJ^}vdJ2_nxzJg=i5q1pYER79CoCS>8#OosR@!`{cSiP#`r2QUZd}I1NK1g 
z2B0giU`+H(0S_@g<$X(sjalH^qr1;_p=E~~?1km-&lT^Q&@TgdC6U)Cl=l+K_Z3)U zFxxnK)2M#$7SC1gt6dGiGZ@VF3@UUi;~(qfY5*F#Qrl`|D#mChMfM-McHa_{hBXkvpVrTNAhJ`F2-2ao1hDxOhFrR z|4JhLwcw9whUPJde=yfxxg$=_CfHv6W1%TCJ=!!@|;qlE_T0nKzBb>yfX91 zUo!Bo;UI`SF|E}+&tdZG8z|66lA#B1#aFP_0h?-*A8f!dUwff%zwUm{UDWE$r#xHm zPx>9)1Rs*>u^P zb2eK`I&hrKZqAGw~4dhz~9lZc~iZSA8EF3s=fcnV_#CZpKihgnz;-HT2 zr=M&^`mxAE^pg#gN4}qsF6_khJFm!Q?j3-Tm*ybpf7{l;v2k`QaSa zg-=tSgTr*80$uh^tQ$QHyDN1;*hz;UFZ$8~-+;^g9ZykR5P4EPsA?IkFYTs!+>J^V z>!0>|pwC}kydF87cpf}Qp&p`r(%=b0Veh1E$oDYf?EYppdoAh2c?J23^@RMKgzZAo zy^^v;Q5;hbX%||cRy$^Avsj+eJSR||Nqd%Z{6n0>*J2G5{+Z{mkh-XAUVg&9<@Ine zjWf4WUm9+Va&_GWo~_9DT*`M3csj-ts_ScrV{FT2XH#A8QTYlRDC#=Qt_EE%7R|MZ z^&h5Z3i1^4T}OEa3+jqls#Xuz=^*mF3CFa4;Q9g<5MTkSaJmfF_QP+UP1oR#Y$Dqz ze3g~L8>5l=;m6Qtad{AVy@|N{Xly=CmYvrYFy3tRxPOO%m|Zv$3f1>{q#1-CNnfq> zOb1wU{}rnh!h19#99unZAZMF@DF3uZhwRE7yeRyzMb=z9vVK@Qn~k0`Wvt| zNafDT^lT&Gt+DME`lkF<_xZ)U)`D^a{HDS+dUaMd`$HUieh+>apKPv+Xs;mRc(4xq zHvDMM+lV99N}9%vYQEZim1n^gZ~v%88%MjJ^o+qmB4+Kx_cEkufIR%f_m4<3c2orz zyZT3M@m}S*+D&!YjA_iBw`Q~JA&aB25$6?ePR*0%4uNCII`4KqN-@1{#2oEu}#5HUnz|RxJ>TZ3os3OxO|N zB*P(K*I*v>b&_GpexLL8^y2N>D*F8?(2*AK*TBnjA*6TYv1x%X>3X(kw`bo!27Xu0 zV7qw&_<1YPTdDv76DyinUumuIUsf_p^sxh%!jpe`;Lf;Rfyen|e?4`o; zig}${+jU*(C`)sB8xZ??fbi|8d>29A4uGF1-|7!nKC%J*xF^=~H0HEaJ_N!Bfkl9w zE)Z!0?g6#|*oR${BK{igyLoWGQ0OmCNjK}DM>Tb@PmsUoemHZ&F0>X|lWZaXVxe!K z%)RrtQkFsb-kZVmGC$tSAzO1EWQ9cb!4;Raew1Zz>u}qEOB!Yxqqgeaf79v6vk(jj zxm)+XT!_5n@?P(zYu-VfBmJxH{f&%tED++g?O6SoN*qu3KB&_KhOx47-j8eJflCPS zPSrpD4PtYP?tN~w{;=zx-FSx?SHHw*PSq=Fya9IuqJ+F%TCYqiugVfWs?WI5CCW^h zqg>@@U4YMQpiF;|Wjay9A4`C4cX{t{Yqz>3E0LI`rDN6Rc`Mq zZvE%f1h;!2OZT5W-rKy|kk|V&ulqr&eXjYzXPSOGRB($|dqwyD#iP9~j-PVT@w=}5 z`d7UA?>u90qIB#8ecS@?x%#DEUHdxemBJhKnJ3`18rrO1^qE7tKA?Mluj@b2y*S?F z^4{md4MFcNm-`1U?~^X=*HqHq(@(B2*=;lQs)%>HR|n$(ul|}Xumw8F$b2_@sub)278+Cu$dx_TIy-C;Fy|||n@lJ4GK_$F_e(vJ3aw0U_ zg#U(|#^H|A@;{E#UVzNUX>W|a+KrPN$05NZ_+|Z$?pAQzNvqG ziu=B?SLhq3xSt(&i|fWI?prFJ*MB?Nz47B0@0{$u;kfB6_fzH>t7d${ye&VMGEW5*U%dhy+F?Fd~5w35-Z! 
zL;@oc7?Hq;1V$tkOV2nJIu{faQa*VR2K+-(uzFYjg(bU@EZQ@}ES z2nYD43hp$Bu%EB1dtVc=%u4~cQbeHjSPHmJQ4lzpw)0c6d@gmBFW?rC2XF~rbA43G2O#-Khr}@HC}J3U}`XxrF#FrWrkif(SoL?IY!m; zRlV^{ui^Js*Hza{^Y@BV|CeiPt83~eF7kA*svjGCk?!{v zi+?I`7w*^`nkOTQcX)94N5&hPC4QGA+C9ja;%fm3h<9mlxR>#uxQT~%2)X*VxMciY z3uS!qUI-3PM!zFI4HKA;BwB#+0SCU4@%@b7!twu!@lB^o0JnZbxQTItZulY4{Yna# zdt|=bIsQZmYdv0x|BUfV7{@IIQO>y>QStvJeA4!n!;?TjnC za|h%1FwR8#H^zU>fGQ7eGRSshT$N`c3>V4S_gPsX#XrUPM8*}LA;#x2uIxHlE#kAB z@pD-|7chQ@gU{2958f>!(z7xY`k{T*pnrbH>YpaB;LPz>D>Z>zU6y=C9=Q5aVi`o#No% z?%@A=3I1bY*Qxx$A4vv%%x4MXD>z@e??{23Tcr4ETzBT%;m9|J_{HkAj``GZf0x`f zZ9U_i4!oan!*Imf-p0zWb5C#X~nz&{T854K3;QNFNJ;3`(0 z_Hz886%teWxxGaE*MXk~{sTXi@#)??h2`)+C|_}d8|UIZEzsK-A5tBH_2hAdUq-2M zA0GCO!YO$6j`-|ZBF}MeNrW24hkE5XJ-13>9dIho;BE3;yjOw42Y^>-CuxSEdXT*S zF!R~JQl8T@`y!y|ly5~up6Z-zG8{rYZa6G##rv%W_-hXS^sL{plwyxe!qw8IGyl$I z5*P1?;gFu?qkIps{ePWv+9Gg1v5_Bee5X8r0i5Kp;oCB@cqash)yEbu|0027T)S7s z$8aFRYQ_h<0{!t|Uag@>5WA=WE4i`r`RBdYHZ8wPp? zg2HPYzk*wmv@i3YEu(8^L&StP+$sT8-){jI?YL3m^n5Ob*Es&bkLCGJNwjfr zIEYX35eb~m_*~#fGAwin9PwM_B`&VEk@*Dgm-uo{d0PoSKL&oR>m<$Te?89p2hWoX z&SyS;$%rp;{3hT+4ofBQO~%h-yoUXp%NhRy<4(KrRe^IP z8@Unqc$B|_^&i8i2oIKs|10Jnyh6sKXNf5UPAD$tuLyj&oWIR@MO3DzcNr+WtniEF zIlUV|;jfH0u>Gl)ur?lgNAhv%|Cxllj@8tBk)vx(9Dk7Q63mDQD+R7%<>@8Bsa>4< z@NLB>BjYMN^Kc11Z*lxxAsPRZEH^(KBr3o0Ly0T@=NyH9OXA}BR2<&S_=ar~7te^| zu(7hZ{5ybC`&K&Ica4Y-`+ue6bCV?6Cg4F5cI{ z;cLui&yOTNS`sZQ;`51(RHEWS|GD2i$=BL|lYE9)|JxXkaQwkLB=8pFZ*lxX_e$V1 zjDMnp{hw3wEuS@KG{j{wJPVJl|G^9)E?%Km9LYu|vA%6MUxM^5FNOKQkA*+n!2P;{@ryYA&|(=+`5_6$wG|R) zcG|VT$-i>?Ki}c_&U|-?_$O;C#&d>}o%Sejp;tV}PV z3k5FY+J7?~KiMZy@!k;*|BG>FeaDTL{GI;ic;FR~n^O;`0jKg0vLRIZe+%OcH^_XK zOQJm~`16U4yrKBqEhN|EU!CRYF6Pt8`NBOG;h@5~KeGAIj-4p;HMoPC9DfFIl4rvs5}3w#gz@Cf z5>R^kG~-S?|0m|(nU?WuS)M1sQ78UG_e)%jze|7@)1S@~c#`=v+$i}d{of5-)m>hZ zzkdL3*a1F1Ecl~84oHCBBcd>|s(Abc;M86X8$>jX-Xo;2p7ELs<@pm zb+-0&gj#z72(g|e9W~|*(@Z4O{uyQ~BAekvPf{dny8xK&XS{<2S3L=~Sq?##Y%(Dna!$ 
zV~NBEY`y>0+M@Afh3n1sRoHyjvRb1ZnYv^snrdAXm8!!E5`ea%!NsXD}qP?-oe>Ps9$_%i;G|iD*}$EVv0}F)6h%bYMl+ z3NNj#9xiEucA?h7w4pe@K_K>Fhf~2W?l6V7`47(gRGI>}81dI?w0-uGcYY@axz375 zVy$5)us_@#O3~mF&7@ny@VLoaf*MD|W&*a<3iY%K2RqO)GpE(Gt2N86uwZmS!%QQc z>BZN>!&<~jS*y`?F%ii0nBiC=Zly5=i6qR`vBWA&C?c3lq|H!opO&+0ZF+JIv34|LKw6)G_ zTCu{MUR{^-7iXuj%f9&BL4on(p;Kr_?d+TvCBJSU)F@>^Mxl7vT2;c-V7k@bVLCI- zkIo$h_Jy1}IC9PpDF~;Kt_9Y<*Q^f)OIRu4*5&HwuPgla2kDTS9NK&yGI6!YrqqDa z)kPdjb_RbSn$NMsB=Um{V>zaI8!ehtw)cnf1c)`fYVk#i)}VdzAf_oK4F62DN3u8F zjXoP*Lp?z>2VDxRWieZDB15@ieh{xRmst5gMO;PAY+Y*N%OugbiLoI++ws?gL#ZA! zjZX<+E_DJdDs5d?}K?CkuryiG71&pFnv`Wi( z5G|qHm^Z(-YDdMKhb4|>kC3?vrYmlS=?g3k=+cezT1|iTbjrBcnnW`hUtF^dB_$kf zm>*&fb0$7VAZu;bR@YX~a4=}{lfjTn0fO%=vLu+!71$B3e5rzgi8N6ef34Y!(IJ&s zFZ7HYo&c74A`)wbjd93REUv)r6YKm%wp&GojU}7+7Il|h)GVM@zwZ{j+II!Qo%*V4Zd@fXwEf( z67cZi3H|A@df++~YJ?d{C-CtYjMr99J!aDvjE*#-NHv_5PIP4uwLo3+!;Vb~!^<-} zXOt+lT%8d*$R=;hbQN5nqhKDiwV(qM(}1=T)byu4*hg_Vt-be0V8A zy(6>UTuX~w7++^}#b6}^!;FN;{~t2>KsgE={%9OuBB2l1lo*_y=1i2-DV9Q+O;__y zF*`pBYb)l>Ikk|M);?vXOQp={Z1PegRxA?=*hk(1mr}Med0vi8>ZO}(Maa0A_=Hm! 
zV+Xz|Vv-A(YuV|-3Cf#};c8Z(Mrw2_HbWQEuZ0tB8#RQ37q`}#vQrl5Ms@9MN0Dp& zp-99;xnl_|V$;H~_+S#&`-V??k0iH(CR0V8$zNpxPi9+$tLR8H&r`EKT4+-!J}W9tiV@i>0`}sKdJ22R5!b zN`!tsljacegFz|iH`bY>j=`9lenvQk&%3EHHn$$|A(yW)Ay{nE+7C|$Y>T*9hD)P_ zM_za(%{r12vsVy_1q*tOhT979NUoXfE(!d=-temT>jzS!cVuE|=*ZDbHN(rWmOsNB zKCs5u^*HTmfmz|UQ-l31{(M}ZhlG%isf=NPkrU2%$;pv$9^TujL9~LYQNd6VvEJWM zkuUIYi)q(4`+$rVeUG)JLQ$yxS4(aY*+Azlr{P%_b@@J+d}y@5d?(z!rqCC_fI%G( zJ$BjRCN*8BaTj;=5^_C)yTtomy$I5cMfN)6482+m5q6`2+n|f3)eyxfbT_K~eiEVG z<6+0H*o0uFT>O_~Yra7X#BTTP+3bnY`zTi23mdb<&K3nItLSvb+l zV~As!M>ZfER;eYkvC?-WKi#C6y>y&%nl&tEG26v@T`mS^R9k#Jn#G<-pdUNnfm zUx&q3HZ=k?qC&HBmCaQsC$#)5UDZ)^VslkGoo|T(l`P)5NC4D5iW^0`aukC`or&qS zt&0`pt7^m(Sk=w2${tOm;T@)}x>ndq(zw!NBQq&l%J;u8Qegd{97XPxlf+9+IgoXx^+eddpp|-BjpOW{(g)1doCAeA{tllTOBB8vOOYRRG-yYojEOd#)0uH)l z(bXZ__<*JEtW}^_lCtIObU|8&bcJ2U9H=W_p;?YTGW>RP zE`omsPfb);F1WMzk@R1i`V=^6v*ja5_U*=m96;1kn1w9L}}PqgVcvSl!FJ zv9?Zl#76{3-SYi~I}6Oz5odtui&Gg}2ZEJ|C05%2(WCNQmFS9lV;^A2v~b25r%M!Vj4Z#DWc=6-Q8MsWPKc2NSaB>EA(A^TB#Ay_{7Au zlobn60zOK{GFr8mj8)^XdUXP)bda{fS~ZqZa4HrCt5XSagR0u<=9{qH5pYvy$R-@d zvKy@r;lWcfa)PV8KP!2tV`LuciGsak2$xL5edZooWR=goe%N2Uo}$sC$NU(rR_T@; z6;HK>MRS66q*LjC$wPO0Q@ouG`b+ z-e2Pv;f!2Ld7#oMn6_miOy#fMqcb=?@wbNt`w5Oe!}-@ZaFu?F16S "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]], timestep: Union[int, List[Any]] - ) -> dict: + List[Any]], timestep: Union[int, List[Any]]=None, task_id=None + ) -> None: """ Overview: Perform Monte Carlo Tree Search (MCTS) for a batch of root nodes in parallel. 
@@ -137,7 +139,15 @@ def search( for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): latent_states.append(latent_state_batch_in_search_path[ix][iy]) - latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + try: + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + except Exception as e: + print("="*20) + print(e) + print("roots:", roots, "latent_state_roots:", latent_state_roots) + print ("latent_state_roots.shape:", latent_state_roots.shape) + + # TODO: .long() is only for discrete action last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() @@ -154,7 +164,23 @@ def search( # search_depth is used for rope in UniZero search_depth = results.get_search_len() # print(f'simulation_index:{simulation_index}, search_depth:{search_depth}, latent_state_index_in_search_path:{latent_state_index_in_search_path}') - network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) + if timestep is None: + # for UniZero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth) + else: + # for UniZero + if task_id is not None: + # multi task setting + # network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep, task_id=task_id) + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ 
-245,10 +271,10 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id=None ) -> None: """ Overview: @@ -318,6 +344,13 @@ def search( """ network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(latent_states, last_actions, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(latent_states, last_actions) + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) @@ -516,7 +549,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e """ return tree_muzero.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None, @@ -715,6 +748,8 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e ..note:: The initialization is achieved by the ``Roots`` class from the ``ctree_efficientzero`` module. """ + # import pudb;pudb.set_trace() + return tree_efficientzero.Roots(active_collect_env_num, legal_actions) def search( @@ -735,6 +770,7 @@ def search( .. note:: The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++. 
""" + with torch.no_grad(): model.eval() @@ -977,6 +1013,284 @@ def search_with_reuse( return length, average_infer +class EZV2MCTSCtree(object): + """ + Overview: + The C++ implementation of MCTS (batch format) for EfficientZero V2. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ctree_efficientzero_v2``, \ + which are implemented in C++. + Interfaces: + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. + + EfficientZero V2 uses Sequential Halving at the root node to progressively eliminate low-scoring actions. + """ + + config = dict( + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + # (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_base=19652, + # (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_init=1.25, + # (float) The maximum change in value allowed during the backup step of the search tree update. + value_delta_max=0.01, + # (int) The initial number of top actions to consider in Sequential Halving. If None, equals num_actions. + num_top_actions=None, + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. + cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. 
+ cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: EasyDict = None) -> None: + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. + """ + # Get the default configuration. + default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. + default_config.update(cfg) + self._cfg = default_config + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + + # EfficientZero V2: Sequential Halving phase management + self.num_top_actions = self._cfg.num_top_actions + + @classmethod + def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "ez_ctree.Roots": + """ + Overview: + Initializes a batch of roots to search parallelly later. + Arguments: + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ctree_efficientzero_v2`` module. 
+ """ + + return tree_efficientzero_v2.Roots(active_collect_env_num, legal_actions) + + def search( + self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], + reward_hidden_state_roots: List[Any], to_play_batch: Union[int, List[Any]], + gumbel_noises + ) -> Tuple[List[List[float]], List[int]]: + """ + Overview: + Do MCTS for a batch of roots. Parallel in model inference. \ + Use C++ to implement the tree search. + Arguments: + - roots (:obj:`Any`): a batch of expanded root nodes. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots. + - model (:obj:`torch.nn.Module`): The model used for inference. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++. + """ + with torch.no_grad(): + model.eval() + + + + # preparation some constant + batch_size = roots.num + pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor + + # the data storage of latent states: storing the latent state of all the nodes in one search. 
+ latent_state_batch_in_search_path = [latent_state_roots] + # the data storage of value prefix hidden states in LSTM + # print(f"reward_hidden_state_roots[0]={reward_hidden_state_roots[0]}") + # print(f"reward_hidden_state_roots[1]={reward_hidden_state_roots[1]}") + reward_hidden_state_c_batch = [reward_hidden_state_roots[0]] + reward_hidden_state_h_batch = [reward_hidden_state_roots[1]] + + # minimax value storage + min_max_stats_lst = tree_efficientzero_v2.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(self._cfg.value_delta_max) + + # ========== EfficientZero V2: Calculate Sequential Halving parameters ========== + legal_actions_list = roots.get_legal_actions() + num_actions = len(legal_actions_list[0]) if legal_actions_list and len(legal_actions_list) > 0 else 0 + num_top_actions = self.num_top_actions if self.num_top_actions is not None else num_actions + if num_top_actions > 0: + num_phases = int(np.ceil(np.log2(num_top_actions))) + sims_per_phase = max(1, self._cfg.num_simulations // num_phases) + else: + num_phases = 1 + sims_per_phase = self._cfg.num_simulations + + + + for simulation_index in range(self._cfg.num_simulations): + # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. 
+ + latent_states = [] + hidden_states_c_reward = [] + hidden_states_h_reward = [] + + # prepare a result wrapper to transport results between python and c++ parts + results = tree_efficientzero_v2.ResultsWrapper(num=batch_size) + + # ========== EfficientZero V2: Sequential Halving at the root ========== + # Calculate current phase and number of top actions to keep based on simulation index + current_phase = min(simulation_index // sims_per_phase, num_phases - 1) + current_num_top_actions = max(1, num_top_actions // (2 ** current_phase)) + + # Call Sequential Halving to prune low-scoring root actions + tree_efficientzero_v2.batch_sequential_halving( + roots, gumbel_noises, min_max_stats_lst, current_phase, current_num_top_actions + ) + + # latent_state_index_in_search_path: the first index of leaf node states in latent_state_batch_in_search_path, i.e. is current_latent_state_index in one the search. + # latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``. + # e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index. + # The index of value prefix hidden state of the leaf node is in the same manner. + """ + MCTS stage 1: Selection + Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. + Sequential Halving has already pruned low-scoring root actions before selection begins. 
+ """ + if self._cfg.env_type == 'not_board_games': + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_efficientzero_v2.batch_traverse( + roots, + min_max_stats_lst, + results, + self._cfg.num_simulations, + simulation_index, + gumbel_noises, + current_num_top_actions, + to_play_batch + ) + else: + # the ``to_play_batch`` is only used in board games, here we need to deepcopy it to avoid changing the original data. + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_efficientzero_v2.batch_traverse( + roots, + min_max_stats_lst, + results, + self._cfg.num_simulations, + simulation_index, + gumbel_noises, + current_num_top_actions, + copy.deepcopy(to_play_batch) + ) + # obtain the search horizon for leaf nodes + search_lens = results.get_search_len() + + # Debug: print batch_traverse return values + print(f"[DEBUG sim {simulation_index}] batch_traverse returned:") + print(f" - latent_state_index_in_search_path: {latent_state_index_in_search_path}") + print(f" - latent_state_index_in_batch: {latent_state_index_in_batch}") + print(f" - last_actions: {last_actions}") + print(f" - search_lens: {search_lens}") + print(f" - latent_state_batch_in_search_path length: {len(latent_state_batch_in_search_path)}") + if len(latent_state_batch_in_search_path) > 0: + print(f" - latent_state_batch_in_search_path[0] type: {type(latent_state_batch_in_search_path[0])}") + if hasattr(latent_state_batch_in_search_path[0], 'shape'): + print(f" - latent_state_batch_in_search_path[0] shape: {latent_state_batch_in_search_path[0].shape}") + + # obtain the latent state for leaf node + for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): + print(f" - Accessing latent_state_batch_in_search_path[{ix}][{iy}]") + if ix >= len(latent_state_batch_in_search_path): + print(f" ERROR: ix={ix} out of range, len={len(latent_state_batch_in_search_path)}") + 
latent_states.append(latent_state_batch_in_search_path[ix][iy]) + hidden_states_c_reward.append(reward_hidden_state_c_batch[ix][0][iy]) + hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy]) + + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(self._cfg.device + ).unsqueeze(0) + hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(self._cfg.device + ).unsqueeze(0) + # .long() is only for discrete action + last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() + """ + MCTS stage 2: Expansion + At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. + Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation) + MCTS stage 3: Backup + At the end of the simulation, the statistics along the trajectory are updated. + """ + network_output = model.recurrent_inference( + latent_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions + ) + + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) + network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) + network_output.value_prefix = to_detach_cpu_numpy( + self.value_inverse_scalar_transform_handle(network_output.value_prefix)) + + network_output.reward_hidden_state = ( + network_output.reward_hidden_state[0].detach().cpu().numpy(), + network_output.reward_hidden_state[1].detach().cpu().numpy() + ) + + latent_state_batch_in_search_path.append(network_output.latent_state) + # tolist() is to be compatible with cpp datatype. 
+ value_prefix_batch = network_output.value_prefix.reshape(-1).tolist() + value_batch = network_output.value.reshape(-1).tolist() + policy_logits_batch = network_output.policy_logits.tolist() + + reward_latent_state_batch = network_output.reward_hidden_state + # reset the hidden states in LSTM every ``lstm_horizon_len`` steps in one search. + # which enable the model only need to predict the value prefix in a range (e.g.: [s0,...,s5]) + assert self._cfg.lstm_horizon_len > 0 + reset_idx = (np.array(search_lens) % self._cfg.lstm_horizon_len == 0) + assert len(reset_idx) == batch_size + reward_latent_state_batch[0][:, reset_idx, :] = 0 + reward_latent_state_batch[1][:, reset_idx, :] = 0 + is_reset_list = reset_idx.astype(np.int32).tolist() + reward_hidden_state_c_batch.append(reward_latent_state_batch[0]) + reward_hidden_state_h_batch.append(reward_latent_state_batch[1]) + + # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and + # ``reward`` predicted by the model, then perform backpropagation along the search path to update the + # statistics. + + # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. 
+ current_latent_state_index = simulation_index + 1 + tree_efficientzero_v2.batch_backpropagate( + current_latent_state_index, discount_factor, value_prefix_batch, value_batch, policy_logits_batch, + min_max_stats_lst, results, is_reset_list, virtual_to_play_batch + ) + + # ========== EfficientZero V2: Return improved policies and best actions ========== + improved_policies = roots.get_root_policies(min_max_stats_lst) + best_actions = roots.get_best_actions() + # Convert to numpy arrays for easier manipulation + improved_policies = [np.array(p) for p in improved_policies] + return improved_policies, best_actions + + + class GumbelMuZeroMCTSCtree(object): """ Overview: diff --git a/lzero/policy/efficientzero_v2.py b/lzero/policy/efficientzero_v2.py new file mode 100644 index 000000000..fb8af3701 --- /dev/null +++ b/lzero/policy/efficientzero_v2.py @@ -0,0 +1,861 @@ +import copy +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.distributions import Categorical +from torch.nn import L1Loss + +from lzero.mcts import EZV2MCTSCtree as MCTSCtree +from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, \ + prepare_obs, \ + configure_optimizers +from lzero.policy.muzero import MuZeroPolicy + + +@POLICY_REGISTRY.register('efficientzero_v2') +class EfficientZeroV2Policy(MuZeroPolicy): + """ + Overview: + The policy class for EfficientZero proposed in the paper https://arxiv.org/abs/2111.00210. + """ + + # The default_config for EfficientZero policy. + config = dict( + model=dict( + # (str) The model type. 
For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The stacked obs shape. + # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), + # (int) The hidden size in LSTM. + lstm_hidden_size=512, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. + discrete_action_encoding_type='one_hot', + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. + norm_type='BN', + ), + # ****** common ****** + # (bool) Whether to use multi-gpu training. + multi_gpu=False, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=True, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. 
+ mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. The options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=200, + # (bool): Indicates whether to perform an offline evaluation of the checkpoint (ckpt). + # If set to True, the checkpoint will be evaluated after the training process is complete. + # IMPORTANT: Setting eval_offline to True requires configuring the saving of checkpoints to align with the evaluation frequency. + # This is done by setting the parameter learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq in the train_muzero.py automatically. + eval_offline=False, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. 
+ ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor + # if we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + optim_type='SGD', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.2, + # (int) Frequency of target network update. + target_update_freq=100, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=10, + # (int) The number of episodes in each collecting stage. + n_episode=8, + # (float) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (int) reset the hidden states in LSTM every ``lstm_horizon_len`` horizon steps. + lstm_horizon_len=5, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. 
+ policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=2, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + piecewise_decay_lr_scheduler=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (bool) Whether to use manually decayed temperature. + # i.e. temperature: 1 -> 0.5 -> 0.25 + manual_temperature_decay=False, + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + # (bool) Whether to add noise to roots during reanalyze process. + reanalyze_noise=True, + # (bool) Whether to reuse the root value between batch searches. + reuse_search=False, + # (bool) whether to use the pure policy to collect data. If False, use the MCTS guided with policy. + collect_with_pure_policy=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. 
def default_model(self) -> Tuple[str, List[str]]:
    """
    Overview:
        Look up the default model that matches ``self._cfg.model.model_type``.
    Returns:
        - model_info (:obj:`Tuple[str, List[str]]`): The model name followed by the list of module paths \
            that have to be imported for that model.
    Raises:
        - ValueError: If the configured model type is neither ``"conv"`` nor ``"mlp"``.
    """
    model_type = self._cfg.model.model_type
    known_models = (
        ("conv", 'EfficientZeroModel', 'lzero.model.efficientzero_model'),
        ("mlp", 'EfficientZeroModelMLP', 'lzero.model.efficientzero_model_mlp'),
    )
    for candidate, model_name, import_path in known_models:
        if model_type == candidate:
            return model_name, [import_path]
    raise ValueError("model type {} is not supported".format(model_type))


def _init_learn(self) -> None:
    """
    Overview:
        Set up everything the learn mode needs: the optimizer, the optional piecewise learning-rate
        scheduler, the target model wrapper, the optional image augmentation, and the discrete
        supports together with their inverse scalar transforms.
    """
    cfg = self._cfg

    # ------------------------------------------------------------------
    # optimizer
    # ------------------------------------------------------------------
    assert cfg.optim_type in ['SGD', 'Adam', 'AdamW'], cfg.optim_type
    if cfg.optim_type == 'AdamW':
        self._optimizer = configure_optimizers(
            model=self._model,
            weight_decay=cfg.weight_decay,
            learning_rate=cfg.learning_rate,
            device_type=cfg.device
        )
    elif cfg.optim_type == 'Adam':
        self._optimizer = optim.Adam(
            self._model.parameters(),
            lr=cfg.learning_rate,
            weight_decay=cfg.weight_decay,
        )
    elif cfg.optim_type == 'SGD':
        self._optimizer = optim.SGD(
            self._model.parameters(),
            lr=cfg.learning_rate,
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay,
        )

    # ------------------------------------------------------------------
    # learning-rate schedule
    # ------------------------------------------------------------------
    if cfg.piecewise_decay_lr_scheduler:
        from torch.optim.lr_scheduler import LambdaLR
        max_step = cfg.threshold_training_steps_for_final_lr

        def lr_multiplier(step):
            # NOTE: these are multipliers applied to the base lr, not learning rates themselves.
            if step < max_step * 0.5:
                return 1
            if step < max_step:
                return 0.1
            return 0.01

        self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_multiplier)

    # ------------------------------------------------------------------
    # target / learn models
    # ------------------------------------------------------------------
    self._target_model = copy.deepcopy(self._model)
    self._target_model = model_wrap(
        self._target_model,
        wrapper_name='target',
        update_type='assign',
        update_kwargs={'freq': cfg.target_update_freq}
    )
    self._learn_model = self._model

    # ------------------------------------------------------------------
    # augmentation and value/reward supports
    # ------------------------------------------------------------------
    if cfg.use_augmentation:
        observation_shape = cfg.model.observation_shape
        self.image_transforms = ImageTransforms(
            cfg.augmentation,
            image_shape=(observation_shape[1], observation_shape[2])
        )
    self.value_support = DiscreteSupport(*cfg.model.value_support_range, cfg.device)
    self.reward_support = DiscreteSupport(*cfg.model.reward_support_range, cfg.device)
    # The policy-side supports must agree with the sizes the model was built with.
    assert self.value_support.size == self._learn_model.value_support_size
    assert self.reward_support.size == self._learn_model.reward_support_size
    self.value_inverse_scalar_transform_handle = InverseScalarTransform(
        self.value_support, cfg.model.categorical_distribution
    )
    self.reward_inverse_scalar_transform_handle = InverseScalarTransform(
        self.reward_support, cfg.model.categorical_distribution
    )
def _select_value_target(self, target_value, target_search_value):
    """
    Overview:
        Choose the value target according to ``cfg.value_target`` (``'bootstrap'`` when the key is absent).
    Arguments:
        - target_value: The bootstrapped value target.
        - target_search_value: The value target taken from search.
    Returns:
        - ``target_search_value`` for ``'search'``; an equal-weight mix of both for ``'mixed'`` once \
            ``self._train_iteration`` has reached ``cfg.start_use_mix_training_steps`` (default 30000); \
            ``target_value`` in every other case.
    """
    value_target_type = getattr(self._cfg, 'value_target', 'bootstrap')
    if value_target_type == 'search':
        return target_search_value
    if value_target_type == 'mixed':
        start_use_mix_steps = getattr(self._cfg, 'start_use_mix_training_steps', 30000)
        # NOTE(review): ``self._train_iteration`` is not set anywhere in the visible part of this
        # class; it is presumably maintained elsewhere -- confirm before enabling 'mixed'.
        if self._train_iteration >= start_use_mix_steps:
            return 0.5 * target_value + 0.5 * target_search_value
    return target_value


def _masked_target_policy_entropy(self, target_policy_step, mask_step):
    """
    Overview:
        Entropy of the target policy of one unroll step, averaged over the rows whose mask is non-zero.
        Only used for logging.
    Returns:
        - A 0-dim tensor. When every row is masked, ``log(|A|)`` is returned instead, where ``|A|`` is \
            the last dimension of ``target_policy_step``.
    """
    non_masked_indices = torch.nonzero(mask_step).squeeze(-1)
    if len(non_masked_indices) > 0:
        kept = torch.index_select(target_policy_step, 0, non_masked_indices)
        return -((kept + 1e-6) * (kept + 1e-6).log()).sum(-1).mean()
    # Build the fallback as a float tensor on the same device, so that it can be accumulated
    # with the entropies of the other steps whatever order they arrive in.
    return torch.log(torch.tensor(float(target_policy_step.shape[-1]), device=target_policy_step.device))


def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
    """
    Overview:
        One training step: compute the policy, value, value-prefix and consistency losses over
        ``num_unroll_steps`` unrolled steps, back-propagate, step the optimizer (and the lr scheduler
        when enabled) and update the target model.
    Arguments:
        - data: A pair ``(current_batch, target_batch)``. ``current_batch`` unpacks into \
            ``(obs, action, mask, indices, weights, make_time)`` and ``target_batch`` into \
            ``(target_value_prefix, target_value, target_policy, target_search_value)``.
    Returns:
        - info_dict (:obj:`Dict[str, Union[float, int]]`): Losses and statistics to be logged. \
            ``value_priority_orig`` is kept as a tensor. ``predicted_value_prefixs`` and \
            ``predicted_values`` are reported as ``0.0`` when ``cfg.monitor_extra_statistics`` is off.
    """
    self._learn_model.train()
    self._target_model.train()

    current_batch, target_batch = data
    obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch
    # EfficientZero V2: the target batch additionally carries the search values.
    target_value_prefix, target_value, target_policy, target_search_value = target_batch

    obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg)

    # do augmentations
    if self._cfg.use_augmentation:
        obs_batch = self.image_transforms.transform(obs_batch)
        obs_target_batch = self.image_transforms.transform(obs_target_batch)

    # NOTE: .long(), in discrete action space.
    action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long()
    data_list = [
        mask_batch,
        target_value_prefix.astype('float32'),
        target_value.astype('float32'), target_policy, weights,
        target_search_value.astype('float32')
    ]
    [mask_batch, target_value_prefix, target_value, target_policy,
     weights, target_search_value] = to_torch_float_tensor(data_list, self._cfg.device)

    target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1)
    target_value = target_value.view(self._cfg.batch_size, -1)
    target_search_value = target_search_value.view(self._cfg.batch_size, -1)
    assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0)

    # ==============================================================
    # EfficientZero V2: value target mixing.
    # ==============================================================
    target_value = self._select_value_target(target_value, target_search_value)

    # ``scalar_transform`` maps the original value to the scaled value,
    # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf.
    transformed_target_value_prefix = scalar_transform(target_value_prefix)
    transformed_target_value = scalar_transform(target_value)
    # ``phi_transform`` turns each scalar into a categorical distribution over the support.
    target_value_prefix_categorical = phi_transform(self.reward_support, transformed_target_value_prefix)
    target_value_categorical = phi_transform(self.value_support, transformed_target_value)

    # ==============================================================
    # the core initial_inference in EfficientZero policy.
    # ==============================================================
    network_output = self._learn_model.initial_inference(obs_batch)
    latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(network_output)

    # h^(-1)(.): back from the categorical/scaled representation to the original value.
    original_value = self.value_inverse_scalar_transform_handle(value)

    # Note: The following lines are just for debugging.
    predicted_value_prefixs = []
    if self._cfg.monitor_extra_statistics:
        predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax(
            policy_logits, dim=1
        ).detach().cpu()

    # calculate the new priorities for each transition.
    value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + 1e-6

    prob = torch.softmax(policy_logits, dim=-1)
    policy_entropy = -(prob * prob.log()).sum(-1).mean()

    # ==============================================================
    # calculate policy and value loss for the first step.
    # ==============================================================
    policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0])

    # ******* NOTE: target_policy_entropy is only for debug. ******
    target_policy_entropy = self._masked_target_policy_entropy(target_policy[:, 0], mask_batch[:, 0])

    value_loss = cross_entropy_loss(value, target_value_categorical[:, 0])

    value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)
    consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)

    # ==============================================================
    # the core recurrent_inference in EfficientZero policy.
    # ==============================================================
    for step_k in range(self._cfg.num_unroll_steps):
        # unroll with the dynamics function, then predict policy_logits and value.
        network_output = self._learn_model.recurrent_inference(
            latent_state, reward_hidden_state, action_batch[:, step_k]
        )
        latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(
            network_output
        )

        original_value = self.value_inverse_scalar_transform_handle(value)

        # ==============================================================
        # calculate consistency loss for the next ``num_unroll_steps`` unroll steps.
        # ==============================================================
        if self._cfg.ssl_loss_weight > 0:
            # obtain the oracle latent states from representation function.
            beg_index, end_index = self._get_target_obs_index_in_step_k(step_k)
            network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index])

            latent_state = to_tensor(latent_state)
            representation_state = to_tensor(network_output.latent_state)

            # NOTE: no grad for the representation_state branch.
            dynamic_proj = self._learn_model.project(latent_state, with_grad=True)
            observation_proj = self._learn_model.project(representation_state, with_grad=False)
            temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k]

            consistency_loss += temp_loss

        # ==============================================================
        # calculate policy loss for the next ``num_unroll_steps`` unroll steps.
        # NOTE: the +=.
        # ==============================================================
        policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1])

        # Here we take the hypothetical step k = step_k + 1
        prob = torch.softmax(policy_logits, dim=-1)
        policy_entropy += -(prob * prob.log()).sum(-1).mean()

        # ******* NOTE: target_policy_entropy is only for debug. ******
        target_policy_entropy += self._masked_target_policy_entropy(
            target_policy[:, step_k + 1], mask_batch[:, step_k + 1]
        )

        value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1])
        value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_k])

        # reset hidden states every ``lstm_horizon_len`` unroll steps.
        if (step_k + 1) % self._cfg.lstm_horizon_len == 0:
            reward_hidden_state = (
                torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device),
                torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device)
            )

        if self._cfg.monitor_extra_statistics:
            # ``value_prefix`` is trained against a target built on ``reward_support``,
            # so it has to be decoded with the reward handle, not the value handle.
            original_value_prefixs = self.reward_inverse_scalar_transform_handle(value_prefix)
            original_value_prefixs_cpu = original_value_prefixs.detach().cpu()
            predicted_values = torch.cat((predicted_values, original_value.detach().cpu()))
            predicted_value_prefixs.append(original_value_prefixs_cpu)
            predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu()))

    # ==============================================================
    # the core learn model update step.
    # ==============================================================
    # weighted loss with masks (some invalid states which are out of trajectory.)
    loss = (
        self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss +
        self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss
    )
    weighted_total_loss = (weights * loss).mean()
    # TODO(pu): test the effect of gradient scale.
    gradient_scale = 1 / self._cfg.num_unroll_steps
    weighted_total_loss.register_hook(lambda grad: grad * gradient_scale)
    self._optimizer.zero_grad()
    weighted_total_loss.backward()
    if self._cfg.multi_gpu:
        self.sync_gradients(self._learn_model)
    total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(
        self._learn_model.parameters(), self._cfg.grad_clip_value
    )
    self._optimizer.step()
    if self._cfg.piecewise_decay_lr_scheduler:
        self.lr_scheduler.step()

    # ==============================================================
    # the core target model update step.
    # ==============================================================
    self._target_model.update(self._learn_model.state_dict())

    if self._cfg.monitor_extra_statistics:
        predicted_value_prefixs = torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1)
        predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1)
        predicted_value_prefixs_mean = predicted_value_prefixs.mean().item()
        predicted_values_mean = predicted_values.mean().item()
    else:
        # Nothing was collected: ``predicted_value_prefixs`` is still an empty list and
        # ``predicted_values`` was never bound. Report placeholders so that the keys
        # listed for monitoring are always present.
        predicted_value_prefixs_mean = 0.0
        predicted_values_mean = 0.0

    return {
        'collect_mcts_temperature': self._collect_mcts_temperature,
        'collect_epsilon': self.collect_epsilon,
        'cur_lr': self._optimizer.param_groups[0]['lr'],
        'weighted_total_loss': weighted_total_loss.item(),
        'total_loss': loss.mean().item(),
        'policy_loss': policy_loss.mean().item(),
        'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1),
        'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1),
        'value_prefix_loss': value_prefix_loss.mean().item(),
        'value_loss': value_loss.mean().item(),
        'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps,
        'target_value_prefix': target_value_prefix.mean().item(),
        'target_value': target_value.mean().item(),
        'transformed_target_value_prefix': transformed_target_value_prefix.mean().item(),
        'transformed_target_value': transformed_target_value.mean().item(),
        'predicted_value_prefixs': predicted_value_prefixs_mean,
        'predicted_values': predicted_values_mean,
        'total_grad_norm_before_clip': total_grad_norm_before_clip.item(),
        # ==============================================================
        # priority related
        # ==============================================================
        'value_priority': value_priority.mean().item(),
        'value_priority_orig': value_priority,  # torch.tensor compatible with ddp settings
    }
def _init_collect(self) -> None:
    """
    Overview:
        Collect mode init method. Binds the collect model, builds the MCTS helper selected by
        ``cfg.mcts_ctree`` and resets the temperature / epsilon values reported by the learner.
    """
    self._collect_model = self._model
    mcts_cls = MCTSCtree if self._cfg.mcts_ctree else MCTSPtree
    self._mcts_collect = mcts_cls(self._cfg)
    self._collect_mcts_temperature = 1
    self.collect_epsilon = 0.0


def _forward_collect(
        self,
        data: torch.Tensor,
        action_mask: list = None,
        temperature: float = 1,
        to_play: List = None,
        epsilon: float = 0.25,
        ready_env_id: np.array = None,
        **kwargs,
):
    """
    Overview:
        Choose one action per active environment in collect mode, either from an MCTS search
        (default) or by sampling the raw policy when ``cfg.collect_with_pure_policy`` is set.
    Arguments:
        - data (:obj:`torch.Tensor`): Batch of observations; ``data.shape[0]`` is the number of active envs.
        - action_mask (:obj:`list`): One mask per env; entries equal to 1 mark legal actions. \
            ``None`` is treated as "every action is legal".
        - temperature (:obj:`float`): Stored for logging only; the searched action is taken greedily.
        - to_play (:obj:`List`): Player indicator forwarded to the search. ``None`` means ``[-1]``.
        - epsilon (:obj:`float`): Probability of a uniformly random legal action when \
            ``cfg.eps.eps_greedy_exploration_in_collect`` is enabled.
        - ready_env_id (:obj:`np.array`): Ids used as keys of the output; defaults to ``0..N-1``.
    Returns:
        - output (:obj:`Dict[int, Any]`): One dict per env id. With search it holds ``action``, \
            ``visit_count_distributions``, ``visit_count_distribution_entropy``, ``searched_value``, \
            ``predicted_value``, ``predicted_policy_logits``, ``improved_policy_probs`` and \
            ``roots_completed_value``; with pure policy only ``action``, ``searched_value``, \
            ``predicted_value`` and ``predicted_policy_logits``.
    """
    self._collect_model.eval()
    self._collect_mcts_temperature = temperature
    self.collect_epsilon = epsilon
    active_collect_env_num = data.shape[0]
    if to_play is None:
        # Built per call, so no list object is shared between invocations.
        to_play = [-1]
    if ready_env_id is None:
        ready_env_id = np.arange(active_collect_env_num)
    output = {i: None for i in ready_env_id}

    with torch.no_grad():
        network_output = self._collect_model.initial_inference(data)
        latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack(
            network_output
        )

        pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
        latent_state_roots = latent_state_roots.detach().cpu().numpy()
        reward_hidden_state_roots = (
            reward_hidden_state_roots[0].detach().cpu().numpy(),
            reward_hidden_state_roots[1].detach().cpu().numpy()
        )
        policy_logits = policy_logits.detach().cpu().numpy().tolist()

        if action_mask is None:
            # No mask supplied: every action the policy head knows about is legal.
            action_mask = [np.ones(len(logits), dtype=np.int8) for logits in policy_logits]

        legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
        # Works for list and ndarray masks alike and keeps the numpy integer type of the action.
        legal_action_ids = [np.flatnonzero(np.asarray(action_mask[j]) == 1.0) for j in range(active_collect_env_num)]

        if not self._cfg.collect_with_pure_policy:
            # collect with MCTS guided with policy.
            # Gumbel noises for Sequential Halving: one entry per legal action, i.e. the same
            # length as the action list each root is created with.
            gumbel_noises = [
                np.random.gumbel(0, 1, len(legal_actions[j])).astype(np.float32) * self._cfg.root_noise_weight
                for j in range(active_collect_env_num)
            ]
            if self._cfg.mcts_ctree:
                # cpp mcts_tree
                roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
            else:
                # python mcts_tree
                roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
            roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play)
            # EfficientZero V2: search returns improved_policies and best_actions
            improved_policies, best_actions = self._mcts_collect.search(
                roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play, gumbel_noises
            )
            roots_visit_count_distributions = roots.get_distributions()
            roots_values = roots.get_values()

            for i, env_id in enumerate(ready_env_id):
                improved_policy = improved_policies[i]
                value = roots_values[i]
                distributions = roots_visit_count_distributions[i]
                # Greedy selection: the search result is used as an index into the legal action set.
                action_index_in_legal_action_set = best_actions[i]

                if self._cfg.eps.eps_greedy_exploration_in_collect:
                    # eps-greedy exploration
                    if np.random.rand() < self.collect_epsilon:
                        action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]))

                # Entropy of the improved policy, for logging.
                visit_count_distribution_entropy = -np.sum(
                    improved_policy * np.log(improved_policy + 1e-9)
                )

                # Convert the index in the legal action set to the action in the entire action set.
                action = legal_action_ids[i][action_index_in_legal_action_set]
                output[env_id] = {
                    'action': action,
                    'visit_count_distributions': distributions,
                    'visit_count_distribution_entropy': visit_count_distribution_entropy,
                    'searched_value': value,
                    'predicted_value': pred_values[i],
                    'predicted_policy_logits': policy_logits[i],
                    "improved_policy_probs": improved_policy,
                    "roots_completed_value": value
                }
        else:
            # collect with pure policy.
            for i, env_id in enumerate(ready_env_id):
                legal_logits = torch.tensor([policy_logits[i][a] for a in legal_actions[i]])
                policy_values = torch.softmax(legal_logits, dim=0).tolist()
                # Re-normalise in float64 so that ``np.random.choice`` accepts the probabilities.
                policy_values = policy_values / np.sum(policy_values)
                action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values)
                action = legal_action_ids[i][action_index_in_legal_action_set]
                output[env_id] = {
                    'action': action,
                    'searched_value': pred_values[i],
                    'predicted_value': pred_values[i],
                    'predicted_policy_logits': policy_logits[i],
                }
    return output
def _init_eval(self) -> None:
    """
    Overview:
        Evaluate mode init method. Binds the eval model and builds the MCTS helper selected by
        ``cfg.mcts_ctree``.
    """
    self._eval_model = self._model
    mcts_cls = MCTSCtree if self._cfg.mcts_ctree else MCTSPtree
    self._mcts_eval = mcts_cls(self._cfg)


def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: Union[int, List] = None,
                  ready_env_id: np.array = None, **kwargs):
    """
    Overview:
        Choose one action per active environment in eval mode. The search is run with all-zero
        Gumbel noise and the action it reports as best is always taken.
    Arguments:
        - data (:obj:`torch.Tensor`): Batch of observations; ``data.shape[0]`` is the number of active envs.
        - action_mask (:obj:`list`): One mask per env; entries equal to 1 mark legal actions.
        - to_play (:obj:`Union[int, List]`): Player indicator forwarded to the search. ``None`` means ``[-1]``.
        - ready_env_id (:obj:`np.array`): Ids used as keys of the output; defaults to ``0..N-1``.
    Returns:
        - output (:obj:`Dict[int, Any]`): One dict per env id with the keys ``action``, \
            ``visit_count_distributions`` (the improved policy), ``visit_count_distribution_entropy``, \
            ``searched_value``, ``predicted_value`` and ``predicted_policy_logits``.
    """
    self._eval_model.eval()
    active_eval_env_num = data.shape[0]
    if to_play is None:
        # Built per call, so no list object is shared between invocations.
        to_play = [-1]
    if ready_env_id is None:
        ready_env_id = np.arange(active_eval_env_num)
    output = {i: None for i in ready_env_id}

    with torch.no_grad():
        network_output = self._eval_model.initial_inference(data)
        latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack(
            network_output
        )

        # The model was switched to eval mode above, so the conversion to scalars / numpy is
        # unconditional; a skipped conversion would leave the search below with tensors.
        pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
        latent_state_roots = latent_state_roots.detach().cpu().numpy()
        reward_hidden_state_roots = (
            reward_hidden_state_roots[0].detach().cpu().numpy(),
            reward_hidden_state_roots[1].detach().cpu().numpy()
        )
        policy_logits = policy_logits.detach().cpu().numpy().tolist()

        legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
        # Works for list and ndarray masks alike and keeps the numpy integer type of the action.
        legal_action_ids = [np.flatnonzero(np.asarray(action_mask[j]) == 1.0) for j in range(active_eval_env_num)]

        if self._cfg.mcts_ctree:
            # cpp mcts_tree
            roots = MCTSCtree.roots(active_eval_env_num, legal_actions)
        else:
            # python mcts_tree
            roots = MCTSPtree.roots(active_eval_env_num, legal_actions)
        roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play)
        # Zero Gumbel noise (no exploration): one entry per legal action, i.e. the same
        # length as the action list each root is created with.
        gumbel_noises = [
            np.zeros(len(legal_actions[j]), dtype=np.float32)
            for j in range(active_eval_env_num)
        ]
        # EfficientZero V2: search returns improved_policies and best_actions
        improved_policies, best_actions = self._mcts_eval.search(
            roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play, gumbel_noises
        )

        roots_values = roots.get_values()

        for i, env_id in enumerate(ready_env_id):
            improved_policy = improved_policies[i]
            value = roots_values[i]
            # Eval mode: the search result is used as an index into the legal action set.
            action_index_in_legal_action_set = best_actions[i]

            # Entropy of the improved policy, for logging.
            visit_count_distribution_entropy = -np.sum(
                improved_policy * np.log(improved_policy + 1e-9)
            )

            # Convert the index in the legal action set to the action in the entire action set.
            action = legal_action_ids[i][action_index_in_legal_action_set]
            output[env_id] = {
                'action': action,
                'visit_count_distributions': improved_policy,
                'visit_count_distribution_entropy': visit_count_distribution_entropy,
                'searched_value': value,
                'predicted_value': pred_values[i],
                'predicted_policy_logits': policy_logits[i],
            }

    return output
Compute improved policy entropy for logging + visit_count_distribution_entropy = -np.sum( + improved_policy * np.log(improved_policy + 1e-9) + ) + + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': improved_policy, # EfficientZero V2: improved policy + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + + return output + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + return [ + 'collect_mcts_temperature', + 'cur_lr', + 'weighted_total_loss', + 'total_loss', + 'policy_loss', + 'policy_entropy', + 'target_policy_entropy', + 'value_prefix_loss', + 'value_loss', + 'consistency_loss', + 'value_priority', + 'target_value_prefix', + 'target_value', + 'predicted_value_prefixs', + 'predicted_values', + 'transformed_target_value_prefix', + 'transformed_target_value', + 'total_grad_norm_before_clip', + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. 
+ Arguments: + - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def _process_transition(self, obs, policy_output, timestep): + # be compatible with DI-engine Policy class + pass + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass diff --git a/zoo/atari/config/atari_efficientzero_v2_config.py b/zoo/atari/config/atari_efficientzero_v2_config.py new file mode 100644 index 000000000..0d8c60d3b --- /dev/null +++ b/zoo/atari/config/atari_efficientzero_v2_config.py @@ -0,0 +1,98 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +env_id = 'PongNoFrameskip-v4' # You can specify any Atari game here +action_space_size = atari_env_action_space_map[env_id] + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +update_per_collect = None +replay_ratio = 0.25 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +batch_size = 256 +max_env_step = int(5e5) +reanalyze_ratio = 0. 
import copy

from easydict import EasyDict
from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map

env_id = 'PongNoFrameskip-v4'  # You can specify any Atari game here
action_space_size = atari_env_action_space_map[env_id]

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
update_per_collect = None
replay_ratio = 0.25
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 50
batch_size = 256
max_env_step = int(5e5)
reanalyze_ratio = 0.
num_unroll_steps = 5
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================


def _experiment_name(seed: int) -> str:
    """Return the experiment directory/name used for ``seed``.

    ``env_id[:-14]`` drops the last 14 characters of the env id (the length of ``'NoFrameskip-v4'``).
    """
    return f'data_efficientzero_v2/{env_id[:-14]}_efficientzero_v2_stack4_H{num_unroll_steps}_seed{seed}'


atari_efficientzero_config = dict(
    # Same naming scheme as the per-seed runs below, so that importing ``main_config`` does not
    # write into the directory used by the EfficientZero (v1) experiments.
    exp_name=_experiment_name(0),
    env=dict(
        stop_value=int(1e6),
        env_id=env_id,
        observation_shape=[4, 64, 64],
        frame_stack_num=4,
        gray_scale=True,
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
        n_evaluator_episode=evaluator_env_num,
        manager=dict(shared_memory=False, ),
    ),
    policy=dict(
        model=dict(
            observation_shape=[4, 64, 64],
            image_channel=1,
            frame_stack_num=4,
            gray_scale=True,
            action_space_size=action_space_size,
            downsample=True,
            self_supervised_learning_loss=True,  # default is False
            discrete_action_encoding_type='one_hot',
            norm_type='BN',
            reward_support_range=(-50., 51., 1.),
            value_support_range=(-50., 51., 1.),
        ),
        cuda=True,
        env_type='not_board_games',
        game_segment_length=400,
        use_augmentation=True,
        use_priority=False,
        replay_ratio=replay_ratio,
        update_per_collect=update_per_collect,
        batch_size=batch_size,
        dormant_threshold=0.025,
        optim_type='SGD',
        piecewise_decay_lr_scheduler=True,
        learning_rate=0.2,
        target_update_freq=100,
        num_simulations=num_simulations,
        reanalyze_ratio=reanalyze_ratio,
        ssl_loss_weight=2,
        n_episode=n_episode,
        eval_freq=int(2e3),
        replay_buffer_size=int(1e6),
        collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
    ),
)
atari_efficientzero_config = EasyDict(atari_efficientzero_config)
main_config = atari_efficientzero_config

atari_efficientzero_create_config = dict(
    env=dict(
        type='atari_lightzero',
        import_names=['zoo.atari.envs.atari_lightzero_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(
        type='efficientzero_v2',
        import_names=['lzero.policy.efficientzero_v2'],
    ),
)
atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config)
create_config = atari_efficientzero_create_config

if __name__ == "__main__":
    from lzero.entry import train_muzero

    # Define a list of seeds for multiple runs
    seeds = [0, 1, 2]  # You can add more seed values here
    for seed in seeds:
        # NOTE(review): the training entry presumably fills in / rewrites config fields in place;
        # handing every run its own copy keeps one seed from leaking state into the next.
        seed_main_config = copy.deepcopy(main_config)
        seed_create_config = copy.deepcopy(create_config)
        seed_main_config.exp_name = _experiment_name(seed)
        train_muzero([seed_main_config, seed_create_config], seed=seed, max_env_step=max_env_step)
+# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cartpole_efficientzero_config = dict( + exp_name=f'data_ezv2/cartpole_efficientzero_v2_ns{num_simulations}_upc{update_per_collect}_rer{reanalyze_ratio}_seed0', + env=dict( + env_id='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=4, + action_space_size=2, + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + piecewise_decay_lr_scheduler=False, + learning_rate=0.003, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +cartpole_efficientzero_config = EasyDict(cartpole_efficientzero_config) +main_config = cartpole_efficientzero_config + +cartpole_efficientzero_create_config = dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero_v2', + import_names=['lzero.policy.efficientzero_v2'], + ), +) +cartpole_efficientzero_create_config = EasyDict(cartpole_efficientzero_create_config) +create_config = cartpole_efficientzero_create_config + +if __name__ == "__main__": + # Users can use different train entry by specifying the entry_type. + entry_type = "train_muzero" # options={"train_muzero", "train_muzero_with_gym_env"} + + if entry_type == "train_muzero": + from lzero.entry import train_muzero + elif entry_type == "train_muzero_with_gym_env": + """ + The ``train_muzero_with_gym_env`` entry means that the environment used in the training process is generated by wrapping the original gym environment with LightZeroEnvWrapper. + Users can refer to lzero/envs/wrappers for more details. + """ + from lzero.entry import train_muzero_with_gym_env as train_muzero + + train_muzero([main_config, create_config], seed=0, model_path=main_config.policy.model_path, max_env_step=max_env_step)