Skip to content

Commit 27c16a6

Browse files
authored
Merge pull request #17 from Geode-solutions/refactor_objectset
feat(ObjectSet): refactor and rename
2 parents 9b28696 + 2a26e70 commit 27c16a6

30 files changed

Lines changed: 1004 additions & 618 deletions

include/geode/stochastic/sampling/direct/object_set_sampler/object_set_sampler.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
#pragma once
2525

2626
#include <geode/stochastic/sampling/random_engine.hpp>
27-
#include <geode/stochastic/spatial/object_set.hpp>
27+
#include <geode/stochastic/spatial/object_sets.hpp>
2828

2929
namespace geode
3030
{

include/geode/stochastic/sampling/mcmc/metropolis_hasting_sampler.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ namespace geode
5959
}
6060

6161
StepResult< ObjectType > step(
62-
ObjectSet< ObjectType >& state, RandomEngine& engine ) const
62+
ObjectSets< ObjectType >& state, RandomEngine& engine ) const
6363
{
6464
Proposal< ObjectType > proposal =
6565
proposal_kernel_->propose( state, engine );
@@ -79,7 +79,7 @@ namespace geode
7979
return StepResult< ObjectType >{};
8080
}
8181

82-
void walk( ObjectSet< ObjectType >& state,
82+
void walk( ObjectSets< ObjectType >& state,
8383
RandomEngine& engine,
8484
index_t nb_steps ) const
8585
{
@@ -90,7 +90,7 @@ namespace geode
9090
}
9191
}
9292

93-
ObjectSet< ObjectType > walk_copy( ObjectSet< ObjectType > initial,
93+
ObjectSets< ObjectType > walk_copy( ObjectSets< ObjectType > initial,
9494
RandomEngine& engine,
9595
index_t nb_steps ) const
9696
{
@@ -160,7 +160,7 @@ namespace geode
160160
template < typename ApplyMove >
161161
StepResult< ObjectType > accept_or_reject(
162162
Proposal< ObjectType >& proposal,
163-
ObjectSet< ObjectType >& state,
163+
ObjectSets< ObjectType >& state,
164164
RandomEngine& engine,
165165
const double delta_log_energy,
166166
ApplyMove&& apply_move ) const
@@ -183,7 +183,7 @@ namespace geode
183183
}
184184

185185
StepResult< ObjectType > birth_step( Proposal< ObjectType >& proposal,
186-
ObjectSet< ObjectType >& state,
186+
ObjectSets< ObjectType >& state,
187187
RandomEngine& engine ) const
188188
{
189189
const auto new_object = proposal.new_object();
@@ -193,12 +193,12 @@ namespace geode
193193
[]( auto& state, auto& proposal ) {
194194
state.add_object(
195195
std::move( proposal.proposed_move.new_object.value() ),
196-
proposal.subset_id );
196+
proposal.set_id );
197197
} );
198198
};
199199

200200
StepResult< ObjectType > death_step( Proposal< ObjectType >& proposal,
201-
ObjectSet< ObjectType >& state,
201+
ObjectSets< ObjectType >& state,
202202
RandomEngine& engine ) const
203203
{
204204
const auto old_object_id = proposal.old_object_id();
@@ -211,7 +211,7 @@ namespace geode
211211
};
212212

213213
StepResult< ObjectType > change_step( Proposal< ObjectType >& proposal,
214-
ObjectSet< ObjectType >& state,
214+
ObjectSets< ObjectType >& state,
215215
RandomEngine& engine ) const
216216
{
217217
const auto new_object = proposal.new_object();

include/geode/stochastic/sampling/mcmc/models/components/density_term.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
*/
2323
#pragma once
2424

25-
#include <geode/stochastic/spatial/object_set.hpp>
25+
#include <geode/stochastic/spatial/object_sets.hpp>
2626

2727
#include <geode/stochastic/sampling/mcmc/models/components/energy_term.hpp>
2828

@@ -35,51 +35,51 @@ namespace geode
3535
public:
3636
explicit DensityTerm( std::string_view name,
3737
double lambda,
38-
absl::flat_hash_set< uuid > targeted_subset_ids )
38+
absl::flat_hash_set< uuid > targeted_set_ids )
3939
: EnergyTerm< ObjectType >(
40-
name, lambda, std::move( targeted_subset_ids ) )
40+
name, lambda, std::move( targeted_set_ids ) )
4141
{
4242
}
4343

44-
double total_log( const ObjectSet< ObjectType >& state ) const override
44+
double total_log( const ObjectSets< ObjectType >& state ) const override
4545
{
4646
const auto n = this->statistic( state );
4747
return this->contribution( n );
4848
}
4949

