From e21715ab26175b1abee7a9a3ba8c89004f955f79 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 14 May 2025 11:57:34 +0200 Subject: [PATCH 1/4] Remove collector --- pina/collector.py | 129 ----------------------------- pina/data/data_module.py | 24 +++--- pina/problem/abstract_problem.py | 45 ++++++----- tests/test_collector.py | 135 ------------------------------- tests/test_problem.py | 32 ++++++++ 5 files changed, 69 insertions(+), 296 deletions(-) delete mode 100644 pina/collector.py delete mode 100644 tests/test_collector.py diff --git a/pina/collector.py b/pina/collector.py deleted file mode 100644 index e0d04a040..000000000 --- a/pina/collector.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Module for the Collector class.""" - -from .graph import Graph -from .utils import check_consistency - - -class Collector: - """ - Collector class for retrieving data from different conditions in the - problem. - """ - - def __init__(self, problem): - """ - Initialize the Collector class, by creating a hook between the collector - and the problem and initializing the data collections (dictionary where - data will be stored). - - :param pina.problem.abstract_problem.AbstractProblem problem: The - problem to collect data from. - """ - # creating a hook between collector and problem - self.problem = problem - - # those variables are used for the dataloading - self._data_collections = {name: {} for name in self.problem.conditions} - self.conditions_name = dict(enumerate(self.problem.conditions)) - - # variables used to check that all conditions are sampled - self._is_conditions_ready = { - name: False for name in self.problem.conditions - } - self.full = False - - @property - def full(self): - """ - Returns ``True`` if the collector is full. The collector is considered - full if all conditions have entries in the ``data_collection`` - dictionary. - - :return: ``True`` if all conditions are ready, ``False`` otherwise. - :rtype: bool - """ - - return all(self._is_conditions_ready.values()) - - @full.setter - def full(self, value): - """ - Set the ``_full`` variable. - - :param bool value: The value to set the ``_full`` variable. - """ - - check_consistency(value, bool) - self._full = value - - @property - def data_collections(self): - """ - Return the data collections (dictionary where data is stored). - - :return: The data collections where the data is stored. - :rtype: dict - """ - - return self._data_collections - - @property - def problem(self): - """ - Problem connected to the collector. - - :return: The problem from which the data is collected. - :rtype: pina.problem.abstract_problem.AbstractProblem - """ - return self._problem - - @problem.setter - def problem(self, value): - """ - Set the problem connected to the collector. - - :param pina.problem.abstract_problem.AbstractProblem value: The problem - to connect to the collector. - """ - - self._problem = value - - def store_fixed_data(self): - """ - Store inside data collections the fixed data of the problem. These comes - from the conditions that do not require sampling. - """ - - # loop over all conditions - for condition_name, condition in self.problem.conditions.items(): - # if the condition is not ready and domain is not attribute - # of condition, we get and store the data - if (not self._is_conditions_ready[condition_name]) and ( - not hasattr(condition, "domain") - ): - # get data - keys = condition.__slots__ - values = [getattr(condition, name) for name in keys] - self.data_collections[condition_name] = dict(zip(keys, values)) - # condition now is ready - self._is_conditions_ready[condition_name] = True - - def store_sample_domains(self): - """ - Store inside data collections the sampled data of the problem. These - comes from the conditions that require sampling (e.g. - :class:`~pina.condition.domain_equation_condition.\ - DomainEquationCondition`). - """ - - for condition_name in self.problem.conditions: - condition = self.problem.conditions[condition_name] - if not hasattr(condition, "domain"): - continue - - samples = self.problem.discretised_domains[condition.domain] - - self.data_collections[condition_name] = { - "input": samples, - "equation": condition.equation, - } diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 566f64625..0e6538c31 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -12,7 +12,8 @@ from torch.utils.data.distributed import DistributedSampler from ..label_tensor import LabelTensor from .dataset import PinaDatasetFactory, PinaTensorDataset -from ..collector import Collector + +# from ..collector import Collector class DummyDataloader: @@ -330,9 +331,10 @@ def __init__( self.pin_memory = pin_memory # Collect data - collector = Collector(problem) - collector.store_fixed_data() - collector.store_sample_domains() + # collector = Collector(problem) + # collector.store_fixed_data() + # collector.store_sample_domains() + problem.aggregate_data() # Check if the splits are correct self._check_slit_sizes(train_size, test_size, val_size) @@ -361,7 +363,7 @@ def __init__( # raises NotImplementedError self.val_dataloader = super().val_dataloader - self.collector_splits = self._create_splits(collector, splits_dict) + self.data_splits = self._create_splits(problem.data, splits_dict) self.transfer_batch_to_device = self._transfer_batch_to_device def setup(self, stage=None): @@ -376,15 +378,15 @@ def setup(self, stage=None): """ if stage == "fit" or stage is None: self.train_dataset = PinaDatasetFactory( - self.collector_splits["train"], + self.data_splits["train"], max_conditions_lengths=self.find_max_conditions_lengths( "train" ), automatic_batching=self.automatic_batching, ) - if "val" in self.collector_splits.keys(): + if "val" in self.data_splits.keys(): self.val_dataset = PinaDatasetFactory( - self.collector_splits["val"], + self.data_splits["val"], max_conditions_lengths=self.find_max_conditions_lengths( "val" ), @@ -392,7 +394,7 @@ def setup(self, stage=None): ) elif stage == "test": self.test_dataset = PinaDatasetFactory( - self.collector_splits["test"], + self.data_splits["test"], max_conditions_lengths=self.find_max_conditions_lengths("test"), automatic_batching=self.automatic_batching, ) @@ -473,7 +475,7 @@ def _apply_shuffle(condition_dict, len_data): for ( condition_name, condition_dict, - ) in collector.data_collections.items(): + ) in collector.items(): len_data = len(condition_dict["input"]) if self.shuffle: _apply_shuffle(condition_dict, len_data) @@ -540,7 +542,7 @@ def find_max_conditions_lengths(self, split): """ max_conditions_lengths = {} - for k, v in self.collector_splits[split].items(): + for k, v in self.data_splits[split].items(): if self.batch_size is None: max_conditions_lengths[k] = len(v["input"]) elif self.repeat: diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 5f601acff..cb1a0a1b5 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -23,14 +23,11 @@ def __init__(self): Initialization of the :class:`AbstractProblem` class. """ self._discretised_domains = {} - # create collector to manage problem data # create hook conditions <-> problems for condition_name in self.conditions: self.conditions[condition_name].problem = self - self._batching_dimension = 0 - # Store in domains dict all the domains object directly passed to # ConditionInterface. Done for back compatibility with PINA <0.2 if not hasattr(self, "domains"): @@ -41,24 +38,7 @@ def __init__(self): self.domains[cond_name] = cond.domain cond.domain = cond_name - @property - def batching_dimension(self): - """ - Get batching dimension. - - :return: The batching dimension. - :rtype: int - """ - return self._batching_dimension - - @batching_dimension.setter - def batching_dimension(self, value): - """ - Set the batching dimension. - - :param int value: The batching dimension. - """ - self._batching_dimension = value + self.data = None # back compatibility 0.1 @property @@ -300,3 +280,26 @@ def add_points(self, new_points_dict): self.discretised_domains[k] = LabelTensor.vstack( [self.discretised_domains[k], v] ) + + def aggregate_data(self): + """ + Aggregate data from the problem's conditions into a single dictionary. + """ + self.data = {} + if not self.are_all_domains_discretised: + raise RuntimeError( + "All domains must be discretised before aggregating data." + ) + for condition_name in self.conditions: + condition = self.conditions[condition_name] + if hasattr(condition, "domain"): + samples = self.discretised_domains[condition.domain] + + self.data[condition_name] = { + "input": samples, + "equation": condition.equation, + } + else: + keys = condition.__slots__ + values = [getattr(condition, name) for name in keys] + self.data[condition_name] = dict(zip(keys, values)) diff --git a/tests/test_collector.py b/tests/test_collector.py deleted file mode 100644 index 3119f9db0..000000000 --- a/tests/test_collector.py +++ /dev/null @@ -1,135 +0,0 @@ -import torch -import pytest -from pina import Condition, LabelTensor, Graph -from pina.condition import InputTargetCondition, DomainEquationCondition -from pina.graph import RadiusGraph -from pina.problem import AbstractProblem, SpatialProblem -from pina.domain import CartesianDomain -from pina.equation.equation import Equation -from pina.equation.equation_factory import FixedValue -from pina.operator import laplacian -from pina.collector import Collector - - -def test_supervised_tensor_collector(): - class SupervisedProblem(AbstractProblem): - output_variables = None - conditions = { - "data1": Condition( - input=torch.rand((10, 2)), - target=torch.rand((10, 2)), - ), - "data2": Condition( - input=torch.rand((20, 2)), - target=torch.rand((20, 2)), - ), - "data3": Condition( - input=torch.rand((30, 2)), - target=torch.rand((30, 2)), - ), - } - - problem = SupervisedProblem() - collector = Collector(problem) - for v in collector.conditions_name.values(): - assert v in problem.conditions.keys() - - -def test_pinn_collector(): - def laplace_equation(input_, output_): - force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin( - input_.extract(["y"]) * torch.pi - ) - delta_u = laplacian(output_.extract(["u"]), input_) - return delta_u - force_term - - my_laplace = Equation(laplace_equation) - in_ = LabelTensor( - torch.tensor([[0.0, 1.0]], requires_grad=True), ["x", "y"] - ) - out_ = LabelTensor(torch.tensor([[0.0]], requires_grad=True), ["u"]) - - class Poisson(SpatialProblem): - output_variables = ["u"] - spatial_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) - - conditions = { - "gamma1": Condition( - domain=CartesianDomain({"x": [0, 1], "y": 1}), - equation=FixedValue(0.0), - ), - "gamma2": Condition( - domain=CartesianDomain({"x": [0, 1], "y": 0}), - equation=FixedValue(0.0), - ), - "gamma3": Condition( - domain=CartesianDomain({"x": 1, "y": [0, 1]}), - equation=FixedValue(0.0), - ), - "gamma4": Condition( - domain=CartesianDomain({"x": 0, "y": [0, 1]}), - equation=FixedValue(0.0), - ), - "D": Condition( - domain=CartesianDomain({"x": [0, 1], "y": [0, 1]}), - equation=my_laplace, - ), - "data": Condition(input=in_, target=out_), - } - - def poisson_sol(self, pts): - return -( - torch.sin(pts.extract(["x"]) * torch.pi) - * torch.sin(pts.extract(["y"]) * torch.pi) - ) / (2 * torch.pi**2) - - truth_solution = poisson_sol - - problem = Poisson() - boundaries = ["gamma1", "gamma2", "gamma3", "gamma4"] - problem.discretise_domain(10, "grid", domains=boundaries) - problem.discretise_domain(10, "grid", domains="D") - - collector = Collector(problem) - collector.store_fixed_data() - collector.store_sample_domains() - - for k, v in problem.conditions.items(): - if isinstance(v, InputTargetCondition): - assert list(collector.data_collections[k].keys()) == [ - "input", - "target", - ] - - for k, v in problem.conditions.items(): - if isinstance(v, DomainEquationCondition): - assert list(collector.data_collections[k].keys()) == [ - "input", - "equation", - ] - - -def test_supervised_graph_collector(): - pos = torch.rand((100, 3)) - x = [torch.rand((100, 3)) for _ in range(10)] - graph_list_1 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x] - out_1 = torch.rand((10, 100, 3)) - - pos = torch.rand((50, 3)) - x = [torch.rand((50, 3)) for _ in range(10)] - graph_list_2 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x] - out_2 = torch.rand((10, 50, 3)) - - class SupervisedProblem(AbstractProblem): - output_variables = None - conditions = { - "data1": Condition(input=graph_list_1, target=out_1), - "data2": Condition(input=graph_list_2, target=out_2), - } - - problem = SupervisedProblem() - collector = Collector(problem) - collector.store_fixed_data() - # assert all(collector._is_conditions_ready.values()) - for v in collector.conditions_name.values(): - assert v in problem.conditions.keys() diff --git a/tests/test_problem.py b/tests/test_problem.py index 069dc0620..235736eda 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -4,6 +4,11 @@ from pina import LabelTensor from pina.domain import Union from pina.domain import CartesianDomain +from pina.condition import ( + Condition, + InputTargetCondition, + DomainEquationCondition, +) def test_discretise_domain(): @@ -84,3 +89,30 @@ def test_wrong_custom_sampling_logic(mode): } with pytest.raises(RuntimeError): poisson_problem.discretise_domain(sample_rules=sampling_rules) + + +def test_aggregate_data(): + poisson_problem = Poisson() + poisson_problem.conditions["data"] = Condition( + input=LabelTensor(torch.tensor([[0.0, 1.0]]), labels=["x", "y"]), + target=LabelTensor(torch.tensor([[0.0]]), labels=["u"]), + ) + poisson_problem.discretise_domain(0, "random", domains="all") + poisson_problem.aggregate_data() + assert isinstance(poisson_problem.data, dict) + for name, conditions in poisson_problem.conditions.items(): + assert name in poisson_problem.data.keys() + if isinstance(conditions, InputTargetCondition): + assert "input" in poisson_problem.data[name].keys() + assert "target" in poisson_problem.data[name].keys() + elif isinstance(conditions, DomainEquationCondition): + assert "input" in poisson_problem.data[name].keys() + assert "target" not in poisson_problem.data[name].keys() + assert "equation" in poisson_problem.data[name].keys() + + +def test_wrong_aggregate_data(): + poisson_problem = Poisson() + poisson_problem.discretise_domain(0, "random", domains=["D"]) + with pytest.raises(RuntimeError): + poisson_problem.aggregate_data() From 31f8a8badedad3e9234cc7e5b393feb40c76bbc8 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 14 May 2025 21:32:28 +0200 Subject: [PATCH 2/4] Fixes --- pina/data/data_module.py | 6 +++-- pina/problem/abstract_problem.py | 42 +++++++++++++++++++++++++++----- tests/test_problem.py | 20 ++++++++------- 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 0e6538c31..f5dbf163f 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -334,7 +334,7 @@ def __init__( # collector = Collector(problem) # collector.store_fixed_data() # collector.store_sample_domains() - problem.aggregate_data() + problem.collect_data() # Check if the splits are correct self._check_slit_sizes(train_size, test_size, val_size) @@ -363,7 +363,9 @@ def __init__( # raises NotImplementedError self.val_dataloader = super().val_dataloader - self.data_splits = self._create_splits(problem.data, splits_dict) + self.data_splits = self._create_splits( + problem.collected_data, splits_dict + ) self.transfer_batch_to_device = self._transfer_batch_to_device def setup(self, stage=None): diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index cb1a0a1b5..6a40134b2 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -38,9 +38,33 @@ def __init__(self): self.domains[cond_name] = cond.domain cond.domain = cond_name - self.data = None + self._collect_data = {} + + @property + def collected_data(self): + """ + Return the collected data from the problem's conditions. + + :return: The collected data. + :rtype: dict + """ + if not self._collect_data: + raise RuntimeError( + "You have to call collect_data() before accessing the data." + ) + return self._collect_data + + @collected_data.setter + def collected_data(self, data): + """ + Set the collected data from the problem's conditions. + + :param dict data: The collected data. + """ + self._collect_data = data # back compatibility 0.1 + @property def input_pts(self): """ @@ -281,25 +305,31 @@ def add_points(self, new_points_dict): [self.discretised_domains[k], v] ) - def aggregate_data(self): + def collect_data(self): """ Aggregate data from the problem's conditions into a single dictionary. """ - self.data = {} + data = {} + # check if all domains are discretised if not self.are_all_domains_discretised: raise RuntimeError( "All domains must be discretised before aggregating data." ) + # Iterate over the conditions and collect data for condition_name in self.conditions: condition = self.conditions[condition_name] + # Check if the condition has an domain attribute if hasattr(condition, "domain"): + # Store the discretisation points samples = self.discretised_domains[condition.domain] - - self.data[condition_name] = { + data[condition_name] = { "input": samples, "equation": condition.equation, } else: + # If the condition does not have a domain attribute, store + # the input and target points keys = condition.__slots__ values = [getattr(condition, name) for name in keys] - self.data[condition_name] = dict(zip(keys, values)) + data[condition_name] = dict(zip(keys, values)) + self.collected_data = data diff --git a/tests/test_problem.py b/tests/test_problem.py index 235736eda..f33370b32 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -98,21 +98,23 @@ def test_aggregate_data(): target=LabelTensor(torch.tensor([[0.0]]), labels=["u"]), ) poisson_problem.discretise_domain(0, "random", domains="all") - poisson_problem.aggregate_data() - assert isinstance(poisson_problem.data, dict) + poisson_problem.collect_data() + assert isinstance(poisson_problem.collected_data, dict) for name, conditions in poisson_problem.conditions.items(): - assert name in poisson_problem.data.keys() + assert name in poisson_problem.collected_data.keys() if isinstance(conditions, InputTargetCondition): - assert "input" in poisson_problem.data[name].keys() - assert "target" in poisson_problem.data[name].keys() + assert "input" in poisson_problem.collected_data[name].keys() + assert "target" in poisson_problem.collected_data[name].keys() elif isinstance(conditions, DomainEquationCondition): - assert "input" in poisson_problem.data[name].keys() - assert "target" not in poisson_problem.data[name].keys() - assert "equation" in poisson_problem.data[name].keys() + assert "input" in poisson_problem.collected_data[name].keys() + assert "target" not in poisson_problem.collected_data[name].keys() + assert "equation" in poisson_problem.collected_data[name].keys() def test_wrong_aggregate_data(): poisson_problem = Poisson() poisson_problem.discretise_domain(0, "random", domains=["D"]) with pytest.raises(RuntimeError): - poisson_problem.aggregate_data() + poisson_problem.collected_data() + with pytest.raises(RuntimeError): + poisson_problem.collect_data() From e30b7b302c2d3f2dfc57155c900fa6fbd7c990e3 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 15 May 2025 14:20:52 +0200 Subject: [PATCH 3/4] Fixes --- pina/data/data_module.py | 3 --- pina/problem/abstract_problem.py | 33 +++++++++++++------------------- tests/test_problem.py | 3 +-- 3 files changed, 14 insertions(+), 25 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index f5dbf163f..1775deef8 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -331,9 +331,6 @@ def __init__( self.pin_memory = pin_memory # Collect data - # collector = Collector(problem) - # collector.store_fixed_data() - # collector.store_sample_domains() problem.collect_data() # Check if the splits are correct diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 6a40134b2..5da2cbf74 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -38,33 +38,25 @@ def __init__(self): self.domains[cond_name] = cond.domain cond.domain = cond_name - self._collect_data = {} + self._collected_data = {} @property def collected_data(self): """ Return the collected data from the problem's conditions. - :return: The collected data. + :return: The collected data. Keys are condition names, and values are + dictionaries containing the input points and the corresponding + equations or target points. :rtype: dict """ - if not self._collect_data: + if not self._collected_data: raise RuntimeError( "You have to call collect_data() before accessing the data." ) - return self._collect_data - - @collected_data.setter - def collected_data(self, data): - """ - Set the collected data from the problem's conditions. - - :param dict data: The collected data. - """ - self._collect_data = data + return self._collected_data # back compatibility 0.1 - @property def input_pts(self): """ @@ -75,11 +67,12 @@ def input_pts(self): :rtype: dict """ to_return = {} - for cond_name, cond in self.conditions.items(): - if hasattr(cond, "input"): - to_return[cond_name] = cond.input - elif hasattr(cond, "domain"): - to_return[cond_name] = self._discretised_domains[cond.domain] + if self._collected_data is None: + raise RuntimeError( + "You have to call collect_data() before accessing the data." + ) + for cond_name, data in self._collected_data.items(): + to_return[cond_name] = data["input"] return to_return @property @@ -332,4 +325,4 @@ def collect_data(self): keys = condition.__slots__ values = [getattr(condition, name) for name in keys] data[condition_name] = dict(zip(keys, values)) - self.collected_data = data + self._collected_data = data diff --git a/tests/test_problem.py b/tests/test_problem.py index f33370b32..04869d5e6 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -114,7 +114,6 @@ def test_aggregate_data(): def test_wrong_aggregate_data(): poisson_problem = Poisson() poisson_problem.discretise_domain(0, "random", domains=["D"]) - with pytest.raises(RuntimeError): - poisson_problem.collected_data() + assert not poisson_problem._collected_data with pytest.raises(RuntimeError): poisson_problem.collect_data() From 764dd1e362f9e89c40baabae8d7ea0b54959a803 Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Fri, 16 May 2025 09:49:29 +0200 Subject: [PATCH 4/4] rm unnecessary comment --- pina/data/data_module.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 1775deef8..9ed5c6437 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -13,8 +13,6 @@ from ..label_tensor import LabelTensor from .dataset import PinaDatasetFactory, PinaTensorDataset -# from ..collector import Collector - class DummyDataloader: