Skip to content

Commit 3c30a67

Browse files
tidy
1 parent 449ca08 commit 3c30a67

6 files changed

Lines changed: 76 additions & 80 deletions

File tree

include/geode/stochastic/inference/target_statistics.hpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,30 +36,24 @@ namespace geode
3636
class TargetStatistics
3737
{
3838
public:
39-
explicit TargetStatistics( const Model< ObjectType >& model )
39+
explicit TargetStatistics( const Model< ObjectType >& model,
40+
const std::vector< TargetStatisticConfig >& statistic_targets )
4041
: model_( model )
4142
{
4243
values_.resize( model.nb_terms(), 0.0 );
4344
tolerances_.resize( model.nb_terms(), 0.0 );
4445
active_.resize( model.nb_terms(), false );
46+
for( const auto& target : statistic_targets )
47+
{
48+
set_target( target );
49+
}
4550
}
4651

4752
const Model< ObjectType >& model() const
4853
{
4954
return model_;
5055
}
5156

52-
void set_target( const TargetStatisticConfig& statistic )
53-
{
54-
const auto term_uuid =
55-
model_.terms().get_term_uuid( statistic.term_name );
56-
const auto idx = model_.term_index( term_uuid );
57-
58-
values_[idx] = statistic.value;
59-
tolerances_[idx] = statistic.tolerance;
60-
active_[idx] = true;
61-
}
62-
6357
bool has_target( const uuid& term_uuid ) const
6458
{
6559
return active_[model_.term_index( term_uuid )];
@@ -90,6 +84,18 @@ namespace geode
9084
return active_terms_uuid;
9185
}
9286

87+
private:
88+
void set_target( const TargetStatisticConfig& statistic )
89+
{
90+
const auto term_uuid =
91+
model_.terms().get_term_uuid( statistic.term_name );
92+
const auto idx = model_.term_index( term_uuid );
93+
94+
values_[idx] = statistic.value;
95+
tolerances_[idx] = statistic.tolerance;
96+
active_[idx] = true;
97+
}
98+
9399
private:
94100
const Model< ObjectType >& model_;
95101

include/geode/stochastic/models/model.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ namespace geode
5050
template < typename ObjectType >
5151
class Model
5252
{
53-
OPENGEODE_DISABLE_COPY( Model );
53+
OPENGEODE_DISABLE_COPY_AND_MOVE( Model );
5454

5555
public:
5656
Model() = delete;

include/geode/stochastic/sampling/mcmc/helpers/simulation_context.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ namespace geode
4141
template < typename ObjectType >
4242
struct SimulationContext
4343
{
44-
std::string string() const
44+
[[nodiscard]] std::string string() const
4545
{
4646
auto message = std::string{ "SimulationContext: " };
4747
absl::StrAppend( &message, "\n\t --> ", domain->string() );
@@ -63,8 +63,6 @@ namespace geode
6363
set_samplers;
6464
std::unique_ptr< Model< ObjectType > > model;
6565
std::unique_ptr< geode::MetropolisHastings< ObjectType > > mh_sampler;
66-
67-
std::optional< TargetStatistics< ObjectType > > target_statistics;
6866
};
6967

7068
} // namespace geode

include/geode/stochastic/sampling/mcmc/simulation_runner.hpp

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,26 @@
2828
#include <geode/stochastic/sampling/mcmc/helpers/simulation_context.hpp>
2929

3030
#include <absl/strings/str_join.h>
31-
3231
namespace geode
3332
{
33+
namespace detail
34+
{
35+
// NOLINTBEGIN(*-magic-numbers)
36+
constexpr index_t default_realizations{ 1000 };
37+
constexpr index_t default_simulation_steps{ 1000 };
38+
constexpr index_t default_burn_in_steps{ 1000 };
39+
// NOLINTEND(*-magic-numbers)
40+
} // namespace detail
41+
3442
struct SimulationConfigurator
3543
{
36-
index_t realizations{ 1000 };
37-
index_t metropolis_hasting_steps{ 1000 };
38-
index_t burn_in_steps{ 1000 };
44+
index_t realizations{ detail::default_realizations };
45+
index_t metropolis_hasting_steps{ detail::default_simulation_steps };
46+
index_t burn_in_steps{ detail::default_burn_in_steps };
3947

4048
std::optional< SimulationPrinterConfigurator > printer{ std::nullopt };
4149

42-
std::string string() const
50+
[[nodiscard]] std::string string() const
4351
{
4452
auto message = absl::StrCat( "SimulationConfigurator: " );
4553
absl::StrAppend( &message, "\n\t --> ", realizations,
@@ -60,19 +68,22 @@ namespace geode
6068
template < typename ObjectType >
6169
class SimulationRunner
6270
{
71+
OPENGEODE_DISABLE_COPY_AND_MOVE( SimulationRunner );
72+
6373
public:
64-
SimulationRunner( SimulationContext< ObjectType >&& context )
65-
: context_( std::move( context ) ) {};
74+
SimulationRunner() = delete;
75+
explicit SimulationRunner( SimulationContext< ObjectType >&& context )
76+
: context_( std::move( context ) ){};
6677
virtual ~SimulationRunner() = default;
6778

68-
const ObjectSets< ObjectType >& run(
79+
[[nodiscard]] const ObjectSets< ObjectType >& run(
6980
RandomEngine& engine, const index_t steps )
7081
{
7182
context_.mh_sampler->walk( *context_.object_sets, engine, steps );
7283
return *context_.object_sets;
7384
}
7485

75-
StatisticsTracker< ObjectType > run(
86+
[[nodiscard]] StatisticsTracker< ObjectType > run(
7687
RandomEngine& engine, const SimulationConfigurator& config )
7788
{
7889
if( config.burn_in_steps > 0 )
@@ -116,23 +127,17 @@ namespace geode
116127
return stats_monitor;
117128
}
118129

119-
[[nodiscard]] const TargetStatistics< ObjectType >&
120-
target_statistics() const
130+
[[nodiscard]] const ObjectSets< ObjectType >& state_realization() const
121131
{
122-
OpenGeodeStochasticStochasticException::check_exception(
123-
context_.target_statistics.has_value(), nullptr,
124-
OpenGeodeException::TYPE::data,
125-
"[SimulationRunner] Target statistics not initialized" );
126-
127-
return *context_.target_statistics;
132+
return *context_.object_sets;
128133
}
129134

130-
[[nodiscard]] const ObjectSets< ObjectType >& state_realization() const
135+
[[nodiscard]] const Model< ObjectType >& model() const
131136
{
132-
return *context_.object_sets;
137+
return *context_.model;
133138
}
134139

135-
protected:
140+
private:
136141
SimulationContext< ObjectType > context_;
137142
};
138143
} // namespace geode

include/geode/stochastic/spatial/spatial_domain.hpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323

2424
#pragma once
2525

26+
#include <geode/stochastic/common.hpp>
27+
2628
#include <geode/geometry/basic_objects/segment.hpp>
2729
#include <geode/geometry/bounding_box.hpp>
2830
#include <geode/geometry/point.hpp>
@@ -32,12 +34,10 @@ namespace geode
3234
template < index_t dimension >
3335
class SpatialDomain
3436
{
35-
public:
36-
SpatialDomain( const SpatialDomain& ) = default;
37-
SpatialDomain( SpatialDomain&& ) noexcept = default;
37+
OPENGEODE_DISABLE_COPY_AND_MOVE( SpatialDomain );
3838

39-
SpatialDomain& operator=( const SpatialDomain& ) = default;
40-
SpatialDomain& operator=( SpatialDomain&& ) noexcept = default;
39+
public:
40+
~SpatialDomain() = default;
4141

4242
SpatialDomain( BoundingBox< dimension > domain, double buffer_size )
4343
: domain_{ domain },
@@ -60,42 +60,43 @@ namespace geode
6060
}
6161
}
6262

63-
const BoundingBox< dimension > box() const
63+
[[nodiscard]] const BoundingBox< dimension >& box() const
6464
{
6565
return domain_;
6666
}
6767

68-
bool contains( const Point< dimension >& point ) const
68+
[[nodiscard]] bool contains( const Point< dimension >& point ) const
6969
{
7070
return domain_.contains( point );
7171
}
7272

73-
double n_volume() const
73+
[[nodiscard]] double n_volume() const
7474
{
7575
return domain_.n_volume();
7676
}
7777

78-
double smallest_length() const
78+
[[nodiscard]] double smallest_length() const
7979
{
8080
return std::get< 1 >( domain_.smallest_length() );
8181
}
8282

83-
bool extended_contains( const Point< dimension >& point ) const
83+
[[nodiscard]] bool extended_contains(
84+
const Point< dimension >& point ) const
8485
{
8586
return extended_domain_.contains( point );
8687
}
8788

88-
double extended_n_volume() const
89+
[[nodiscard]] double extended_n_volume() const
8990
{
9091
return extended_domain_.n_volume();
9192
}
9293

93-
const BoundingBox< dimension > extended_box() const
94+
[[nodiscard]] const BoundingBox< dimension >& extended_box() const
9495
{
9596
return extended_domain_;
9697
}
9798

98-
std::string string() const
99+
[[nodiscard]] std::string string() const
99100
{
100101
return absl::StrCat( "Spatial Domain --> center ", domain_.string(),
101102
" extended: ", extended_domain_.string(),

tests/stochastic/sampling/mcmc/test-mh-poisson.cpp

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,7 @@ namespace
7575
density_descriptors_.push_back( descriptor );
7676
}
7777

78-
void add_target_statistics(
79-
const geode::TargetStatisticConfig& statistic_descriptor )
80-
{
81-
targeted_statistics_descriptors_.push_back( statistic_descriptor );
82-
}
83-
84-
geode::SimulationContext< geode::Point2D > build() const
78+
[[nodiscard]] geode::SimulationContext< geode::Point2D > build() const
8579
{
8680
geode::SimulationContext< geode::Point2D > context;
8781

@@ -92,7 +86,6 @@ namespace
9286
context.mh_sampler =
9387
std::make_unique< geode::MetropolisHastings< geode::Point2D > >(
9488
*context.model, std::move( proposal_kernel ) );
95-
create_target_statistics( context );
9689
return context;
9790
}
9891

@@ -131,22 +124,10 @@ namespace
131124
config, *context.object_sets, *context.domain ) );
132125
}
133126

134-
void create_target_statistics(
135-
geode::SimulationContext< geode::Point2D >& context ) const
136-
{
137-
context.target_statistics.emplace( *context.model );
138-
for( const auto& target_stat : targeted_statistics_descriptors_ )
139-
{
140-
context.target_statistics->set_target( target_stat );
141-
}
142-
}
143-
144127
private:
145128
geode::SpatialDomainConfig< 2 > domain_config_;
146129
std::vector< SetDescription > set_descriptors_;
147130
std::vector< PoissonDensityDescription > density_descriptors_;
148-
std::vector< geode::TargetStatisticConfig >
149-
targeted_statistics_descriptors_;
150131
};
151132

152133
void test_single_type_poisson()
@@ -155,14 +136,17 @@ namespace
155136

156137
geode::RandomEngine engine;
157138
engine.set_seed( "@mh-test-single-POISSON@" );
139+
140+
// NOLINTBEGIN(*-magic-numbers)
158141
std::array< double, 4 > birth_ratio{ 0.1, 0.5, 2., 4. };
159142
std::array< double, 4 > change_ratio{ 0., 1., 1., 0. };
160143

161144
for( const auto config : geode::Range{ birth_ratio.size() } )
162145
{
163146
PoissonConfig poisson_config;
147+
std::vector< geode::TargetStatisticConfig >
148+
targeted_statistics_descriptors;
164149

165-
// NOLINTBEGIN(*-magic-numbers)
166150
poisson_config.add_domain_config( geode::Point2D{ { 0.0, 0.0 } },
167151
geode::Point2D{ { 10.0, 10.0 } }, 0. );
168152
// --- Set description
@@ -180,10 +164,10 @@ namespace
180164
density_a.object_feature = geode::ObjectInDomainFeatureConfig{};
181165

182166
geode::TargetStatisticConfig stat_a{ "density", 30.0, 0.15 };
167+
targeted_statistics_descriptors.push_back( stat_a );
183168

184169
poisson_config.add_set_descriptor( set_a );
185170
poisson_config.add_density_descriptor( density_a );
186-
poisson_config.add_target_statistics( stat_a );
187171

188172
auto context = poisson_config.build();
189173

@@ -203,8 +187,9 @@ namespace
203187
sim_config.printer = printer_config;
204188

205189
auto statistic_tracker = runner.run( engine, sim_config );
206-
geode::statistics::validate(
207-
statistic_tracker, runner.target_statistics() );
190+
geode::TargetStatistics target_stats{ runner.model(),
191+
targeted_statistics_descriptors };
192+
geode::statistics::validate( statistic_tracker, target_stats );
208193
}
209194
// NOLINTEND(*-magic-numbers)
210195

@@ -218,7 +203,8 @@ namespace
218203
geode::RandomEngine engine;
219204
engine.set_seed( "@mh-test-POISSON-multi@" );
220205
PoissonConfig poisson_config;
221-
206+
std::vector< geode::TargetStatisticConfig >
207+
targeted_statistics_descriptors;
222208
// NOLINTBEGIN(*-magic-numbers)
223209
poisson_config.add_domain_config( geode::Point2D{ { 0.0, 0.0 } },
224210
geode::Point2D{ { 10.0, 10.0 } }, 0. );
@@ -236,6 +222,7 @@ namespace
236222
density01.object_feature = geode::ObjectInDomainFeatureConfig{};
237223

238224
geode::TargetStatisticConfig stat01{ "density01", 10.0, 0.15 };
225+
targeted_statistics_descriptors.push_back( stat01 );
239226

240227
PoissonDensityDescription density02;
241228
density02.term_name = "density02";
@@ -244,6 +231,7 @@ namespace
244231
density02.object_feature = geode::ObjectInDomainFeatureConfig{};
245232

246233
geode::TargetStatisticConfig stat02{ "density02", 40.0, 0.15 };
234+
targeted_statistics_descriptors.push_back( stat02 );
247235

248236
PoissonDensityDescription density03;
249237
density03.term_name = "density03";
@@ -252,6 +240,7 @@ namespace
252240
density03.object_feature = geode::ObjectInDomainFeatureConfig{};
253241

254242
geode::TargetStatisticConfig stat03{ "density03", 30.0, 0.15 };
243+
targeted_statistics_descriptors.push_back( stat03 );
255244

256245
poisson_config.add_set_descriptor( set01 );
257246
poisson_config.add_set_descriptor( set02 );
@@ -261,10 +250,6 @@ namespace
261250
poisson_config.add_density_descriptor( density02 );
262251
poisson_config.add_density_descriptor( density03 );
263252

264-
poisson_config.add_target_statistics( stat01 );
265-
poisson_config.add_target_statistics( stat02 );
266-
poisson_config.add_target_statistics( stat03 );
267-
268253
geode::SimulationRunner< geode::Point2D > runner(
269254
poisson_config.build() );
270255
// run simulation
@@ -280,8 +265,9 @@ namespace
280265
// NOLINTEND(*-magic-numbers)
281266

282267
auto statistic_tracker = runner.run( engine, sim_config );
283-
geode::statistics::validate(
284-
statistic_tracker, runner.target_statistics() );
268+
geode::TargetStatistics target_stats{ runner.model(),
269+
targeted_statistics_descriptors };
270+
geode::statistics::validate( statistic_tracker, target_stats );
285271

286272
geode::Logger::info( "--> SUCCESS!" );
287273
}

0 commit comments

Comments
 (0)