Skip to content

Commit 7c55238

Browse files
authored
Merge pull request #633 from astrorama/feature/model_fitting_param_optimization
Model fitting parameter creation optimization
2 parents 7444e2c + 82663c3 commit 7c55238

3 files changed

Lines changed: 48 additions & 61 deletions

File tree

ModelFitting/ModelFitting/Parameters/DependentParameter.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ class DependentParameter: public BasicParameter {
6565
//m_get_value_hook = std::bind(&DependentParameter::getValueHook, this);
6666
}
6767

68-
virtual ~DependentParameter() = default;
68+
virtual ~DependentParameter() {
69+
for (auto& parameter_observer : m_parameter_observers) {
70+
std::get<0>(parameter_observer)->removeObserver(std::get<1>(parameter_observer));
71+
}
72+
}
6973

7074
double getValue() const override {
7175
if (!this->isObserved()) {
@@ -118,13 +122,17 @@ class DependentParameter: public BasicParameter {
118122

119123
template<typename Param>
120124
void addParameterObserver(int, Param& param) {
121-
param->addObserver([this](double){
125+
auto id = param->addObserver([this](double){
122126
// Do not bother updating live if there are no observers
123127
if (this->isObserved()) {
124128
this->update((*m_params)[0]->getValue());
125129
}
126130
});
131+
132+
m_parameter_observers.emplace_back(param, id);
127133
}
134+
135+
std::vector<std::tuple<std::shared_ptr<BasicParameter>, size_t>> m_parameter_observers;
128136
};
129137

130138
template<typename ... Parameters>
@@ -133,6 +141,7 @@ std::shared_ptr<DependentParameter<Parameters...>> createDependentParameter(
133141
return std::make_shared<DependentParameter<Parameters...>>(value_calculator, parameters...);
134142
}
135143

144+
136145
}
137146

138147
#endif /* MODELFITTING_DEPENDENTPARAMETER_H */

SEImplementation/SEImplementation/Plugin/FlexibleModelFitting/FlexibleModelFittingIterativeTask.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "ModelFitting/Models/FrameModel.h"
2222
#include "ModelFitting/Engine/ResidualEstimator.h"
2323
#include "ModelFitting/Engine/LeastSquareEngineManager.h"
24+
#include "ModelFitting/Engine/EngineParameterManager.h"
2425

2526
#include "SEUtils/PixelRectangle.h"
2627

@@ -31,6 +32,7 @@
3132
#include "SEImplementation/Plugin/FlexibleModelFitting/FlexibleModelFittingParameter.h"
3233
#include "SEImplementation/Plugin/FlexibleModelFitting/FlexibleModelFittingFrame.h"
3334
#include "SEImplementation/Plugin/FlexibleModelFitting/FlexibleModelFittingPrior.h"
35+
#include "SEImplementation/Plugin/FlexibleModelFitting/FlexibleModelFittingParameterManager.h"
3436

3537
#include "SEImplementation/Image/DownSampledImagePsf.h"
3638

@@ -85,6 +87,10 @@ class FlexibleModelFittingIterativeTask : public GroupTask {
8587
std::vector<int> iterations_per_meta;
8688
std::vector<SeFloat> fitting_areas_x;
8789
std::vector<SeFloat> fitting_areas_y;
90+
91+
FlexibleModelFittingParameterManager parameter_manager;
92+
ModelFitting::EngineParameterManager engine_parameter_manager {};
93+
int n_free_parameters = 0;
8894
};
8995

9096
struct FittingState {
@@ -115,7 +121,7 @@ class FlexibleModelFittingIterativeTask : public GroupTask {
115121
std::shared_ptr<const Image<SeFloat>> model, std::shared_ptr<const Image<SeFloat>> weights, int& data_points) const;
116122
int fitSourcePrepareParameters(FlexibleModelFittingParameterManager& parameter_manager,
117123
ModelFitting::EngineParameterManager& engine_parameter_manager,
118-
SourceInterface& source, int index, FittingState& state) const;
124+
SourceInterface& source, SourceState& state) const;
119125
int fitSourcePrepareModels(FlexibleModelFittingParameterManager& parameter_manager,
120126
ModelFitting::ResidualEstimator& res_estimator, int& good_pixels,
121127
SourceGroupInterface& group, SourceInterface& source, int index, FittingState& state, double downscaling) const;

SEImplementation/src/lib/Plugin/FlexibleModelFitting/FlexibleModelFittingIterativeTask.cpp

Lines changed: 30 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,8 @@ void FlexibleModelFittingIterativeTask::computeProperties(SourceGroupInterface&
426426
}
427427
}
428428

429+
initial_state.n_free_parameters = fitSourcePrepareParameters(initial_state.parameter_manager, initial_state.engine_parameter_manager, source, initial_state);
430+
429431
fitting_state.source_states.emplace_back(std::move(initial_state));
430432
}
431433

@@ -501,38 +503,20 @@ std::shared_ptr<VectorImage<SeFloat>> FlexibleModelFittingIterativeTask::createD
501503
auto rect = getFittingRect(source, frame_index);
502504

503505
double pixel_scale = 1.0;
504-
FlexibleModelFittingParameterManager parameter_manager;
505-
ModelFitting::EngineParameterManager engine_parameter_manager {};
506-
int n_free_parameters = 0;
507506

507+
auto deblend_image = VectorImage<SeFloat>::create(rect.getWidth(), rect.getHeight());
508508
int index = 0;
509509
for (auto& src : group) {
510-
if (index != source_index) {
510+
if (index != source_index && isFrameValid(src, frame->getFrameNb())) {
511+
// reset parameters to final values after fitting
511512
for (auto parameter : m_parameters) {
512-
auto free_parameter = std::dynamic_pointer_cast<FlexibleModelFittingFreeParameter>(parameter);
513-
514-
if (free_parameter != nullptr) {
515-
++n_free_parameters;
516-
517-
// Initial with the values from the current iteration run
518-
parameter_manager.addParameter(src, parameter,
519-
free_parameter->create(parameter_manager, engine_parameter_manager, src,
520-
state.source_states[index].parameters_initial_values.at(free_parameter->getId()),
521-
state.source_states[index].parameters_values.at(free_parameter->getId())));
522-
} else {
523-
parameter_manager.addParameter(src, parameter,
524-
parameter->create(parameter_manager, engine_parameter_manager, src));
513+
auto engine_parameter = std::dynamic_pointer_cast<ModelFitting::EngineParameter>(state.source_states[index].parameter_manager.getParameter(src, parameter));
514+
if (engine_parameter != nullptr) {
515+
engine_parameter->setValue(state.source_states[index].parameters_values.at(parameter->getId()));
525516
}
526517
}
527-
}
528-
index++;
529-
}
530518

531-
auto deblend_image = VectorImage<SeFloat>::create(rect.getWidth(), rect.getHeight());
532-
index = 0;
533-
for (auto& src : group) {
534-
if (index != source_index && isFrameValid(src, frame->getFrameNb())) {
535-
auto frame_model = createFrameModel(src, pixel_scale, parameter_manager, frame, rect);
519+
auto frame_model = createFrameModel(src, pixel_scale, state.source_states[index].parameter_manager, frame, rect);
536520
auto final_stamp = frame_model.getImage();
537521

538522
for (int y = 0; y < final_stamp->getHeight(); ++y) {
@@ -550,7 +534,7 @@ std::shared_ptr<VectorImage<SeFloat>> FlexibleModelFittingIterativeTask::createD
550534
int FlexibleModelFittingIterativeTask::fitSourcePrepareParameters(
551535
FlexibleModelFittingParameterManager& parameter_manager,
552536
ModelFitting::EngineParameterManager& engine_parameter_manager,
553-
SourceInterface& source, int index, FittingState& state) const {
537+
SourceInterface& source, SourceState& state) const {
554538
int free_parameters_nb = 0;
555539
for (auto parameter : m_parameters) {
556540
auto free_parameter = std::dynamic_pointer_cast<FlexibleModelFittingFreeParameter>(parameter);
@@ -561,8 +545,8 @@ int FlexibleModelFittingIterativeTask::fitSourcePrepareParameters(
561545
// Initial with the values from the current iteration run
562546
parameter_manager.addParameter(source, parameter,
563547
free_parameter->create(parameter_manager, engine_parameter_manager, source,
564-
state.source_states[index].parameters_initial_values.at(free_parameter->getId()),
565-
state.source_states[index].parameters_values.at(free_parameter->getId())));
548+
state.parameters_initial_values.at(free_parameter->getId()),
549+
state.parameters_values.at(free_parameter->getId())));
566550
} else {
567551
parameter_manager.addParameter(source, parameter,
568552
parameter->create(parameter_manager, engine_parameter_manager, source));
@@ -717,17 +701,19 @@ void FlexibleModelFittingIterativeTask::fitSource(SourceGroupInterface& group, S
717701
//////////////////////////////////////////////
718702
// Prepare parameters
719703

720-
FlexibleModelFittingParameterManager parameter_manager;
721-
ModelFitting::EngineParameterManager engine_parameter_manager{};
722-
int n_free_parameters = fitSourcePrepareParameters(
723-
parameter_manager, engine_parameter_manager, source, index, state);
704+
for (auto parameter : m_parameters) {
705+
auto engine_parameter = std::dynamic_pointer_cast<ModelFitting::EngineParameter>(state.source_states[index].parameter_manager.getParameter(source, parameter));
706+
if (engine_parameter != nullptr) {
707+
engine_parameter->setValue(state.source_states[index].parameters_values.at(parameter->getId()));
708+
}
709+
}
724710

725711
///////////////////////////////////////////////////////////////////////////////////
726712
// Add models for all frames
727713
ResidualEstimator res_estimator {};
728714
int n_good_pixels = 0;
729715
int valid_frames = fitSourcePrepareModels(
730-
parameter_manager, res_estimator, n_good_pixels, group, source, index, state, down_scaling);
716+
state.source_states[index].parameter_manager, res_estimator, n_good_pixels, group, source, index, state, down_scaling);
731717

732718
///////////////////////////////////////////////////////////////////////////////
733719
// Check that we had enough data for the fit
@@ -737,7 +723,7 @@ void FlexibleModelFittingIterativeTask::fitSource(SourceGroupInterface& group, S
737723
if (valid_frames == 0) {
738724
flags = Flags::OUTSIDE;
739725
}
740-
else if (n_good_pixels < n_free_parameters) {
726+
else if (n_good_pixels < state.source_states[index].n_free_parameters) {
741727
flags = Flags::INSUFFICIENT_DATA;
742728
}
743729

@@ -754,14 +740,14 @@ void FlexibleModelFittingIterativeTask::fitSource(SourceGroupInterface& group, S
754740
////////////////////////////////////////////////////////////////////////////////
755741
// Add priors
756742
for (auto prior : m_priors) {
757-
prior->setupPrior(parameter_manager, source, res_estimator);
743+
prior->setupPrior(state.source_states[index].parameter_manager, source, res_estimator);
758744
}
759745

760746
/////////////////////////////////////////////////////////////////////////////////
761747
// Model fitting
762748

763749
auto engine = LeastSquareEngineManager::create(m_least_squares_engine, m_max_iterations);
764-
auto solution = engine->solveProblem(engine_parameter_manager, res_estimator);
750+
auto solution = engine->solveProblem(state.source_states[index].engine_parameter_manager, res_estimator);
765751

766752
auto iterations = solution.iteration_no;
767753
auto stop_reason = solution.engine_stop_reason;
@@ -773,49 +759,34 @@ void FlexibleModelFittingIterativeTask::fitSource(SourceGroupInterface& group, S
773759
////////////////////////////////////////////////////////////////////////////////////
774760
// compute chi squared
775761

776-
SeFloat avg_reduced_chi_squared = fitSourceComputeChiSquared(parameter_manager, group, source, index, state);
762+
SeFloat avg_reduced_chi_squared = fitSourceComputeChiSquared(state.source_states[index].parameter_manager, group, source, index, state);
777763

778764
////////////////////////////////////////////////////////////////////////////////////
779765
// update state with results
780-
fitSourceUpdateState(parameter_manager, source, avg_reduced_chi_squared, duration, iterations, stop_reason, flags, solution,
766+
fitSourceUpdateState(state.source_states[index].parameter_manager, source, avg_reduced_chi_squared, duration, iterations, stop_reason, flags, solution,
781767
index, state);
782768
}
783769

784770
void FlexibleModelFittingIterativeTask::updateCheckImages(SourceGroupInterface& group,
785771
double pixel_scale, FittingState& state) const {
786772

787-
// recreate parameters
788-
789-
FlexibleModelFittingParameterManager parameter_manager;
790-
ModelFitting::EngineParameterManager engine_parameter_manager {};
791-
792773
int index = 0;
793774
for (auto& src : group) {
794775
for (auto parameter : m_parameters) {
795-
auto free_parameter = std::dynamic_pointer_cast<FlexibleModelFittingFreeParameter>(parameter);
796-
797-
if (free_parameter != nullptr) {
798-
// Initialize with the values from the current iteration run
799-
parameter_manager.addParameter(src, parameter,
800-
free_parameter->create(parameter_manager, engine_parameter_manager, src,
801-
state.source_states[index].parameters_initial_values.at(free_parameter->getId()),
802-
state.source_states[index].parameters_values.at(free_parameter->getId())));
803-
} else {
804-
parameter_manager.addParameter(src, parameter,
805-
parameter->create(parameter_manager, engine_parameter_manager, src));
776+
// reset parameters to final values after fitting
777+
auto engine_parameter = std::dynamic_pointer_cast<ModelFitting::EngineParameter>(state.source_states[index].parameter_manager.getParameter(src, parameter));
778+
if (engine_parameter != nullptr) {
779+
engine_parameter->setValue(state.source_states[index].parameters_values.at(parameter->getId()));
806780
}
807781
}
808-
index++;
809-
}
810782

811-
for (auto& src : group) {
812783
for (auto frame : m_frames) {
813784
int frame_index = frame->getFrameNb();
814785

815786
if (isFrameValid(src, frame_index)) {
816787
auto stamp_rect = getFittingRect(src, frame_index);
817788

818-
auto frame_model = createFrameModel(src, pixel_scale, parameter_manager, frame, stamp_rect);
789+
auto frame_model = createFrameModel(src, pixel_scale, state.source_states[index].parameter_manager, frame, stamp_rect);
819790
auto final_stamp = frame_model.getImage();
820791

821792
auto weight_image = createWeightImage(src, frame_index);
@@ -855,6 +826,7 @@ void FlexibleModelFittingIterativeTask::updateCheckImages(SourceGroupInterface&
855826

856827
}
857828
}
829+
index++;
858830
}
859831
}
860832

0 commit comments

Comments
 (0)