Skip to content

Commit dbf2fdf

Browse files
feat(SimulationContext): remove inheritance in simulation runner
1 parent 3331379 commit dbf2fdf

8 files changed

Lines changed: 224 additions & 193 deletions

File tree

include/geode/stochastic/models/model.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ namespace geode
118118
return names;
119119
}
120120

121+
[[nodiscard]] std::string string() const
122+
{
123+
return terms_collection_.string();
124+
}
125+
121126
private:
122127
EnergyTermCollection< ObjectType > terms_collection_;
123128
GibbsEnergy< ObjectType > energy_;
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright (c) 2019 - 2026 Geode-solutions
3+
*
4+
* Permission is hereby granted, free of charge, to any person obtaining a copy
5+
* of this software and associated documentation files (the "Software"), to deal
6+
* in the Software without restriction, including without limitation the rights
7+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
* copies of the Software, and to permit persons to whom the Software is
9+
* furnished to do so, subject to the following conditions:
10+
*
11+
* The above copyright notice and this permission notice shall be included in
12+
* all copies or substantial portions of the Software.
13+
*
14+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
* SOFTWARE.
21+
*
22+
*/
23+
24+
#pragma once
25+
26+
#include <geode/stochastic/common.hpp>
27+
28+
#include <geode/stochastic/spatial/object_sets.hpp>
29+
30+
#include <geode/stochastic/inference/target_statistics.hpp>
31+
32+
#include <geode/stochastic/models/model.hpp>
33+
34+
#include <geode/stochastic/sampling/direct/object_set_sampler/object_set_sampler.hpp>
35+
#include <geode/stochastic/sampling/mcmc/metropolis_hasting_sampler.hpp>
36+
37+
#include <geode/stochastic/spatial/spatial_domain.hpp>
38+
39+
namespace geode
40+
{
41+
template < typename ObjectType >
42+
struct SimulationContext
43+
{
44+
explicit SimulationContext( SpatialDomain< ObjectType::dim > domain_in )
45+
: domain( std::move( domain_in ) )
46+
{
47+
}
48+
49+
std::string string() const
50+
{
51+
auto message = std::string{ "SimulationContext: " };
52+
absl::StrAppend( &message, "\n\t --> ", domain.string() );
53+
absl::StrAppend( &message, "\n\t --> ", object_sets.string() );
54+
absl::StrAppend(
55+
&message, "\n\t --> ", set_samplers.size(), " Sets samplers " );
56+
absl::StrAppend( &message, "\n\t --> ", model->string() );
57+
// absl::StrAppend( &message, "\n\t --> ", mh_sampler_ > string() );
58+
59+
return message;
60+
}
61+
62+
SpatialDomain< ObjectType::dim > domain;
63+
64+
ObjectSets< ObjectType > object_sets;
65+
std::vector< std::unique_ptr< geode::ObjectSetSampler< ObjectType > > >
66+
set_samplers;
67+
std::unique_ptr< Model< ObjectType > > model;
68+
std::unique_ptr< geode::MetropolisHastings< ObjectType > > mh_sampler;
69+
70+
std::optional< TargetStatistics< ObjectType > > target_statistics;
71+
};
72+
73+
} // namespace geode

include/geode/stochastic/sampling/mcmc/simulation_runner.hpp

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,9 @@
2323

2424
#pragma once
2525
#include <geode/stochastic/common.hpp>
26-
#include <geode/stochastic/inference/statistics_tracker.hpp>
27-
#include <geode/stochastic/inference/target_statistics.hpp>
28-
29-
#include <geode/stochastic/models/energy_term_collection.hpp>
3026
#include <geode/stochastic/sampling/mcmc/helpers/simulation_printer.hpp>
31-
#include <geode/stochastic/sampling/mcmc/metropolis_hasting_sampler.hpp>
32-
#include <geode/stochastic/spatial/spatial_domain.hpp>
27+
28+
#include <geode/stochastic/sampling/mcmc/helpers/simulation_context.hpp>
3329

3430
#include <absl/strings/str_join.h>
3531

