Skip to content

Commit 2cd0a74

Browse files
authored
Merge pull request #18 from Geode-solutions/simulation_helpers
Simulation helpers
2 parents 27c16a6 + 8fdc892 commit 2cd0a74

21 files changed

Lines changed: 712 additions & 603 deletions
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
/*
2+
* Copyright (c) 2019 - 2025 Geode-solutions
3+
*
4+
* Permission is hereby granted, free of charge, to any person obtaining a copy
5+
* of this software and associated documentation files (the "Software"), to deal
6+
* in the Software without restriction, including without limitation the rights
7+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
* copies of the Software, and to permit persons to whom the Software is
9+
* furnished to do so, subject to the following conditions:
10+
*
11+
* The above copyright notice and this permission notice shall be included in
12+
* all copies or substantial portions of the Software.
13+
*
14+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
* SOFTWARE.
21+
*
22+
*/
23+
24+
#pragma once
25+
#include <geode/stochastic/common.hpp>
26+
#include <geode/stochastic/sampling/mcmc/metropolis_hasting_sampler.hpp>
27+
#include <geode/stochastic/sampling/mcmc/models/energy_term_collection.hpp>
28+
29+
#include <absl/strings/str_join.h>
30+
#include <fstream>
31+
32+
namespace geode
33+
{
34+
class MonitoringStatistics
35+
{
36+
public:
37+
MonitoringStatistics( MonitoringStatistics&& ) = default;
38+
MonitoringStatistics( const MonitoringStatistics& ) = default;
39+
MonitoringStatistics& operator=(
40+
MonitoringStatistics&& ) noexcept = default;
41+
MonitoringStatistics& operator=(
42+
const MonitoringStatistics& ) noexcept = default;
43+
44+
MonitoringStatistics( const index_t nb_energy_terms )
45+
{
46+
sum.resize( nb_energy_terms, 0.0 );
47+
sum_squares.resize( nb_energy_terms, 0.0 );
48+
means.resize( nb_energy_terms, 0.0 );
49+
variances.resize( nb_energy_terms, 0.0 );
50+
}
51+
52+
void add_realization( const std::vector< double >& values )
53+
{
54+
for( const auto stat_id : Range{ values.size() } )
55+
{
56+
sum[stat_id] += values[stat_id];
57+
sum_squares[stat_id] += values[stat_id] * values[stat_id];
58+
}
59+
}
60+
61+
void finalize( const index_t nb_realizations )
62+
{
63+
for( const auto stat_id : Range{ sum.size() } )
64+
{
65+
means[stat_id] = sum[stat_id] / nb_realizations;
66+
double variance =
67+
( sum_squares[stat_id]
68+
- ( sum[stat_id] * sum[stat_id] ) / nb_realizations )
69+
/ ( nb_realizations - 1 );
70+
variances[stat_id] = variance;
71+
// stddevs[stat_id] =std::sqrt( std::max( variance, 0.0 ) );
72+
}
73+
}
74+
75+
public:
76+
std::vector< double > sum;
77+
std::vector< double > sum_squares;
78+
std::vector< double > means;
79+
std::vector< double > variances;
80+
};
81+
82+
template < typename ObjectType >
83+
class SimulationRunner
84+
{
85+
public:
86+
SimulationRunner() = default;
87+
virtual ~SimulationRunner() = default;
88+
89+
virtual void initialize() = 0;
90+
91+
const ObjectSets< ObjectType >& run(
92+
RandomEngine& engine, const index_t steps )
93+
{
94+
mh_sampler_->walk( object_sets_, engine, steps );
95+
return object_sets_;
96+
}
97+
98+
void run_and_print( std::string_view filename,
99+
RandomEngine& engine,
100+
const index_t steps,
101+
const index_t nb_realizations )
102+
{
103+
const auto file_exist =
104+
static_cast< bool >( std::ifstream( filename.data() ) );
105+
if( !file_exist )
106+
{
107+
const auto header = statistics_header_file();
108+
print_to_file( filename, header );
109+
}
110+
111+
for( const auto realization : Range{ nb_realizations } )
112+
{
113+
run( engine, steps );
114+
const auto statistics = statistics_string();
115+
print_to_file( filename, statistics );
116+
}
117+
}
118+
119+
MonitoringStatistics run_print_and_monitor( std::string_view filename,
120+
RandomEngine& engine,
121+
const index_t steps,
122+
const index_t nb_realizations )
123+
{
124+
const auto file_exist =
125+
static_cast< bool >( std::ifstream( filename.data() ) );
126+
if( !file_exist )
127+
{
128+
const auto header = statistics_header_file();
129+
print_to_file( filename, header );
130+
}
131+
MonitoringStatistics stat_monitoring(
132+
energy_terms_collection_.size() );
133+
134+
for( const auto realization : Range{ nb_realizations } )
135+
{
136+
run( engine, steps );
137+
const auto stats = statistics();
138+
print_to_file( filename,
139+
absl::StrCat( absl::StrJoin( stats, " ; " ), "\n" ) );
140+
stat_monitoring.add_realization( stats );
141+
}
142+
stat_monitoring.finalize( nb_realizations );
143+
return stat_monitoring;
144+
}
145+
146+
const ObjectSets< ObjectType >& current_pattern_realization() const
147+
{
148+
return object_sets_;
149+
}
150+
151+
std::vector< double > statistics() const
152+
{
153+
std::vector< double > statistic_values;
154+
statistic_values.reserve( ordered_energy_terms_.size() );
155+
156+
for( const auto& energy_term_uuid : ordered_energy_terms_ )
157+
{
158+
const auto& term =
159+
energy_terms_collection_.get( energy_term_uuid );
160+
statistic_values.push_back( term.statistic( object_sets_ ) );
161+
}
162+
163+
return statistic_values;
164+
}
165+
166+
std::string statistics_log_info() const
167+
{
168+
std::string message( "Pattern statistics: " );
169+
for( const auto term_id :
170+
geode::Range{ ordered_energy_terms_.size() } )
171+
{
172+
const auto& energy_term = energy_terms_collection_.get(
173+
ordered_energy_terms_[term_id] );
174+
const double value = energy_term.statistic( object_sets_ );
175+
absl::StrAppend( &message, " \t Term(", energy_term.name(),
176+
") --> value/traget: ", value, " / ",
177+
ordered_target_statistics_[term_id] );
178+
}
179+
return message;
180+
}
181+
182+
protected:
183+
std::string energy_term_names() const
184+
{
185+
std::vector< std::string > term_names;
186+
term_names.reserve( ordered_energy_terms_.size() );
187+
188+
for( const auto& energy_term_uuid : ordered_energy_terms_ )
189+
{
190+
const auto& term =
191+
energy_terms_collection_.get( energy_term_uuid );
192+
term_names.push_back( term.name().data() );
193+
}
194+
195+
return absl::StrCat( absl::StrJoin( term_names, " ; " ), "\n" );
196+
}
197+
198+
std::string statistics_string() const
199+
{
200+
return absl::StrCat( absl::StrJoin( statistics(), " ; " ), "\n" );
201+
}
202+
203+
std::string statistics_header_file()
204+
{
205+
std::string message( "Sufficient statistics mcmc iterations:\n" );
206+
absl::StrAppend( &message, energy_term_names() );
207+
return message;
208+
}
209+
210+
void print_to_file(
211+
absl::string_view filename, absl::string_view message )
212+
{
213+
std::ofstream file(
214+
filename.data(), std::ofstream::out | std::ofstream::app );
215+
file << message;
216+
file.close();
217+
return;
218+
}
219+
220+
protected:
221+
std::vector< std::unique_ptr< geode::ObjectSetSampler< ObjectType > > >
222+
set_samplers_;
223+
224+
std::vector< geode::uuid > ordered_energy_terms_;
225+
std::vector< double > ordered_target_statistics_;
226+
227+
EnergyTermCollection< ObjectType > energy_terms_collection_;
228+
std::unique_ptr< geode::MetropolisHastings< ObjectType > > mh_sampler_;
229+
230+
ObjectSets< ObjectType > object_sets_;
231+
};
232+
} // namespace geode

