1+ /*
2+ * Copyright (c) 2019 - 2025 Geode-solutions
3+ *
4+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5+ * of this software and associated documentation files (the "Software"), to deal
6+ * in the Software without restriction, including without limitation the rights
7+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+ * copies of the Software, and to permit persons to whom the Software is
9+ * furnished to do so, subject to the following conditions:
10+ *
11+ * The above copyright notice and this permission notice shall be included in
12+ * all copies or substantial portions of the Software.
13+ *
14+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+ * SOFTWARE.
21+ *
22+ */
23+
24+ #pragma once
25+ #include < geode/stochastic/common.hpp>
26+ #include < geode/stochastic/sampling/mcmc/metropolis_hasting_sampler.hpp>
27+ #include < geode/stochastic/sampling/mcmc/models/energy_term_collection.hpp>
28+
29+ #include < absl/strings/str_join.h>
30+ #include < fstream>
31+
32+ namespace geode
33+ {
34+ class MonitoringStatistics
35+ {
36+ public:
37+ MonitoringStatistics ( MonitoringStatistics&& ) = default ;
38+ MonitoringStatistics ( const MonitoringStatistics& ) = default ;
39+ MonitoringStatistics& operator =(
40+ MonitoringStatistics&& ) noexcept = default ;
41+ MonitoringStatistics& operator =(
42+ const MonitoringStatistics& ) noexcept = default ;
43+
44+ MonitoringStatistics ( const index_t nb_energy_terms )
45+ {
46+ sum.resize ( nb_energy_terms, 0.0 );
47+ sum_squares.resize ( nb_energy_terms, 0.0 );
48+ means.resize ( nb_energy_terms, 0.0 );
49+ variances.resize ( nb_energy_terms, 0.0 );
50+ }
51+
52+ void add_realization ( const std::vector< double >& values )
53+ {
54+ for ( const auto stat_id : Range{ values.size () } )
55+ {
56+ sum[stat_id] += values[stat_id];
57+ sum_squares[stat_id] += values[stat_id] * values[stat_id];
58+ }
59+ }
60+
61+ void finalize ( const index_t nb_realizations )
62+ {
63+ for ( const auto stat_id : Range{ sum.size () } )
64+ {
65+ means[stat_id] = sum[stat_id] / nb_realizations;
66+ double variance =
67+ ( sum_squares[stat_id]
68+ - ( sum[stat_id] * sum[stat_id] ) / nb_realizations )
69+ / ( nb_realizations - 1 );
70+ variances[stat_id] = variance;
71+ // stddevs[stat_id] =std::sqrt( std::max( variance, 0.0 ) );
72+ }
73+ }
74+
75+ public:
76+ std::vector< double > sum;
77+ std::vector< double > sum_squares;
78+ std::vector< double > means;
79+ std::vector< double > variances;
80+ };
81+
82+ template < typename ObjectType >
83+ class SimulationRunner
84+ {
85+ public:
86+ SimulationRunner () = default ;
87+ virtual ~SimulationRunner () = default ;
88+
89+ virtual void initialize () = 0;
90+
91+ const ObjectSets< ObjectType >& run (
92+ RandomEngine& engine, const index_t steps )
93+ {
94+ mh_sampler_->walk ( object_sets_, engine, steps );
95+ return object_sets_;
96+ }
97+
98+ void run_and_print ( std::string_view filename,
99+ RandomEngine& engine,
100+ const index_t steps,
101+ const index_t nb_realizations )
102+ {
103+ const auto file_exist =
104+ static_cast < bool >( std::ifstream ( filename.data () ) );
105+ if ( !file_exist )
106+ {
107+ const auto header = statistics_header_file ();
108+ print_to_file ( filename, header );
109+ }
110+
111+ for ( const auto realization : Range{ nb_realizations } )
112+ {
113+ run ( engine, steps );
114+ const auto statistics = statistics_string ();
115+ print_to_file ( filename, statistics );
116+ }
117+ }
118+
119+ MonitoringStatistics run_print_and_monitor ( std::string_view filename,
120+ RandomEngine& engine,
121+ const index_t steps,
122+ const index_t nb_realizations )
123+ {
124+ const auto file_exist =
125+ static_cast < bool >( std::ifstream ( filename.data () ) );
126+ if ( !file_exist )
127+ {
128+ const auto header = statistics_header_file ();
129+ print_to_file ( filename, header );
130+ }
131+ MonitoringStatistics stat_monitoring (
132+ energy_terms_collection_.size () );
133+
134+ for ( const auto realization : Range{ nb_realizations } )
135+ {
136+ run ( engine, steps );
137+ const auto stats = statistics ();
138+ print_to_file ( filename,
139+ absl::StrCat ( absl::StrJoin ( stats, " ; " ), " \n " ) );
140+ stat_monitoring.add_realization ( stats );
141+ }
142+ stat_monitoring.finalize ( nb_realizations );
143+ return stat_monitoring;
144+ }
145+
146+ const ObjectSets< ObjectType >& current_pattern_realization () const
147+ {
148+ return object_sets_;
149+ }
150+
151+ std::vector< double > statistics () const
152+ {
153+ std::vector< double > statistic_values;
154+ statistic_values.reserve ( ordered_energy_terms_.size () );
155+
156+ for ( const auto & energy_term_uuid : ordered_energy_terms_ )
157+ {
158+ const auto & term =
159+ energy_terms_collection_.get ( energy_term_uuid );
160+ statistic_values.push_back ( term.statistic ( object_sets_ ) );
161+ }
162+
163+ return statistic_values;
164+ }
165+
166+ std::string statistics_log_info () const
167+ {
168+ std::string message ( " Pattern statistics: " );
169+ for ( const auto term_id :
170+ geode::Range{ ordered_energy_terms_.size () } )
171+ {
172+ const auto & energy_term = energy_terms_collection_.get (
173+ ordered_energy_terms_[term_id] );
174+ const double value = energy_term.statistic ( object_sets_ );
175+ absl::StrAppend ( &message, " \t Term(" , energy_term.name (),
176+ " ) --> value/traget: " , value, " / " ,
177+ ordered_target_statistics_[term_id] );
178+ }
179+ return message;
180+ }
181+
182+ protected:
183+ std::string energy_term_names () const
184+ {
185+ std::vector< std::string > term_names;
186+ term_names.reserve ( ordered_energy_terms_.size () );
187+
188+ for ( const auto & energy_term_uuid : ordered_energy_terms_ )
189+ {
190+ const auto & term =
191+ energy_terms_collection_.get ( energy_term_uuid );
192+ term_names.push_back ( term.name ().data () );
193+ }
194+
195+ return absl::StrCat ( absl::StrJoin ( term_names, " ; " ), " \n " );
196+ }
197+
198+ std::string statistics_string () const
199+ {
200+ return absl::StrCat ( absl::StrJoin ( statistics (), " ; " ), " \n " );
201+ }
202+
203+ std::string statistics_header_file ()
204+ {
205+ std::string message ( " Sufficient statistics mcmc iterations:\n " );
206+ absl::StrAppend ( &message, energy_term_names () );
207+ return message;
208+ }
209+
210+ void print_to_file (
211+ absl::string_view filename, absl::string_view message )
212+ {
213+ std::ofstream file (
214+ filename.data (), std::ofstream::out | std::ofstream::app );
215+ file << message;
216+ file.close ();
217+ return ;
218+ }
219+
220+ protected:
221+ std::vector< std::unique_ptr< geode::ObjectSetSampler< ObjectType > > >
222+ set_samplers_;
223+
224+ std::vector< geode::uuid > ordered_energy_terms_;
225+ std::vector< double > ordered_target_statistics_;
226+
227+ EnergyTermCollection< ObjectType > energy_terms_collection_;
228+ std::unique_ptr< geode::MetropolisHastings< ObjectType > > mh_sampler_;
229+
230+ ObjectSets< ObjectType > object_sets_;
231+ };
232+ } // namespace geode
0 commit comments