50-
double delta_log_add( const ObjectSet< ObjectType >& /*state*/,
50+
double delta_log_add( const ObjectSets< ObjectType >& /*state*/,
5151
const ObjectRef< ObjectType >& new_object ) const override
5252
{
53-
if( !this->is_targeted_subset( new_object.subset ) )
53+
if( !this->is_targeted_set( new_object.set_id ) )
5454
{
5555
return 0.0;
5656
}
5757
return this->contribution( 1.0 );
5858
}
5959

60-
double delta_log_remove( const ObjectSet< ObjectType >& /*state*/,
60+
double delta_log_remove( const ObjectSets< ObjectType >& /*state*/,
6161
const ObjectId& object_id ) const override
6262
{
63-
if( !this->is_targeted_subset( object_id.subset ) )
63+
if( !this->is_targeted_set( object_id.set_id ) )
6464
{
6565
return 0.0;
6666
}
6767
return this->contribution( -1.0 );
6868
}
6969

70-
double delta_log_change( const ObjectSet< ObjectType >& /*state*/,
70+
double delta_log_change( const ObjectSets< ObjectType >& /*state*/,
7171
const ObjectId& /*old_object_id*/,
7272
const ObjectRef< ObjectType >& /*new_object*/ ) const override
7373
{
7474
return 0.0;
7575
}
7676

77-
double statistic( const ObjectSet< ObjectType >& state ) const override
77+
double statistic( const ObjectSets< ObjectType >& state ) const override
7878
{
7979
index_t count{ 0 };
80-
for( const auto& subset_uuid : this->targeted_subset_ids() )
80+
for( const auto& set_id : this->targeted_set_ids() )
8181
{
82-
count += state.nb_objects_in_subset( subset_uuid );
82+
count += state.nb_objects_in_set( set_id );
8383
}
8484
return static_cast< double >( count );
8585
}

include/geode/stochastic/sampling/mcmc/models/components/energy_term.hpp

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#include <geode/basic/identifier_builder.hpp>
2929

3030
#include <geode/stochastic/common.hpp>
31-
#include <geode/stochastic/spatial/object_set.hpp>
31+
#include <geode/stochastic/spatial/object_sets.hpp>
3232

3333
#include <optional>
3434

@@ -88,14 +88,14 @@ namespace geode
8888
// std::string name;
8989
// std::string type;
9090
// double parameter_value;
91-
// std::optional< uuid > targeted_subset_id{};
91+
// std::optional< uuid > targeted_set_id{};
9292
// }
9393
//
9494
// struct StatisticalDescription
9595
// {
9696
// std::string label;
9797
// double value;
98-
// std::optional< uuid > targeted_subset_id{};
98+
// std::optional< uuid > targeted_set_id{};
9999
// };
100100

101101
template < typename ObjectType >
@@ -104,10 +104,9 @@ namespace geode
104104
public:
105105
explicit EnergyTerm( std::string_view name,
106106
double param,
107-
absl::flat_hash_set< uuid >&& targeted_subset_ids )
108-
: Identifier{},
109-
energy_scale_{ param },
110-
targeted_subset_ids_{ std::move( targeted_subset_ids ) }
107+
absl::flat_hash_set< uuid >&& targeted_set_ids )
108+
: energy_scale_{ param },
109+
targeted_set_ids_{ std::move( targeted_set_ids ) }
111110
{
112111
IdentifierBuilder builder( *this );
113112
builder.set_name( name );
@@ -120,9 +119,9 @@ namespace geode
120119
return energy_scale_.parameter();
121120
}
122121

123-
const absl::flat_hash_set< uuid >& targeted_subset_ids() const
122+
const absl::flat_hash_set< uuid >& targeted_set_ids() const
124123
{
125-
return targeted_subset_ids_;
124+
return targeted_set_ids_;
126125
}
127126

128127
/// Energy contribution for a given statistic multiplier
@@ -132,58 +131,57 @@ namespace geode
132131
}
133132

134133
virtual double total_log(
135-
const ObjectSet< ObjectType >& state ) const = 0;
134+
const ObjectSets< ObjectType >& state ) const = 0;
136135

