Skip to content

Commit 9b28696

Browse files
authored
Merge pull request #16 from Geode-solutions/feat_refactor_gibbs_energy
feat(GibbsEnergy): refactor energy terms
2 parents 00d4de7 + 27fb033 commit 9b28696

12 files changed

Lines changed: 234 additions & 218 deletions

File tree

include/geode/stochastic/sampling/mcmc/metropolis_hasting_sampler.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ namespace geode
188188
{
189189
const auto new_object = proposal.new_object();
190190
const auto delta_log_energy =
191-
energy_.delta_log_energy_add( state, new_object );
191+
energy_.delta_log_add( state, new_object );
192192
return accept_or_reject( proposal, state, engine, delta_log_energy,
193193
[]( auto& state, auto& proposal ) {
194194
state.add_object(
@@ -203,7 +203,7 @@ namespace geode
203203
{
204204
const auto old_object_id = proposal.old_object_id();
205205
const auto delta_log_energy =
206-
energy_.delta_log_energy_remove( state, old_object_id );
206+
energy_.delta_log_remove( state, old_object_id );
207207
return accept_or_reject( proposal, state, engine, delta_log_energy,
208208
[]( auto& state, auto& proposal ) {
209209
state.remove_object( proposal.old_object_id() );
@@ -216,8 +216,8 @@ namespace geode
216216
{
217217
const auto new_object = proposal.new_object();
218218
const auto old_object_id = proposal.old_object_id();
219-
const auto delta_log_energy = energy_.delta_log_energy_change(
220-
state, old_object_id, new_object );
219+
const auto delta_log_energy =
220+
energy_.delta_log_change( state, old_object_id, new_object );
221221
// should we test that objects are in the same group?
222222
// should be ensured by the dynamic
223223
return accept_or_reject( proposal, state, engine, delta_log_energy,

include/geode/stochastic/sampling/mcmc/models/components/density_term.hpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,11 @@ namespace geode
3333
class DensityTerm : public EnergyTerm< ObjectType >
3434
{
3535
public:
36-
explicit DensityTerm( std::string_view name, double lambda )
37-
: EnergyTerm< ObjectType >( name, lambda )
38-
{
39-
}
40-
41-
explicit DensityTerm(
42-
std::string_view name, double lambda, const uuid& subset_id )
43-
: EnergyTerm< ObjectType >( name, lambda, subset_id )
36+
explicit DensityTerm( std::string_view name,
37+
double lambda,
38+
absl::flat_hash_set< uuid > targeted_subset_ids )
39+
: EnergyTerm< ObjectType >(
40+
name, lambda, std::move( targeted_subset_ids ) )
4441
{
4542
}
4643

@@ -79,12 +76,12 @@ namespace geode
7976

8077
double statistic( const ObjectSet< ObjectType >& state ) const override
8178
{
82-
if( this->targeted_subset_id() )
79+
index_t count{ 0 };
80+
for( const auto& subset_uuid : this->targeted_subset_ids() )
8381
{
84-
return static_cast< double >( state.nb_objects_in_subset(
85-
this->targeted_subset_id().value() ) );
82+
count += state.nb_objects_in_subset( subset_uuid );
8683
}
87-
return static_cast< double >( state.nb_objects() );
84+
return static_cast< double >( count );
8885
}
8986
};
9087
} // namespace geode

include/geode/stochastic/sampling/mcmc/models/components/energy_term.hpp

Lines changed: 25 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,14 @@
2222
*/
2323
#pragma once
2424

25+
#include <absl/container/flat_hash_set.h>
26+
27+
#include <geode/basic/identifier.hpp>
28+
#include <geode/basic/identifier_builder.hpp>
29+
2530
#include <geode/stochastic/common.hpp>
2631
#include <geode/stochastic/spatial/object_set.hpp>
32+
2733
#include <optional>
2834

2935
namespace geode
@@ -93,43 +99,30 @@ namespace geode
9399
// };
94100

95101
template < typename ObjectType >
96-
class EnergyTerm
102+
class EnergyTerm : public Identifier
97103
{
98104
public:
99-
explicit EnergyTerm( std::string_view name, double param )
100-
: name_{ name }, energy_scale_{ param }
101-
{
102-
}
103-
104105
explicit EnergyTerm( std::string_view name,
105106
double param,
106-
const uuid& targeted_subset_id )
107-
: name_{ name },
107+
absl::flat_hash_set< uuid >&& targeted_subset_ids )
108+
: Identifier{},
108109
energy_scale_{ param },
109-
targeted_subset_id_{ targeted_subset_id }
110+
targeted_subset_ids_{ std::move( targeted_subset_ids ) }
110111
{
112+
IdentifierBuilder builder( *this );
113+
builder.set_name( name );
111114
}
112115

113116
virtual ~EnergyTerm() = default;
114117

115-
const uuid& id() const
116-
{
117-
return energy_term_id_;
118-
}
119-
120-
std::string_view name() const
121-
{
122-
return name_;
123-
}
124-
125118
double parameter() const
126119
{
127120
return energy_scale_.parameter();
128121
}
129122

130-
std::optional< uuid > targeted_subset_id() const
123+
const absl::flat_hash_set< uuid >& targeted_subset_ids() const
131124
{
132-
return targeted_subset_id_;
125+
return targeted_subset_ids_;
133126
}
134127

135128
/// Energy contribution for a given statistic multiplier
@@ -158,45 +151,39 @@ namespace geode
158151
{
159152
auto message =
160153
absl::StrCat( "Term : ", name(), "; uuid: ", id().string(),
161-
" parameter value: ", energy_scale_.parameter() );
162-
if( targeted_subset_id_ )
154+
" parameter value: ", energy_scale_.parameter(),
155+
" applyied on ", targeted_subset_ids_.size(),
156+
" object subsets -->" );
157+
for( const auto& subset_uuid : targeted_subset_ids_ )
163158
{
164-
absl::StrAppend( &message, " targetted subset: ",
165-
targeted_subset_id_.value().string() );
159+
absl::StrAppend( &message, "\t", subset_uuid.string() );
166160
}
167161
return message;
168162
}
169163

170164
protected:
171165
bool is_targeted_subset( const uuid& subset_id ) const
172166
{
173-
return !targeted_subset_id_ || subset_id == *targeted_subset_id_;
167+
return targeted_subset_ids_.find( subset_id )
168+
!= targeted_subset_ids_.end();
174169
}
175170

176171
template < typename Func >
177172
void for_each_targeted_object(
178173
const ObjectSet< ObjectType >& state, Func&& do_apply ) const
179174
{
180-
if( targeted_subset_id_ )
175+
for( const auto& targeted_subset_id : targeted_subset_ids_ )
181176
{
182177
for( const auto id : geode::Range{
183-
state.nb_objects_in_subset( *targeted_subset_id_ ) } )
178+
state.nb_objects_in_subset( targeted_subset_id ) } )
184179
{
185-
do_apply( { id, *targeted_subset_id_ } );
180+
do_apply( ObjectId{ id, targeted_subset_id } );
186181
}
187-
return;
188-
}
189-
for( const auto& id : state.get_all_object() )
190-
{
191-
do_apply( id );
192182
}
193183
}
194184

195185
private:
196-
std::string name_;
197186
detail::EnergyScale energy_scale_;
198-
199-
std::optional< uuid > targeted_subset_id_{};
200-
uuid energy_term_id_{};
187+
absl::flat_hash_set< uuid > targeted_subset_ids_;
201188
};
202189
} // namespace geode
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#pragma once
2+
3+
#include <absl/container/flat_hash_map.h>
4+
#include <geode/stochastic/sampling/mcmc/models/components/energy_term.hpp>
5+
6+
namespace geode
7+
{
8+
template < typename ObjectType >
9+
class EnergyTermCollection
10+
{
11+
OPENGEODE_DISABLE_COPY( EnergyTermCollection );
12+
13+
public:
14+
EnergyTermCollection() = default;
15+
EnergyTermCollection( EnergyTermCollection&& ) noexcept = default;
16+
~EnergyTermCollection() = default;
17+
18+
uuid add_energy_term( std::shared_ptr< EnergyTerm< ObjectType > > term )
19+
{
20+
const uuid id = term->id();
21+
energy_terms_.emplace( id, term );
22+
for( const uuid& subset_id : term->targeted_subset_ids() )
23+
{
24+
subset_to_terms_[subset_id].push_back( term );
25+
}
26+
return id;
27+
}
28+
29+
bool remove_energy_term( const uuid& id )
30+
{
31+
auto it = energy_terms_.find( id );
32+
if( it == energy_terms_.end() )
33+
return false;
34+
35+
auto term = it->second;
36+
37+
for( const uuid& subset_id : term->targeted_subset_ids() )
38+
{
39+
auto vec_it = subset_to_terms_.find( subset_id );
40+
if( vec_it == subset_to_terms_.end() )
41+
continue;
42+
43+
auto& vec = vec_it->second;
44+
vec.erase(
45+
std::remove( vec.begin(), vec.end(), term ), vec.end() );
46+
47+
if( vec.empty() )
48+
subset_to_terms_.erase( vec_it );
49+
}
50+
51+
energy_terms_.erase( it );
52+
return true;
53+
}
54+
55+
void clear()
56+
{
57+
energy_terms_.clear();
58+
subset_to_terms_.clear();
59+
}
60+
61+
[[nodiscard]] index_t size() const
62+
{
63+
return energy_terms_.size();
64+
}
65+
66+
[[nodiscard]] std::shared_ptr< const EnergyTerm< ObjectType > > get(
67+
const uuid& id ) const
68+
{
69+
auto it = energy_terms_.find( id );
70+
OPENGEODE_EXCEPTION( it != energy_terms_.end(),
71+
absl::StrCat( "[EnergyTermCollection] Unknown energy term: ",
72+
id.string() ) );
73+
return it->second;
74+
}
75+
76+
[[nodiscard]] const absl::flat_hash_map< uuid,
77+
std::shared_ptr< EnergyTerm< ObjectType > > >&
78+
all_terms() const
79+
{
80+
return energy_terms_;
81+
}
82+
83+
[[nodiscard]] const std::vector<
84+
std::shared_ptr< EnergyTerm< ObjectType > > >&
85+
terms_for_subset( const uuid& subset_id ) const
86+
{
87+
const auto it = subset_to_terms_.find( subset_id );
88+
OPENGEODE_EXCEPTION( it != subset_to_terms_.end(),
89+
"[EnergyTermCollection] - Object Subset (", subset_id.string(),
90+
") does not have any energy term." );
91+
return it->second;
92+
}
93+
94+
std::string string() const
95+
{
96+
auto message = absl::StrCat(
97+
"EnergyTermCollection: ", energy_terms_.size(), " terms:" );
98+
for( const auto& [id, term] : energy_terms_ )
99+
{
100+
absl::StrAppend( &message, "\n\t --> ", term->string() );
101+
}
102+
return message;
103+
}
104+
105+
private:
106+
// strong ownership
107+
absl::flat_hash_map< uuid, std::shared_ptr< EnergyTerm< ObjectType > > >
108+
energy_terms_;
109+
110+
// subset index (shared ownership)
111+
absl::flat_hash_map< uuid,
112+
std::vector< std::shared_ptr< EnergyTerm< ObjectType > > > >
113+
subset_to_terms_;
114+
};
115+
116+
} // namespace geode

include/geode/stochastic/sampling/mcmc/models/components/pairwise_term.hpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,10 @@ namespace geode
3737
public:
3838
explicit PairwiseTerm( std::string_view name,
3939
double gamma,
40-
std::unique_ptr< PairwiseInteraction< ObjectType > >&& interaction )
41-
: EnergyTerm< ObjectType >( name, gamma ),
42-
interaction_( std::move( interaction ) )
43-
{
44-
}
45-
46-
explicit PairwiseTerm( std::string_view name,
47-
double gamma,
48-
std::unique_ptr< PairwiseInteraction< ObjectType > >&& interaction,
49-
const uuid& subset_id )
50-
: EnergyTerm< ObjectType >( name, gamma, subset_id ),
40+
absl::flat_hash_set< uuid > targeted_subset_ids,
41+
std::unique_ptr< PairwiseInteraction< ObjectType > > interaction )
42+
: EnergyTerm< ObjectType >(
43+
name, gamma, std::move( targeted_subset_ids ) ),
5144
interaction_( std::move( interaction ) )
5245
{
5346
}
@@ -157,6 +150,16 @@ namespace geode
157150
state.get_all_object(); // state.neighbors( obj_id, 1.1 );
158151
for( const auto& neigh_obj_id : neighbors )
159152
{
153+
// if( neigh_obj_id.subset_id < obj_id.subset_id )
154+
//{
155+
// continue;
156+
// }
157+
// if( neigh_obj_id.subset_id == obj_id.subset_id
158+
// && neigh_obj_id.id <= obj_id.id )
159+
//{
160+
// continue;
161+
// }
162+
160163
if( neigh_obj_id == obj_id )
161164
{
162165
continue;

include/geode/stochastic/sampling/mcmc/models/components/single_object_term.hpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,9 @@ namespace geode
3434
public:
3535
explicit SingleObjectTerm( std::string_view name,
3636
double lambda,
37+
absl::flat_hash_set< uuid > targeted_subset_ids,
3738
ObjectContributionFunc contribution_func )
38-
: EnergyTerm< ObjectType >( name, lambda ),
39-
contribution_func_( std::move( contribution_func ) )
40-
{
41-
}
42-
43-
explicit SingleObjectTerm( std::string_view name,
44-
double lambda,
45-
ObjectContributionFunc contribution_func,
46-
std::optional< uuid > subset_id )
47-
: EnergyTerm< ObjectType >( name, lambda, subset_id ),
39+
: EnergyTerm< ObjectType >( name, lambda, targeted_subset_ids ),
4840
contribution_func_( std::move( contribution_func ) )
4941
{
5042
}

0 commit comments

Comments
 (0)