@@ -65,49 +61,51 @@ namespace geode
6561
class SimulationRunner
6662
{
6763
public:
68-
SimulationRunner( const SpatialDomain< ObjectType::dim >& domain )
69-
: domain_( domain ) {};
64+
SimulationRunner( SimulationContext< ObjectType >&& context )
65+
: context_( std::move( context ) ){};
7066
virtual ~SimulationRunner() = default;
7167

72-
virtual void initialize() = 0;
73-
7468
const ObjectSets< ObjectType >& run(
7569
RandomEngine& engine, const index_t steps )
7670
{
77-
mh_sampler_->walk( object_sets_, engine, steps );
78-
return object_sets_;
71+
context_.mh_sampler->walk( context_.object_sets, engine, steps );
72+
return context_.object_sets;
7973
}
8074

8175
StatisticsTracker< ObjectType > run(
8276
RandomEngine& engine, const SimulationConfigurator& config )
8377
{
78+
Logger::info( context_.string() );
8479
if( config.burn_in_steps > 0 )
8580
{
86-
mh_sampler_->walk( object_sets_, engine, config.burn_in_steps );
81+
context_.mh_sampler->walk(
82+
context_.object_sets, engine, config.burn_in_steps );
8783
}
8884

8985
// Initialize monitoring
90-
StatisticsTracker< ObjectType > stats_monitor( *model_ );
86+
StatisticsTracker< ObjectType > stats_monitor( *context_.model );
9187
std::unique_ptr< SimulationPrinter< ObjectType > > printer;
9288

9389
if( config.printer.has_value() )
9490
{
9591
printer = std::make_unique< SimulationPrinter< ObjectType > >(
96-
*model_, config.printer.value() );
92+
*context_.model, config.printer.value() );
9793
}
9894

9995
for( const auto realization : Range{ config.realizations } )
10096
{
101-
mh_sampler_->walk(
102-
object_sets_, engine, config.metropolis_hasting_steps );
97+
context_.mh_sampler->walk( context_.object_sets, engine,
98+
config.metropolis_hasting_steps );
10399

104-
const auto stats = model_->compute_statistics( object_sets_ );
100+
const auto stats =
101+
context_.model->compute_statistics( context_.object_sets );
105102
stats_monitor.add_realization( stats );
106103

107104
if( printer )
108105
{
109106
printer->print_statistics( stats );
110-
printer->print_object_sets( object_sets_, realization );
107+
printer->print_object_sets(
108+
context_.object_sets, realization );
111109
}
112110
}
113111

@@ -123,27 +121,19 @@ namespace geode
123121
target_statistics() const
124122
{
125123
OpenGeodeStochasticStochasticException::check_exception(
126-
target_statistics_.has_value(), nullptr,
124+
context_.target_statistics.has_value(), nullptr,
127125
OpenGeodeException::TYPE::data,
128126
"[SimulationRunner] Target statistics not initialized" );
129127

130-
return *target_statistics_;
128+
return *context_.target_statistics;
131129
}
132130

133131
[[nodiscard]] const ObjectSets< ObjectType >& state_realization() const
134132
{
135-
return object_sets_;
133+
return context_.object_sets;
136134
}
137135

138136
protected:
139-
SpatialDomain< ObjectType::dim > domain_;
140-
141-
ObjectSets< ObjectType > object_sets_;
142-
std::vector< std::unique_ptr< geode::ObjectSetSampler< ObjectType > > >
143-
set_samplers_;
144-
std::unique_ptr< Model< ObjectType > > model_;
145-
std::unique_ptr< geode::MetropolisHastings< ObjectType > > mh_sampler_;
146-
147-
std::optional< TargetStatistics< ObjectType > > target_statistics_;
137+
SimulationContext< ObjectType > context_;
148138
};
149139
} // namespace geode

