Skip to content

Commit a14d16e

Browse files
authored
Merge pull request #26 from Geode-solutions/use_rtree_for_stat
fix(Neighborhood): use rtree search in pairwise interactions.
2 parents 81cf426 + 2768652 commit a14d16e

14 files changed

Lines changed: 101 additions & 90 deletions

File tree

commitlint.config.js

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
export default {
1+
const Configuration = {
22
extends: ["@commitlint/config-angular"],
33
rules: {
44
"scope-empty": [2, "never"],
@@ -12,5 +12,8 @@ export default {
1212
"subject-full-stop": [0],
1313
"type-case": [0],
1414
"type-empty": [0],
15+
"type-enum": [2, "always", ["feat", "fix", "perf"]],
1516
},
1617
}
18+
19+
export default Configuration

include/geode/stochastic/sampling/mcmc/helpers/fracture_simulation_runner.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,7 @@ namespace geode
117117
this->energy_terms_collection_.add_energy_term(
118118
std::make_unique< DensityTerm< OwnerSegment2D > >(
119119
absl::StrCat( set_desc.name, "_density" ),
120-
set_desc.p20,
121-
absl::flat_hash_set< uuid >{ set_id } ) ) );
120+
set_desc.p20, std::vector< uuid >{ set_id } ) ) );
122121
// spacing
123122
if( set_desc.minimal_spacing < GLOBAL_EPSILON )
124123
{
@@ -133,7 +132,7 @@ namespace geode
133132
this->energy_terms_collection_.add_energy_term(
134133
std::make_unique< PairwiseTerm< OwnerSegment2D > >(
135134
absl::StrCat( set_desc.name, "_min_spacing" ), 0.,
136-
absl::flat_hash_set< uuid >{ set_id },
135+
std::vector< uuid >{ set_id },
137136
std::move( interaction ) ) ) );
138137
}
139138

include/geode/stochastic/sampling/mcmc/models/components/density_term.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace geode
3535
public:
3636
explicit DensityTerm( std::string_view name,
3737
double lambda,
38-
absl::flat_hash_set< uuid > targeted_set_ids )
38+
std::vector< uuid > targeted_set_ids )
3939
: EnergyTerm< ObjectType >(
4040
name, lambda, std::move( targeted_set_ids ) )
4141
{

include/geode/stochastic/sampling/mcmc/models/components/energy_term.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,11 @@ namespace geode
8888
public:
8989
explicit EnergyTerm( std::string_view name,
9090
double param,
91-
absl::flat_hash_set< uuid >&& targeted_set_ids )
91+
std::vector< uuid >&& targeted_set_ids )
9292
: energy_scale_{ param },
9393
targeted_set_ids_{ std::move( targeted_set_ids ) }
9494
{
95+
std::sort( targeted_set_ids_.begin(), targeted_set_ids_.end() );
9596
IdentifierBuilder builder( *this );
9697
builder.set_name( name );
9798
}
@@ -103,7 +104,7 @@ namespace geode
103104
return energy_scale_.parameter();
104105
}
105106

106-
const absl::flat_hash_set< uuid >& targeted_set_ids() const
107+
const std::vector< uuid >& targeted_set_ids() const
107108
{
108109
return targeted_set_ids_;
109110
}
@@ -147,7 +148,8 @@ namespace geode
147148
protected:
148149
bool is_targeted_set( const uuid& set_id ) const
149150
{
150-
return targeted_set_ids_.find( set_id ) != targeted_set_ids_.end();
151+
return std::binary_search(
152+
targeted_set_ids_.begin(), targeted_set_ids_.end(), set_id );
151153
}
152154

153155
template < typename Func >
@@ -166,6 +168,6 @@ namespace geode
166168

167169
private:
168170
detail::EnergyScale energy_scale_;
169-
absl::flat_hash_set< uuid > targeted_set_ids_;
171+
std::vector< uuid > targeted_set_ids_;
170172
};
171173
} // namespace geode

