Skip to content

Commit 8e86a7b

Browse files
Merge pull request #41 from Geode-solutions/refacto_simulation_runner
feat(SimulationContext): remove inheritance in simulation runner
2 parents 3331379 + 2ae78a9 commit 8e86a7b

11 files changed

Lines changed: 427 additions & 333 deletions

File tree

include/geode/stochastic/inference/target_statistics.hpp

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
*/
2323
#pragma once
2424
#include <geode/basic/uuid.hpp>
25+
#include <geode/stochastic/models/model.hpp>
2526

2627
namespace geode
2728
{
@@ -36,60 +37,66 @@ namespace geode
3637
class TargetStatistics
3738
{
3839
public:
39-
explicit TargetStatistics( const Model< ObjectType >& model )
40+
explicit TargetStatistics( const Model< ObjectType >& model,
41+
const std::vector< TargetStatisticConfig >& statistic_targets )
4042
: model_( model )
4143
{
4244
values_.resize( model.nb_terms(), 0.0 );
4345
tolerances_.resize( model.nb_terms(), 0.0 );
4446
active_.resize( model.nb_terms(), false );
47+
for( const auto& target : statistic_targets )
48+
{
49+
set_target( target );
50+
}
4551
}
4652

47-
const Model< ObjectType >& model() const
53+
[[nodiscard]] const Model< ObjectType >& model() const
4854
{
4955
return model_;
5056
}
5157

52-
void set_target( const TargetStatisticConfig& statistic )
53-
{
54-
const auto term_uuid =
55-
model_.terms().get_term_uuid( statistic.term_name );
56-
const auto idx = model_.term_index( term_uuid );
57-
58-
values_[idx] = statistic.value;
59-
tolerances_[idx] = statistic.tolerance;
60-
active_[idx] = true;
61-
}
62-
63-
bool has_target( const uuid& term_uuid ) const
58+
[[nodiscard]] bool has_target( const uuid& term_uuid ) const
6459
{
6560
return active_[model_.term_index( term_uuid )];
6661
}
6762

68-
double target( const uuid& term_uuid ) const
63+
[[nodiscard]] double target( const uuid& term_uuid ) const
6964
{
7065
return values_[model_.term_index( term_uuid )];
7166
}
7267

73-
double tolerance( const uuid& term_uuid ) const
68+
[[nodiscard]] double tolerance( const uuid& term_uuid ) const
7469
{
7570
return tolerances_[model_.term_index( term_uuid )];
7671
}
7772

78-
std::vector< uuid > active_terms() const
73+
[[nodiscard]] std::vector< uuid > active_terms() const
7974
{
8075
std::vector< uuid > active_terms_uuid;
8176

8277
for( const auto& term : model_.terms().energy_terms() )
8378
{
84-
const auto& id = term->id();
85-
if( active_[model_.term_index( id )] )
79+
const auto& term_id = term->id();
80+
if( active_[model_.term_index( term_id )] )
8681
{
87-
active_terms_uuid.push_back( id );
82+
active_terms_uuid.push_back( term_id );
8883
}
8984
}
9085
return active_terms_uuid;
9186
}
9287

88+
private:
89+
void set_target( const TargetStatisticConfig& statistic )
90+
{
91+
const auto term_uuid =
92+
model_.terms().get_term_uuid( statistic.term_name );
93+
const auto idx = model_.term_index( term_uuid );
94+
95+
values_[idx] = statistic.value;
96+
tolerances_[idx] = statistic.tolerance;
97+
active_[idx] = true;
98+
}
99+
93100
private:
94101
const Model< ObjectType >& model_;
95102

include/geode/stochastic/models/model.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@ namespace geode
5050
template < typename ObjectType >
5151
class Model
5252
{
53-
OPENGEODE_DISABLE_COPY( Model );
53+
OPENGEODE_DISABLE_COPY_AND_MOVE( Model );
5454

5555
public:
5656
Model() = delete;
57-
Model( EnergyTermCollection< ObjectType >&& energy_terms )
57+
~Model() = default;
58+
59+
explicit Model( EnergyTermCollection< ObjectType >&& energy_terms )
5860
: terms_collection_( std::move( energy_terms ) ),
5961
energy_{ terms_collection_ }
6062
{
@@ -118,6 +120,11 @@ namespace geode
118120
return names;
119121
}
120122

123+
[[nodiscard]] std::string string() const
124+
{
125+
return terms_collection_.string();
126+
}
127+
121128
private:
122129
EnergyTermCollection< ObjectType > terms_collection_;
123130
GibbsEnergy< ObjectType > energy_;
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright (c) 2019 - 2026 Geode-solutions
3+
*
4+
* Permission is hereby granted, free of charge, to any person obtaining a copy
5+
* of this software and associated documentation files (the "Software"), to deal
6+
* in the Software without restriction, including without limitation the rights
7+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
* copies of the Software, and to permit persons to whom the Software is
9+
* furnished to do so, subject to the following conditions:
10+
*
11+
* The above copyright notice and this permission notice shall be included in
12+
* all copies or substantial portions of the Software.
13+
*
14+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
* SOFTWARE.
21+
*
22+
*/
23+
24+
#pragma once
25+
26+
#include <geode/stochastic/common.hpp>
27+
28+
#include <geode/stochastic/spatial/object_sets.hpp>
29+
#include <geode/stochastic/spatial/spatial_domain.hpp>
30+
31+
#include <geode/stochastic/inference/target_statistics.hpp>
32+
33+
#include <geode/stochastic/models/model.hpp>
34+
35+
#include <geode/stochastic/sampling/direct/object_set_sampler/object_set_sampler.hpp>
36+
#include <geode/stochastic/sampling/direct/object_set_sampler/point_set_sampler.hpp>
37+
38+
#include <geode/stochastic/sampling/mcmc/proposal/classical_proposals.hpp>
39+
40+
#include <geode/stochastic/sampling/mcmc/metropolis_hasting_sampler.hpp>
41+
#include <geode/stochastic/sampling/mcmc/proposal/object_set_dynamic_config.hpp>
42+
43+
namespace geode
44+
{
45+
template < typename ObjectType >
46+
struct SimulationContext
47+
{
48+
[[nodiscard]] std::string string() const
49+
{
50+
auto message = std::string{ "SimulationContext: " };
51+
absl::StrAppend( &message, "\n\t --> ", domain->string() );
52+
absl::StrAppend( &message, "\n\t --> ", object_sets->string() );
53+
absl::StrAppend(
54+
&message, "\n\t --> ", set_samplers.size(), " Sets samplers " );
55+
absl::StrAppend( &message, "\n\t --> ", model->string() );
56+
// absl::StrAppend( &message, "\n\t --> ", mh_sampler_ > string() );
57+
58+
return message;
59+
}
60+
61+
std::unique_ptr< SpatialDomain< ObjectType::dim > > domain;
62+
63+
std::unique_ptr< ObjectSets< ObjectType > > object_sets{
64+
std::make_unique< ObjectSets< ObjectType > >()
65+
};
66+
std::vector< std::unique_ptr< geode::ObjectSetSampler< ObjectType > > >
67+
set_samplers;
68+
std::unique_ptr< Model< ObjectType > > model;
69+
std::unique_ptr< geode::MetropolisHastings< ObjectType > > mh_sampler;
70+
};
71+
72+
template < typename ObjectType >
73+
struct SimulationContextConfig
74+
{
75+
SpatialDomainConfig< ObjectType::dim > domain;
76+
77+
std::vector< ObjectSetConfig > sets;
78+
std::vector< ObjectSetDynamicsConfig > proposals;
79+
80+
geode::ModelConfig model;
81+
};
82+
83+
template < typename ObjectType >
84+
[[nodiscard]] geode::SimulationContext< ObjectType >
85+
build_simulation_context(
86+
const SimulationContextConfig< ObjectType >& config )
87+
{
88+
geode::SimulationContext< ObjectType > context;
89+
90+
// -------------------------
91+
// Domain
92+
// -------------------------
93+
context.domain = geode::build_spatial_domain( config.domain );
94+
95+
// -------------------------
96+
// Sets
97+
// -------------------------
98+
99+
// auto proposal_kernel =
100+
// std::make_unique< geode::ProposalKernel< geode::Point2D >
101+
// >();
102+
// for( const auto& set_desc : set_descriptors_ )
103+
// {
104+
// const auto set_id = context.object_sets->add_set(
105+
// set_desc.name ); context.set_samplers.push_back(
106+
// std::make_unique< geode::UniformPointSetSampler< 2 >
107+
// >(
108+
// *context.domain ) );
109+
// geode::add_birth_death_change_moves(
110+
// context.set_samplers.back(),
111+
// *proposal_kernel, set_id, set_desc.birth_ratio,
112+
// set_desc.death_ratio, set_desc.change_ratio );
113+
// }
114+
// return proposal_kernel;
115+
116+
for( const auto& set_cfg : config.sets )
117+
{
118+
const auto set_id = context.object_sets->add_set( set_cfg.name );
119+
geode_unused( set_id );
120+
}
121+
122+
// -------------------------
123+
// Model
124+
// -------------------------
125+
context.model = geode::build_model< ObjectType >(
126+
config.model, *context.object_sets, *context.domain );
127+
128+
// -------------------------
129+
// Proposal
130+
// -------------------------
131+
auto proposal_kernel =
132+
std::make_unique< geode::ProposalKernel< ObjectType > >();
133+
for( const auto& set_proposal : config.proposals )
134+
{
135+
const auto set_id =
136+
context.object_sets->get_set_uuid( set_proposal.name );
137+
context.set_samplers.push_back(
138+
std::make_unique< geode::UniformPointSetSampler< 2 > >(
139+
*context.domain ) );
140+
141+
geode::add_birth_death_change_moves( context.set_samplers.back(),
142+
*proposal_kernel, set_id, set_proposal.birth_ratio,
143+
set_proposal.death_ratio, set_proposal.change_ratio );
144+
}
145+
146+
// -------------------------
147+
// MH sampler
148+
// -------------------------
149+
context.mh_sampler =
150+
std::make_unique< geode::MetropolisHastings< geode::Point2D > >(
151+
*context.model, std::move( proposal_kernel ) );
152+
153+
return context;
154+
}
155+
156+
} // namespace geode
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) 2019 - 2026 Geode-solutions
3+
*
4+
* Permission is hereby granted, free of charge, to any person obtaining a copy
5+
* of this software and associated documentation files (the "Software"), to deal
6+
* in the Software without restriction, including without limitation the rights
7+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
* copies of the Software, and to permit persons to whom the Software is
9+
* furnished to do so, subject to the following conditions:
10+
*
11+
* The above copyright notice and this permission notice shall be included in
12+
* all copies or substantial portions of the Software.
13+
*
14+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
* SOFTWARE.
21+
*
22+
*/
23+
24+
#pragma once
25+
26+
#include <string>
27+
28+
namespace geode
29+
{
30+
struct ObjectSetDynamicsConfig
31+
{
32+
std::string name;
33+
34+
double birth_ratio = 1.0;
35+
double death_ratio = 1.0;
36+
double change_ratio = 1.0;
37+
};
38+
} // namespace geode

0 commit comments

Comments
 (0)