Skip to content

Commit a1b1ea5

Browse files
feat(Helpers): create inference helpers.
1 parent 964ff1f commit a1b1ea5

26 files changed

Lines changed: 419 additions & 333 deletions

bindings/python/src/stochastic/sampling/mcmc/helpers/fracture_simulation_runner.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ namespace geode
8888
// pybind11::arg( "engine" ), pybind11::arg( "steps"
8989
// ), "Run simulation for a fixed number of steps." )
9090
.def( "run",
91-
static_cast< StatisticsMonitor ( FractureSimulationRunner::* )(
91+
static_cast< StatisticsTracker ( FractureSimulationRunner::* )(
9292
RandomEngine&, const SimulationConfigurator& ) >(
9393
&FractureSimulationRunner::run ),
9494
pybind11::arg( "engine" ), pybind11::arg( "config" ),

bindings/python/src/stochastic/sampling/mcmc/helpers/simulation_monitor.hpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,30 +21,30 @@
2121
*
2222
*/
2323

24-
#include <geode/stochastic/sampling/mcmc/helpers/simulation_monitor.hpp>
24+
#include <geode/stochastic/inference/statistics_tracker.hpp>
2525

2626
namespace geode
2727
{
2828
void define_simulation_monitor( pybind11::module &module )
2929
{
30-
pybind11::class_< geode::StatisticsMonitor >(
31-
module, "StatisticsMonitor" )
30+
pybind11::class_< geode::StatisticsTracker >(
31+
module, "StatisticsTracker" )
3232
.def( pybind11::init< geode::index_t >(),
3333
pybind11::arg( "nb_energy_terms" ),
34-
"Create a StatisticsMonitor for a given number of energy "
34+
"Create a StatisticsTracker for a given number of energy "
3535
"terms" )
36-
.def( "add_realization", &geode::StatisticsMonitor::add_realization,
36+
.def( "add_realization", &geode::StatisticsTracker::add_realization,
3737
pybind11::arg( "values" ),
3838
"Add a realization (vector of doubles) to update statistics" )
39-
.def( "statiscal_count", &geode::StatisticsMonitor::statiscal_count,
39+
.def( "statiscal_count", &geode::StatisticsTracker::statiscal_count,
4040
"Return the number of realizations added" )
41-
.def_property_readonly( "means", &geode::StatisticsMonitor::means,
41+
.def_property_readonly( "means", &geode::StatisticsTracker::means,
4242
"Return the computed mean values for each energy term" )
4343
.def_property_readonly( "variances",
44-
&geode::StatisticsMonitor::variances,
44+
&geode::StatisticsTracker::variances,
4545
"Return the computed variances for each energy term" )
46-
.def( "__repr__", []( const geode::StatisticsMonitor &self ) {
47-
return "<StatisticsMonitor count="
46+
.def( "__repr__", []( const geode::StatisticsTracker &self ) {
47+
return "<StatisticsTracker count="
4848
+ std::to_string( self.statiscal_count() ) + ">";
4949
} );
5050
}

bindings/python/src/stochastic/sampling/mcmc/helpers/simulation_printer.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ namespace geode
6464
// &SimulationPrinter::print_statistics_summary,
6565
// pybind11::arg( "monitor" ),
6666
// pybind11::arg( "energy_term_names" ) = "",
67-
// "Print statistics summary from a StatisticsMonitor."
67+
// "Print statistics summary from a StatisticsTracker."
6868
// );
6969
}
7070
} // namespace geode

bindings/python/src/stochastic/stochastic.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
#include "sampling/direct/double_sampler.hpp"
2828

2929
// #include "sampling/mcmc/helpers/fracture_simulation_runner.hpp"
30-
#include "sampling/mcmc/helpers/simulation_monitor.hpp"
30+
// #include "sampling/mcmc/helpers/simulation_monitor.hpp"
3131
#include "sampling/mcmc/helpers/simulation_printer.hpp"
3232
#include "sampling/mcmc/simulation_runner.hpp"
3333

@@ -47,7 +47,7 @@ PYBIND11_MODULE( opengeode_stochastic_py_stochastic, module )
4747
geode::define_random_engine( module );
4848
geode::define_double_sampler( module );
4949

50-
geode::define_simulation_monitor( module );
50+
// geode::define_simulation_monitor( module );
5151
geode::define_simulation_printer( module );
5252
geode::define_simulation_runner( module );
5353
// geode::define_fracture_simulation( module );

include/geode/stochastic/inference/statistic_monitor.hpp

Lines changed: 0 additions & 51 deletions
This file was deleted.

include/geode/stochastic/inference/statistic_objective.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#pragma once
2424

2525
#include <absl/container/flat_hash_map.h>
26-
#include <geode/stochastic/inference/statistic_monitor.hpp>
26+
#include <geode/stochastic/inference/statistics_tracker.hpp>
2727
#include <geode/stochastic/inference/target_statistic.hpp>
2828

2929
namespace geode
@@ -33,7 +33,7 @@ namespace geode
3333
class StatisticObjective
3434
{
3535
public:
36-
double compute_loss( const StatMonitor& monitor,
36+
double compute_loss( const StatisticsTracker< ObjectType >& monitor,
3737
const std::vector< TargetStatistic >& targets ) const
3838
{
3939
double loss = 0.0;

include/geode/stochastic/inference/statistic_validator.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
#include <absl/container/flat_hash_map.h>
2626

27-
#include <geode/stochastic/inference/statistic_monitor.hpp>
27+
#include <geode/stochastic/inference/statistics_tracker.hpp>
2828
#include <geode/stochastic/inference/target_statistic.hpp>
2929

3030
namespace geode
@@ -33,7 +33,7 @@ namespace geode
3333
class StatisticsValidator
3434
{
3535
public:
36-
void check( const StatisticsMonitor& monitor,
36+
void check( const StatisticsTracker& monitor,
3737
const std::vector< TargetStatistic >& targets ) const
3838
{
3939
for( const auto& target : targets )
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#pragma once
2+
3+
// #include <absl/container/flat_hash_map.h>
4+
#include <geode/basic/common.hpp>
5+
#include <geode/basic/uuid.hpp>
6+
7+
#include <geode/stochastic/common.hpp>
8+
#include <geode/stochastic/models/model.hpp>
9+
10+
namespace geode
11+
{
12+
template < typename ObjectType >
13+
class StatisticsTracker
14+
{
15+
public:
16+
StatisticsTracker( const Model< ObjectType >& model ) : model_{ model }
17+
{
18+
means_.resize( model.nb_terms(), 0.0 );
19+
m2_.resize( model.nb_terms(), 0.0 );
20+
}
21+
22+
[[nodiscard]] index_t statiscal_count() const
23+
{
24+
return count_;
25+
}
26+
27+
void add_realization( const std::vector< double >& values )
28+
{
29+
++count_;
30+
for( const auto value_id : geode::Range{ values.size() } )
31+
{
32+
auto& value = values[value_id];
33+
auto& mean = means_[value_id];
34+
auto& m2 = m2_[value_id]; // somme des carrés
35+
36+
double delta = value - mean;
37+
mean += delta / count_;
38+
double delta2 = value - mean;
39+
m2 += delta * delta2;
40+
}
41+
}
42+
43+
[[nodiscard]] double mean( const uuid& term_uuid ) const
44+
{
45+
return means_[model_.term_index( term_uuid )];
46+
}
47+
48+
[[nodiscard]] const std::vector< double >& means() const
49+
{
50+
return means_;
51+
}
52+
53+
[[nodiscard]] double variance( const uuid& term_uuid ) const
54+
{
55+
return variance( model_.term_index( term_uuid ) );
56+
}
57+
58+
[[nodiscard]] std::vector< double > variances() const
59+
{
60+
std::vector< double > variances;
61+
variances.reserve( model_.nb_terms() );
62+
for( const auto variance_id : geode::Range{ model_.nb_terms() } )
63+
{
64+
variances.emplace_back( this->variance( variance_id ) );
65+
}
66+
return variances;
67+
}
68+
69+
private:
70+
[[nodiscard]] double variance( index_t term_index ) const
71+
{
72+
if( count_ < 2 )
73+
{
74+
return 0.0;
75+
}
76+
return m2_[term_index] / ( count_ - 1 );
77+
}
78+
79+
private:
80+
const Model< ObjectType >& model_;
81+
82+
std::vector< double > means_{};
83+
std::vector< double > m2_{};
84+
index_t count_{ 0 };
85+
};
86+
} // namespace geode

include/geode/stochastic/inference/target_statistic.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
*
2222
*/
2323
#pragma once
24+
#include <geode/basic/uuid.hpp>
2425

2526
namespace geode
2627
{

include/geode/stochastic/models/energy_term_collection.hpp

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,26 @@
1+
/*
2+
* Copyright (c) 2019 - 2026 Geode-solutions
3+
*
4+
* Permission is hereby granted, free of charge, to any person obtaining a copy
5+
* of this software and associated documentation files (the "Software"), to deal
6+
* in the Software without restriction, including without limitation the rights
7+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
* copies of the Software, and to permit persons to whom the Software is
9+
* furnished to do so, subject to the following conditions:
10+
*
11+
* The above copyright notice and this permission notice shall be included in
12+
* all copies or substantial portions of the Software.
13+
*
14+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20+
* SOFTWARE.
21+
*
22+
*/
23+
124
#pragma once
225

326
#include <absl/container/flat_hash_map.h>
@@ -49,14 +72,19 @@ namespace geode
4972
return energy_terms_.size();
5073
}
5174

52-
[[nodiscard]] const EnergyTerm< ObjectType >& get(
53-
const uuid& term_id ) const
75+
[[nodiscard]] index_t get_term_index( const uuid& term_uuid ) const
5476
{
55-
auto term_it = uuid_to_index_.find( term_id );
77+
auto term_it = uuid_to_index_.find( term_uuid );
5678
OPENGEODE_EXCEPTION( term_it != uuid_to_index_.end(),
5779
absl::StrCat( "[EnergyTermCollection] Unknown energy term: ",
58-
term_id.string() ) );
59-
return *energy_terms_[term_it->second];
80+
term_uuid.string() ) );
81+
return term_it->second;
82+
}
83+
84+
[[nodiscard]] const EnergyTerm< ObjectType >& get(
85+
const uuid& term_uuid ) const
86+
{
87+
return *energy_terms_[get_term_index( term_uuid )];
6088
}
6189

6290
[[nodiscard]] uuid get_term_uuid( std::string_view name ) const

0 commit comments

Comments
 (0)