include/geode/stochastic/spatial/spatial_domain.hpp

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,16 @@ namespace geode
3333
class SpatialDomain
3434
{
3535
public:
36-
SpatialDomain(
37-
const BoundingBox< dimension >& domain, double buffer_size )
36+
SpatialDomain( const SpatialDomain& ) = default;
37+
SpatialDomain( SpatialDomain&& ) noexcept = default;
38+
39+
SpatialDomain& operator=( const SpatialDomain& ) = default;
40+
SpatialDomain& operator=( SpatialDomain&& ) noexcept = default;
41+
42+
SpatialDomain( BoundingBox< dimension > domain, double buffer_size )
3843
: domain_{ domain },
3944
buffer_size_{ buffer_size },
40-
extended_domain_{ domain }
45+
extended_domain_{ domain_ }
4146
{
4247
auto volume = domain_.n_volume();
4348
OpenGeodeStochasticStochasticException::check_exception(
@@ -53,6 +58,7 @@ namespace geode
5358
{
5459
extended_domain_.extends( buffer_size_ );
5560
}
61+
Logger::info( domain_.string(), " ", extended_domain_.string() );
5662
}
5763

5864
const BoundingBox< dimension > box() const
@@ -90,6 +96,13 @@ namespace geode
9096
return extended_domain_;
9197
}
9298

99+
std::string string() const
100+
{
101+
return absl::StrCat( "Spatial Domain --> center ", domain_.string(),
102+
" extended: ", extended_domain_.string(),
103+
" buffer: ", buffer_size_ );
104+
}
105+
93106
private:
94107
BoundingBox< dimension > domain_;
95108

@@ -132,4 +145,21 @@ namespace geode
132145
return domain.box().intersects( seg );
133146
}
134147
};
148+
149+
template < index_t dimension >
150+
struct SpatialDomainConfig
151+
{
152+
Point< dimension > min_point;
153+
Point< dimension > max_point;
154+
double buffer_size{ 0.0 };
155+
};
156+
157+
template < index_t dimension >
158+
SpatialDomain< dimension > build_spatial_domain(
159+
const SpatialDomainConfig< dimension >& config )
160+
{
161+
BoundingBox< dimension > box{ config.min_point, config.max_point };
162+
Logger::info( box.string() );
163+
return SpatialDomain{ std::move( box ), config.buffer_size };
164+
}
135165
} // namespace geode

src/geode/stochastic/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ add_geode_library(
6464
"sampling/direct/point_uniform_sampler.hpp"
6565
"sampling/direct/segment_uniform_sampler.hpp"
6666
#"sampling/mcmc/helpers/fracture_simulation_runner.hpp"
67+
"sampling/mcmc/helpers/simulation_context.hpp"
6768
"sampling/mcmc/helpers/simulation_printer.hpp"
6869
"models/energy_terms/energy_term.hpp"
6970
"models/energy_terms/pairwise_term.hpp"

tests/stochastic/CMakeLists.txt

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -161,28 +161,21 @@ add_geode_test(
161161
# ${PROJECT_NAME}::stochastic
162162
#)
163163

164-
add_geode_test(
165-
SOURCE "sampling/mcmc/test-mh-poisson.cpp"
166-
DEPENDENCIES
167-
OpenGeode::basic
168-
OpenGeode::geometry
169-
${PROJECT_NAME}::stochastic
170-
)
171-
add_geode_test(
172-
SOURCE "sampling/mcmc/test-mh-poisson-new.cpp"
173-
DEPENDENCIES
174-
OpenGeode::basic
175-
OpenGeode::geometry
176-
${PROJECT_NAME}::stochastic
177-
)
178-
179-
add_geode_test(
180-
SOURCE "sampling/mcmc/test-mh-strauss.cpp"
181-
DEPENDENCIES
182-
OpenGeode::basic
183-
OpenGeode::geometry
184-
${PROJECT_NAME}::stochastic
185-
)
164+
add_geode_test(
165+
SOURCE "sampling/mcmc/test-mh-poisson.cpp"
166+
DEPENDENCIES
167+
OpenGeode::basic
168+
OpenGeode::geometry
169+
${PROJECT_NAME}::stochastic
170+
)
171+
172+
# add_geode_test(
173+
# SOURCE "sampling/mcmc/test-mh-strauss.cpp"
174+
# DEPENDENCIES
175+
# OpenGeode::basic
176+
# OpenGeode::geometry
177+
# ${PROJECT_NAME}::stochastic
178+
# )
186179

187180

188181
add_geode_test(

0 commit comments

Comments
 (0)