Skip to content

Commit 6da78cf

Browse files
feat(Helpers): add poisson and strauss helpers
1 parent 700f135 commit 6da78cf

20 files changed

Lines changed: 980 additions & 677 deletions

File tree

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#pragma once
2+
3+
#include <geode/stochastic/inference/target_statistics.hpp>
4+
#include <geode/stochastic/sampling/mcmc/helpers/simulation_context.hpp>
5+
6+
namespace geode
7+
{
8+
template < typename ObjectType >
9+
struct PoissonSetDescription
10+
{
11+
std::string set_name;
12+
13+
ObjectSamplerConfig< ObjectType > sampler;
14+
15+
std::string density_name;
16+
double lambda;
17+
std::optional< double > expected_nb_objects;
18+
19+
double birth_ratio{ 1.0 };
20+
double death_ratio{ 1.0 };
21+
double change_ratio{ 1.0 };
22+
};
23+
24+
template < typename ObjectType >
25+
struct PoissonProcessDescription
26+
{
27+
SpatialDomainConfig< ObjectType::dim > domain;
28+
29+
std::vector< PoissonSetDescription< ObjectType > > sets;
30+
31+
PoissonSetDescription< ObjectType >& add_set(
32+
absl::string_view set_name, absl::string_view density_name )
33+
{
34+
auto& set = sets.emplace_back();
35+
set.set_name = set_name;
36+
set.density_name = density_name;
37+
return set;
38+
}
39+
};
40+
41+
template < typename ObjectType >
42+
SimulationContext< ObjectType > build_poisson_process(
43+
const PoissonProcessDescription< ObjectType >& description );
44+
45+
template < typename ObjectType >
46+
std::vector< geode::TargetStatisticConfig > build_poisson_targeted_stat(
47+
const PoissonProcessDescription< ObjectType >& description );
48+
49+
} // namespace geode
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#pragma once
2+
3+
#include <geode/stochastic/applications/poisson_process.hpp>
4+
#include <geode/stochastic/inference/target_statistics.hpp>
5+
#include <geode/stochastic/sampling/mcmc/helpers/simulation_context.hpp>
6+
7+
namespace geode
8+
{
9+
template < typename ObjectType >
10+
struct StraussInteractionDescription
11+
{
12+
std::string interaction_name;
13+
14+
std::vector< std::string > set_names;
15+
16+
double gamma;
17+
double distance;
18+
19+
bool include_intra_set{ true };
20+
bool include_inter_set{ false };
21+
22+
std::optional< double > expected_nb_interactions;
23+
};
24+
25+
template < typename ObjectType >
26+
struct StraussProcessDescription
27+
{
28+
SpatialDomainConfig< ObjectType::dim > domain;
29+
30+
std::vector< PoissonSetDescription< ObjectType > > sets;
31+
32+
std::vector< StraussInteractionDescription< ObjectType > > interactions;
33+
34+
PoissonSetDescription< ObjectType >& add_set(
35+
absl::string_view set_name, absl::string_view density_name )
36+
{
37+
auto& set = sets.emplace_back();
38+
set.set_name = set_name;
39+
set.density_name = density_name;
40+
return set;
41+
}
42+
43+
StraussInteractionDescription< ObjectType >& add_interaction(
44+
absl::string_view interaction_name )
45+
{
46+
auto& interaction = interactions.emplace_back();
47+
interaction.interaction_name = interaction_name;
48+
return interaction;
49+
}
50+
};
51+
52+
template < typename ObjectType >
53+
SimulationContext< ObjectType > build_strauss_process(
54+
const StraussProcessDescription< ObjectType >& description );
55+
56+
template < typename ObjectType >
57+
std::vector< geode::TargetStatisticConfig > build_strauss_targeted_stat(
58+
const StraussProcessDescription< ObjectType >& description );
59+
60+
} // namespace geode

