diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/Makefile b/lzero/mcts/ctree/ctree_efficientzero_v2/Makefile new file mode 100644 index 000000000..0f1774f48 --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/Makefile @@ -0,0 +1,72 @@ +CXX = g++ +CXXFLAGS = -std=c++11 -Wall -Wextra -O2 +INCLUDES = -I. + +# Source files +SOURCES = lib/cnode.cpp lib/cminimax.cpp +TEST_SOURCE = test_cnode.cpp +TEST_SH_SOURCE = test_sequential_halving.cpp +TEST_BATCH_SH_SOURCE = test_batch_sequential_halving.cpp + +# Object files +OBJECTS = $(SOURCES:.cpp=.o) +TEST_OBJECTS = $(TEST_SOURCE:.cpp=.o) +TEST_SH_OBJECTS = $(TEST_SH_SOURCE:.cpp=.o) +TEST_BATCH_SH_OBJECTS = $(TEST_BATCH_SH_SOURCE:.cpp=.o) + +# Executables +TEST_EXECUTABLE = test_cnode +TEST_SH_EXECUTABLE = test_sequential_halving +TEST_BATCH_SH_EXECUTABLE = test_batch_sequential_halving + +.PHONY: all clean test test_sh test_batch_sh + +all: $(TEST_EXECUTABLE) + +$(TEST_EXECUTABLE): $(OBJECTS) $(TEST_OBJECTS) + @echo "Linking test executable..." + $(CXX) $(CXXFLAGS) -o $(TEST_EXECUTABLE) $(OBJECTS) $(TEST_OBJECTS) + @echo "Build successful!" + +lib/%.o: lib/%.cpp lib/%.h lib/cminimax.h + @echo "Compiling $<..." + $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ + +test_cnode.o: test_cnode.cpp + @echo "Compiling test_cnode.cpp..." + $(CXX) $(CXXFLAGS) $(INCLUDES) -c test_cnode.cpp -o test_cnode.o + +test_sequential_halving.o: test_sequential_halving.cpp + @echo "Compiling test_sequential_halving.cpp..." + $(CXX) $(CXXFLAGS) $(INCLUDES) -c test_sequential_halving.cpp -o test_sequential_halving.o + +$(TEST_SH_EXECUTABLE): $(OBJECTS) $(TEST_SH_OBJECTS) + @echo "Linking test_sequential_halving executable..." + $(CXX) $(CXXFLAGS) -o $(TEST_SH_EXECUTABLE) $(OBJECTS) $(TEST_SH_OBJECTS) + @echo "Build successful!" + +test_batch_sequential_halving.o: test_batch_sequential_halving.cpp + @echo "Compiling test_batch_sequential_halving.cpp..." + $(CXX) $(CXXFLAGS) $(INCLUDES) -c test_batch_sequential_halving.cpp -o test_batch_sequential_halving.o + +$(TEST_BATCH_SH_EXECUTABLE): $(OBJECTS) $(TEST_BATCH_SH_OBJECTS) + @echo "Linking test_batch_sequential_halving executable..." + $(CXX) $(CXXFLAGS) -o $(TEST_BATCH_SH_EXECUTABLE) $(OBJECTS) $(TEST_BATCH_SH_OBJECTS) + @echo "Build successful!" + +test: all + @echo "Running tests..." + @./$(TEST_EXECUTABLE) + +test_sh: $(TEST_SH_EXECUTABLE) + @echo "Running Sequential Halving tests..." + @./$(TEST_SH_EXECUTABLE) + +test_batch_sh: $(TEST_BATCH_SH_EXECUTABLE) + @echo "Running batch Sequential Halving tests..." + @./$(TEST_BATCH_SH_EXECUTABLE) + +clean: + @echo "Cleaning up..." + rm -f $(OBJECTS) $(TEST_OBJECTS) $(TEST_SH_OBJECTS) $(TEST_BATCH_SH_OBJECTS) $(TEST_EXECUTABLE) $(TEST_SH_EXECUTABLE) $(TEST_BATCH_SH_EXECUTABLE) + @echo "Clean complete!" diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/__init__.py b/lzero/mcts/ctree/ctree_efficientzero_v2/__init__.py new file mode 100644 index 000000000..a2ee6590c --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/__init__.py @@ -0,0 +1,3 @@ +from . 
import ez_tree + +__all__ = ['ez_tree'] diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pxd b/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pxd new file mode 100644 index 000000000..b935d5b4e --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pxd @@ -0,0 +1,124 @@ +# distutils:language=c++ +# cython:language_level=3 +from libcpp.vector cimport vector + + +cdef extern from "lib/cminimax.cpp": + pass + + +cdef extern from "lib/cminimax.h" namespace "tools": + cdef cppclass CMinMaxStats: + CMinMaxStats() except + + int c_visit + float c_scale + float maximum, minimum, value_delta_max + + void set_delta(float value_delta_max) + void set_static_val(float value_delta_max, int c_visit, float c_scale) + void update(float value) + void clear() + float normalize(float value) + + cdef cppclass CMinMaxStatsList: + CMinMaxStatsList() except + + CMinMaxStatsList(int num) except + + int num + vector[CMinMaxStats] stats_lst + + void set_delta(float value_delta_max) + void set_static_val(float value_delta_max, int c_visit, float c_scale) + +cdef extern from "lib/cnode.cpp": + pass + + +cdef extern from "lib/cnode.h" namespace "tree": + cdef cppclass CNode: + CNode() except + + CNode(float prior, vector[int] & legal_actions) except + + int visit_count, to_play, current_latent_state_index, batch_index, best_action + float value_prefixs, prior, value_sum, parent_value_prefix + + void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefixs, + vector[float] policy_logits) + void add_exploration_noise(float exploration_fraction, vector[float] noises) + float compute_mean_q(int isRoot, float parent_q, float discount_factor) + + int expanded() + float value() + vector[int] get_trajectory() + vector[int] get_children_distribution() + CNode * get_child(int action) + + cdef cppclass CRoots: + CRoots() except + + CRoots(int root_num, vector[vector[int]] legal_actions_list) except + + int root_num + vector[CNode] roots + + void prepare(float root_noise_weight, const vector[vector[float]] & noises, + const vector[float] & value_prefixs, const vector[vector[float]] & policies, + vector[int] to_play_batch) + void prepare_no_noise(const vector[float] & value_prefixs, const vector[vector[float]] & policies, + vector[int] to_play_batch) + void clear() + vector[vector[int]] get_trajectories() + vector[vector[int]] get_distributions() + vector[float] get_values() + vector[vector[float]] get_root_policies(CMinMaxStatsList *min_max_stats_lst) + vector[int] get_best_actions() + # visualize related code + # CNode* get_root(int index) + + cdef cppclass CSearchResults: + CSearchResults() except + + CSearchResults(int num) except + + int num + vector[int] latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens + vector[int] virtual_to_play_batchs + vector[CNode *] nodes + + cdef void cbackpropagate(vector[CNode *] & search_path, CMinMaxStats & min_max_stats, + int to_play, float value, float discount_factor) + void cbatch_backpropagate(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, + vector[float] values, vector[vector[float]] policies, + CMinMaxStatsList *min_max_stats_lst, CSearchResults & results, + vector[int] is_reset_list, vector[int] & to_play_batch) + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, vector[float] value_prefixs, + vector[float] values, vector[vector[float]] policies, + CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + 
vector[int] is_reset_list, vector[int] &to_play_batch, vector[int] &no_inference_lst, + vector[int] &reuse_lst, vector[float] &reuse_value_lst) + # ========== MuZero/UCB 风格的遍历(备份) ========== + void cbatch_traverse_ucb(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, + CMinMaxStatsList *min_max_stats_lst, CSearchResults & results, + vector[int] & virtual_to_play_batch) + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, + CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + vector[int] &virtual_to_play_batch, vector[int] &true_action, vector[float] &reuse_value) + + # ========== EfficientZero V2 风格的遍历(Sequential Halving 集成) ========== + void cbatch_traverse(CRoots *roots, CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + int num_simulations, int simulation_idx, const vector[vector[float]]& gumble_noise, + int current_num_top_actions, vector[int] &virtual_to_play_batch) + + # ========== EfficientZero V2 Sequential Halving ========== + vector[int] c_batch_sequential_halving(CRoots *roots, vector[vector[float]] gumbel_noises, + CMinMaxStatsList *min_max_stats_lst, int current_phase, + int current_num_top_actions) + +cdef class MinMaxStatsList: + cdef CMinMaxStatsList *cmin_max_stats_lst + +cdef class ResultsWrapper: + cdef CSearchResults cresults + +cdef class Roots: + cdef readonly int root_num + cdef CRoots *roots + cdef readonly int num_actions + cdef public object legal_actions_list + +cdef class Node: + cdef CNode cnode diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pyx b/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pyx new file mode 100644 index 000000000..1a589db0b --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/ez_tree.pyx @@ -0,0 +1,158 @@ +# distutils:language=c++ +# cython:language_level=3 +import cython +from libcpp.vector cimport vector + +cdef class MinMaxStatsList: + @cython.binding + def __cinit__(self, int num): + self.cmin_max_stats_lst = new CMinMaxStatsList(num) + + @cython.binding + def set_delta(self, float value_delta_max): + self.cmin_max_stats_lst[0].set_delta(value_delta_max) + + def __dealloc__(self): + del self.cmin_max_stats_lst + +cdef class ResultsWrapper: + @cython.binding + def __cinit__(self, int num): + self.cresults = CSearchResults(num) + + @cython.binding + def get_search_len(self): + return self.cresults.search_lens + +cdef class Roots: + @cython.binding + def __cinit__(self, int root_num, vector[vector[int]] legal_actions_list): + self.root_num = root_num + self.roots = new CRoots(root_num, legal_actions_list) + # Store legal_actions for access from Python + self.legal_actions_list = legal_actions_list + + @cython.binding + def prepare(self, float root_noise_weight, list noises, list value_prefix_pool, + list policy_logits_pool, vector[int] & to_play_batch): + self.roots[0].prepare(root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play_batch) + + @cython.binding + def prepare_no_noise(self, list value_prefix_pool, list policy_logits_pool, vector[int] & to_play_batch): + self.roots[0].prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play_batch) + + @cython.binding + def get_trajectories(self): + return self.roots[0].get_trajectories() + + @cython.binding + def get_distributions(self): + return self.roots[0].get_distributions() + + @cython.binding + def get_values(self): + return self.roots[0].get_values() + + @cython.binding + def get_root_policies(self, MinMaxStatsList min_max_stats_lst): + return 
self.roots[0].get_root_policies(min_max_stats_lst.cmin_max_stats_lst) + + @cython.binding + def get_best_actions(self): + return self.roots[0].get_best_actions() + + # visualize related code + #def get_root(self, int index): + # return self.roots[index] + + @cython.binding + def clear(self): + self.roots[0].clear() + + @cython.binding + def get_legal_actions(self): + """Get the legal actions list""" + return list(self.legal_actions_list) + + def __dealloc__(self): + del self.roots + + @property + def num(self): + return self.root_num + +cdef class Node: + def __cinit__(self): + pass + + def __cinit__(self, float prior, vector[int] & legal_actions): + pass + + @cython.binding + def expand(self, int to_play, int current_latent_state_index, int batch_index, float value_prefix, + list policy_logits): + cdef vector[float] cpolicy = policy_logits + self.cnode.expand(to_play, current_latent_state_index, batch_index, value_prefix, cpolicy) + +@cython.binding +def batch_backpropagate(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies, + MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list is_reset_list, + list to_play_batch): + cdef int i + cdef vector[float] cvalue_prefixs = value_prefixs + cdef vector[float] cvalues = values + cdef vector[vector[float]] cpolicies = policies + + cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies, + min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch) + +@cython.binding +def batch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies, + MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list is_reset_list, + list to_play_batch, list no_inference_lst, list reuse_lst, list reuse_value_lst): + cdef int i + cdef vector[float] cvalue_prefixs = value_prefixs + cdef vector[float] cvalues = values + cdef vector[vector[float]] cpolicies = policies + cdef vector[float] creuse_value_lst = reuse_value_lst + + cbatch_backpropagate_with_reuse(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies, + min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch, no_inference_lst, reuse_lst, creuse_value_lst) + +# ========== MuZero/UCB 风格的遍历(备份版本) ========== +@cython.binding +def batch_traverse_ucb(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst, + ResultsWrapper results, list virtual_to_play_batch): + """MuZero/UCB 风格的批量树遍历(备份版本)""" + cbatch_traverse_ucb(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, + results.cresults, virtual_to_play_batch) + + return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs + +# ========== EfficientZero V2 风格的遍历(Sequential Halving 集成) ========== +@cython.binding +def batch_traverse(Roots roots, MinMaxStatsList min_max_stats_lst, ResultsWrapper results, + int num_simulations, int simulation_idx, list gumbel_noises, + int current_num_top_actions, list virtual_to_play_batch): + """EfficientZero V2 风格的批量树遍历,集成 Sequential Halving 逻辑""" + cdef vector[vector[float]] c_gumbel_noises = gumbel_noises + cbatch_traverse(roots.roots, min_max_stats_lst.cmin_max_stats_lst, results.cresults, + num_simulations, simulation_idx, c_gumbel_noises, + current_num_top_actions, 
virtual_to_play_batch) + + return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs + +@cython.binding +def batch_traverse_with_reuse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst, + ResultsWrapper results, list virtual_to_play_batch, list true_action, list reuse_value): + cbatch_traverse_with_reuse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst, results.cresults, + virtual_to_play_batch, true_action, reuse_value) + + return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs + +@cython.binding +def batch_sequential_halving(Roots roots, list gumbel_noises, MinMaxStatsList min_max_stats_lst, + int current_phase, int current_num_top_actions): + cdef vector[vector[float]] c_gumbel_noises = gumbel_noises + return c_batch_sequential_halving(roots.roots, c_gumbel_noises, min_max_stats_lst.cmin_max_stats_lst, + current_phase, current_num_top_actions) diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.cpp b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.cpp new file mode 100644 index 000000000..b438ad51e --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.cpp @@ -0,0 +1,71 @@ +// C++11 + +#include "cminimax.h" +#include +#include + +namespace tools { + + // ========== CMinMaxStats Implementation ========== + CMinMaxStats::CMinMaxStats() { + this->maximum = FLOAT_MIN; + this->minimum = FLOAT_MAX; + this->value_delta_max = 0.01; + this->c_visit = 1; + this->c_scale = 1.0; + } + + CMinMaxStats::~CMinMaxStats() {} + + void CMinMaxStats::set_delta(float value_delta_max) { + this->value_delta_max = value_delta_max; + } + + void CMinMaxStats::set_static_val(float value_delta_max, int c_visit, float c_scale) { + this->value_delta_max = value_delta_max; + this->c_visit = c_visit; + this->c_scale = c_scale; + } + + void CMinMaxStats::update(float value) { + this->maximum = std::max(this->maximum, value); + this->minimum = std::min(this->minimum, value); + } + + void CMinMaxStats::clear() { + this->maximum = FLOAT_MIN; + this->minimum = FLOAT_MAX; + } + + float CMinMaxStats::normalize(float value) { + float delta = this->maximum - this->minimum; + delta = std::max(delta, this->value_delta_max); + return (value - this->minimum) / delta; + } + + // ========== CMinMaxStatsList Implementation ========== + CMinMaxStatsList::CMinMaxStatsList() { + this->num = 0; + } + + CMinMaxStatsList::CMinMaxStatsList(int num) { + this->num = num; + for (int i = 0; i < num; ++i) { + this->stats_lst.push_back(CMinMaxStats()); + } + } + + CMinMaxStatsList::~CMinMaxStatsList() {} + + void CMinMaxStatsList::set_delta(float value_delta_max) { + for (int i = 0; i < this->num; ++i) { + this->stats_lst[i].set_delta(value_delta_max); + } + } + + void CMinMaxStatsList::set_static_val(float value_delta_max, int c_visit, float c_scale) { + for (int i = 0; i < this->num; ++i) { + this->stats_lst[i].set_static_val(value_delta_max, c_visit, c_scale); + } + } +} diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.h b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.h new file mode 100644 index 000000000..4debf35c9 --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cminimax.h @@ -0,0 +1,45 @@ +// C++11 + +#ifndef CMINIMAX_H +#define CMINIMAX_H + 
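+// Note (added for clarity, not part of the original sources): CMinMaxStats keeps a running
+// minimum/maximum of the Q values seen during backup, and normalize() maps a value into
+// [0, 1] via (value - minimum) / max(maximum - minimum, value_delta_max).
+// Illustrative example: after update(2.0) and update(6.0), normalize(4.0) returns
+// (4.0 - 2.0) / 4.0 = 0.5; value_delta_max only matters when maximum - minimum is smaller
+// than it, guarding against division by a near-zero range early in the search.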
+#include +#include + +const float FLOAT_MAX = 1000000.0; +const float FLOAT_MIN = -FLOAT_MAX; +const float EPSILON = 0.000001; + +namespace tools { + + class CMinMaxStats { + public: + int c_visit; + float c_scale; + float maximum, minimum, value_delta_max; + + CMinMaxStats(); + ~CMinMaxStats(); + + void set_delta(float value_delta_max); + void set_static_val(float value_delta_max, int c_visit, float c_scale); + void update(float value); + void clear(); + float normalize(float value); + }; + + class CMinMaxStatsList { + public: + int num; + std::vector stats_lst; + + CMinMaxStatsList(); + CMinMaxStatsList(int num); + ~CMinMaxStatsList(); + + void set_delta(float value_delta_max); + void set_static_val(float value_delta_max, int c_visit, float c_scale); + }; +} + +#endif diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.cpp b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.cpp new file mode 100644 index 000000000..4a2afc645 --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.cpp @@ -0,0 +1,1792 @@ +// C++11 + +#include +#include "cnode.h" +#include "cminimax.h" +#include +#include +#include +#include + +#ifdef _WIN32 +#include "..\..\common_lib\utils.cpp" +#else +#include "../../common_lib/utils.cpp" +#endif + + +namespace tree +{ + + CSearchResults::CSearchResults() + { + /* + Overview: + Initialization of CSearchResults, the default result number is set to 0. + */ + this->num = 0; + } + + CSearchResults::CSearchResults(int num) + { + /* + Overview: + Initialization of CSearchResults with result number. + */ + this->num = num; + for (int i = 0; i < num; ++i) + { + this->search_paths.push_back(std::vector()); + } + } + + CSearchResults::~CSearchResults() {} + + //********************************************************* + + CNode::CNode() + { + /* + Overview: + Initialization of CNode. + */ + this->prior = 0; + this->legal_actions = legal_actions; + + this->is_reset = 0; + this->visit_count = 0; + this->value_sum = 0; + this->best_action = -1; + this->to_play = 0; + this->value_prefix = 0.0; + this->parent_value_prefix = 0.0; + } + + CNode::CNode(float prior, std::vector &legal_actions) + { + /* + Overview: + Initialization of CNode with prior value and legal actions. + Arguments: + - prior: the prior value of this node. + - legal_actions: a vector of legal actions of this node. + */ + this->prior = prior; + this->legal_actions = legal_actions; + + this->is_reset = 0; + this->visit_count = 0; + this->value_sum = 0; + this->best_action = -1; + this->to_play = 0; + this->value_prefix = 0.0; + this->parent_value_prefix = 0.0; + this->current_latent_state_index = -1; + this->batch_index = -1; + } + + CNode::~CNode() {} + + void CNode::expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector &policy_logits) + { + /* + Overview: + Expand the child nodes of the current node. + Arguments: + - to_play: which player to play the game in the current node. + - current_latent_state_index: the x/first index of hidden state vector of the current node, i.e. the search depth. + - batch_index: the y/second index of hidden state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. + - value_prefix: the value prefix of the current node. + - policy_logits: the policy logit of the child nodes. 
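+            Example (illustrative, not from the original docstring):
+                With legal_actions = [0, 1, 2] and policy_logits = [1.0, 2.0, 0.0], the children
+                are created with priors softmax([1.0, 2.0, 0.0]) ≈ [0.24, 0.67, 0.09]; the
+                max-logit subtraction below is only for numerical stability and does not change
+                the resulting priors.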
+ */ + this->to_play = to_play; + this->current_latent_state_index = current_latent_state_index; + this->batch_index = batch_index; + this->value_prefix = value_prefix; + + int action_num = policy_logits.size(); + if (this->legal_actions.size() == 0) + { + for (int i = 0; i < action_num; ++i) + { + this->legal_actions.push_back(i); + } + } + float temp_policy; + float policy_sum = 0.0; + + #ifdef _WIN32 + // 创建动态数组 + float* policy = new float[action_num]; + #else + float policy[action_num]; + #endif + + float policy_max = FLOAT_MIN; + for (auto a : this->legal_actions) + { + if (policy_max < policy_logits[a]) + { + policy_max = policy_logits[a]; + } + } + + for (auto a : this->legal_actions) + { + temp_policy = exp(policy_logits[a] - policy_max); + policy_sum += temp_policy; + policy[a] = temp_policy; + } + + float prior; + for (auto a : this->legal_actions) + { + prior = policy[a] / policy_sum; + std::vector tmp_empty; + this->children[a] = CNode(prior, tmp_empty); // only for muzero/efficient zero, not support alphazero + } + #ifdef _WIN32 + // 释放数组内存 + delete[] policy; + #else + #endif + } + + void CNode::add_exploration_noise(float exploration_fraction, const std::vector &noises) + { + /* + Overview: + Add a noise to the prior of the child nodes. + Arguments: + - exploration_fraction: the fraction to add noise. + - noises: the vector of noises added to each child node. + */ + float noise, prior; + for (int i = 0; i < this->legal_actions.size(); ++i) + { + noise = noises[i]; + CNode *child = this->get_child(this->legal_actions[i]); + + prior = child->prior; + child->prior = prior * (1 - exploration_fraction) + noise * exploration_fraction; + } + } + + float CNode::compute_mean_q(int isRoot, float parent_q, float discount_factor) + { + /* + Overview: + Compute the mean q value of the current node. + Arguments: + - isRoot: whether the current node is a root node. + - parent_q: the q value of the parent node. + - discount_factor: the discount_factor of reward. + */ + float total_unsigned_q = 0.0; + int total_visits = 0; + float parent_value_prefix = this->value_prefix; + for (auto a : this->legal_actions) + { + CNode *child = this->get_child(a); + if (child->visit_count > 0) + { + float true_reward = child->value_prefix - parent_value_prefix; + if (this->is_reset == 1) + { + true_reward = child->value_prefix; + } + float qsa = true_reward + discount_factor * child->value(); + total_unsigned_q += qsa; + total_visits += 1; + } + } + + float mean_q = 0.0; + if (isRoot && total_visits > 0) + { + mean_q = (total_unsigned_q) / (total_visits); + } + else + { + mean_q = (parent_q + total_unsigned_q) / (total_visits + 1); + } + return mean_q; + } + + int CNode::expanded() + { + /* + Overview: + Return whether the current node is expanded. + */ + return this->children.size() > 0; + } + + float CNode::value() + { + /* + Overview: + Return the estimated value of the current tree. + Current implementation: uses value_sum / visit_count (traditional MCTS style). + Returns: + The mean value of all backpropagations through this node. + */ + float true_value = 0.0; + if (this->visit_count == 0) + { + return true_value; + } + else + { + true_value = this->value_sum / this->visit_count; + return true_value; + } + } + + float CNode::value_v2() + { + /* + Overview: + EfficientZero V2 original style value estimation using estimated_value_lst. + This method is NOT USED in current implementation, kept for reference. 
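+        Illustration (added for clarity): after three backups of 1.0, 2.0 and 3.0,
+        value() returns value_sum / visit_count = 6.0 / 3 = 2.0 and value_v2() would
+        return (1.0 + 2.0 + 3.0) / 3 = 2.0, so the two estimators agree whenever
+        estimated_value_lst is actually populated in cbackpropagate().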
+ + Difference from value(): + - value(): uses value_sum / visit_count (O(1) memory, cumulative sum) + - value_v2(): uses sum(estimated_value_lst) / len(estimated_value_lst) (O(n) memory, keeps history) + + Mathematical equivalence: + Both methods produce the same average value. + + Why keep estimated_value_lst? + 1. Reanalyze: Can re-evaluate nodes with new network + 2. Uncertainty estimation: Can compute std(estimated_value_lst) + 3. Value distribution: Can analyze value distribution + 4. Advanced features: Ensemble learning, distributed training + + Current status: + UNUSED - The estimated_value_lst is initialized but not populated. + To enable: uncomment the push_back lines in cbackpropagate(). + */ + if (this->estimated_value_lst.empty()) + { + return 0.0; + } + + float sum = 0.0; + for (float v : this->estimated_value_lst) + { + sum += v; + } + return sum / this->estimated_value_lst.size(); + } + + std::vector CNode::get_trajectory() + { + /* + Overview: + Find the current best trajectory starts from the current node. + Returns: + - traj: a vector of node index, which is the current best trajectory from this node. + */ + std::vector traj; + + CNode *node = this; + int best_action = node->best_action; + while (best_action >= 0) + { + traj.push_back(best_action); + + node = node->get_child(best_action); + best_action = node->best_action; + } + return traj; + } + + std::vector CNode::get_children_distribution() + { + /* + Overview: + Get the distribution of child nodes in the format of visit_count. + Returns: + - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + */ + std::vector distribution; + if (this->expanded()) + { + for (auto a : this->legal_actions) + { + CNode *child = this->get_child(a); + distribution.push_back(child->visit_count); + } + } + return distribution; + } + + CNode *CNode::get_child(int action) + { + /* + Overview: + Get the child node corresponding to the input action. + Arguments: + - action: the action to get child. + */ + return &(this->children[action]); + } + + //********************************************************* + + CRoots::CRoots() + { + /* + Overview: + The initialization of CRoots. + */ + this->root_num = 0; + } + + CRoots::CRoots(int root_num, std::vector > &legal_actions_list) + { + /* + Overview: + The initialization of CRoots with root num and legal action lists. + Arguments: + - root_num: the number of the current root. + - legal_action_list: the vector of the legal action of this root. + */ + this->root_num = root_num; + this->legal_actions_list = legal_actions_list; + + for (int i = 0; i < root_num; ++i) + { + this->roots.push_back(CNode(0, this->legal_actions_list[i])); + } + } + + CRoots::~CRoots() {} + + void CRoots::prepare(float root_noise_weight, const std::vector > &noises, const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch) + { + /* + Overview: + Expand the roots and add noises. + Arguments: + - root_noise_weight: the exploration fraction of roots + - noises: the vector of noise add to the roots. + - value_prefixs: the vector of value prefixs of each root. + - policies: the vector of policy logits of each root. + - to_play_batch: the vector of the player side of each root. 
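+        Note (added, assuming the usual Sequential Halving schedule): selected_children_idx is
+        initialized below with every legal action; on the first simulation select_action() then
+        keeps only the top current_num_top_actions of them ranked by gumbel_noise + prior, and
+        c_batch_sequential_halving() presumably halves that candidate set at each later phase
+        (e.g. 16 -> 8 -> 4 -> 2 -> 1) until a single root action remains.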
+ */ + for (int i = 0; i < this->root_num; ++i) + { + this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]); + this->roots[i].add_exploration_noise(root_noise_weight, noises[i]); + this->roots[i].visit_count += 1; + + // Initialize selected_children_idx with all legal actions for Sequential Halving + this->roots[i].selected_children_idx.clear(); + for (int action : this->roots[i].legal_actions) { + this->roots[i].selected_children_idx.push_back(action); + } + } + } + + void CRoots::prepare_no_noise(const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch) + { + /* + Overview: + Expand the roots without noise. + Arguments: + - value_prefixs: the vector of value prefixs of each root. + - policies: the vector of policy logits of each root. + - to_play_batch: the vector of the player side of each root. + */ + for (int i = 0; i < this->root_num; ++i) + { + this->roots[i].expand(to_play_batch[i], 0, i, value_prefixs[i], policies[i]); + this->roots[i].visit_count += 1; + + // Initialize selected_children_idx with all legal actions for Sequential Halving + this->roots[i].selected_children_idx.clear(); + for (int action : this->roots[i].legal_actions) { + this->roots[i].selected_children_idx.push_back(action); + } + } + } + + void CRoots::clear() + { + /* + Overview: + Clear the roots vector. + */ + this->roots.clear(); + } + + std::vector > CRoots::get_trajectories() + { + /* + Overview: + Find the current best trajectory starts from each root. + Returns: + - traj: a vector of node index, which is the current best trajectory from each root. + */ + std::vector > trajs; + trajs.reserve(this->root_num); + + for (int i = 0; i < this->root_num; ++i) + { + trajs.push_back(this->roots[i].get_trajectory()); + } + return trajs; + } + + std::vector > CRoots::get_distributions() + { + /* + Overview: + Get the children distribution of each root. + Returns: + - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + */ + std::vector > distributions; + distributions.reserve(this->root_num); + + for (int i = 0; i < this->root_num; ++i) + { + distributions.push_back(this->roots[i].get_children_distribution()); + } + return distributions; + } + + std::vector CRoots::get_values() + { + /* + Overview: + Return the estimated value of each root. + */ + std::vector values; + for (int i = 0; i < this->root_num; ++i) + { + values.push_back(this->roots[i].value()); + } + return values; + } + + std::vector> CRoots::get_root_policies(tools::CMinMaxStatsList *min_max_stats_lst) + { + /* + Overview: + Get improved policies for each root based on MCTS search results. + The improved policy is computed as softmax(prior + transformed_Q). + Arguments: + - min_max_stats_lst: min-max statistics for Q value normalization. + Returns: + - policies: vector of policy distributions for each root. 
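+        Note (added for clarity): per the code comments below, the improved policy is
+        softmax(prior + transformed_completed_Q); the sigma transform applied in
+        get_transformed_completed_Qs() presumably scales each completed Q by
+        (c_visit + max_child_visit_count) * c_scale, as in Gumbel MuZero, which is why
+        CMinMaxStats carries the c_visit and c_scale fields.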
+ */ + std::vector> policies; + policies.reserve(this->root_num); + + for (int i = 0; i < this->root_num; ++i) + { + // Get transformed completed Q values (with Sigma transform) + std::vector transformed_completed_Qs = + get_transformed_completed_Qs(&(this->roots[i]), min_max_stats_lst->stats_lst[i], 0); + + // Get improved policy based on Q values: softmax(prior + transformed_Q) + std::vector improved_policy = + this->roots[i].get_improved_policy(transformed_completed_Qs); + + policies.push_back(improved_policy); + } + return policies; + } + + std::vector CRoots::get_best_actions() + { + /* + Overview: + Get the best action for each root after Sequential Halving. + The best action is the first element in selected_children_idx, + which has the highest score (Gumbel noise + prior + transformed Q). + Returns: + - best_actions: vector of best action indices for each root. + */ + std::vector best_actions(this->root_num, -1); + + for (int i = 0; i < this->root_num; ++i) + { + // selected_children_idx[0] is the action with highest score after Sequential Halving + best_actions[i] = this->roots[i].selected_children_idx[0]; + } + return best_actions; + } + + //********************************************************* + // + void update_tree_q(CNode *root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players) + { + /* + Overview: + Update the q value of the root and its child nodes. + Arguments: + - root: the root that update q value from. + - min_max_stats: a tool used to min-max normalize the q value. + - discount_factor: the discount factor of reward. + - players: the number of players. + */ + std::stack node_stack; + node_stack.push(root); + float parent_value_prefix = 0.0; + int is_reset = 0; + while (node_stack.size() > 0) + { + CNode *node = node_stack.top(); + node_stack.pop(); + + if (node != root) + { + // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node, + // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node. + // true_reward = node.value_prefix - (- parent_value_prefix) + float true_reward = node->value_prefix - node->parent_value_prefix; + + if (is_reset == 1) + { + true_reward = node->value_prefix; + } + float qsa; + if (players == 1) + { + qsa = true_reward + discount_factor * node->value(); + } + else if (players == 2) + { + // TODO(pu): why only the last reward multiply the discount_factor? + qsa = true_reward + discount_factor * (-1) * node->value(); + } + + min_max_stats.update(qsa); + } + + for (auto a : node->legal_actions) + { + CNode *child = node->get_child(a); + if (child->expanded()) + { + child->parent_value_prefix = node->value_prefix; + node_stack.push(child); + } + } + + is_reset = node->is_reset; + } + } + + void cbackpropagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor) + { + /* + Overview: + Update the value sum and visit count of nodes along the search path. + Arguments: + - search_path: a vector of nodes on the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - to_play: which player to play the game in the current node. + - value: the value to propagate along the search path. + - discount_factor: the discount factor of reward. 
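+        Worked example (illustrative, play-with-bot mode, to_play == -1): for a two-node
+        search path with value prefixes [0.5, 1.2], no reset flags, leaf value 3.0 and
+        discount gamma, the leaf adds 3.0 to its value_sum, the bootstrap becomes
+        (1.2 - 0.5) + gamma * 3.0, the root adds that amount to its value_sum, and the
+        final bootstrap is 0.5 + gamma * (0.7 + gamma * 3.0).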
+ */ + assert(to_play == -1 || to_play == 1 || to_play == 2); + if (to_play == -1) + { + // for play-with-bot-mode + float bootstrap_value = value; + int path_len = search_path.size(); + for (int i = path_len - 1; i >= 0; --i) + { + CNode *node = search_path[i]; + + // ========== Current Implementation: value_sum (Traditional MCTS) ========== + node->value_sum += bootstrap_value; + node->visit_count += 1; + + // ========== Alternative Implementation: estimated_value_lst (EfficientZero V2 Original) ========== + // Uncomment the line below to use EfficientZero V2 original style: + // (node->estimated_value_lst).push_back(bootstrap_value); + // node->visit_count += 1; + // + // Then in value calculation, use: + // node->value_v2() instead of node->value() + // + // Difference: + // - value_sum: O(1) memory, stores cumulative sum + // - estimated_value_lst: O(n) memory, stores all values (enables reanalyze, uncertainty estimation) + // - Result: Both produce the same average value mathematically + // ============================================================================ + + float parent_value_prefix = 0.0; + int is_reset = 0; + if (i >= 1) + { + CNode *parent = search_path[i - 1]; + parent_value_prefix = parent->value_prefix; + is_reset = parent->is_reset; + } + + float true_reward = node->value_prefix - parent_value_prefix; + min_max_stats.update(true_reward + discount_factor * node->value()); + + if (is_reset == 1) + { + // parent is reset + true_reward = node->value_prefix; + } + + bootstrap_value = true_reward + discount_factor * bootstrap_value; + } + } + else + { + // for self-play-mode + float bootstrap_value = value; + int path_len = search_path.size(); + for (int i = path_len - 1; i >= 0; --i) + { + CNode *node = search_path[i]; + + // ========== Current Implementation: value_sum (Traditional MCTS) ========== + if (node->to_play == to_play) + { + node->value_sum += bootstrap_value; + } + else + { + node->value_sum += -bootstrap_value; + } + node->visit_count += 1; + + // ========== Alternative Implementation: estimated_value_lst (EfficientZero V2 Original) ========== + // For self-play mode with estimated_value_lst: + // if (node->to_play == to_play) + // { + // (node->estimated_value_lst).push_back(bootstrap_value); + // } + // else + // { + // (node->estimated_value_lst).push_back(-bootstrap_value); + // } + // node->visit_count += 1; + // ============================================================================ + + float parent_value_prefix = 0.0; + int is_reset = 0; + if (i >= 1) + { + CNode *parent = search_path[i - 1]; + parent_value_prefix = parent->value_prefix; + is_reset = parent->is_reset; + } + + // NOTE: in self-play-mode, value_prefix is not calculated according to the perspective of current player of node, + // but treated as 1 player, just for obtaining the true reward in the perspective of current player of node. 
+ float true_reward = node->value_prefix - parent_value_prefix; + + min_max_stats.update(true_reward + discount_factor * node->value()); + + if (is_reset == 1) + { + // parent is reset + true_reward = node->value_prefix; + } + if (node->to_play == to_play) + { + bootstrap_value = -true_reward + discount_factor * bootstrap_value; + } + else + { + bootstrap_value = true_reward + discount_factor * bootstrap_value; + } + } + } + } + + void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch) + { + /* + Overview: + Expand the nodes along the search path and update the infos. + Arguments: + - current_latent_state_index: The index of latent state of the leaf node in the search path. + - discount_factor: the discount factor of reward. + - value_prefixs: the value prefixs of nodes along the search path. + - values: the values to propagate along the search path. + - policies: the policy logits of nodes along the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - results: the search results. + - is_reset_list: the vector of is_reset nodes along the search path, where is_reset represents for whether the parent value prefix needs to be reset. + - to_play_batch: the batch of which player is playing on this node. + */ + for (int i = 0; i < results.num; ++i) + { + results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, i, value_prefixs[i], policies[i]); + // reset + results.nodes[i]->is_reset = is_reset_list[i]; + + cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], values[i], discount_factor); + } + } + + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch, std::vector &no_inference_lst, std::vector &reuse_lst, std::vector &reuse_value_lst) + { + /* + Overview: + Expand the nodes along the search path and update the infos. + Arguments: + - current_latent_state_index: The index of latent state of the leaf node in the search path. + - discount_factor: the discount factor of reward. + - value_prefixs: the value prefixs of nodes along the search path. + - values: the values to propagate along the search path. + - policies: the policy logits of nodes along the search path. + - min_max_stats: a tool used to min-max normalize the q value. + - results: the search results. + - to_play_batch: the batch of which player is playing on this node. + - no_inference_lst: the list of the nodes which does not need to expand. + - reuse_lst: the list of the nodes which should use reuse-value to backpropagate. + - reuse_value_lst: the list of the reuse-value. 
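+        Note (added for clarity): batch indices listed in no_inference_lst are skipped for
+        expansion and back up reuse_value_lst[i] directly; indices in reuse_lst are expanded
+        with the freshly predicted prior/value_prefix but still back up reuse_value_lst[i];
+        every other index is expanded and backs up the network value values[count_b], where
+        count_b only advances for nodes that actually went through inference.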
+ */ + int count_a = 0; + int count_b = 0; + int count_c = 0; + float value_propagate = 0; + for (int i = 0; i < results.num; ++i) + { + if (i == no_inference_lst[count_a]) + { + count_a = count_a + 1; + value_propagate = reuse_value_lst[i]; + } + else + { + results.nodes[i]->expand(to_play_batch[i], current_latent_state_index, count_b, value_prefixs[count_b], policies[count_b]); + if (i == reuse_lst[count_c]) + { + value_propagate = reuse_value_lst[i]; + count_c = count_c + 1; + } + else + { + value_propagate = values[count_b]; + } + count_b = count_b + 1; + } + results.nodes[i]->is_reset = is_reset_list[i]; + cbackpropagate(results.search_paths[i], min_max_stats_lst->stats_lst[i], to_play_batch[i], value_propagate, discount_factor); + } + } + + int cselect_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players) + { + /* + Overview: + Select the child node of the roots according to ucb scores. + Arguments: + - root: the roots to select the child node. + - min_max_stats: a tool used to min-max normalize the score. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - mean_q: the mean q value of the parent node. + - players: the number of players. + Returns: + - action: the action to select. + */ + float max_score = FLOAT_MIN; + const float epsilon = 0.000001; + std::vector max_index_lst; + for (auto a : root->legal_actions) + { + CNode *child = root->get_child(a); + float temp_score = cucb_score(child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players); + + if (max_score < temp_score) + { + max_score = temp_score; + + max_index_lst.clear(); + max_index_lst.push_back(a); + } + else if (temp_score >= max_score - epsilon) + { + max_index_lst.push_back(a); + } + } + + int action = 0; + if (max_index_lst.size() > 0) + { + int rand_index = rand() % max_index_lst.size(); + action = max_index_lst[rand_index]; + } + return action; + } + + int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value) + { + /* + Overview: + Select the child node of the roots according to ucb scores. + Arguments: + - root: the roots to select the child node. + - min_max_stats: a tool used to min-max normalize the score. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - mean_q: the mean q value of the parent node. + - players: the number of players. + - true_action: the action chosen in the trajectory. + - reuse_value: the value obtained from the search of the next state in the trajectory. + Returns: + - action: the action to select. 
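+        Note (added for clarity): the difference from cselect_child() is that the child
+        corresponding to true_action is scored with carm_score(), which substitutes reuse_value
+        for that child's own value() estimate and, once that child has been visited, drops the
+        prior exploration bonus, so the value already searched at the next state in the
+        trajectory is reused instead of being re-estimated here.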
+ */ + + float max_score = FLOAT_MIN; + const float epsilon = 0.000001; + std::vector max_index_lst; + for (auto a : root->legal_actions) + { + + CNode *child = root->get_child(a); + float temp_score = 0.0; + if (a == true_action) + { + temp_score = carm_score(child, min_max_stats, mean_q, root->is_reset, reuse_value, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players); + } + else + { + temp_score = cucb_score(child, min_max_stats, mean_q, root->is_reset, root->visit_count - 1, root->value_prefix, pb_c_base, pb_c_init, discount_factor, players); + } + + if (max_score < temp_score) + { + max_score = temp_score; + + max_index_lst.clear(); + max_index_lst.push_back(a); + } + else if (temp_score >= max_score - epsilon) + { + max_index_lst.push_back(a); + } + } + + int action = 0; + if (max_index_lst.size() > 0) + { + int rand_index = rand() % max_index_lst.size(); + action = max_index_lst[rand_index]; + } + // printf("select root child ends"); + return action; + } + + float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players) + { + /* + Overview: + Compute the ucb score of the child. + Arguments: + - child: the child node to compute ucb score. + - min_max_stats: a tool used to min-max normalize the score. + - parent_mean_q: the mean q value of the parent node. + - is_reset: whether the value prefix needs to be reset. + - total_children_visit_counts: the total visit counts of the child nodes of the parent node. + - parent_value_prefix: the value prefix of parent node. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - players: the number of players. + Returns: + - ucb_value: the ucb score of the child. + */ + float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; + pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init; + pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1)); + + prior_score = pb_c * child->prior; + if (child->visit_count == 0) + { + value_score = parent_mean_q; + } + else + { + float true_reward = child->value_prefix - parent_value_prefix; + if (is_reset == 1) + { + true_reward = child->value_prefix; + } + + if (players == 1) + { + value_score = true_reward + discount_factor * child->value(); + } + else if (players == 2) + { + value_score = true_reward + discount_factor * (-child->value()); + } + } + + value_score = min_max_stats.normalize(value_score); + + if (value_score < 0) + { + value_score = 0; + } + else if (value_score > 1) + { + value_score = 1; + } + + return prior_score + value_score; // ucb_value + } + + float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float reuse_value, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players) + { + /* + Overview: + Compute the ucb score of the child. + Arguments: + - child: the child node to compute ucb score. + - min_max_stats: a tool used to min-max normalize the score. + - parent_mean_q: the mean q value of the parent node. + - is_reset: whether the value prefix needs to be reset. + - total_children_visit_counts: the total visit counts of the child nodes of the parent node. + - parent_value_prefix: the value prefix of parent node. 
+ - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - players: the number of players. + Returns: + - ucb_value: the ucb score of the child. + */ + float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; + pb_c = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init; + pb_c *= (sqrt(total_children_visit_counts) / (child->visit_count + 1)); + + prior_score = pb_c * child->prior; + if (child->visit_count == 0) + { + value_score = parent_mean_q; + } + else + { + float true_reward = child->value_prefix - parent_value_prefix; + if (is_reset == 1) + { + true_reward = child->value_prefix; + } + + if (players == 1) + { + value_score = true_reward + discount_factor * reuse_value; + } + else if (players == 2) + { + value_score = true_reward + discount_factor * (-reuse_value); + } + } + + value_score = min_max_stats.normalize(value_score); + + if (value_score < 0) + { + value_score = 0; + } + else if (value_score > 1) + { + value_score = 1; + } + + float ucb_value = 0.0; + if (child->visit_count == 0) + { + ucb_value = prior_score + value_score; + } + else + { + ucb_value = value_score; + } + // printf("carmscore ends"); + return ucb_value; + } + + /** + * ===================================================================== + * 【EfficientZero V2】动作选择函数 + * ===================================================================== + * 根据当前节点类型和搜索进度采用不同的动作选择策略: + * 1. 根节点:Sequential Halving + 等量访问(Gumbel-Max 采样初始化) + * 2. 非根节点:改进策略 + 访问计数平衡 + * + * 参数: + * - node: 当前树节点(可能是根节点或内部节点) + * - min_max_stats: 最小最大值统计,用于Q值归一化 + * - num_simulations: 总模拟数(序列削减的总步数) + * - simulation_idx: 当前模拟索引(序列削减的当前阶段,从0开始) + * - gumble_noise: 该样本的Gumbel噪声向量(长度=动作数) + * - current_num_top_actions: 当前保留的顶部候选动作数量 + * + * 返回:选中的动作索引 + * ===================================================================== + */ + int select_action(CNode* node, tools::CMinMaxStats &min_max_stats, + int num_simulations, int simulation_idx, + const std::vector& gumble_noise, + int current_num_top_actions){ + + int action = -1; + int num_actions = node->legal_actions.size(); + + // Check if this is a root node by examining if selected_children_idx is populated + // (Root nodes go through Sequential Ha2lving and have selected_children_idx set) + bool is_root = !node->selected_children_idx.empty() || node->visit_count == 0; + + if(is_root){ + // ============================================================ + // 【根节点处理】Sequential Halving + 等量访问策略 + // ============================================================ + + if(simulation_idx == 0){ + // 【第一阶段 (simulation_idx == 0)】 + // 目的:基于Gumbel-Max采样初始化顶部候选动作集合 + + // 1. 获取根节点所有子节点的先验概率(来自网络) + std::vector children_prior = node->get_children_priors(); + + // 2. 计算初始得分 = Gumbel噪声 + log(先验) + // 这实现了Gumbel-Max采样:max_a[g_a + log(p_a)] + // Gumbel噪声用于探索,先验用于指导 + std::vector children_scores; + for(int a = 0; a < num_actions; ++a){ + // g_a: Gumbel噪声(从高斯分布生成) + // p_a: 网络输出的先验策略 + children_scores.push_back(gumble_noise[a] + children_prior[a]); + } + + // 3. 对得分进行降序排序,获取排序后的动作索引 + std::vector idx(children_scores.size()); + std::iota(idx.begin(), idx.end(), 0); // idx初始化为[0, 1, 2, ...] + std::sort(idx.begin(), idx.end(), + [&children_scores](size_t i1, size_t i2) { + return children_scores[i1] > children_scores[i2]; + }); + + // 4. 
保存前K个最佳动作到节点中(作为候选集合) + // 【关键】这些候选在后续迭代中会被等量访问 + // 实现了序列削减的"筛选"阶段 + node->selected_children_idx.clear(); + for(int a = 0; a < current_num_top_actions; ++a){ + node->selected_children_idx.push_back(idx[a]); + } + } + + // 【后续阶段 (simulation_idx > 0)】 + // 使用等量访问策略:从候选集合中轮流选择动作 + // 目的:平衡评估各个候选动作,同时让高价值候选更早收敛 + action = node->do_equal_visit(num_simulations); + } + else{ + // ============================================================ + // 【非根节点处理】改进策略 + 访问计数平衡 + // ============================================================ + + // 1. 计算所有子节点的【变换后完整Q值】 + // 完整Q = R + γ*V (当前节点的累积奖励 + 折扣未来值) + // 通过min-max统计进行归一化到[0,1]范围内 + std::vector transformed_completed_Qs = + get_transformed_completed_Qs(node, min_max_stats, 0); + + // 2. 计算【改进策略】= softmax(transformed_Q) + // 这是MCTS搜索中学到的改进策略,优于网络初始策略 + std::vector improved_policy = + node->get_improved_policy(transformed_completed_Qs); + + // 3. 获取原始网络先验策略(用于对比或调试) + std::vector ori_policy = node->get_children_priors(); + + // 4. 获取每个子节点的访问次数(反映已探索程度) + std::vector children_visits = node->get_children_visits(); + + // 5. 计算每个动作的【选择得分】 + // 得分 = + // 其中:访问惩罚 = 子节点访问次数 / (1 + 父节点访问次数) + // + // 设计原理: + // - 改进策略项:偏向高Q值的动作(利用) + // - 访问惩罚项:偏向访问较少的动作(探索) + // - 除以父节点访问次数:动态平衡,搜索越深惩罚越大 + std::vector children_scores(num_actions, 0.0); + for(int a = 0; a < num_actions; ++a){ + float visit_penalty = children_visits[a] / (1.0f + float(node->visit_count)); + float score = improved_policy[a] - visit_penalty; + children_scores[a] = score; + } + + // 6. 选择得分最高的动作(贪心选择) + action = argmax(children_scores); + } + + return action; + } + + // ========== MuZero/UCB 风格的批量遍历(备份版本) ========== + void cbatch_traverse_ucb(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch) + { + /* + Overview: + Search node path from the roots using UCB/PUCT selection (MuZero style). + Arguments: + - roots: the roots that search from. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - min_max_stats: a tool used to min-max normalize the score. + - results: the search results. + - virtual_to_play_batch: the batch of which player is playing on this node. 
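+        Note (added for clarity): child selection here uses cselect_child(), i.e. the PUCT score
+        prior_score + value_score with
+            prior_score = prior * (log((N_parent + pb_c_base + 1) / pb_c_base) + pb_c_init)
+                          * sqrt(N_parent) / (N_child + 1),
+        where N_parent is the parent visit count minus one; this is the MuZero-style fallback,
+        while cbatch_traverse() below uses the Gumbel / Sequential Halving selection instead.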
+ */ + // set seed + get_time_and_set_rand_seed(); + + int last_action = -1; + float parent_q = 0.0; + results.search_lens = std::vector(); + + int players = 0; + int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // 0 or 2 + if (largest_element == -1) + { + players = 1; + } + else + { + players = 2; + } + + for (int i = 0; i < results.num; ++i) + { + CNode *node = &(roots->roots[i]); + int is_root = 1; + int search_len = 0; + results.search_paths[i].push_back(node); + + while (node->expanded()) + { + float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor); + is_root = 0; + parent_q = mean_q; + + int action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players); + if (players > 1) + { + assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); + if (virtual_to_play_batch[i] == 1) + { + virtual_to_play_batch[i] = 2; + } + else + { + virtual_to_play_batch[i] = 1; + } + } + + node->best_action = action; + // next + node = node->get_child(action); + last_action = action; + results.search_paths[i].push_back(node); + search_len += 1; + } + + CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; + + results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); + results.latent_state_index_in_batch.push_back(parent->batch_index); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + } + + /** + * ===================================================================== + * 【EfficientZero V2 核心】批量树遍历函数 - cbatch_traverse + * ===================================================================== + * 该函数是 EfficientZero V2 MCTS 搜索的核心遍历算法,集成了 Sequential Halving: + * 1. 对批次中的每个样本进行树遍历 + * 2. 使用 select_action() 进行动作选择(集成 Sequential Halving) + * 3. 遍历直至到达叶节点(未扩展的节点) + * 4. 
记录搜索路径、动作序列、叶节点等信息供后续回传使用 + * + * 参数说明: + * roots - 批量样本的根节点集合(CRoots对象数组) + * min_max_stats_lst - 批量min-max统计对象集合(用于Q值归一化) + * results - 搜索结果缓冲区(保存搜索路径、动作、叶节点等) + * num_simulations - 总模拟次数(序列削减的总步数) + * simulation_idx - 当前模拟索引(当前处于第几次迭代) + * gumble_noise - 批量Gumbel噪声矩阵(大小:batch_size × num_actions) + * current_num_top_actions - 当前保留的顶部候选动作数(随削减递减) + * virtual_to_play_batch - 虚拟玩家批次(用于两人游戏,单人游戏时为[-1,-1,...]) + * ===================================================================== + */ + void cbatch_traverse(CRoots *roots, + tools::CMinMaxStatsList *min_max_stats_lst, + CSearchResults &results, + int num_simulations, + int simulation_idx, + const std::vector>& gumble_noise, + int current_num_top_actions, + std::vector &virtual_to_play_batch) { + + // 初始化结果容器 + int last_action = -1; + results.search_lens = std::vector(); + + // 判断游戏类型(单人/双人) + int players = 0; + int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); + if (largest_element == -1) { + players = 1; // 单人游戏 + } else { + players = 2; // 双人游戏 + } + + // 对批次中每个样本独立遍历 + for (int i = 0; i < results.num; ++i) { + CNode *node = &(roots->roots[i]); + int search_len = 0; + results.search_paths[i].push_back(node); + + // 从根节点遍历到叶节点 + while (node->expanded()) { + // 【关键】使用 select_action 替代 cselect_child + // 该函数内部集成了 Sequential Halving 逻辑: + // - 根节点:在 simulation_idx==0 时初始化候选集,后续等量访问 + // - 非根节点:使用改进策略 + 访问计数平衡 + int action = select_action( + node, + min_max_stats_lst->stats_lst[i], + num_simulations, // Sequential Halving 参数 + simulation_idx, // Sequential Halving 参数 + gumble_noise[i], // Sequential Halving 参数 + current_num_top_actions // Sequential Halving 参数 + ); + + // 【两人游戏兼容】切换玩家 + if (players > 1) { + assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); + if (virtual_to_play_batch[i] == 1) { + virtual_to_play_batch[i] = 2; + } else { + virtual_to_play_batch[i] = 1; + } + } + + node->best_action = action; + node = node->get_child(action); + last_action = action; + results.search_paths[i].push_back(node); + search_len += 1; + } + + // 记录叶节点父节点的信息 + CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; + + // 记录父节点的隐状态索引位置(后续用于神经网络推理定位隐状态) + results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); + results.latent_state_index_in_batch.push_back(parent->batch_index); + + // 记录此次搜索的其他信息 + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + } + + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch, std::vector &true_action, std::vector &reuse_value) + { + /* + Overview: + Search node path from the roots. + Arguments: + - roots: the roots that search from. + - pb_c_base: constants c2 in muzero. + - pb_c_init: constants c1 in muzero. + - disount_factor: the discount factor of reward. + - min_max_stats: a tool used to min-max normalize the score. + - results: the search results. + - virtual_to_play_batch: the batch of which player is playing on this node. + - true_action: the action chosen in the trajectory. + - reuse_value: the value obtained from the search of the next state in the trajectory. 
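+        Note (added for clarity): the root uses cselect_root_child() with true_action and
+        reuse_value, and when the root actually picks true_action the traversal stops one step
+        early (the child is already expanded), so the caller backs that sample up with the
+        reused value via cbatch_backpropagate_with_reuse() instead of running inference on it;
+        this is why a still-expanded leaf is recorded below with latent_state_index -1.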
+ */ + // set seed + get_time_and_set_rand_seed(); + + int last_action = -1; + float parent_q = 0.0; + results.search_lens = std::vector(); + + int players = 0; + int largest_element = *max_element(virtual_to_play_batch.begin(), virtual_to_play_batch.end()); // 0 or 2 + if (largest_element == -1) + { + players = 1; + } + else + { + players = 2; + } + + for (int i = 0; i < results.num; ++i) + { + CNode *node = &(roots->roots[i]); + int is_root = 1; + int search_len = 0; + results.search_paths[i].push_back(node); + + while (node->expanded()) + { + float mean_q = node->compute_mean_q(is_root, parent_q, discount_factor); + parent_q = mean_q; + + int action = 0; + if (is_root) + { + action = cselect_root_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players, true_action[i], reuse_value[i]); + } + else + { + action = cselect_child(node, min_max_stats_lst->stats_lst[i], pb_c_base, pb_c_init, discount_factor, mean_q, players); + } + + if (players > 1) + { + assert(virtual_to_play_batch[i] == 1 || virtual_to_play_batch[i] == 2); + if (virtual_to_play_batch[i] == 1) + { + virtual_to_play_batch[i] = 2; + } + else + { + virtual_to_play_batch[i] = 1; + } + } + + node->best_action = action; + // next + node = node->get_child(action); + last_action = action; + results.search_paths[i].push_back(node); + search_len += 1; + + if(is_root && action == true_action[i]) + { + break; + } + + is_root = 0; + } + + if (node->expanded()) + { + results.latent_state_index_in_search_path.push_back(-1); + results.latent_state_index_in_batch.push_back(i); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + else + { + CNode *parent = results.search_paths[i][results.search_paths[i].size() - 2]; + + results.latent_state_index_in_search_path.push_back(parent->current_latent_state_index); + results.latent_state_index_in_batch.push_back(parent->batch_index); + + results.last_actions.push_back(last_action); + results.search_lens.push_back(search_len); + results.nodes.push_back(node); + results.virtual_to_play_batchs.push_back(virtual_to_play_batch[i]); + } + } + } + + // ==================== EfficientZero V2 Sequential Halving 实现 ==================== + + // Softmax 函数:将 logits 转换为概率分布 + std::vector softmax(const std::vector &logits){ + std::vector policy(logits.size(), 0.0); + + // 数值稳定:减去最大值防止溢出 + float max_logit = -1e9; + for(size_t a = 0; a < logits.size(); ++a){ + if(logits[a] > max_logit) max_logit = logits[a]; + } + + for(size_t a = 0; a < logits.size(); ++a){ + policy[a] = exp(logits[a] - max_logit); + } + + // 归一化 + float policy_sum = 0.0; + for(size_t a = 0; a < policy.size(); ++a){ + policy_sum += policy[a]; + } + for(size_t a = 0; a < policy.size(); ++a){ + policy[a] = policy[a] / policy_sum; + } + return policy; + } + + // 获取子节点的先验概率 + std::vector CNode::get_children_priors(){ + std::vector priors(this->legal_actions.size(), 0.0); + for(size_t i = 0; i < this->legal_actions.size(); ++i){ + int action = this->legal_actions[i]; + priors[i] = this->children[action].prior; + } + return priors; + } + + /** + * ===================================================================== + * 获取子节点访问次数向量 + * ===================================================================== + * 将该节点的所有子节点的访问次数合并为一个向量 + * 向量大小与合法动作数量相同 + * + * 返回:visit_count 向量,索引对应合法动作顺序 + * ===================================================================== + */ + 
+ std::vector<int> CNode::get_children_visits(){
+     // create a vector with one entry per legal action, initialized to 0
+     std::vector<int> visits(this->legal_actions.size(), 0);
+
+     // iterate over all legal actions
+     for(size_t i = 0; i < this->legal_actions.size(); ++i){
+         // the action index
+         int action = this->legal_actions[i];
+
+         // read the visit count of the corresponding child
+         visits[i] = this->children[action].visit_count;
+     }
+
+     return visits;
+ }
+
+ // get the reward of this node
+ float CNode::get_reward(){
+     if(this->is_reset){
+         // at a reset point, return value_prefix directly
+         return this->value_prefix;
+     } else {
+         // otherwise return the difference w.r.t. the parent's value prefix
+         return this->value_prefix - this->parent_value_prefix;
+     }
+ }
+
+ // Q-value of a single action: Q(s,a) = R(a) + γ*V(s')
+ float CNode::get_qsa(int action){
+     CNode* child = &(this->children[action]);
+     // Q(s,a) = R(a) + γ*V(s')
+     float qsa = child->get_reward() + this->discount * child->value();
+     return qsa;
+ }
+
+ // mixed value estimate (an optimistic estimate used for unvisited children)
+ float CNode::get_v_mix(){
+     // 1. policy distribution of the current node
+     std::vector<float> priors = this->get_children_priors();
+     std::vector<float> pi_lst = softmax(priors);
+
+     float pi_sum = 0.0;      // sum of the policy probabilities of the visited actions
+     float pi_qsa_sum = 0.0;  // weighted sum of π(a) * Q(a)
+
+     // 2. iterate over all legal actions, counting only the expanded children
+     for(size_t i = 0; i < this->legal_actions.size(); ++i){
+         int action = this->legal_actions[i];
+         if(this->children[action].expanded()){
+             // the child has been visited: accumulate its policy probability
+             pi_sum += pi_lst[i];
+             // accumulate the weighted Q-value
+             pi_qsa_sum += pi_lst[i] * this->get_qsa(action);
+         }
+     }
+
+     // 3. compute v_mix
+     float v_mix = 0.0;
+     const float EPSILON = 0.000001;
+
+     if(pi_sum < EPSILON) {
+         // no child has been visited: fall back to the node value
+         v_mix = this->value();
+     }
+     else{
+         // v_mix = (1/(1+N)) * [V + N * Σ(π*Q) / Σ(π)]
+         v_mix = (1.0 / (1.0 + this->visit_count)) * (this->value() + this->visit_count * pi_qsa_sum / pi_sum);
+     }
+
+     return v_mix;
+ }
+
+ // completed Q-values (estimates for both visited and unvisited children)
+ std::vector<float> CNode::get_completed_Q(tools::CMinMaxStats &min_max_stats, int to_normalize){
+     // 1. result vector
+     std::vector<float> completed_Qs(this->legal_actions.size(), 0.0);
+
+     // 2. mixed value (used for the unvisited children)
+     float v_mix = this->get_v_mix();
+
+     // 3. compute the completed Q-value of every action
+     for(size_t i = 0; i < this->legal_actions.size(); ++i){
+         int action = this->legal_actions[i];
+         float Q = 0.0;
+
+         // has this child been expanded (visited)?
+         if(this->children[action].expanded()){
+             // visited child: Q = R + γ*V
+             Q = this->get_qsa(action);
+         }
+         else {
+             // unvisited child: use the optimistic estimate v_mix
+             Q = v_mix;
+         }
+
+         // 4. choose the normalization mode according to the to_normalize flag
+         if (to_normalize == 1) {
+             // mode 1: normalize with the min-max statistics
+             completed_Qs[i] = min_max_stats.normalize(Q);
+             if (completed_Qs[i] < 0.0) completed_Qs[i] = 0.0;
+             if (completed_Qs[i] > 1.0) completed_Qs[i] = 1.0;
+         }
+         else {
+             completed_Qs[i] = Q;
+         }
+     }
+
+     // 5. 
如果选择最终归一化模式(to_normalize == 2) + if (to_normalize == 2){ + float v_max = -1e9; + float v_min = 1e9; + for(float q : completed_Qs){ + if(q > v_max) v_max = q; + if(q < v_min) v_min = q; + } + + if(v_max > v_min){ + for(size_t i = 0; i < completed_Qs.size(); ++i){ + completed_Qs[i] = (completed_Qs[i] - v_min) / (v_max - v_min); + } + } + } + + return completed_Qs; + } + + // 获取改进后的策略(基于 Q 值) + std::vector CNode::get_improved_policy(const std::vector &transformed_completed_Qs){ + // 获取子节点的先验 + std::vector logits = this->get_children_priors(); + + // EfficientZero V2 改进:logits = prior + transformed_Q + for(size_t i = 0; i < logits.size(); ++i){ + logits[i] = logits[i] + transformed_completed_Qs[i]; + } + + // 转换为概率分布 + std::vector policy = softmax(logits); + return policy; + } + + /** + * ===================================================================== + * 【EfficientZero V2 根节点策略】等量访问函数 - do_equal_visit + * ===================================================================== + * 在 Sequential Halving 中,根节点采用等量访问策略来评估候选动作。 + * 该函数从已筛选的候选集合中选择访问次数最少的动作。 + * + * 算法原理: + * 1. 遍历 selected_children_idx 中所有候选动作 + * 2. 找到访问次数最少的动作 + * 3. 返回该动作的索引 + * + * 目的:确保在同一阶段中各个候选动作被公平地评估 + * + * 参数: + * num_simulations - 总模拟数(用于初始化最小值) + * + * 返回:访问次数最少的候选动作索引 + * ===================================================================== + */ + int CNode::do_equal_visit(int num_simulations){ + // 初始化最小访问次数为一个很大的值(num_simulations + 1) + // 这样确保任何实际的访问次数都会更小 + int min_visit_count = num_simulations + 1; + int action = -1; // 初始化为-1,表示未选择 + + // 遍历所有候选动作(在 selected_children_idx 中) + for(int selected_child_idx : this->selected_children_idx){ + // 获取该候选动作的访问次数 + int visit_count = (this->get_child(selected_child_idx))->visit_count; + + // 如果该动作的访问次数更少,更新最小访问次数和选中的动作 + if(visit_count < min_visit_count){ + action = selected_child_idx; + min_visit_count = visit_count; + } + } + + // 返回选中的动作(访问次数最少的候选) + return action; + } + + // ==================== 辅助函数 ==================== + + // 获取向量中的最大浮点数 + float max_float(const std::vector &arr){ + if(arr.empty()) return -1e9; + float max_val = arr[0]; + for(size_t i = 1; i < arr.size(); ++i){ + if(arr[i] > max_val) max_val = arr[i]; + } + return max_val; + } + + // 获取向量中的最小浮点数 + float min_float(const std::vector &arr){ + if(arr.empty()) return 1e9; + float min_val = arr[0]; + for(size_t i = 1; i < arr.size(); ++i){ + if(arr[i] < min_val) min_val = arr[i]; + } + return min_val; + } + + // 获取向量中的最大整数 + int max_int(const std::vector &arr){ + if(arr.empty()) return -1e9; + int max_val = arr[0]; + for(size_t i = 1; i < arr.size(); ++i){ + if(arr[i] > max_val) max_val = arr[i]; + } + return max_val; + } + + // 求浮点数向量的和 + float sum_float(const std::vector &arr){ + float res = 0.0; + for(float a : arr) res += a; + return res; + } + + // 求整数向量的和 + int sum_int(const std::vector &arr){ + int res = 0; + for(int a : arr) res += a; + return res; + } + + // 获取最大值的索引 + int argmax(const std::vector &arr){ + if(arr.empty()) return -1; + int index = 0; + float max_val = arr[0]; + for(size_t i = 1; i < arr.size(); ++i){ + if(arr[i] > max_val){ + max_val = arr[i]; + index = i; + } + } + return index; + } + + // 获取转换后的完整 Q 值(带 Sigma 变换) + std::vector get_transformed_completed_Qs(CNode* node, tools::CMinMaxStats &min_max_stats, int final){ + // 1. 获取完整Q值(根据 final 参数选择归一化模式) + int to_normalize = (final == 0) ? 1 : 2; + std::vector completed_Qs = node->get_completed_Q(min_max_stats, to_normalize); + + // 2. 
计算最大访问数 + int max_child_visit_count = 0; + for(int action : node->legal_actions){ + if(node->children.count(action) > 0){ + int visit_count = node->children[action].visit_count; + if(visit_count > max_child_visit_count){ + max_child_visit_count = visit_count; + } + } + } + + // 3. Sigma 变换(缩放):Q' = (c_visit + max_visit) * c_scale * Q + for(size_t i = 0; i < completed_Qs.size(); ++i){ + completed_Qs[i] = (min_max_stats.c_visit + max_child_visit_count) * min_max_stats.c_scale * completed_Qs[i]; + } + + return completed_Qs; + } + + // Sequential Halving:逐步淘汰差的动作 + int sequential_halving(CNode* root, const std::vector& gumbel_noise, + tools::CMinMaxStats &min_max_stats, int current_phase, int current_num_top_actions){ + // 1. 获取子节点的先验概率和转换后的Q值 + std::vector children_prior = root->get_children_priors(); + std::vector transformed_completed_Qs = get_transformed_completed_Qs(root, min_max_stats, 0); + + // 获取当前已选的动作列表 + std::vector selected_children_idx = root->selected_children_idx; + std::vector children_scores; + + // 2. 计算分数:gumbel噪声 + 先验 + 转换后的Q值 + // 直接用 action 作为索引(action ∈ [0, num_actions-1]) + for(int action : selected_children_idx){ + float score = gumbel_noise[action] + children_prior[action] + transformed_completed_Qs[action]; + children_scores.push_back(score); + } + + // 3. 创建索引数组并初始化为 [0, 1, 2, ...] + std::vector idx(children_scores.size()); + std::iota(idx.begin(), idx.end(), 0); + + // 4. 按分数从高到低排序索引 + std::sort(idx.begin(), idx.end(), + [&children_scores](size_t index_1, size_t index_2) { + return children_scores[index_1] > children_scores[index_2]; + }); + + // 5. 清空已选动作,只保留分数最高的 top-m 个 + root->selected_children_idx.clear(); + int keep_count = std::min(current_num_top_actions, (int)selected_children_idx.size()); + for(int i = 0; i < keep_count; ++i){ + root->selected_children_idx.push_back(selected_children_idx[idx[i]]); + } + + // 6. 
返回分数最高的动作 + int best_action = root->selected_children_idx[0]; + return best_action; + } + + // 批量 Sequential Halving:对多个搜索树进行动作选择 + std::vector c_batch_sequential_halving(CRoots *roots, const std::vector>& gumbel_noises, + tools::CMinMaxStatsList *min_max_stats_lst, + int current_phase, int current_num_top_actions){ + std::vector best_actions(roots->root_num, -1); + + for(int i = 0; i < roots->root_num; ++i){ + int action = sequential_halving(&(roots->roots[i]), gumbel_noises[i], + min_max_stats_lst->stats_lst[i], current_phase, current_num_top_actions); + best_actions[i] = action; + } + + return best_actions; + } +} \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.h b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.h new file mode 100644 index 000000000..491e93001 --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/lib/cnode.h @@ -0,0 +1,134 @@ +// C++11 + +#ifndef CNODE_H +#define CNODE_H + +#include "cminimax.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +const int DEBUG_MODE = 0; + +namespace tree { + class CNode { + public: + int visit_count, to_play, current_latent_state_index, batch_index, best_action, is_reset; + float value_prefix, prior, value_sum; + float parent_value_prefix; + std::vector children_index; + std::map children; + + std::vector legal_actions; + + // ========== V2 新增字段 ========== + std::vector selected_children_idx; // Sequential Halving 选中的动作 + std::vector estimated_value_lst; // 多值估计列表(EfficientZero V2 原版风格,未启用) + float discount; // 折扣因子(V2 用) + + CNode(); + CNode(float prior, std::vector &legal_actions); + ~CNode(); + + void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefix, const std::vector &policy_logits); + void add_exploration_noise(float exploration_fraction, const std::vector &noises); + float compute_mean_q(int isRoot, float parent_q, float discount_factor); + void print_out(); + + int expanded(); + + float value(); + float value_v2(); // EfficientZero V2 原版风格的值估计(基于 estimated_value_lst,未启用) + + std::vector get_trajectory(); + std::vector get_children_distribution(); + CNode* get_child(int action); + + // ========== EfficientZero V2 新增方法 ========== + std::vector get_children_priors(); + std::vector get_children_visits(); // 获取子节点访问次数 + float get_reward(); + float get_qsa(int action); + float get_v_mix(); + std::vector get_completed_Q(tools::CMinMaxStats &min_max_stats, int to_normalize); + std::vector get_improved_policy(const std::vector &transformed_completed_Qs); + int do_equal_visit(int num_simulations); // Sequential Halving 等量访问策略 + }; + + class CRoots{ + public: + int root_num; + std::vector roots; + std::vector > legal_actions_list; + + CRoots(); + CRoots(int root_num, std::vector > &legal_actions_list); + ~CRoots(); + + void prepare(float root_noise_weight, const std::vector > &noises, const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch); + void prepare_no_noise(const std::vector &value_prefixs, const std::vector > &policies, std::vector &to_play_batch); + void clear(); + std::vector > get_trajectories(); + std::vector > get_distributions(); + std::vector get_values(); + std::vector > get_root_policies(tools::CMinMaxStatsList *min_max_stats_lst); + std::vector get_best_actions(); + CNode* get_root(int index); + }; + + class CSearchResults{ + public: + int num; + std::vector latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens; + 
std::vector virtual_to_play_batchs; + std::vector nodes; + std::vector > search_paths; + + CSearchResults(); + CSearchResults(int num); + ~CSearchResults(); + + }; + + + //********************************************************* + void update_tree_q(CNode* root, tools::CMinMaxStats &min_max_stats, float discount_factor, int players); + void cbackpropagate(std::vector &search_path, tools::CMinMaxStats &min_max_stats, int to_play, float value, float discount_factor); + void cbatch_backpropagate(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch); + void cbatch_backpropagate_with_reuse(int current_latent_state_index, float discount_factor, const std::vector &value_prefixs, const std::vector &values, const std::vector > &policies, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector is_reset_list, std::vector &to_play_batch, std::vector &no_inference_lst, std::vector &reuse_lst, std::vector &reuse_value_lst); + int cselect_child(CNode* root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players); + int cselect_root_child(CNode *root, tools::CMinMaxStats &min_max_stats, int pb_c_base, float pb_c_init, float discount_factor, float mean_q, int players, int true_action, float reuse_value); + float cucb_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players); + float carm_score(CNode *child, tools::CMinMaxStats &min_max_stats, float parent_mean_q, int is_reset, float reuse_value, float total_children_visit_counts, float parent_value_prefix, float pb_c_base, float pb_c_init, float discount_factor, int players); + + // ========== MuZero/UCB 风格的遍历(备份) ========== + void cbatch_traverse_ucb(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch); + void cbatch_traverse_with_reuse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, std::vector &virtual_to_play_batch, std::vector &true_action, std::vector &reuse_value); + + // ========== EfficientZero V2 风格的遍历(Sequential Halving 集成) ========== + void cbatch_traverse(CRoots *roots, tools::CMinMaxStatsList *min_max_stats_lst, CSearchResults &results, + int num_simulations, int simulation_idx, const std::vector>& gumble_noise, + int current_num_top_actions, std::vector &virtual_to_play_batch); + + // ========== EfficientZero V2 Sequential Halving 相关 ========== + std::vector softmax(const std::vector &logits); + std::vector get_transformed_completed_Qs(CNode* node, tools::CMinMaxStats &min_max_stats, int final); + int sequential_halving(CNode* root, const std::vector& gumbel_noise, tools::CMinMaxStats &min_max_stats, int current_phase, int current_num_top_actions); + std::vector c_batch_sequential_halving(CRoots *roots, const std::vector>& gumbel_noises, tools::CMinMaxStatsList *min_max_stats_lst, int current_phase, int current_num_top_actions); + + // 辅助函数 + float max_float(const std::vector &arr); + float min_float(const std::vector &arr); + int max_int(const std::vector &arr); + float 
sum_float(const std::vector &arr); + int sum_int(const std::vector &arr); + int argmax(const std::vector &arr); +} + +#endif \ No newline at end of file diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving b/lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving new file mode 100755 index 000000000..b308acee9 Binary files /dev/null and b/lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving differ diff --git a/lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving.py b/lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving.py new file mode 100644 index 000000000..3c0d6045a --- /dev/null +++ b/lzero/mcts/ctree/ctree_efficientzero_v2/test_batch_sequential_halving.py @@ -0,0 +1,259 @@ +""" +Test script for c_batch_sequential_halving +Tests the Sequential Halving algorithm with batch processing +""" + +import sys +sys.path.insert(0, '/mnt/shared-storage-user/tangjia/eff/eff_orign/LightZero') + +import numpy as np +from lzero.mcts.ctree.ctree_efficientzero_v2 import ez_tree as tree_efficientzero_v2 + + +def test_sequential_halving_phase_reduction(): + """Test that Sequential Halving correctly reduces action space across phases""" + print("\n" + "="*70) + print("[Test 1] Sequential Halving Phase Reduction") + print("="*70) + + try: + # Setup + batch_size = 2 + num_actions = 8 + num_simulations = 8 + + print(f"\nSetup:") + print(f" - batch_size: {batch_size}") + print(f" - num_actions: {num_actions}") + print(f" - num_simulations: {num_simulations}") + + # Create roots + legal_actions = [[0, 1, 2, 3, 4, 5, 6, 7] for _ in range(batch_size)] + roots = tree_efficientzero_v2.Roots(batch_size, legal_actions) + print(f" ✓ Roots created: root_num={roots.root_num}") + + # Prepare roots with initial policy and values + value_prefix_roots = [0.0 for _ in range(batch_size)] + policy_logits = [np.random.randn(num_actions).astype(np.float32) for _ in range(batch_size)] + noises = [np.random.randn(num_actions).astype(np.float32) for _ in range(batch_size)] + to_play_batch = [0, 0] + + roots.prepare(0.25, noises, value_prefix_roots, policy_logits, to_play_batch) + print(f" ✓ Roots prepared") + + # Create MinMaxStats + min_max_stats_lst = tree_efficientzero_v2.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(0.01) + print(f" ✓ MinMaxStats created") + + # Generate Gumbel noises + gumbel_noises = [np.random.randn(num_actions).astype(np.float32) for _ in range(batch_size)] + print(f" ✓ Gumbel noises generated") + + # Test Sequential Halving across phases + print(f"\nSequential Halving Phase Progression:") + num_phases = 4 + expected_actions = [8, 4, 2, 1] + + for phase in range(num_phases): + current_num_top_actions = max(1, num_actions // (2 ** phase)) + + print(f"\n Phase {phase}:") + print(f" - current_num_top_actions: {current_num_top_actions}") + print(f" - Expected actions to keep: {expected_actions[phase]}") + + # Call Sequential Halving + best_actions = tree_efficientzero_v2.batch_sequential_halving( + roots, gumbel_noises, min_max_stats_lst, phase, current_num_top_actions + ) + + print(f" - Best actions selected: {best_actions}") + print(f" ✓ Phase {phase} completed") + + # Verify that the number of selected actions is correct + assert current_num_top_actions == expected_actions[phase], \ + f"Phase {phase}: Expected {expected_actions[phase]}, got {current_num_top_actions}" + + print(f"\n ✓ All phases completed successfully") + print(f" ✓ Actions correctly reduced: {expected_actions}") + + return True + + 
except Exception as e: + print(f"\n✗ Test failed: {e}") + import traceback + traceback.print_exc() + return False + + +def test_sequential_halving_action_scores(): + """Test that Sequential Halving ranks actions by score (gumbel + prior + Q)""" + print("\n" + "="*70) + print("[Test 2] Sequential Halving Action Scoring") + print("="*70) + + try: + batch_size = 1 + num_actions = 4 + + print(f"\nSetup:") + print(f" - batch_size: {batch_size}") + print(f" - num_actions: {num_actions} (small for easy inspection)") + + # Create roots + legal_actions = [[0, 1, 2, 3]] + roots = tree_efficientzero_v2.Roots(batch_size, legal_actions) + print(f" ✓ Roots created") + + # Create deterministic policy for easy prediction + value_prefix_roots = [0.0] + # Manually set policy logits so action 0 has highest prior + policy_logits = [np.array([2.0, 1.0, 0.5, 0.2], dtype=np.float32)] + noises = [np.zeros(num_actions, dtype=np.float32)] # Zero noise for deterministic test + to_play_batch = [0] + + roots.prepare(0.25, noises, value_prefix_roots, policy_logits, to_play_batch) + print(f" ✓ Roots prepared with deterministic policy") + print(f" - Policy logits: {policy_logits[0]}") + + # Create MinMaxStats + min_max_stats_lst = tree_efficientzero_v2.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(0.01) + print(f" ✓ MinMaxStats created") + + # Zero Gumbel noise for deterministic test + gumbel_noises = [np.zeros(num_actions, dtype=np.float32)] + print(f" ✓ Gumbel noises set to zero (deterministic)") + + # Run Sequential Halving + print(f"\nRunning Sequential Halving (Phase 0, keep all 4):") + best_actions = tree_efficientzero_v2.batch_sequential_halving( + roots, gumbel_noises, min_max_stats_lst, 0, 4 + ) + + print(f" - Best action selected: {best_actions[0]}") + print(f" ✓ Sequential Halving completed") + + # With zero gumbel noise and no Q-values yet, the best action should be the one with highest prior + # which is action 0 (logit = 2.0) + expected_best = 0 + assert best_actions[0] == expected_best, \ + f"Expected best action {expected_best}, got {best_actions[0]}" + + print(f" ✓ Best action is action {best_actions[0]} (as expected with highest prior)") + + return True + + except Exception as e: + print(f"\n✗ Test failed: {e}") + import traceback + traceback.print_exc() + return False + + +def test_sequential_halving_batch_consistency(): + """Test that Sequential Halving works consistently across batch""" + print("\n" + "="*70) + print("[Test 3] Sequential Halving Batch Consistency") + print("="*70) + + try: + batch_size = 4 + num_actions = 6 + + print(f"\nSetup:") + print(f" - batch_size: {batch_size}") + print(f" - num_actions: {num_actions}") + + # Create roots + legal_actions = [[0, 1, 2, 3, 4, 5] for _ in range(batch_size)] + roots = tree_efficientzero_v2.Roots(batch_size, legal_actions) + print(f" ✓ Roots created for batch of {batch_size}") + + # Prepare roots with different policies for each batch + value_prefix_roots = [0.0 for _ in range(batch_size)] + policy_logits = [np.random.randn(num_actions).astype(np.float32) for _ in range(batch_size)] + noises = [np.random.randn(num_actions).astype(np.float32) for _ in range(batch_size)] + to_play_batch = list(range(batch_size)) + + roots.prepare(0.25, noises, value_prefix_roots, policy_logits, to_play_batch) + print(f" ✓ Roots prepared for all batch items") + + # Create MinMaxStats + min_max_stats_lst = tree_efficientzero_v2.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(0.01) + + # Generate Gumbel noises + gumbel_noises = 
[np.random.randn(num_actions).astype(np.float32) for _ in range(batch_size)] + + # Test multiple phases + print(f"\nTesting Sequential Halving across phases:") + phases = [0, 1, 2] + expected_top_actions = [6, 3, 1] + + for phase in phases: + current_num_top_actions = expected_top_actions[phase] + + best_actions = tree_efficientzero_v2.batch_sequential_halving( + roots, gumbel_noises, min_max_stats_lst, phase, current_num_top_actions + ) + + print(f" Phase {phase} (keep {current_num_top_actions}):") + print(f" - Best actions per batch: {best_actions}") + assert len(best_actions) == batch_size, \ + f"Expected {batch_size} results, got {len(best_actions)}" + print(f" ✓ Returned {len(best_actions)} best actions (one per batch item)") + + print(f"\n ✓ Sequential Halving works consistently across batch") + + return True + + except Exception as e: + print(f"\n✗ Test failed: {e}") + import traceback + traceback.print_exc() + return False + + +def main(): + """Run all tests""" + print("\n" + "="*70) + print("c_batch_sequential_halving Test Suite") + print("="*70) + + results = [] + + # Test 1: Phase reduction + results.append(("Phase Reduction", test_sequential_halving_phase_reduction())) + + # Test 2: Action scoring + results.append(("Action Scoring", test_sequential_halving_action_scores())) + + # Test 3: Batch consistency + results.append(("Batch Consistency", test_sequential_halving_batch_consistency())) + + # Summary + print("\n" + "="*70) + print("Test Summary") + print("="*70) + + passed = sum(1 for _, result in results if result) + total = len(results) + + for test_name, result in results: + status = "✓ PASSED" if result else "✗ FAILED" + print(f" {test_name}: {status}") + + print(f"\nTotal: {passed}/{total} tests passed") + + if passed == total: + print("\n✓ All tests passed!") + return 0 + else: + print(f"\n✗ {total - passed} test(s) failed") + return 1 + + +if __name__ == '__main__': + exit_code = main() + sys.exit(exit_code) diff --git a/lzero/mcts/tree_search/__init__.py b/lzero/mcts/tree_search/__init__.py index 514daf813..3f50f8733 100644 --- a/lzero/mcts/tree_search/__init__.py +++ b/lzero/mcts/tree_search/__init__.py @@ -1,4 +1,4 @@ -from .mcts_ctree import MuZeroMCTSCtree, EfficientZeroMCTSCtree, GumbelMuZeroMCTSCtree, UniZeroMCTSCtree, MuZeroRNNFullObsMCTSCtree +from .mcts_ctree import MuZeroMCTSCtree, EfficientZeroMCTSCtree, GumbelMuZeroMCTSCtree, UniZeroMCTSCtree, MuZeroRNNFullObsMCTSCtree,EZV2MCTSCtree from .mcts_ctree_sampled import SampledEfficientZeroMCTSCtree, SampledMuZeroMCTSCtree, SampledUniZeroMCTSCtree from .mcts_ctree_stochastic import StochasticMuZeroMCTSCtree from .mcts_ptree import MuZeroMCTSPtree, EfficientZeroMCTSPtree diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 4e238a6b3..b9549fa31 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -1,11 +1,12 @@ import copy -from typing import TYPE_CHECKING, List, Any, Union +from typing import TYPE_CHECKING, List, Any, Union, Tuple import numpy as np import torch from easydict import EasyDict from lzero.mcts.ctree.ctree_efficientzero import ez_tree as tree_efficientzero +from lzero.mcts.ctree.ctree_efficientzero_v2 import ez_tree as tree_efficientzero_v2 from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as tree_gumbel_muzero from lzero.mcts.ctree.ctree_muzero import mz_tree as tree_muzero from lzero.policy import DiscreteSupport, InverseScalarTransform, to_detach_cpu_numpy @@ -15,6 +16,7 @@ from 
lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as gmz_ctree +from line_profiler import line_profiler class UniZeroMCTSCtree(object): """ @@ -72,11 +74,11 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]], timestep: Union[int, List[Any]] - ) -> dict: + List[Any]], timestep: Union[int, List[Any]]=None, task_id=None + ) -> None: """ Overview: Perform Monte Carlo Tree Search (MCTS) for a batch of root nodes in parallel. @@ -137,7 +139,15 @@ def search( for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): latent_states.append(latent_state_batch_in_search_path[ix][iy]) - latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + try: + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + except Exception as e: + print("="*20) + print(e) + print("roots:", roots, "latent_state_roots:", latent_state_roots) + print ("latent_state_roots.shape:", latent_state_roots.shape) + + # TODO: .long() is only for discrete action last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() @@ -154,7 +164,23 @@ def search( # search_depth is used for rope in UniZero search_depth = results.get_search_len() # print(f'simulation_index:{simulation_index}, search_depth:{search_depth}, latent_state_index_in_search_path:{latent_state_index_in_search_path}') - network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) + if timestep is None: + # for UniZero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth) + else: + # for UniZero + if task_id is not None: + # multi task setting + # network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep, task_id=task_id) + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, task_id=task_id) + else: + # single task setting + network_output = model.recurrent_inference(state_action_history, simulation_index, search_depth, timestep) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ -245,10 +271,10 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id=None ) -> None: """ Overview: @@ -318,6 +344,13 @@ def search( """ network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero + if task_id is not None: + # multi task setting + network_output = model.recurrent_inference(latent_states, last_actions, task_id=task_id) + else: + # single task setting + network_output = 
model.recurrent_inference(latent_states, last_actions) + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) @@ -516,7 +549,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e """ return tree_muzero.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None, @@ -715,6 +748,8 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e ..note:: The initialization is achieved by the ``Roots`` class from the ``ctree_efficientzero`` module. """ + # import pudb;pudb.set_trace() + return tree_efficientzero.Roots(active_collect_env_num, legal_actions) def search( @@ -735,6 +770,7 @@ def search( .. note:: The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++. """ + with torch.no_grad(): model.eval() @@ -977,6 +1013,284 @@ def search_with_reuse( return length, average_infer +class EZV2MCTSCtree(object): + """ + Overview: + The C++ implementation of MCTS (batch format) for EfficientZero V2. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ctree_efficientzero_v2``, \ + which are implemented in C++. + Interfaces: + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. + + EfficientZero V2 uses Sequential Halving at the root node to progressively eliminate low-scoring actions. + """ + + config = dict( + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + # (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_base=19652, + # (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search. + pb_c_init=1.25, + # (float) The maximum change in value allowed during the backup step of the search tree update. + value_delta_max=0.01, + # (int) The initial number of top actions to consider in Sequential Halving. If None, equals num_actions. + num_top_actions=None, + ) + + @classmethod + def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. + cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. + cfg.cfg_type = cls.__name__ + 'Dict' + return cfg + + def __init__(self, cfg: EasyDict = None) -> None: + """ + Overview: + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key + in the default configuration, the user-provided value will override the default configuration. Otherwise, + the default configuration will be used. + Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. 
+ """ + # Get the default configuration. + default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. + default_config.update(cfg) + self._cfg = default_config + self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device) + self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device) + self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution) + self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution) + + # EfficientZero V2: Sequential Halving phase management + self.num_top_actions = self._cfg.num_top_actions + + @classmethod + def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "ez_ctree.Roots": + """ + Overview: + Initializes a batch of roots to search parallelly later. + Arguments: + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ctree_efficientzero_v2`` module. + """ + + return tree_efficientzero_v2.Roots(active_collect_env_num, legal_actions) + + def search( + self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], + reward_hidden_state_roots: List[Any], to_play_batch: Union[int, List[Any]], + gumbel_noises + ) -> Tuple[List[List[float]], List[int]]: + """ + Overview: + Do MCTS for a batch of roots. Parallel in model inference. \ + Use C++ to implement the tree search. + Arguments: + - roots (:obj:`Any`): a batch of expanded root nodes. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots. + - model (:obj:`torch.nn.Module`): The model used for inference. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++. + """ + with torch.no_grad(): + model.eval() + + + + # preparation some constant + batch_size = roots.num + pb_c_base, pb_c_init, discount_factor = self._cfg.pb_c_base, self._cfg.pb_c_init, self._cfg.discount_factor + + # the data storage of latent states: storing the latent state of all the nodes in one search. 
+ latent_state_batch_in_search_path = [latent_state_roots] + # the data storage of value prefix hidden states in LSTM + # print(f"reward_hidden_state_roots[0]={reward_hidden_state_roots[0]}") + # print(f"reward_hidden_state_roots[1]={reward_hidden_state_roots[1]}") + reward_hidden_state_c_batch = [reward_hidden_state_roots[0]] + reward_hidden_state_h_batch = [reward_hidden_state_roots[1]] + + # minimax value storage + min_max_stats_lst = tree_efficientzero_v2.MinMaxStatsList(batch_size) + min_max_stats_lst.set_delta(self._cfg.value_delta_max) + + # ========== EfficientZero V2: Calculate Sequential Halving parameters ========== + legal_actions_list = roots.get_legal_actions() + num_actions = len(legal_actions_list[0]) if legal_actions_list and len(legal_actions_list) > 0 else 0 + num_top_actions = self.num_top_actions if self.num_top_actions is not None else num_actions + if num_top_actions > 0: + num_phases = int(np.ceil(np.log2(num_top_actions))) + sims_per_phase = max(1, self._cfg.num_simulations // num_phases) + else: + num_phases = 1 + sims_per_phase = self._cfg.num_simulations + + + + for simulation_index in range(self._cfg.num_simulations): + # In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most. + + latent_states = [] + hidden_states_c_reward = [] + hidden_states_h_reward = [] + + # prepare a result wrapper to transport results between python and c++ parts + results = tree_efficientzero_v2.ResultsWrapper(num=batch_size) + + # ========== EfficientZero V2: Sequential Halving at the root ========== + # Calculate current phase and number of top actions to keep based on simulation index + current_phase = min(simulation_index // sims_per_phase, num_phases - 1) + current_num_top_actions = max(1, num_top_actions // (2 ** current_phase)) + + # Call Sequential Halving to prune low-scoring root actions + tree_efficientzero_v2.batch_sequential_halving( + roots, gumbel_noises, min_max_stats_lst, current_phase, current_num_top_actions + ) + + # latent_state_index_in_search_path: the first index of leaf node states in latent_state_batch_in_search_path, i.e. is current_latent_state_index in one the search. + # latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``. + # e.g. the latent state of the leaf node in (x, y) is latent_state_batch_in_search_path[x, y], where x is current_latent_state_index, y is batch_index. + # The index of value prefix hidden state of the leaf node is in the same manner. + """ + MCTS stage 1: Selection + Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. + Sequential Halving has already pruned low-scoring root actions before selection begins. + """ + if self._cfg.env_type == 'not_board_games': + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_efficientzero_v2.batch_traverse( + roots, + min_max_stats_lst, + results, + self._cfg.num_simulations, + simulation_index, + gumbel_noises, + current_num_top_actions, + to_play_batch + ) + else: + # the ``to_play_batch`` is only used in board games, here we need to deepcopy it to avoid changing the original data. 
+ latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play_batch = tree_efficientzero_v2.batch_traverse( + roots, + min_max_stats_lst, + results, + self._cfg.num_simulations, + simulation_index, + gumbel_noises, + current_num_top_actions, + copy.deepcopy(to_play_batch) + ) + # obtain the search horizon for leaf nodes + search_lens = results.get_search_len() + + # Debug: print batch_traverse return values + print(f"[DEBUG sim {simulation_index}] batch_traverse returned:") + print(f" - latent_state_index_in_search_path: {latent_state_index_in_search_path}") + print(f" - latent_state_index_in_batch: {latent_state_index_in_batch}") + print(f" - last_actions: {last_actions}") + print(f" - search_lens: {search_lens}") + print(f" - latent_state_batch_in_search_path length: {len(latent_state_batch_in_search_path)}") + if len(latent_state_batch_in_search_path) > 0: + print(f" - latent_state_batch_in_search_path[0] type: {type(latent_state_batch_in_search_path[0])}") + if hasattr(latent_state_batch_in_search_path[0], 'shape'): + print(f" - latent_state_batch_in_search_path[0] shape: {latent_state_batch_in_search_path[0].shape}") + + # obtain the latent state for leaf node + for ix, iy in zip(latent_state_index_in_search_path, latent_state_index_in_batch): + print(f" - Accessing latent_state_batch_in_search_path[{ix}][{iy}]") + if ix >= len(latent_state_batch_in_search_path): + print(f" ERROR: ix={ix} out of range, len={len(latent_state_batch_in_search_path)}") + latent_states.append(latent_state_batch_in_search_path[ix][iy]) + hidden_states_c_reward.append(reward_hidden_state_c_batch[ix][0][iy]) + hidden_states_h_reward.append(reward_hidden_state_h_batch[ix][0][iy]) + + latent_states = torch.from_numpy(np.asarray(latent_states)).to(self._cfg.device) + hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(self._cfg.device + ).unsqueeze(0) + hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(self._cfg.device + ).unsqueeze(0) + # .long() is only for discrete action + last_actions = torch.from_numpy(np.asarray(last_actions)).to(self._cfg.device).long() + """ + MCTS stage 2: Expansion + At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function. + Then we calculate the policy_logits and value for the leaf node (next_latent_state) by the prediction function. (aka. evaluation) + MCTS stage 3: Backup + At the end of the simulation, the statistics along the trajectory are updated. + """ + network_output = model.recurrent_inference( + latent_states, (hidden_states_c_reward, hidden_states_h_reward), last_actions + ) + + network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) + network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) + network_output.value = to_detach_cpu_numpy(self.value_inverse_scalar_transform_handle(network_output.value)) + network_output.value_prefix = to_detach_cpu_numpy( + self.value_inverse_scalar_transform_handle(network_output.value_prefix)) + + network_output.reward_hidden_state = ( + network_output.reward_hidden_state[0].detach().cpu().numpy(), + network_output.reward_hidden_state[1].detach().cpu().numpy() + ) + + latent_state_batch_in_search_path.append(network_output.latent_state) + # tolist() is to be compatible with cpp datatype. 
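+                # These flat Python lists are what ``batch_backpropagate`` below consumes; the Cython binding
+                # converts them into the std::vector arguments declared in cnode.h.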
+ value_prefix_batch = network_output.value_prefix.reshape(-1).tolist() + value_batch = network_output.value.reshape(-1).tolist() + policy_logits_batch = network_output.policy_logits.tolist() + + reward_latent_state_batch = network_output.reward_hidden_state + # reset the hidden states in LSTM every ``lstm_horizon_len`` steps in one search. + # which enable the model only need to predict the value prefix in a range (e.g.: [s0,...,s5]) + assert self._cfg.lstm_horizon_len > 0 + reset_idx = (np.array(search_lens) % self._cfg.lstm_horizon_len == 0) + assert len(reset_idx) == batch_size + reward_latent_state_batch[0][:, reset_idx, :] = 0 + reward_latent_state_batch[1][:, reset_idx, :] = 0 + is_reset_list = reset_idx.astype(np.int32).tolist() + reward_hidden_state_c_batch.append(reward_latent_state_batch[0]) + reward_hidden_state_h_batch.append(reward_latent_state_batch[1]) + + # In ``batch_backpropagate()``, we first expand the leaf node using ``the policy_logits`` and + # ``reward`` predicted by the model, then perform backpropagation along the search path to update the + # statistics. + + # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. + current_latent_state_index = simulation_index + 1 + tree_efficientzero_v2.batch_backpropagate( + current_latent_state_index, discount_factor, value_prefix_batch, value_batch, policy_logits_batch, + min_max_stats_lst, results, is_reset_list, virtual_to_play_batch + ) + + # ========== EfficientZero V2: Return improved policies and best actions ========== + improved_policies = roots.get_root_policies(min_max_stats_lst) + best_actions = roots.get_best_actions() + # Convert to numpy arrays for easier manipulation + improved_policies = [np.array(p) for p in improved_policies] + return improved_policies, best_actions + + + class GumbelMuZeroMCTSCtree(object): """ Overview: diff --git a/lzero/policy/efficientzero_v2.py b/lzero/policy/efficientzero_v2.py new file mode 100644 index 000000000..fb8af3701 --- /dev/null +++ b/lzero/policy/efficientzero_v2.py @@ -0,0 +1,861 @@ +import copy +from typing import List, Dict, Any, Tuple, Union + +import numpy as np +import torch +import torch.optim as optim +from ding.model import model_wrap +from ding.torch_utils import to_tensor +from ding.utils import POLICY_REGISTRY +from torch.distributions import Categorical +from torch.nn import L1Loss + +from lzero.mcts import EZV2MCTSCtree as MCTSCtree +from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree +from lzero.model import ImageTransforms +from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ + DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, \ + prepare_obs, \ + configure_optimizers +from lzero.policy.muzero import MuZeroPolicy + + +@POLICY_REGISTRY.register('efficientzero_v2') +class EfficientZeroV2Policy(MuZeroPolicy): + """ + Overview: + The policy class for EfficientZero proposed in the paper https://arxiv.org/abs/2111.00210. + """ + + # The default_config for EfficientZero policy. + config = dict( + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For 3-dimensional image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The stacked obs shape. 
+ # observation_shape=(1, 96, 96), # if frame_stack_num=1 + observation_shape=(4, 96, 96), # if frame_stack_num=4 + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=1, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (tuple) The range of supports used in categorical distribution. + # These variables are only effective when ``model.categorical_distribution=True``. + reward_support_range=(-300., 301., 1.), + value_support_range=(-300., 301., 1.), + # (int) The hidden size in LSTM. + lstm_hidden_size=512, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (str) The type of action encoding. Options are ['one_hot', 'not_one_hot']. Default to 'one_hot'. + discrete_action_encoding_type='one_hot', + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'LN'. + norm_type='BN', + ), + # ****** common ****** + # (bool) Whether to use multi-gpu training. + multi_gpu=False, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=True, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. The options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. The options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=200, + # (bool): Indicates whether to perform an offline evaluation of the checkpoint (ckpt). + # If set to True, the checkpoint will be evaluated after the training process is complete. + # IMPORTANT: Setting eval_offline to True requires configuring the saving of checkpoints to align with the evaluation frequency. + # This is done by setting the parameter learn.learner.hook.save_ckpt_after_iter to the same value as eval_freq in the train_muzero.py automatically. + eval_offline=False, + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. 
+ # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor + # if we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + replay_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW'] + optim_type='SGD', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. + learning_rate=0.2, + # (int) Frequency of target network update. + target_update_freq=100, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=10, + # (int) The number of episodes in each collecting stage. + n_episode=8, + # (float) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=5, + # (int) reset the hidden states in LSTM every ``lstm_horizon_len`` horizon steps. + lstm_horizon_len=5, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=2, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + piecewise_decay_lr_scheduler=True, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (bool) Whether to use manually decayed temperature. + # i.e. temperature: 1 -> 0.5 -> 0.25 + manual_temperature_decay=False, + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + # (bool) Whether to add noise to roots during reanalyze process. + reanalyze_noise=True, + # (bool) Whether to reuse the root value between batch searches. + reuse_search=False, + # (bool) whether to use the pure policy to collect data. If False, use the MCTS guided with policy. 
+ collect_with_pure_policy=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. + end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For EfficientZero, ``lzero.model.efficientzero_model.EfficientZeroModel`` + """ + if self._cfg.model.model_type == "conv": + return 'EfficientZeroModel', ['lzero.model.efficientzero_model'] + elif self._cfg.model.model_type == "mlp": + return 'EfficientZeroModelMLP', ['lzero.model.efficientzero_model_mlp'] + else: + raise ValueError("model type {} is not supported".format(self._cfg.model.model_type)) + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + """ + assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type + if self._cfg.optim_type == 'SGD': + self._optimizer = optim.SGD( + self._model.parameters(), + lr=self._cfg.learning_rate, + momentum=self._cfg.momentum, + weight_decay=self._cfg.weight_decay, + ) + + elif self._cfg.optim_type == 'Adam': + self._optimizer = optim.Adam( + self._model.parameters(), + lr=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + ) + elif self._cfg.optim_type == 'AdamW': + self._optimizer = configure_optimizers( + model=self._model, + weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, + device_type=self._cfg.device + ) + + if self._cfg.piecewise_decay_lr_scheduler: + from torch.optim.lr_scheduler import LambdaLR + max_step = self._cfg.threshold_training_steps_for_final_lr + # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr. 
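+            # For example, with learning_rate=0.2 and threshold_training_steps_for_final_lr=int(5e4), the factor
+            # below keeps the lr at 0.2 for the first 25k iterations, 0.02 until 50k, and 0.002 afterwards.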
+            lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)  # noqa
+            self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda)
+
+        # use model_wrapper for specialized demands of different modes
+        self._target_model = copy.deepcopy(self._model)
+        self._target_model = model_wrap(
+            self._target_model,
+            wrapper_name='target',
+            update_type='assign',
+            update_kwargs={'freq': self._cfg.target_update_freq}
+        )
+        self._learn_model = self._model
+
+        if self._cfg.use_augmentation:
+            self.image_transforms = ImageTransforms(
+                self._cfg.augmentation,
+                image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2])
+            )
+        self.value_support = DiscreteSupport(*self._cfg.model.value_support_range, self._cfg.device)
+        self.reward_support = DiscreteSupport(*self._cfg.model.reward_support_range, self._cfg.device)
+        # If these assertions fail, the value/reward support sizes configured in the policy and in the model are inconsistent.
+        assert self.value_support.size == self._learn_model.value_support_size
+        assert self.reward_support.size == self._learn_model.reward_support_size
+        self.value_inverse_scalar_transform_handle = InverseScalarTransform(self.value_support, self._cfg.model.categorical_distribution)
+        self.reward_inverse_scalar_transform_handle = InverseScalarTransform(self.reward_support, self._cfg.model.categorical_distribution)
+
+    def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
+        """
+        Overview:
+            The forward function for learning policy in learn mode, which is the core of the learning process. \
+            The data is sampled from the replay buffer. \
+            The loss is calculated by the loss function and is backpropagated to update the model.
+        Arguments:
+            - data (:obj:`Tuple[torch.Tensor]`): The data sampled from the replay buffer, which is a tuple of tensors. \
+                The first tensor is the current_batch, the second tensor is the target_batch.
+        Returns:
+            - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \
+                current learning loss and learning statistics.
+        """
+        self._learn_model.train()
+        self._target_model.train()
+
+        current_batch, target_batch = data
+        obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch
+        # EfficientZero V2: unpack search_values for value mixing.
+        target_value_prefix, target_value, target_policy, target_search_value = target_batch
+
+        obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg)
+
+        # do augmentations
+        if self._cfg.use_augmentation:
+            obs_batch = self.image_transforms.transform(obs_batch)
+            obs_target_batch = self.image_transforms.transform(obs_target_batch)
+
+        # shape: (batch_size, num_unroll_steps, action_dim)
+        # NOTE: .long() is required in the discrete action space.
+ action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze(-1).long() + data_list = [ + mask_batch, + target_value_prefix.astype('float32'), + target_value.astype('float32'), target_policy, weights, + target_search_value.astype('float32') # EfficientZero V2: search values + ] + [mask_batch, target_value_prefix, target_value, target_policy, + weights, target_search_value] = to_torch_float_tensor(data_list, self._cfg.device) + + target_value_prefix = target_value_prefix.view(self._cfg.batch_size, -1) + target_value = target_value.view(self._cfg.batch_size, -1) + target_search_value = target_search_value.view(self._cfg.batch_size, -1) # EfficientZero V2 + assert obs_batch.size(0) == self._cfg.batch_size == target_value_prefix.size(0) + + # ============================================================== + # EfficientZero V2: Value Target Mixing + # ============================================================== + # Get value target type from config + value_target_type = getattr(self._cfg, 'value_target', 'bootstrap') # 'bootstrap', 'search', 'mixed' + + if value_target_type == 'mixed': + # Get the threshold for starting to use mixed values + start_use_mix_steps = getattr(self._cfg, 'start_use_mix_training_steps', 30000) + + if self._train_iteration < start_use_mix_steps: + # Early training: use bootstrap only + final_target_value = target_value + else: + # Later training: mix bootstrap and search + # Use a simple equal mixing (can be adjusted) + final_target_value = 0.5 * target_value + 0.5 * target_search_value + elif value_target_type == 'search': + # Use search values only + final_target_value = target_search_value + else: + # Default: use bootstrap values only + final_target_value = target_value + + # Use the final mixed target value for all subsequent computations + target_value = final_target_value + + # ``scalar_transform`` to transform the original value to the scaled value, + # i.e. h(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + transformed_target_value_prefix = scalar_transform(target_value_prefix) + transformed_target_value = scalar_transform(target_value) + # transform a scalar to its categorical_distribution. After this transformation, each scalar is + # represented as the linear combination of its two adjacent supports. + target_value_prefix_categorical = phi_transform(self.reward_support, transformed_target_value_prefix) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # ============================================================== + # the core initial_inference in EfficientZero policy. + # ============================================================== + network_output = self._learn_model.initial_inference(obs_batch) + # value_prefix shape: (batch_size, 10), the ``value_prefix`` at the first step is zero padding. + latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack(network_output) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.value_inverse_scalar_transform_handle(value) + + # Note: The following lines are just for debugging. + predicted_value_prefixs = [] + if self._cfg.monitor_extra_statistics: + predicted_values, predicted_policies = original_value.detach().cpu(), torch.softmax( + policy_logits, dim=1 + ).detach().cpu() + + # calculate the new priorities for each transition. 
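The statement that follows computes the per-transition priority as the L1 distance between the predicted root value and the 0-step target value, plus a small epsilon so no sample ends up with zero priority. A minimal standalone sketch of the same computation; the tensor values and function name are illustrative.

import torch

def compute_value_priority(predicted_value: torch.Tensor, target_value: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Per-transition priority: |v_pred - v_target| + eps, shape (batch_size,)."""
    return torch.nn.L1Loss(reduction='none')(predicted_value, target_value) + eps

pred = torch.tensor([0.5, -1.0, 2.0])
target = torch.tensor([0.0, -1.0, 1.5])
print(compute_value_priority(pred, target))  # tensor([0.5000, 0.0000, 0.5000]) (+ 1e-6)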
+ value_priority = L1Loss(reduction='none')(original_value.squeeze(-1), target_value[:, 0]) + 1e-6 + + prob = torch.softmax(policy_logits, dim=-1) + policy_entropy = -(prob * prob.log()).sum(-1).mean() + + # ============================================================== + # calculate policy and value loss for the first step. + # ============================================================== + policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0]) + + # Here we take the init hypothetical step k=0. + target_normalized_visit_count_init_step = target_policy[:, 0] + + # ******* NOTE: target_policy_entropy is only for debug. ****** + non_masked_indices = torch.nonzero(mask_batch[:, 0]).squeeze(-1) + # Check if there are any unmasked rows + if len(non_masked_indices) > 0: + target_normalized_visit_count_masked = torch.index_select( + target_normalized_visit_count_init_step, 0, non_masked_indices + ) + target_policy_entropy = -((target_normalized_visit_count_masked+1e-6) * (target_normalized_visit_count_masked+1e-6).log()).sum(-1).mean() + else: + # Set target_policy_entropy to log(|A|) if all rows are masked + target_policy_entropy = torch.log(torch.tensor(target_normalized_visit_count_init_step.shape[-1])) + + value_loss = cross_entropy_loss(value, target_value_categorical[:, 0]) + + value_prefix_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device) + + # ============================================================== + # the core recurrent_inference in EfficientZero policy. + # ============================================================== + for step_k in range(self._cfg.num_unroll_steps): + # unroll with the dynamics function: predict the next ``latent_state``, ``reward_hidden_state``, + # `` value_prefix`` given current ``latent_state`` ``reward_hidden_state`` and ``action``. + # And then predict policy_logits and value with the prediction function. + network_output = self._learn_model.recurrent_inference( + latent_state, reward_hidden_state, action_batch[:, step_k] + ) + latent_state, value_prefix, reward_hidden_state, value, policy_logits = ez_network_output_unpack( + network_output + ) + + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = self.value_inverse_scalar_transform_handle(value) + + # ============================================================== + # calculate consistency loss for the next ``num_unroll_steps`` unroll steps. + # ============================================================== + if self._cfg.ssl_loss_weight > 0: + # obtain the oracle latent states from representation function. + beg_index, end_index = self._get_target_obs_index_in_step_k(step_k) + network_output = self._learn_model.initial_inference(obs_target_batch[:, beg_index:end_index]) + + latent_state = to_tensor(latent_state) + representation_state = to_tensor(network_output.latent_state) + + # NOTE: no grad for the representation_state branch. 
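The projection code below implements the SimSiam-style consistency loss: the dynamics-predicted latent is projected with gradients enabled, the latent encoded from the real observation is projected behind a stop-gradient, and their negative cosine similarity is minimized. A hedged sketch of that pattern with a toy projector; the projector architecture here is illustrative and not the model's actual ``project`` head.

import torch
import torch.nn.functional as F
from torch import nn

projector = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))  # toy projection head

def consistency_loss(dynamics_latent: torch.Tensor, obs_latent: torch.Tensor) -> torch.Tensor:
    # Gradients flow only through the dynamics branch; the observation branch is detached (stop-gradient).
    pred = projector(dynamics_latent)
    target = projector(obs_latent).detach()
    # Negative cosine similarity per sample, shape (batch_size,).
    return -F.cosine_similarity(pred, target, dim=-1)

loss = consistency_loss(torch.randn(8, 64), torch.randn(8, 64))
print(loss.shape)  # torch.Size([8])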
+                dynamic_proj = self._learn_model.project(latent_state, with_grad=True)
+                observation_proj = self._learn_model.project(representation_state, with_grad=False)
+                temp_loss = negative_cosine_similarity(dynamic_proj, observation_proj) * mask_batch[:, step_k]
+
+                consistency_loss += temp_loss
+
+            # NOTE: the target_policy, target_value_categorical and target_value_prefix_categorical are calculated in
+            # the game buffer now.
+            # ==============================================================
+            # calculate policy loss for the next ``num_unroll_steps`` unroll steps.
+            # NOTE: the ``+=`` accumulates the loss over unroll steps.
+            # ==============================================================
+            policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1])
+
+            # Here we take the hypothetical step k = step_k + 1.
+            prob = torch.softmax(policy_logits, dim=-1)
+            policy_entropy += -(prob * prob.log()).sum(-1).mean()
+
+            target_normalized_visit_count = target_policy[:, step_k + 1]
+
+            # ******* NOTE: target_policy_entropy is only for debugging. *******
+            non_masked_indices = torch.nonzero(mask_batch[:, step_k + 1]).squeeze(-1)
+            # Check if there are any unmasked rows.
+            if len(non_masked_indices) > 0:
+                target_normalized_visit_count_masked = torch.index_select(
+                    target_normalized_visit_count, 0, non_masked_indices
+                )
+                target_policy_entropy += -((target_normalized_visit_count_masked + 1e-6) * (target_normalized_visit_count_masked + 1e-6).log()).sum(-1).mean()
+            else:
+                # Set target_policy_entropy to log(|A|) if all rows are masked.
+                target_policy_entropy += torch.log(torch.tensor(target_normalized_visit_count.shape[-1]))
+
+            value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1])
+            value_prefix_loss += cross_entropy_loss(value_prefix, target_value_prefix_categorical[:, step_k])
+
+            # reset hidden states every ``lstm_horizon_len`` unroll steps.
+            if (step_k + 1) % self._cfg.lstm_horizon_len == 0:
+                reward_hidden_state = (
+                    torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device),
+                    torch.zeros(1, self._cfg.batch_size, self._cfg.model.lstm_hidden_size).to(self._cfg.device)
+                )
+
+            if self._cfg.monitor_extra_statistics:
+                original_value_prefixs = self.value_inverse_scalar_transform_handle(value_prefix)
+                original_value_prefixs_cpu = original_value_prefixs.detach().cpu()
+                predicted_values = torch.cat(
+                    (predicted_values, self.value_inverse_scalar_transform_handle(value).detach().cpu())
+                )
+                predicted_value_prefixs.append(original_value_prefixs_cpu)
+                predicted_policies = torch.cat((predicted_policies, torch.softmax(policy_logits, dim=1).detach().cpu()))
+
+        # ==============================================================
+        # the core learn model update step.
+        # ==============================================================
+        # weighted loss with masks (some invalid states are out of the trajectory).
+        loss = (
+            self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss +
+            self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * value_prefix_loss
+        )
+        weighted_total_loss = (weights * loss).mean()
+        # TODO(pu): test the effect of gradient scale.
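The next two statements divide the gradient of the accumulated unrolled loss by ``num_unroll_steps`` via a backward hook, so summing losses over K unroll steps does not inflate gradients by a factor of K. A toy demonstration of the mechanism with made-up tensors:

import torch

num_unroll_steps = 5
w = torch.tensor(2.0, requires_grad=True)

# Accumulate a loss over the unroll steps, as the training loop does.
loss = sum(w * k for k in range(1, num_unroll_steps + 1))  # d(loss)/dw = 1 + 2 + 3 + 4 + 5 = 15

# Scale the incoming gradient by 1/num_unroll_steps before it propagates further.
loss.register_hook(lambda grad: grad / num_unroll_steps)
loss.backward()
print(w.grad)  # tensor(3.) instead of tensor(15.)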
+ gradient_scale = 1 / self._cfg.num_unroll_steps + weighted_total_loss.register_hook(lambda grad: grad * gradient_scale) + self._optimizer.zero_grad() + weighted_total_loss.backward() + if self._cfg.multi_gpu: + self.sync_gradients(self._learn_model) + total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_( + self._learn_model.parameters(), self._cfg.grad_clip_value + ) + self._optimizer.step() + if self._cfg.piecewise_decay_lr_scheduler: + self.lr_scheduler.step() + + # ============================================================== + # the core target model update step. + # ============================================================== + self._target_model.update(self._learn_model.state_dict()) + + if self._cfg.monitor_extra_statistics: + predicted_value_prefixs = torch.stack(predicted_value_prefixs).transpose(1, 0).squeeze(-1) + predicted_value_prefixs = predicted_value_prefixs.reshape(-1).unsqueeze(-1) + + return { + 'collect_mcts_temperature': self._collect_mcts_temperature, + 'collect_epsilon': self.collect_epsilon, + 'cur_lr': self._optimizer.param_groups[0]['lr'], + 'weighted_total_loss': weighted_total_loss.item(), + 'total_loss': loss.mean().item(), + 'policy_loss': policy_loss.mean().item(), + 'policy_entropy': policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'target_policy_entropy': target_policy_entropy.item() / (self._cfg.num_unroll_steps + 1), + 'value_prefix_loss': value_prefix_loss.mean().item(), + 'value_loss': value_loss.mean().item(), + 'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps, + 'target_value_prefix': target_value_prefix.mean().item(), + 'target_value': target_value.mean().item(), + 'transformed_target_value_prefix': transformed_target_value_prefix.mean().item(), + 'transformed_target_value': transformed_target_value.mean().item(), + 'predicted_value_prefixs': predicted_value_prefixs.mean().item(), + 'predicted_values': predicted_values.mean().item(), + 'total_grad_norm_before_clip': total_grad_norm_before_clip.item(), + # ============================================================== + # priority related + # ============================================================== + 'value_priority': value_priority.mean().item(), + 'value_priority_orig': value_priority, # torch.tensor compatible with ddp settings + } + + def _init_collect(self) -> None: + """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ + self._collect_model = self._model + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1 + self.collect_epsilon = 0.0 + + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + **kwargs, + ): + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. 
+        Shape:
+            - data (:obj:`torch.Tensor`):
+                - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \
+                    S is the number of stacked frames, H is the height of the image, W is the width of the image.
+                - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size.
+            - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env.
+            - temperature: :math:`(1, )`.
+            - to_play: :math:`(N, 1)`, where N is the number of collect_env.
+            - ready_env_id: None
+        Returns:
+            - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \
+                ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``.
+        """
+        self._collect_model.eval()
+        self._collect_mcts_temperature = temperature
+        self.collect_epsilon = epsilon
+        active_collect_env_num = data.shape[0]
+        if ready_env_id is None:
+            ready_env_id = np.arange(active_collect_env_num)
+        output = {i: None for i in ready_env_id}
+
+        with torch.no_grad():
+            # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
+            network_output = self._collect_model.initial_inference(data)
+            latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack(
+                network_output
+            )
+
+            pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
+            latent_state_roots = latent_state_roots.detach().cpu().numpy()
+            reward_hidden_state_roots = (
+                reward_hidden_state_roots[0].detach().cpu().numpy(),
+                reward_hidden_state_roots[1].detach().cpu().numpy()
+            )
+
+            policy_logits = policy_logits.detach().cpu().numpy().tolist()
+
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
+
+            if not self._cfg.collect_with_pure_policy:
+                # collect with MCTS guided by the policy.
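Unlike v1, which injects Dirichlet noise at the root, the V2 collect phase draws per-action Gumbel noise that the Sequential Halving search adds to the root logits, in the spirit of Gumbel MuZero's Gumbel-Top-k sampling. A small illustrative sketch of how such noise ranks candidate actions; the logits and seed are made up, and only the scaling by ``root_noise_weight`` mirrors the code below.

import numpy as np

rng = np.random.default_rng(0)
root_noise_weight = 0.25
policy_logits = np.array([1.2, 0.3, -0.5, 0.8], dtype=np.float32)  # toy root logits for 4 legal actions

# One Gumbel(0, 1) sample per legal action, scaled like the collect code scales its noise.
gumbel = rng.gumbel(0.0, 1.0, size=policy_logits.shape).astype(np.float32) * root_noise_weight

# With unscaled Gumbel(0, 1) noise, ranking ``logits + noise`` samples actions without replacement
# from softmax(logits) (the Gumbel-Top-k trick); the root_noise_weight factor tempers that exploration.
top_2 = np.argsort(policy_logits + gumbel)[::-1][:2]
print(top_2)  # indices of the two highest noisy logits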
+                # Generate Gumbel noises for EfficientZero V2 Sequential Halving.
+                gumbel_noises = [
+                    np.random.gumbel(0, 1, int(sum(action_mask[j]))).astype(np.float32) * self._cfg.root_noise_weight
+                    for j in range(active_collect_env_num)
+                ]
+                if self._cfg.mcts_ctree:
+                    # cpp mcts_tree
+                    roots = MCTSCtree.roots(active_collect_env_num, legal_actions)
+                else:
+                    # python mcts_tree
+                    roots = MCTSPtree.roots(active_collect_env_num, legal_actions)
+                roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play)
+                # EfficientZero V2: search returns improved_policies and best_actions.
+                improved_policies, best_actions = self._mcts_collect.search(
+                    roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play, gumbel_noises
+                )
+                roots_visit_count_distributions = roots.get_distributions()
+                roots_values = roots.get_values()  # shape: {list: batch_size}
+
+                for i, env_id in enumerate(ready_env_id):
+                    # EfficientZero V2: use improved_policy and best_action.
+                    improved_policy = improved_policies[i]
+                    best_action_idx = best_actions[i]
+                    value = roots_values[i]
+                    distributions = roots_visit_count_distributions[i]
+                    # EfficientZero V2: greedy action selection (Sequential Halving best action).
+                    action_index_in_legal_action_set = best_action_idx
+
+                    if self._cfg.eps.eps_greedy_exploration_in_collect:
+                        # eps-greedy exploration
+                        if np.random.rand() < self.collect_epsilon:
+                            action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]))
+
+                    # Compute the improved policy entropy for logging.
+                    visit_count_distribution_entropy = -np.sum(
+                        improved_policy * np.log(improved_policy + 1e-9)
+                    )
+
+                    # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set.
+                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+                    output[env_id] = {
+                        'action': action,
+                        # EfficientZero V2: raw visit counts; the improved policy is returned under ``improved_policy_probs``.
+                        'visit_count_distributions': distributions,
+                        'visit_count_distribution_entropy': visit_count_distribution_entropy,
+                        'searched_value': value,
+                        'predicted_value': pred_values[i],
+                        'predicted_policy_logits': policy_logits[i],
+                        'improved_policy_probs': improved_policy,
+                        'roots_completed_value': value,
+                    }
+            else:
+                # collect with the pure policy.
+                for i, env_id in enumerate(ready_env_id):
+                    policy_values = torch.softmax(torch.tensor([policy_logits[i][a] for a in legal_actions[i]]), dim=0).tolist()
+                    policy_values = policy_values / np.sum(policy_values)
+                    action_index_in_legal_action_set = np.random.choice(len(legal_actions[i]), p=policy_values)
+                    action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
+                    output[env_id] = {
+                        'action': action,
+                        'searched_value': pred_values[i],
+                        'predicted_value': pred_values[i],
+                        'predicted_policy_logits': policy_logits[i],
+                    }
+        return output
+
+    def _init_eval(self) -> None:
+        """
+        Overview:
+            Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
+        """
+        self._eval_model = self._model
+        if self._cfg.mcts_ctree:
+            self._mcts_eval = MCTSCtree(self._cfg)
+        else:
+            self._mcts_eval = MCTSPtree(self._cfg)
+
+    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: Union[int, List] = [-1], ready_env_id: np.array = None, **kwargs):
+        """
+        Overview:
+            The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
+ Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} + network_output = self._eval_model.initial_inference(data) + latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( + network_output + ) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.value_inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + reward_hidden_state_roots = ( + reward_hidden_state_roots[0].detach().cpu().numpy(), + reward_hidden_state_roots[1].detach().cpu().numpy() + ) + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) + # Generate zero Gumbel noises for evaluation (no exploration) + gumbel_noises = [ + np.zeros(int(sum(action_mask[j])), dtype=np.float32) + for j in range(active_eval_env_num) + ] + # EfficientZero V2: search returns improved_policies and best_actions + improved_policies, best_actions = self._mcts_eval.search( + roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play, gumbel_noises + ) + + roots_values = roots.get_values() # shape: {list: batch_size} + + for i, env_id in enumerate(ready_env_id): + # EfficientZero V2: use improved_policy and best_action + improved_policy = improved_policies[i] + best_action_idx = best_actions[i] + value = roots_values[i] + + # EfficientZero V2: greedy action selection (Sequential Halving best action) + # Eval mode: always use best action (deterministic=True) + action_index_in_legal_action_set = best_action_idx + + # Compute improved policy entropy for logging + visit_count_distribution_entropy = -np.sum( + improved_policy * np.log(improved_policy + 1e-9) + ) + + # NOTE: Convert the 
``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + output[env_id] = { + 'action': action, + 'visit_count_distributions': improved_policy, # EfficientZero V2: improved policy + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + + return output + + def _monitor_vars_learn(self) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + """ + return [ + 'collect_mcts_temperature', + 'cur_lr', + 'weighted_total_loss', + 'total_loss', + 'policy_loss', + 'policy_entropy', + 'target_policy_entropy', + 'value_prefix_loss', + 'value_loss', + 'consistency_loss', + 'value_priority', + 'target_value_prefix', + 'target_value', + 'predicted_value_prefixs', + 'predicted_values', + 'transformed_target_value_prefix', + 'transformed_target_value', + 'total_grad_norm_before_clip', + ] + + def _state_dict_learn(self) -> Dict[str, Any]: + """ + Overview: + Return the state_dict of learn mode, usually including model and optimizer. + Returns: + - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring. + """ + return { + 'model': self._learn_model.state_dict(), + 'target_model': self._target_model.state_dict(), + 'optimizer': self._optimizer.state_dict(), + } + + def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None: + """ + Overview: + Load the state_dict variable into policy learn mode. + Arguments: + - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before. + """ + self._learn_model.load_state_dict(state_dict['model']) + self._target_model.load_state_dict(state_dict['target_model']) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def _process_transition(self, obs, policy_output, timestep): + # be compatible with DI-engine Policy class + pass + + def _get_train_sample(self, data): + # be compatible with DI-engine Policy class + pass diff --git a/zoo/atari/config/atari_efficientzero_v2_config.py b/zoo/atari/config/atari_efficientzero_v2_config.py new file mode 100644 index 000000000..0d8c60d3b --- /dev/null +++ b/zoo/atari/config/atari_efficientzero_v2_config.py @@ -0,0 +1,98 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +env_id = 'PongNoFrameskip-v4' # You can specify any Atari game here +action_space_size = atari_env_action_space_map[env_id] + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +update_per_collect = None +replay_ratio = 0.25 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +batch_size = 256 +max_env_step = int(5e5) +reanalyze_ratio = 0. 
+num_unroll_steps = 5 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_efficientzero_config = dict( + exp_name=f'data_efficientzero/{env_id[:-14]}_efficientzero_stack4_H{num_unroll_steps}_seed0', + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=[4, 64, 64], + frame_stack_num=4, + gray_scale=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=[4, 64, 64], + image_channel=1, + frame_stack_num=4, + gray_scale=True, + action_space_size=action_space_size, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + reward_support_range=(-50., 51., 1.), + value_support_range=(-50., 51., 1.), + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + use_augmentation=True, + use_priority=False, + replay_ratio=replay_ratio, + update_per_collect=update_per_collect, + batch_size=batch_size, + dormant_threshold=0.025, + optim_type='SGD', + piecewise_decay_lr_scheduler=True, + learning_rate=0.2, + target_update_freq=100, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +atari_efficientzero_config = EasyDict(atari_efficientzero_config) +main_config = atari_efficientzero_config + +atari_efficientzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero_v2', + import_names=['lzero.policy.efficientzero_v2'], + ), +) +atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config) +create_config = atari_efficientzero_create_config + +if __name__ == "__main__": + # Define a list of seeds for multiple runs + seeds = [0, 1, 2] # You can add more seed values here + for seed in seeds: + # Update exp_name to include the current seed + main_config.exp_name = f'data_efficientzero_v2/{env_id[:-14]}_efficientzero_v2_stack4_H{num_unroll_steps}_seed{seed}' + from lzero.entry import train_muzero + train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step) diff --git a/zoo/classic_control/cartpole/config/cartpole_efficientzero_v2_config.py b/zoo/classic_control/cartpole/config/cartpole_efficientzero_v2_config.py new file mode 100644 index 000000000..f6bd11ba5 --- /dev/null +++ b/zoo/classic_control/cartpole/config/cartpole_efficientzero_v2_config.py @@ -0,0 +1,89 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(1e5) +reanalyze_ratio = 0. 
+# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cartpole_efficientzero_config = dict( + exp_name=f'data_ezv2/cartpole_efficientzero_v2_ns{num_simulations}_upc{update_per_collect}_rer{reanalyze_ratio}_seed0', + env=dict( + env_id='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=4, + action_space_size=2, + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + # (str) The path of the pretrained model. If None, the model will be initialized by the default model. + model_path=None, + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + piecewise_decay_lr_scheduler=False, + learning_rate=0.003, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +cartpole_efficientzero_config = EasyDict(cartpole_efficientzero_config) +main_config = cartpole_efficientzero_config + +cartpole_efficientzero_create_config = dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero_v2', + import_names=['lzero.policy.efficientzero_v2'], + ), +) +cartpole_efficientzero_create_config = EasyDict(cartpole_efficientzero_create_config) +create_config = cartpole_efficientzero_create_config + +if __name__ == "__main__": + # Users can use different train entry by specifying the entry_type. + entry_type = "train_muzero" # options={"train_muzero", "train_muzero_with_gym_env"} + + if entry_type == "train_muzero": + from lzero.entry import train_muzero + elif entry_type == "train_muzero_with_gym_env": + """ + The ``train_muzero_with_gym_env`` entry means that the environment used in the training process is generated by wrapping the original gym environment with LightZeroEnvWrapper. + Users can refer to lzero/envs/wrappers for more details. + """ + from lzero.entry import train_muzero_with_gym_env as train_muzero + + train_muzero([main_config, create_config], seed=0, model_path=main_config.policy.model_path, max_env_step=max_env_step)
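For reference, the Sequential Halving procedure behind the ``best_actions`` returned by the V2 search can be summarized as: start from a set of candidate root actions, spend an equal slice of the simulation budget on each, keep the better half, and repeat until one action remains. The sketch below is a simplified, framework-free illustration of that budget split with a stubbed scoring function; it is not the C++ tree implementation used by ``MCTSCtree``.

import math
from typing import Callable, List

def sequential_halving(actions: List[int], num_simulations: int, score: Callable[[int, int], float]) -> int:
    """Return the action that survives Sequential Halving.

    ``score(action, extra_visits)`` is assumed to return the action's value estimate after
    spending ``extra_visits`` additional simulations on it (stubbed here for illustration).
    """
    candidates = list(actions)
    num_phases = max(1, math.ceil(math.log2(len(candidates))))
    budget_per_phase = max(1, num_simulations // num_phases)
    while len(candidates) > 1:
        # Split this phase's budget evenly across the remaining candidates.
        visits_each = max(1, budget_per_phase // len(candidates))
        scored = sorted(((score(a, visits_each), a) for a in candidates), reverse=True)
        # Keep the better half (at least one candidate).
        candidates = [a for _, a in scored[:max(1, len(candidates) // 2)]]
    return candidates[0]

# Toy usage: with this stub, higher-indexed actions score better, so action 7 survives.
print(sequential_halving(list(range(8)), num_simulations=50, score=lambda a, v: float(a)))  # -> 7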