include/geode/stochastic/sampling/mcmc/models/components/pairwise_term.hpp

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ namespace geode
3737
public:
3838
explicit PairwiseTerm( std::string_view name,
3939
double gamma,
40-
absl::flat_hash_set< uuid > targeted_set_ids,
40+
std::vector< uuid > targeted_set_ids,
4141
std::unique_ptr< PairwiseInteraction< ObjectType > > interaction )
4242
: EnergyTerm< ObjectType >(
4343
name, gamma, std::move( targeted_set_ids ) ),
@@ -59,8 +59,9 @@ namespace geode
5959
return 0.0;
6060
}
6161
double delta = 0.0;
62-
const auto neighbors = state.neighbors( new_object.object,
63-
interaction_->neighborhood_searching_distance() );
62+
const auto neighbors =
63+
state.neighbors( new_object.object, this->targeted_set_ids(),
64+
interaction_->neighborhood_searching_distance() );
6465
for( const auto& neigh_id : neighbors )
6566
{
6667
geode::ObjectRef< ObjectType > neigh_object{
@@ -83,8 +84,9 @@ namespace geode
8384
ObjectRef< ObjectType > object_to_remove{
8485
state.get_object( object_id ), object_id.set_id
8586
};
86-
const auto neighbors = state.neighbors(
87-
object_id, interaction_->neighborhood_searching_distance() );
87+
const auto neighbors =
88+
state.neighbors( object_id, this->targeted_set_ids(),
89+
interaction_->neighborhood_searching_distance() );
8890
for( auto neigh_id : neighbors )
8991
{
9092
ObjectRef< ObjectType > neigh_object{
@@ -111,8 +113,9 @@ namespace geode
111113
ObjectRef< ObjectType > object_to_remove{
112114
state.get_object( old_object_id ), old_object_id.set_id
113115
};
114-
const auto old_neighbors = state.neighbors( old_object_id,
115-
interaction_->neighborhood_searching_distance() );
116+
const auto old_neighbors =
117+
state.neighbors( old_object_id, this->targeted_set_ids(),
118+
interaction_->neighborhood_searching_distance() );
116119
for( auto neigh_id : old_neighbors )
117120
{
118121
ObjectRef< ObjectType > neigh_object{
@@ -123,8 +126,9 @@ namespace geode
123126
}
124127

125128
// Add new object's interactions
126-
const auto new_neighbors = state.neighbors( new_object.object,
127-
interaction_->neighborhood_searching_distance() );
129+
const auto new_neighbors =
130+
state.neighbors( new_object.object, this->targeted_set_ids(),
131+
interaction_->neighborhood_searching_distance() );
128132
for( auto neigh_id : new_neighbors )
129133
{
130134
if( old_object_id == neigh_id )
@@ -143,36 +147,33 @@ namespace geode
143147
double statistic( const ObjectSets< ObjectType >& state ) const override
144148
{
145149
double sum = 0.0;
146-
this->for_each_targeted_object( state, [&]( const ObjectId&
147-
obj_id ) {
148-
const auto& cur_obj = state.get_object( obj_id );
149-
const auto neighbors =
150-
state.get_all_object(); // state.neighbors( obj_id, 1.1 );
151-
for( const auto& neigh_obj_id : neighbors )
152-
{
153-
// if( neigh_obj_id.set_id < obj_id.set_id )
154-
//{
155-
// continue;
156-
// }
157-
// if( neigh_obj_id.set_id == obj_id.set_id
158-
// && neigh_obj_id.index <= obj_id.index )
159-
//{
160-
// continue;
161-
// }
162-
163-
if( neigh_obj_id == obj_id )
164-
{
165-
continue;
166-
}
150+
this->for_each_targeted_object(
151+
state, [&]( const ObjectId& obj_id ) {
152+
const auto& cur_obj = state.get_object( obj_id );
167153
ObjectRef< ObjectType > object{ cur_obj, obj_id.set_id };
168-
ObjectRef< ObjectType > neigh_object{
169-
state.get_object( neigh_obj_id ), neigh_obj_id.set_id
170-
};
154+
const auto neighbors =
155+
state.neighbors( obj_id, this->targeted_set_ids(),
156+
interaction_->neighborhood_searching_distance() );
157+
for( const auto& neigh_obj_id : neighbors )
158+
{
159+
if( neigh_obj_id.set_id < obj_id.set_id )
160+
{
161+
continue;
162+
}
163+
if( neigh_obj_id.set_id == obj_id.set_id
164+
&& neigh_obj_id.index <= obj_id.index )
165+
{
166+
continue;
167+
}
168+
ObjectRef< ObjectType > neigh_object{
169+
state.get_object( neigh_obj_id ),
170+
neigh_obj_id.set_id
171+
};
171172

172-
sum += interaction_->evaluate( object, neigh_object );
173-
}
174-
} );
175-
return sum / 2.;
173+
sum += interaction_->evaluate( object, neigh_object );
174+
}
175+
} );
176+
return sum;
176177
}
177178

178179
private:

include/geode/stochastic/spatial/object_neighborhood.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ namespace geode
6464

6565
std::vector< ObjectId > get_all_neighbor_ids(
6666
const BoundingBox< dimension >& box,
67+
const std::vector< uuid >& targeted_set_ids,
6768
std::optional< ObjectId > exclude_self_id ) const;
6869

6970
private:

include/geode/stochastic/spatial/object_sets.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,14 @@ namespace geode
7070
void remove_object( const ObjectId& object_id );
7171

7272
// Object neighbor search by ObjectId (always excludes self)
73-
std::vector< ObjectId > neighbors(
74-
const ObjectId& object_id, double searching_distance ) const;
73+
std::vector< ObjectId > neighbors( const ObjectId& object_id,
74+
const std::vector< uuid >& targeted_set_ids,
75+
double searching_distance ) const;
7576
// Object neighbor search by arbitrary object (return self if in the
7677
// object_set)
77-
std::vector< ObjectId > neighbors(
78-
const Type& object, double searching_distance ) const;
78+
std::vector< ObjectId > neighbors( const Type& object,
79+
const std::vector< uuid >& targeted_set_ids,
80+
double searching_distance ) const;
7981

8082
std::string string() const;
8183

src/geode/stochastic/spatial/object_neighborhood.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,19 @@ namespace geode
9090
std::vector< ObjectId >
9191
ObjectNeighborhood< dimension >::get_all_neighbor_ids(
9292
const BoundingBox< dimension >& box,
93+
const std::vector< uuid >& targeted_set_ids,
9394
std::optional< ObjectId > exclude_self_id ) const
9495
{
9596
auto [min_bound, max_bound] = bounding_box_bounds< dimension >( box );
9697
std::vector< ObjectId > res;
9798
tree_.Search( min_bound.data(), max_bound.data(),
98-
[&res, &exclude_self_id]( const ObjectId& cur_id ) -> bool {
99+
[&res, &exclude_self_id, &targeted_set_ids](
100+
const ObjectId& cur_id ) -> bool {
101+
if( !std::binary_search( targeted_set_ids.begin(),
102+
targeted_set_ids.end(), cur_id.set_id ) )
103+
{
104+
return true;
105+
}
99106
if( exclude_self_id && exclude_self_id.value() == cur_id )
100107
{
101108
return true;

src/geode/stochastic/spatial/object_sets.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,20 +123,25 @@ namespace geode
123123

124124
template < typename Type >
125125
std::vector< ObjectId > ObjectSets< Type >::neighbors(
126-
const ObjectId& object_id, double searching_distance ) const
126+
const ObjectId& object_id,
127+
const std::vector< uuid >& targeted_set_ids,
128+
double searching_distance ) const
127129
{
128130
auto box = object_bounding_box( get_object( object_id ) );
129131
box.extends( searching_distance * 2. );
130-
return neighborhood_.get_all_neighbor_ids( box, object_id );
132+
return neighborhood_.get_all_neighbor_ids(
133+
box, targeted_set_ids, object_id );
131134
}
132135

133136
template < typename Type >
134-
std::vector< ObjectId > ObjectSets< Type >::neighbors(
135-
const Type& object, double searching_distance ) const
137+
std::vector< ObjectId > ObjectSets< Type >::neighbors( const Type& object,
138+
const std::vector< uuid >& targeted_set_ids,
139+
double searching_distance ) const
136140
{
137141
auto box = object_bounding_box( object );
138142
box.extends( searching_distance * 2. );
139-
return neighborhood_.get_all_neighbor_ids( box, std::nullopt );
143+
return neighborhood_.get_all_neighbor_ids(
144+
box, targeted_set_ids, std::nullopt );
140145
}
141146

142147
template < typename Type >

tests/stochastic/sampling/mcmc/models/test-gibbs-energy.cpp

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,30 +31,21 @@
3131
#include <geode/stochastic/sampling/mcmc/models/gibbs_energy.hpp>
3232
#include <geode/stochastic/spatial/object_sets.hpp>
3333

34-
namespace
34+
void test_gibbs_energy()
3535
{
36-
geode::ObjectSets< geode::Point2D > create_object_set(
37-
const geode::uuid& set_id )
38-
{
39-
geode::Point2D p1{ { 0., 0. } };
40-
geode::Point2D p2{ { 1., 1. } };
41-
42-
geode::ObjectSets< geode::Point2D > pattern;
43-
pattern.add_object( std::move( p1 ), set_id );
44-
pattern.add_object( std::move( p2 ), set_id );
45-
46-
return pattern;
47-
}
48-
} // namespace
36+
geode::ObjectSets< geode::Point2D > pattern;
37+
const auto set_id = pattern.add_set( "default_name" );
38+
geode::Point2D p1{ { 0., 0. } };
39+
geode::Point2D p2{ { 1., 1. } };
40+
pattern.add_object( std::move( p1 ), set_id );
41+
pattern.add_object( std::move( p2 ), set_id );
4942

50-
void test_gibbs_energy( const geode::uuid& set_id )
51-
{
5243
geode::EnergyTermCollection< geode::Point2D > energy_terms;
5344

5445
// Add intensity term
5546
energy_terms.add_energy_term(
5647
std::make_unique< geode::DensityTerm< geode::Point2D > >(
57-
"intensity", 0.5, absl::flat_hash_set< geode::uuid >{ set_id } ) );
48+
"intensity", 0.5, std::vector< geode::uuid >{ set_id } ) );
5849

5950
// Add pairwise term with trivial interaction: always counts 1 for each pair
6051
auto interaction =
@@ -63,16 +54,14 @@ void test_gibbs_energy( const geode::uuid& set_id )
6354

6455
energy_terms.add_energy_term(
6556
std::make_unique< geode::PairwiseTerm< geode::Point2D > >(
66-
"interaction", 0.8, absl::flat_hash_set< geode::uuid >{ set_id },
57+
"interaction", 0.8, std::vector< geode::uuid >{ set_id },
6758
std::move( interaction ) ) );
6859

6960
OPENGEODE_EXCEPTION( energy_terms.size() == 2,
7061
"[test gibbs] Wrong number of components after adding terms." );
7162

7263
geode::GibbsEnergy< geode::Point2D > gibbs_energy( energy_terms );
7364

74-
auto pattern = create_object_set( set_id );
75-
7665
// Check total log-energy is finite
7766
double total_energy = gibbs_energy.total_log_energy( pattern );
7867
OPENGEODE_EXCEPTION( std::isfinite( total_energy ),
@@ -99,7 +88,7 @@ void test_gibbs_energy( const geode::uuid& set_id )
9988

10089
// Clear components and verify
10190
energy_terms.clear();
102-
OPENGEODE_EXCEPTION( energy_terms.size() == 2,
91+
OPENGEODE_EXCEPTION( energy_terms.size() == 0,
10392
"[test gibbs] Components not cleared properly." );
10493
}
10594

@@ -108,9 +97,9 @@ int main()
10897
try
10998
{
11099
geode::StochasticLibrary::initialize();
111-
geode::uuid set_id;
112-
113-
// test_gibbs_energy( set_id );
100+
test_gibbs_energy();
101+
geode::Logger::info( "MH TEST GIBBS ENERGY SUCCESS" );
102+
return 0;
114103
}
115104
catch( ... )
116105
{

0 commit comments

Comments
 (0)