Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions ModelFitting/ModelFitting/Parameters/DependentParameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ class DependentParameter: public BasicParameter {
//m_get_value_hook = std::bind(&DependentParameter::getValueHook, this);
}

virtual ~DependentParameter() = default;
virtual ~DependentParameter() {
for (auto& parameter_observer : m_parameter_observers) {
std::get<0>(parameter_observer)->removeObserver(std::get<1>(parameter_observer));
}
}

double getValue() const override {
if (!this->isObserved()) {
Expand Down Expand Up @@ -118,13 +122,17 @@ class DependentParameter: public BasicParameter {

template<typename Param>
void addParameterObserver(int, Param& param) {
param->addObserver([this](double){
auto id = param->addObserver([this](double){
// Do not bother updating live if there are no observers
if (this->isObserved()) {
this->update((*m_params)[0]->getValue());
}
});

m_parameter_observers.emplace_back(param, id);
}

std::vector<std::tuple<std::shared_ptr<BasicParameter>, size_t>> m_parameter_observers;
};

template<typename ... Parameters>
Expand All @@ -133,6 +141,7 @@ std::shared_ptr<DependentParameter<Parameters...>> createDependentParameter(
return std::make_shared<DependentParameter<Parameters...>>(value_calculator, parameters...);
}


}

#endif /* MODELFITTING_DEPENDENTPARAMETER_H */
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "ModelFitting/Models/FrameModel.h"
#include "ModelFitting/Engine/ResidualEstimator.h"
#include "ModelFitting/Engine/LeastSquareEngineManager.h"
#include "ModelFitting/Engine/EngineParameterManager.h"

#include "SEUtils/PixelRectangle.h"

Expand All @@ -31,6 +32,7 @@
#include "SEImplementation/Plugin/FlexibleModelFitting/FlexibleModelFittingParameter.h"
#include "SEImplementation/Plugin/FlexibleModelFitting/FlexibleModelFittingFrame.h"
#include "SEImplementation/Plugin/FlexibleModelFitting/FlexibleModelFittingPrior.h"
#include "SEImplementation/Plugin/FlexibleModelFitting/FlexibleModelFittingParameterManager.h"

#include "SEImplementation/Image/DownSampledImagePsf.h"

Expand Down Expand Up @@ -85,6 +87,10 @@ class FlexibleModelFittingIterativeTask : public GroupTask {
std::vector<int> iterations_per_meta;
std::vector<SeFloat> fitting_areas_x;
std::vector<SeFloat> fitting_areas_y;

FlexibleModelFittingParameterManager parameter_manager;
ModelFitting::EngineParameterManager engine_parameter_manager {};
int n_free_parameters = 0;
};

struct FittingState {
Expand Down Expand Up @@ -115,7 +121,7 @@ class FlexibleModelFittingIterativeTask : public GroupTask {
std::shared_ptr<const Image<SeFloat>> model, std::shared_ptr<const Image<SeFloat>> weights, int& data_points) const;
int fitSourcePrepareParameters(FlexibleModelFittingParameterManager& parameter_manager,
ModelFitting::EngineParameterManager& engine_parameter_manager,
SourceInterface& source, int index, FittingState& state) const;
SourceInterface& source, SourceState& state) const;
int fitSourcePrepareModels(FlexibleModelFittingParameterManager& parameter_manager,
ModelFitting::ResidualEstimator& res_estimator, int& good_pixels,
SourceGroupInterface& group, SourceInterface& source, int index, FittingState& state, double downscaling) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,71 +73,71 @@
FlexibleModelFittingIterativeTask::~FlexibleModelFittingIterativeTask() {
}

PixelRectangle FlexibleModelFittingIterativeTask::getUnclippedFittingRect(SourceInterface& source, int frame_index) const {
auto& measurement_frame_rectangle = source.getProperty<MeasurementFrameRectangle>(frame_index);
if (!measurement_frame_rectangle.isValid() || measurement_frame_rectangle.isEmpty()) {
return PixelRectangle();
}

ImageCoordinate min_coord = measurement_frame_rectangle.getTopLeft();
ImageCoordinate max_coord = measurement_frame_rectangle.getBottomRight();

if (m_window_type == WindowType::ROTATED_ELLIPSE) {
auto ellipse = getFittingEllipse(source, frame_index);

ellipse.m_a *= m_ellipse_scale;
ellipse.m_b *= m_ellipse_scale;
return getEllipseRect(ellipse);
}

if ((max_coord.m_x - min_coord.m_x <= 0.0) || (max_coord.m_y - min_coord.m_y <= 0.0)) {
return PixelRectangle();
} else {
auto min = min_coord;
auto max = max_coord;

// FIXME temporary, for now just enlarge the area by a fixed amount of pixels
ImageCoordinate border = (max - min) * .8 + ImageCoordinate(2.0, 2.0);

min -= border;
max += border;

if (m_window_type == WindowType::DISK_MIN || m_window_type == WindowType::SQUARE_MIN) {
if (max.m_x - min.m_x > max.m_y - min.m_y) {
min.m_x += ((max.m_x - min.m_x) - (max.m_y - min.m_y)) / 2;
max.m_x = min.m_x + (max.m_y - min.m_y);
} else {
min.m_y += ((max.m_y - min.m_y) - (max.m_x - min.m_x)) / 2;
max.m_y = min.m_y + (max.m_x - min.m_x);
}
}

if (m_window_type == WindowType::DISK_MAX || m_window_type == WindowType::SQUARE_MAX) {
if (max.m_x - min.m_x < max.m_y - min.m_y) {
min.m_x += ((max.m_x - min.m_x) - (max.m_y - min.m_y)) / 2;
max.m_x = min.m_x + (max.m_y - min.m_y);
} else {
min.m_y += ((max.m_y - min.m_y) - (max.m_x - min.m_x)) / 2;
max.m_y = min.m_y + (max.m_x - min.m_x);
}
}

if (m_window_type == WindowType::DISK_AREA || m_window_type == WindowType::SQUARE_AREA) {
int area = (max.m_x - min.m_x) * (max.m_y - min.m_y);
int size = int(sqrt(area));
min.m_x += ((max.m_x - min.m_x) - size) / 2;
max.m_x = min.m_x + size;
min.m_y += ((max.m_y - min.m_y) - size) / 2;
max.m_y = min.m_y + size;
}

auto min_pc = PixelCoordinate(static_cast<int>(min.m_x), static_cast<int>(min.m_y));
auto max_pc = PixelCoordinate(static_cast<int>(max.m_x), static_cast<int>(max.m_y));

return PixelRectangle(min_pc, max_pc);
}
}

Check notice on line 140 in SEImplementation/src/lib/Plugin/FlexibleModelFitting/FlexibleModelFittingIterativeTask.cpp

View check run for this annotation

codefactor.io / CodeFactor

SEImplementation/src/lib/Plugin/FlexibleModelFitting/FlexibleModelFittingIterativeTask.cpp#L76-L140

Complex Method
PixelRectangle FlexibleModelFittingIterativeTask::clipFittingRect(PixelRectangle fitting_rect,
SourceInterface& source, int frame_index) const {
if (fitting_rect.getWidth() <= 0 || fitting_rect.getHeight() <= 0) {
Expand Down Expand Up @@ -426,6 +426,8 @@
}
}

initial_state.n_free_parameters = fitSourcePrepareParameters(initial_state.parameter_manager, initial_state.engine_parameter_manager, source, initial_state);

fitting_state.source_states.emplace_back(std::move(initial_state));
}

Expand Down Expand Up @@ -501,38 +503,20 @@
auto rect = getFittingRect(source, frame_index);

double pixel_scale = 1.0;
FlexibleModelFittingParameterManager parameter_manager;
ModelFitting::EngineParameterManager engine_parameter_manager {};
int n_free_parameters = 0;

auto deblend_image = VectorImage<SeFloat>::create(rect.getWidth(), rect.getHeight());
int index = 0;
for (auto& src : group) {
if (index != source_index) {
if (index != source_index && isFrameValid(src, frame->getFrameNb())) {
// reset parameters to final values after fitting
for (auto parameter : m_parameters) {
auto free_parameter = std::dynamic_pointer_cast<FlexibleModelFittingFreeParameter>(parameter);

if (free_parameter != nullptr) {
++n_free_parameters;

// Initial with the values from the current iteration run
parameter_manager.addParameter(src, parameter,
free_parameter->create(parameter_manager, engine_parameter_manager, src,
state.source_states[index].parameters_initial_values.at(free_parameter->getId()),
state.source_states[index].parameters_values.at(free_parameter->getId())));
} else {
parameter_manager.addParameter(src, parameter,
parameter->create(parameter_manager, engine_parameter_manager, src));
auto engine_parameter = std::dynamic_pointer_cast<ModelFitting::EngineParameter>(state.source_states[index].parameter_manager.getParameter(src, parameter));
if (engine_parameter != nullptr) {
engine_parameter->setValue(state.source_states[index].parameters_values.at(parameter->getId()));
}
}
}
index++;
}

auto deblend_image = VectorImage<SeFloat>::create(rect.getWidth(), rect.getHeight());
index = 0;
for (auto& src : group) {
if (index != source_index && isFrameValid(src, frame->getFrameNb())) {
auto frame_model = createFrameModel(src, pixel_scale, parameter_manager, frame, rect);
auto frame_model = createFrameModel(src, pixel_scale, state.source_states[index].parameter_manager, frame, rect);
auto final_stamp = frame_model.getImage();

for (int y = 0; y < final_stamp->getHeight(); ++y) {
Expand All @@ -550,7 +534,7 @@
int FlexibleModelFittingIterativeTask::fitSourcePrepareParameters(
FlexibleModelFittingParameterManager& parameter_manager,
ModelFitting::EngineParameterManager& engine_parameter_manager,
SourceInterface& source, int index, FittingState& state) const {
SourceInterface& source, SourceState& state) const {
int free_parameters_nb = 0;
for (auto parameter : m_parameters) {
auto free_parameter = std::dynamic_pointer_cast<FlexibleModelFittingFreeParameter>(parameter);
Expand All @@ -561,8 +545,8 @@
// Initial with the values from the current iteration run
parameter_manager.addParameter(source, parameter,
free_parameter->create(parameter_manager, engine_parameter_manager, source,
state.source_states[index].parameters_initial_values.at(free_parameter->getId()),
state.source_states[index].parameters_values.at(free_parameter->getId())));
state.parameters_initial_values.at(free_parameter->getId()),
state.parameters_values.at(free_parameter->getId())));
} else {
parameter_manager.addParameter(source, parameter,
parameter->create(parameter_manager, engine_parameter_manager, source));
Expand Down Expand Up @@ -717,17 +701,19 @@
//////////////////////////////////////////////
// Prepare parameters

FlexibleModelFittingParameterManager parameter_manager;
ModelFitting::EngineParameterManager engine_parameter_manager{};
int n_free_parameters = fitSourcePrepareParameters(
parameter_manager, engine_parameter_manager, source, index, state);
for (auto parameter : m_parameters) {
auto engine_parameter = std::dynamic_pointer_cast<ModelFitting::EngineParameter>(state.source_states[index].parameter_manager.getParameter(source, parameter));
if (engine_parameter != nullptr) {
engine_parameter->setValue(state.source_states[index].parameters_values.at(parameter->getId()));
}
}

///////////////////////////////////////////////////////////////////////////////////
// Add models for all frames
ResidualEstimator res_estimator {};
int n_good_pixels = 0;
int valid_frames = fitSourcePrepareModels(
parameter_manager, res_estimator, n_good_pixels, group, source, index, state, down_scaling);
state.source_states[index].parameter_manager, res_estimator, n_good_pixels, group, source, index, state, down_scaling);

///////////////////////////////////////////////////////////////////////////////
// Check that we had enough data for the fit
Expand All @@ -737,7 +723,7 @@
if (valid_frames == 0) {
flags = Flags::OUTSIDE;
}
else if (n_good_pixels < n_free_parameters) {
else if (n_good_pixels < state.source_states[index].n_free_parameters) {
flags = Flags::INSUFFICIENT_DATA;
}

Expand All @@ -754,14 +740,14 @@
////////////////////////////////////////////////////////////////////////////////
// Add priors
for (auto prior : m_priors) {
prior->setupPrior(parameter_manager, source, res_estimator);
prior->setupPrior(state.source_states[index].parameter_manager, source, res_estimator);
}

/////////////////////////////////////////////////////////////////////////////////
// Model fitting

auto engine = LeastSquareEngineManager::create(m_least_squares_engine, m_max_iterations);
auto solution = engine->solveProblem(engine_parameter_manager, res_estimator);
auto solution = engine->solveProblem(state.source_states[index].engine_parameter_manager, res_estimator);

auto iterations = solution.iteration_no;
auto stop_reason = solution.engine_stop_reason;
Expand All @@ -773,49 +759,34 @@
////////////////////////////////////////////////////////////////////////////////////
// compute chi squared

SeFloat avg_reduced_chi_squared = fitSourceComputeChiSquared(parameter_manager, group, source, index, state);
SeFloat avg_reduced_chi_squared = fitSourceComputeChiSquared(state.source_states[index].parameter_manager, group, source, index, state);

////////////////////////////////////////////////////////////////////////////////////
// update state with results
fitSourceUpdateState(parameter_manager, source, avg_reduced_chi_squared, duration, iterations, stop_reason, flags, solution,
fitSourceUpdateState(state.source_states[index].parameter_manager, source, avg_reduced_chi_squared, duration, iterations, stop_reason, flags, solution,
index, state);
}

void FlexibleModelFittingIterativeTask::updateCheckImages(SourceGroupInterface& group,
double pixel_scale, FittingState& state) const {

// recreate parameters

FlexibleModelFittingParameterManager parameter_manager;
ModelFitting::EngineParameterManager engine_parameter_manager {};

int index = 0;
for (auto& src : group) {
for (auto parameter : m_parameters) {
auto free_parameter = std::dynamic_pointer_cast<FlexibleModelFittingFreeParameter>(parameter);

if (free_parameter != nullptr) {
// Initialize with the values from the current iteration run
parameter_manager.addParameter(src, parameter,
free_parameter->create(parameter_manager, engine_parameter_manager, src,
state.source_states[index].parameters_initial_values.at(free_parameter->getId()),
state.source_states[index].parameters_values.at(free_parameter->getId())));
} else {
parameter_manager.addParameter(src, parameter,
parameter->create(parameter_manager, engine_parameter_manager, src));
// reset parameters to final values after fitting
auto engine_parameter = std::dynamic_pointer_cast<ModelFitting::EngineParameter>(state.source_states[index].parameter_manager.getParameter(src, parameter));
if (engine_parameter != nullptr) {
engine_parameter->setValue(state.source_states[index].parameters_values.at(parameter->getId()));
}
}
index++;
}

for (auto& src : group) {
for (auto frame : m_frames) {
int frame_index = frame->getFrameNb();

if (isFrameValid(src, frame_index)) {
auto stamp_rect = getFittingRect(src, frame_index);

auto frame_model = createFrameModel(src, pixel_scale, parameter_manager, frame, stamp_rect);
auto frame_model = createFrameModel(src, pixel_scale, state.source_states[index].parameter_manager, frame, stamp_rect);
auto final_stamp = frame_model.getImage();

auto weight_image = createWeightImage(src, frame_index);
Expand Down Expand Up @@ -855,6 +826,7 @@

}
}
index++;
}
}

Expand Down
Loading