137-
virtual double delta_log_add( const ObjectSet< ObjectType >& state,
136+
virtual double delta_log_add( const ObjectSets< ObjectType >& state,
138137
const ObjectRef< ObjectType >& new_object ) const = 0;
139138

140-
virtual double delta_log_remove( const ObjectSet< ObjectType >& state,
139+
virtual double delta_log_remove( const ObjectSets< ObjectType >& state,
141140
const ObjectId& object_id ) const = 0;
142141

143-
virtual double delta_log_change( const ObjectSet< ObjectType >& state,
142+
virtual double delta_log_change( const ObjectSets< ObjectType >& state,
144143
const ObjectId& old_object_id,
145144
const ObjectRef< ObjectType >& new_object ) const = 0;
146145

147146
virtual double statistic(
148-
const ObjectSet< ObjectType >& state ) const = 0;
147+
const ObjectSets< ObjectType >& state ) const = 0;
149148

150149
std::string string() const
151150
{
152151
auto message =
153152
absl::StrCat( "Term : ", name(), "; uuid: ", id().string(),
154153
" parameter value: ", energy_scale_.parameter(),
155-
" applyied on ", targeted_subset_ids_.size(),
154+
" applyied on ", targeted_set_ids_.size(),
156155
" object subsets -->" );
157-
for( const auto& subset_uuid : targeted_subset_ids_ )
156+
for( const auto& set_id : targeted_set_ids_ )
158157
{
159-
absl::StrAppend( &message, "\t", subset_uuid.string() );
158+
absl::StrAppend( &message, "\t", set_id.string() );
160159
}
161160
return message;
162161
}
163162

164163
protected:
165-
bool is_targeted_subset( const uuid& subset_id ) const
164+
bool is_targeted_set( const uuid& set_id ) const
166165
{
167-
return targeted_subset_ids_.find( subset_id )
168-
!= targeted_subset_ids_.end();
166+
return targeted_set_ids_.find( set_id ) != targeted_set_ids_.end();
169167
}
170168

171169
template < typename Func >
172170
void for_each_targeted_object(
173-
const ObjectSet< ObjectType >& state, Func&& do_apply ) const
171+
const ObjectSets< ObjectType >& state, Func&& do_apply ) const
174172
{
175-
for( const auto& targeted_subset_id : targeted_subset_ids_ )
173+
for( const auto& targeted_set_id : targeted_set_ids_ )
176174
{
177-
for( const auto id : geode::Range{
178-
state.nb_objects_in_subset( targeted_subset_id ) } )
175+
for( const auto id :
176+
geode::Range{ state.nb_objects_in_set( targeted_set_id ) } )
179177
{
180-
do_apply( ObjectId{ id, targeted_subset_id } );
178+
do_apply( ObjectId{ id, targeted_set_id } );
181179
}
182180
}
183181
}
184182

185183
private:
186184
detail::EnergyScale energy_scale_;
187-
absl::flat_hash_set< uuid > targeted_subset_ids_;
185+
absl::flat_hash_set< uuid > targeted_set_ids_;
188186
};
189187
} // namespace geode

include/geode/stochastic/sampling/mcmc/models/components/energy_term_collection.hpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ namespace geode
1919
{
2020
const uuid id = term->id();
2121
energy_terms_.emplace( id, term );
22-
for( const uuid& subset_id : term->targeted_subset_ids() )
22+
for( const uuid& set_id : term->targeted_set_ids() )
2323
{
24-
subset_to_terms_[subset_id].push_back( term );
24+
set_to_terms_[set_id].push_back( term );
2525
}
2626
return id;
2727
}
@@ -34,18 +34,18 @@ namespace geode
3434

3535
auto term = it->second;
3636

37-
for( const uuid& subset_id : term->targeted_subset_ids() )
37+
for( const uuid& set_id : term->targeted_set_ids() )
3838
{
39-
auto vec_it = subset_to_terms_.find( subset_id );
40-
if( vec_it == subset_to_terms_.end() )
39+
auto vec_it = set_to_terms_.find( set_id );
40+
if( vec_it == set_to_terms_.end() )
4141
continue;
4242

4343
auto& vec = vec_it->second;
4444
vec.erase(
4545
std::remove( vec.begin(), vec.end(), term ), vec.end() );
4646

4747
if( vec.empty() )
48-
subset_to_terms_.erase( vec_it );
48+
set_to_terms_.erase( vec_it );
4949
}
5050

5151
energy_terms_.erase( it );
@@ -55,7 +55,7 @@ namespace geode
5555
void clear()
5656
{
5757
energy_terms_.clear();
58-
subset_to_terms_.clear();
58+
set_to_terms_.clear();
5959
}
6060

6161
[[nodiscard]] index_t size() const
@@ -82,11 +82,11 @@ namespace geode
8282

8383
[[nodiscard]] const std::vector<
8484
std::shared_ptr< EnergyTerm< ObjectType > > >&
85-
terms_for_subset( const uuid& subset_id ) const
85+
terms_for_set( const uuid& set_id ) const
8686
{
87-
const auto it = subset_to_terms_.find( subset_id );
88-
OPENGEODE_EXCEPTION( it != subset_to_terms_.end(),
89-
"[EnergyTermCollection] - Object Subset (", subset_id.string(),
87+
const auto it = set_to_terms_.find( set_id );
88+
OPENGEODE_EXCEPTION( it != set_to_terms_.end(),
89+
"[EnergyTermCollection] - Object Subset (", set_id.string(),
9090
") does not have any energy term." );
9191
return it->second;
9292
}
@@ -103,14 +103,12 @@ namespace geode
103103
}
104104

105105
private:
106-
// strong ownership
107106
absl::flat_hash_map< uuid, std::shared_ptr< EnergyTerm< ObjectType > > >
108107
energy_terms_;
109108

110-
// subset index (shared ownership)
111109
absl::flat_hash_map< uuid,
112110
std::vector< std::shared_ptr< EnergyTerm< ObjectType > > > >
113-
subset_to_terms_;
111+
set_to_terms_;
114112
};
115113

116114
} // namespace geode

0 commit comments

Comments
 (0)