include/geode/stochastic/sampling/mcmc/metropolis_hasting_sampler.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ namespace geode
4949
class MetropolisHastings
5050
{
5151
public:
52-
MetropolisHastings( GibbsEnergy< ObjectType >& energy,
52+
MetropolisHastings(
53+
const EnergyTermCollection< ObjectType >& energy_term_collection,
5354
std::unique_ptr< ProposalKernel< ObjectType > > proposal_kernel )
54-
: energy_( energy ),
55+
: gibbs_energy_{ energy_term_collection },
5556
proposal_kernel_( std::move( proposal_kernel ) )
5657
{
5758
OPENGEODE_ASSERT(
@@ -188,7 +189,7 @@ namespace geode
188189
{
189190
const auto new_object = proposal.new_object();
190191
const auto delta_log_energy =
191-
energy_.delta_log_add( state, new_object );
192+
gibbs_energy_.delta_log_add( state, new_object );
192193
return accept_or_reject( proposal, state, engine, delta_log_energy,
193194
[]( auto& state, auto& proposal ) {
194195
state.add_object(
@@ -203,7 +204,7 @@ namespace geode
203204
{
204205
const auto old_object_id = proposal.old_object_id();
205206
const auto delta_log_energy =
206-
energy_.delta_log_remove( state, old_object_id );
207+
gibbs_energy_.delta_log_remove( state, old_object_id );
207208
return accept_or_reject( proposal, state, engine, delta_log_energy,
208209
[]( auto& state, auto& proposal ) {
209210
state.remove_object( proposal.old_object_id() );
@@ -216,8 +217,8 @@ namespace geode
216217
{
217218
const auto new_object = proposal.new_object();
218219
const auto old_object_id = proposal.old_object_id();
219-
const auto delta_log_energy =
220-
energy_.delta_log_change( state, old_object_id, new_object );
220+
const auto delta_log_energy = gibbs_energy_.delta_log_change(
221+
state, old_object_id, new_object );
221222
// should we test that objects are in the same group?
222223
// should be ensured by the dynamic
223224
return accept_or_reject( proposal, state, engine, delta_log_energy,
@@ -229,7 +230,7 @@ namespace geode
229230
};
230231

231232
private:
232-
const GibbsEnergy< ObjectType >& energy_;
233+
GibbsEnergy< ObjectType > gibbs_energy_;
233234
std::unique_ptr< ProposalKernel< ObjectType > > proposal_kernel_;
234235
double beta_{ 1.0 };
235236
};

include/geode/stochastic/sampling/mcmc/models/components/energy_term.hpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,22 +82,6 @@ namespace geode
8282

8383
namespace geode
8484
{
85-
// struct EnergyTermDescription
86-
// {
87-
// geode::uuid id;
88-
// std::string name;
89-
// std::string type;
90-
// double parameter_value;
91-
// std::optional< uuid > targeted_set_id{};
92-
// }
93-
//
94-
// struct StatisticalDescription
95-
// {
96-
// std::string label;
97-
// double value;
98-
// std::optional< uuid > targeted_set_id{};
99-
// };
100-
10185
template < typename ObjectType >
10286
class EnergyTerm : public Identifier
10387
{

include/geode/stochastic/sampling/mcmc/models/components/pairwise_term.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ namespace geode
155155
// continue;
156156
// }
157157
// if( neigh_obj_id.set_id == obj_id.set_id
158-
// && neigh_obj_id.id <= obj_id.id )
158+
// && neigh_obj_id.index <= obj_id.index )
159159
//{
160160
// continue;
161161
// }

include/geode/stochastic/sampling/mcmc/models/components/energy_term_collection.hpp renamed to include/geode/stochastic/sampling/mcmc/models/energy_term_collection.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@ namespace geode
6363
return energy_terms_.size();
6464
}
6565

66-
[[nodiscard]] std::shared_ptr< const EnergyTerm< ObjectType > > get(
66+
[[nodiscard]] const EnergyTerm< ObjectType >& get(
6767
const uuid& id ) const
6868
{
6969
auto it = energy_terms_.find( id );
7070
OPENGEODE_EXCEPTION( it != energy_terms_.end(),
7171
absl::StrCat( "[EnergyTermCollection] Unknown energy term: ",
7272
id.string() ) );
73-
return it->second;
73+
return *it->second;
7474
}
7575

7676
[[nodiscard]] const absl::flat_hash_map< uuid,

include/geode/stochastic/sampling/mcmc/models/gibbs_energy.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
*/
2323
#pragma once
2424

25-
#include <geode/stochastic/sampling/mcmc/models/components/energy_term_collection.hpp>
25+
#include <geode/stochastic/sampling/mcmc/models/energy_term_collection.hpp>
2626
#include <geode/stochastic/spatial/object_sets.hpp>
2727

2828
namespace geode
@@ -65,8 +65,7 @@ namespace geode
6565
{
6666
double log_energy = 0.0;
6767
for( const auto& term :
68-
energy_terms_collection_.terms_for_set( new_object
69-
.set_id ) ) // energy_terms_collection_.all_terms() )
68+
energy_terms_collection_.terms_for_set( new_object.set_id ) )
7069
{
7170
log_energy += term->delta_log_add( state, new_object );
7271
}
@@ -77,8 +76,8 @@ namespace geode
7776
const ObjectSets< ObjectType >& state, const ObjectId& id ) const
7877
{
7978
double log_energy = 0.0;
80-
for( const auto& term : energy_terms_collection_.terms_for_set(
81-
id.set_id ) ) // energy_terms_collection_.all_terms() )
79+
for( const auto& term :
80+
energy_terms_collection_.terms_for_set( id.set_id ) )
8281
{
8382
log_energy += term->delta_log_remove( state, id );
8483
}
@@ -90,8 +89,8 @@ namespace geode
9089
const ObjectRef< ObjectType >& new_object ) const
9190
{
9291
double log_energy = 0.0;
93-
for( const auto& term : energy_terms_collection_.terms_for_set(
94-
old_id.set_id ) ) // energy_terms_collection_.all_terms() )
92+
for( const auto& term :
93+
energy_terms_collection_.terms_for_set( old_id.set_id ) )
9594
{
9695
log_energy +=
9796
term->delta_log_change( state, old_id, new_object );

0 commit comments

Comments
 (0)