include/geode/stochastic/sampling/direct/object_set_sampler/point_set_sampler.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ namespace geode
7070
constexpr index_t max_try{ 100 };
7171
for( const auto try_id : geode::Range{ max_try } )
7272
{
73+
geode_unused( try_id );
7374
if( domain_.extended_contains( new_point ) )
7475
{
7576
return new_point;
@@ -107,11 +108,11 @@ namespace geode
107108
};
108109

109110
template < index_t dimension >
110-
std::unique_ptr< ObjectSetSampler< Point2D > > build_objectset_sampler(
111-
const SpatialDomain< 2 >& domain,
112-
const ObjectSamplerConfig< Point< dimension > >& config )
111+
std::unique_ptr< ObjectSetSampler< Point< dimension > > >
112+
build_objectset_sampler( const SpatialDomain< dimension >& domain,
113+
const ObjectSamplerConfig< Point< dimension > >& config )
113114
{
114-
return std::make_unique< UniformPointSetSampler< 2 > >(
115+
return std::make_unique< UniformPointSetSampler< dimension > >(
115116
domain, config );
116117
}
117118

include/geode/stochastic/sampling/mcmc/helpers/simulation_context.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,13 @@ namespace geode
8282
template < typename ObjectType >
8383
struct SimulationContextConfig
8484
{
85+
ObjectSetDefinition< ObjectType >& add_set( absl::string_view name )
86+
{
87+
auto& set = sets.emplace_back();
88+
set.name = name;
89+
return set;
90+
}
91+
8592
SpatialDomainConfig< ObjectType::dim > domain;
8693

8794
std::vector< ObjectSetDefinition< ObjectType > > sets;
@@ -127,7 +134,7 @@ namespace geode
127134
// MH sampler
128135
// -------------------------
129136
context.mh_sampler =
130-
std::make_unique< geode::MetropolisHastings< geode::Point2D > >(
137+
std::make_unique< geode::MetropolisHastings< ObjectType > >(
131138
*context.model, std::move( proposal_kernel ) );
132139

133140
return context;

include/geode/stochastic/sampling/mcmc/metropolis_hasting_sampler.hpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ namespace geode
150150
}
151151

152152
private:
153-
const double compute_log_accept( const double deltaU,
154-
const ProposalProbabilities& proposal_probas ) const
153+
double compute_log_accept(
154+
double deltaU, const ProposalProbabilities& proposal_probas ) const
155155
{
156156
return -beta_ * deltaU + proposal_probas.transition_probability();
157157
}
@@ -189,10 +189,11 @@ namespace geode
189189
const auto delta_log_energy =
190190
model_.energy().delta_log_add( state, new_object );
191191
return accept_or_reject( proposal, state, engine, delta_log_energy,
192-
[]( auto& state, auto& proposal ) {
193-
state.add_object(
194-
std::move( proposal.proposed_move.new_object.value() ),
195-
proposal.set_id, false );
192+
[]( auto& cur_state, auto& accepted_proposal ) {
193+
cur_state.add_object(
194+
std::move( accepted_proposal.proposed_move.new_object
195+
.value() ),
196+
accepted_proposal.set_id, false );
196197
} );
197198
};
198199

@@ -204,8 +205,9 @@ namespace geode
204205
const auto delta_log_energy =
205206
model_.energy().delta_log_remove( state, old_object_id );
206207
return accept_or_reject( proposal, state, engine, delta_log_energy,
207-
[]( auto& state, auto& proposal ) {
208-
state.remove_free_object( proposal.old_object_id() );
208+
[]( auto& cur_state, auto& accepted_proposal ) {
209+
cur_state.remove_free_object(
210+
accepted_proposal.old_object_id() );
209211
} );
210212
};
211213

@@ -218,10 +220,11 @@ namespace geode
218220
const auto delta_log_energy = model_.energy().delta_log_change(
219221
state, old_object_id, new_object );
220222
return accept_or_reject( proposal, state, engine, delta_log_energy,
221-
[]( auto& state, auto& proposal ) {
222-
state.update_free_object( proposal.old_object_id(),
223-
std::move(
224-
proposal.proposed_move.new_object.value() ) );
223+
[]( auto& cur_state, auto& accepted_proposal ) {
224+
cur_state.update_free_object(
225+
accepted_proposal.old_object_id(),
226+
std::move( accepted_proposal.proposed_move.new_object
227+
.value() ) );
225228
} );
226229
};
227230

include/geode/stochastic/spatial/single_object_features/segment_length_feature.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
* USE OR OTHER DEALINGS IN THE SOFTWARE.
2222
*
2323
*/
24+
#pragma once
2425

2526
#include <geode/stochastic/spatial/single_object_features/single_object_feature.hpp>
2627

src/geode/stochastic/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ add_geode_library(
2222
NAME stochastic
2323
FOLDER "geode/stochastic"
2424
SOURCES
25+
"applications/poisson_process.cpp"
26+
"applications/strauss_process.cpp"
2527
"spatial/object_set.cpp"
2628
"spatial/object_sets.cpp"
2729
"spatial/object_neighborhood.cpp"
@@ -37,6 +39,8 @@ add_geode_library(
3739
"sampling/random_engine.cpp"
3840
"common.cpp"
3941
PUBLIC_HEADERS
42+
"applications/poisson_process.hpp"
43+
"applications/strauss_process.hpp"
4044
#"inference/abc_shadow.hpp"
4145
"inference/statistics_tools.hpp"
4246
"inference/statistics_tracker.hpp"
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#pragma once
2+
3+
#include <geode/stochastic/applications/poisson_process.hpp>
4+
5+
namespace geode
6+
{
7+
using PoissonDensityDescription = geode::SingleObjectTermConfig;
8+
9+
template < typename ObjectType >
10+
SimulationContext< ObjectType > build_poisson_process(
11+
const PoissonProcessDescription< ObjectType >& desc )
12+
{
13+
SimulationContextConfig< ObjectType > config;
14+
15+
config.domain = desc.domain;
16+
17+
for( const auto& set_desc : desc.sets )
18+
{
19+
auto& set = config.add_set( set_desc.set_name );
20+
21+
set.sampler = set_desc.sampler;
22+
23+
set.dynamics.birth_ratio = set_desc.birth_ratio;
24+
set.dynamics.death_ratio = set_desc.death_ratio;
25+
set.dynamics.change_ratio = set_desc.change_ratio;
26+
27+
PoissonDensityDescription density;
28+
29+
density.term_name = set_desc.density_name;
30+
density.object_set_names = { set_desc.set_name };
31+
density.lambda = set_desc.lambda;
32+
density.object_feature = ObjectInDomainFeatureConfig{};
33+
34+
config.model.terms.emplace_back( std::move( density ) );
35+
}
36+
37+
return build_simulation_context( config );
38+
}
39+
40+
template opengeode_stochastic_stochastic_api SimulationContext< Point2D >
41+
build_poisson_process< Point2D >(
42+
const PoissonProcessDescription< Point2D >& );
43+
template opengeode_stochastic_stochastic_api SimulationContext< Point3D >
44+
build_poisson_process< Point3D >(
45+
const PoissonProcessDescription< Point3D >& );
46+
47+
template < typename ObjectType >
48+
std::vector< geode::TargetStatisticConfig > build_poisson_targeted_stat(
49+
const PoissonProcessDescription< ObjectType >& description )
50+
{
51+
std::vector< geode::TargetStatisticConfig > targets;
52+
53+
for( const auto& set_desc : description.sets )
54+
{
55+
if( !set_desc.expected_nb_objects )
56+
{
57+
continue;
58+
}
59+
60+
targets.push_back( geode::TargetStatisticConfig{
61+
set_desc.density_name, *set_desc.expected_nb_objects, 0.1 } );
62+
}
63+
64+
return targets;
65+
}
66+
67+
template opengeode_stochastic_stochastic_api
68+
std::vector< geode::TargetStatisticConfig >
69+
build_poisson_targeted_stat< Point2D >(
70+
const PoissonProcessDescription< Point2D >& );
71+
template opengeode_stochastic_stochastic_api
72+
std::vector< geode::TargetStatisticConfig >
73+
build_poisson_targeted_stat< Point3D >(
74+
const PoissonProcessDescription< Point3D >& );
75+
} // namespace geode

0 commit comments

Comments
 (0)