diff --git a/pina/collector.py b/pina/collector.py deleted file mode 100644 index e0d04a040..000000000 --- a/pina/collector.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Module for the Collector class.""" - -from .graph import Graph -from .utils import check_consistency - - -class Collector: - """ - Collector class for retrieving data from different conditions in the - problem. - """ - - def __init__(self, problem): - """ - Initialize the Collector class, by creating a hook between the collector - and the problem and initializing the data collections (dictionary where - data will be stored). - - :param pina.problem.abstract_problem.AbstractProblem problem: The - problem to collect data from. - """ - # creating a hook between collector and problem - self.problem = problem - - # those variables are used for the dataloading - self._data_collections = {name: {} for name in self.problem.conditions} - self.conditions_name = dict(enumerate(self.problem.conditions)) - - # variables used to check that all conditions are sampled - self._is_conditions_ready = { - name: False for name in self.problem.conditions - } - self.full = False - - @property - def full(self): - """ - Returns ``True`` if the collector is full. The collector is considered - full if all conditions have entries in the ``data_collection`` - dictionary. - - :return: ``True`` if all conditions are ready, ``False`` otherwise. - :rtype: bool - """ - - return all(self._is_conditions_ready.values()) - - @full.setter - def full(self, value): - """ - Set the ``_full`` variable. - - :param bool value: The value to set the ``_full`` variable. - """ - - check_consistency(value, bool) - self._full = value - - @property - def data_collections(self): - """ - Return the data collections (dictionary where data is stored). - - :return: The data collections where the data is stored. - :rtype: dict - """ - - return self._data_collections - - @property - def problem(self): - """ - Problem connected to the collector. - - :return: The problem from which the data is collected. - :rtype: pina.problem.abstract_problem.AbstractProblem - """ - return self._problem - - @problem.setter - def problem(self, value): - """ - Set the problem connected to the collector. - - :param pina.problem.abstract_problem.AbstractProblem value: The problem - to connect to the collector. - """ - - self._problem = value - - def store_fixed_data(self): - """ - Store inside data collections the fixed data of the problem. These comes - from the conditions that do not require sampling. - """ - - # loop over all conditions - for condition_name, condition in self.problem.conditions.items(): - # if the condition is not ready and domain is not attribute - # of condition, we get and store the data - if (not self._is_conditions_ready[condition_name]) and ( - not hasattr(condition, "domain") - ): - # get data - keys = condition.__slots__ - values = [getattr(condition, name) for name in keys] - self.data_collections[condition_name] = dict(zip(keys, values)) - # condition now is ready - self._is_conditions_ready[condition_name] = True - - def store_sample_domains(self): - """ - Store inside data collections the sampled data of the problem. These - comes from the conditions that require sampling (e.g. - :class:`~pina.condition.domain_equation_condition.\ - DomainEquationCondition`). - """ - - for condition_name in self.problem.conditions: - condition = self.problem.conditions[condition_name] - if not hasattr(condition, "domain"): - continue - - samples = self.problem.discretised_domains[condition.domain] - - self.data_collections[condition_name] = { - "input": samples, - "equation": condition.equation, - } diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 566f64625..9ed5c6437 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -12,7 +12,6 @@ from torch.utils.data.distributed import DistributedSampler from ..label_tensor import LabelTensor from .dataset import PinaDatasetFactory, PinaTensorDataset -from ..collector import Collector class DummyDataloader: @@ -330,9 +329,7 @@ def __init__( self.pin_memory = pin_memory # Collect data - collector = Collector(problem) - collector.store_fixed_data() - collector.store_sample_domains() + problem.collect_data() # Check if the splits are correct self._check_slit_sizes(train_size, test_size, val_size) @@ -361,7 +358,9 @@ def __init__( # raises NotImplementedError self.val_dataloader = super().val_dataloader - self.collector_splits = self._create_splits(collector, splits_dict) + self.data_splits = self._create_splits( + problem.collected_data, splits_dict + ) self.transfer_batch_to_device = self._transfer_batch_to_device def setup(self, stage=None): @@ -376,15 +375,15 @@ def setup(self, stage=None): """ if stage == "fit" or stage is None: self.train_dataset = PinaDatasetFactory( - self.collector_splits["train"], + self.data_splits["train"], max_conditions_lengths=self.find_max_conditions_lengths( "train" ), automatic_batching=self.automatic_batching, ) - if "val" in self.collector_splits.keys(): + if "val" in self.data_splits.keys(): self.val_dataset = PinaDatasetFactory( - self.collector_splits["val"], + self.data_splits["val"], max_conditions_lengths=self.find_max_conditions_lengths( "val" ), @@ -392,7 +391,7 @@ def setup(self, stage=None): ) elif stage == "test": self.test_dataset = PinaDatasetFactory( - self.collector_splits["test"], + self.data_splits["test"], max_conditions_lengths=self.find_max_conditions_lengths("test"), automatic_batching=self.automatic_batching, ) @@ -473,7 +472,7 @@ def _apply_shuffle(condition_dict, len_data): for ( condition_name, condition_dict, - ) in collector.data_collections.items(): + ) in collector.items(): len_data = len(condition_dict["input"]) if self.shuffle: _apply_shuffle(condition_dict, len_data) @@ -540,7 +539,7 @@ def find_max_conditions_lengths(self, split): """ max_conditions_lengths = {} - for k, v in self.collector_splits[split].items(): + for k, v in self.data_splits[split].items(): if self.batch_size is None: max_conditions_lengths[k] = len(v["input"]) elif self.repeat: diff --git a/pina/problem/abstract_problem.py b/pina/problem/abstract_problem.py index 5f601acff..5da2cbf74 100644 --- a/pina/problem/abstract_problem.py +++ b/pina/problem/abstract_problem.py @@ -23,14 +23,11 @@ def __init__(self): Initialization of the :class:`AbstractProblem` class. """ self._discretised_domains = {} - # create collector to manage problem data # create hook conditions <-> problems for condition_name in self.conditions: self.conditions[condition_name].problem = self - self._batching_dimension = 0 - # Store in domains dict all the domains object directly passed to # ConditionInterface. Done for back compatibility with PINA <0.2 if not hasattr(self, "domains"): @@ -41,24 +38,23 @@ def __init__(self): self.domains[cond_name] = cond.domain cond.domain = cond_name - @property - def batching_dimension(self): - """ - Get batching dimension. - - :return: The batching dimension. - :rtype: int - """ - return self._batching_dimension + self._collected_data = {} - @batching_dimension.setter - def batching_dimension(self, value): + @property + def collected_data(self): """ - Set the batching dimension. + Return the collected data from the problem's conditions. - :param int value: The batching dimension. + :return: The collected data. Keys are condition names, and values are + dictionaries containing the input points and the corresponding + equations or target points. + :rtype: dict """ - self._batching_dimension = value + if not self._collected_data: + raise RuntimeError( + "You have to call collect_data() before accessing the data." + ) + return self._collected_data # back compatibility 0.1 @property @@ -71,11 +67,12 @@ def input_pts(self): :rtype: dict """ to_return = {} - for cond_name, cond in self.conditions.items(): - if hasattr(cond, "input"): - to_return[cond_name] = cond.input - elif hasattr(cond, "domain"): - to_return[cond_name] = self._discretised_domains[cond.domain] + if self._collected_data is None: + raise RuntimeError( + "You have to call collect_data() before accessing the data." + ) + for cond_name, data in self._collected_data.items(): + to_return[cond_name] = data["input"] return to_return @property @@ -300,3 +297,32 @@ def add_points(self, new_points_dict): self.discretised_domains[k] = LabelTensor.vstack( [self.discretised_domains[k], v] ) + + def collect_data(self): + """ + Aggregate data from the problem's conditions into a single dictionary. + """ + data = {} + # check if all domains are discretised + if not self.are_all_domains_discretised: + raise RuntimeError( + "All domains must be discretised before aggregating data." + ) + # Iterate over the conditions and collect data + for condition_name in self.conditions: + condition = self.conditions[condition_name] + # Check if the condition has an domain attribute + if hasattr(condition, "domain"): + # Store the discretisation points + samples = self.discretised_domains[condition.domain] + data[condition_name] = { + "input": samples, + "equation": condition.equation, + } + else: + # If the condition does not have a domain attribute, store + # the input and target points + keys = condition.__slots__ + values = [getattr(condition, name) for name in keys] + data[condition_name] = dict(zip(keys, values)) + self._collected_data = data diff --git a/tests/test_collector.py b/tests/test_collector.py deleted file mode 100644 index 3119f9db0..000000000 --- a/tests/test_collector.py +++ /dev/null @@ -1,135 +0,0 @@ -import torch -import pytest -from pina import Condition, LabelTensor, Graph -from pina.condition import InputTargetCondition, DomainEquationCondition -from pina.graph import RadiusGraph -from pina.problem import AbstractProblem, SpatialProblem -from pina.domain import CartesianDomain -from pina.equation.equation import Equation -from pina.equation.equation_factory import FixedValue -from pina.operator import laplacian -from pina.collector import Collector - - -def test_supervised_tensor_collector(): - class SupervisedProblem(AbstractProblem): - output_variables = None - conditions = { - "data1": Condition( - input=torch.rand((10, 2)), - target=torch.rand((10, 2)), - ), - "data2": Condition( - input=torch.rand((20, 2)), - target=torch.rand((20, 2)), - ), - "data3": Condition( - input=torch.rand((30, 2)), - target=torch.rand((30, 2)), - ), - } - - problem = SupervisedProblem() - collector = Collector(problem) - for v in collector.conditions_name.values(): - assert v in problem.conditions.keys() - - -def test_pinn_collector(): - def laplace_equation(input_, output_): - force_term = torch.sin(input_.extract(["x"]) * torch.pi) * torch.sin( - input_.extract(["y"]) * torch.pi - ) - delta_u = laplacian(output_.extract(["u"]), input_) - return delta_u - force_term - - my_laplace = Equation(laplace_equation) - in_ = LabelTensor( - torch.tensor([[0.0, 1.0]], requires_grad=True), ["x", "y"] - ) - out_ = LabelTensor(torch.tensor([[0.0]], requires_grad=True), ["u"]) - - class Poisson(SpatialProblem): - output_variables = ["u"] - spatial_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]}) - - conditions = { - "gamma1": Condition( - domain=CartesianDomain({"x": [0, 1], "y": 1}), - equation=FixedValue(0.0), - ), - "gamma2": Condition( - domain=CartesianDomain({"x": [0, 1], "y": 0}), - equation=FixedValue(0.0), - ), - "gamma3": Condition( - domain=CartesianDomain({"x": 1, "y": [0, 1]}), - equation=FixedValue(0.0), - ), - "gamma4": Condition( - domain=CartesianDomain({"x": 0, "y": [0, 1]}), - equation=FixedValue(0.0), - ), - "D": Condition( - domain=CartesianDomain({"x": [0, 1], "y": [0, 1]}), - equation=my_laplace, - ), - "data": Condition(input=in_, target=out_), - } - - def poisson_sol(self, pts): - return -( - torch.sin(pts.extract(["x"]) * torch.pi) - * torch.sin(pts.extract(["y"]) * torch.pi) - ) / (2 * torch.pi**2) - - truth_solution = poisson_sol - - problem = Poisson() - boundaries = ["gamma1", "gamma2", "gamma3", "gamma4"] - problem.discretise_domain(10, "grid", domains=boundaries) - problem.discretise_domain(10, "grid", domains="D") - - collector = Collector(problem) - collector.store_fixed_data() - collector.store_sample_domains() - - for k, v in problem.conditions.items(): - if isinstance(v, InputTargetCondition): - assert list(collector.data_collections[k].keys()) == [ - "input", - "target", - ] - - for k, v in problem.conditions.items(): - if isinstance(v, DomainEquationCondition): - assert list(collector.data_collections[k].keys()) == [ - "input", - "equation", - ] - - -def test_supervised_graph_collector(): - pos = torch.rand((100, 3)) - x = [torch.rand((100, 3)) for _ in range(10)] - graph_list_1 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x] - out_1 = torch.rand((10, 100, 3)) - - pos = torch.rand((50, 3)) - x = [torch.rand((50, 3)) for _ in range(10)] - graph_list_2 = [RadiusGraph(pos=pos, radius=0.4, x=x_) for x_ in x] - out_2 = torch.rand((10, 50, 3)) - - class SupervisedProblem(AbstractProblem): - output_variables = None - conditions = { - "data1": Condition(input=graph_list_1, target=out_1), - "data2": Condition(input=graph_list_2, target=out_2), - } - - problem = SupervisedProblem() - collector = Collector(problem) - collector.store_fixed_data() - # assert all(collector._is_conditions_ready.values()) - for v in collector.conditions_name.values(): - assert v in problem.conditions.keys() diff --git a/tests/test_problem.py b/tests/test_problem.py index 069dc0620..04869d5e6 100644 --- a/tests/test_problem.py +++ b/tests/test_problem.py @@ -4,6 +4,11 @@ from pina import LabelTensor from pina.domain import Union from pina.domain import CartesianDomain +from pina.condition import ( + Condition, + InputTargetCondition, + DomainEquationCondition, +) def test_discretise_domain(): @@ -84,3 +89,31 @@ def test_wrong_custom_sampling_logic(mode): } with pytest.raises(RuntimeError): poisson_problem.discretise_domain(sample_rules=sampling_rules) + + +def test_aggregate_data(): + poisson_problem = Poisson() + poisson_problem.conditions["data"] = Condition( + input=LabelTensor(torch.tensor([[0.0, 1.0]]), labels=["x", "y"]), + target=LabelTensor(torch.tensor([[0.0]]), labels=["u"]), + ) + poisson_problem.discretise_domain(0, "random", domains="all") + poisson_problem.collect_data() + assert isinstance(poisson_problem.collected_data, dict) + for name, conditions in poisson_problem.conditions.items(): + assert name in poisson_problem.collected_data.keys() + if isinstance(conditions, InputTargetCondition): + assert "input" in poisson_problem.collected_data[name].keys() + assert "target" in poisson_problem.collected_data[name].keys() + elif isinstance(conditions, DomainEquationCondition): + assert "input" in poisson_problem.collected_data[name].keys() + assert "target" not in poisson_problem.collected_data[name].keys() + assert "equation" in poisson_problem.collected_data[name].keys() + + +def test_wrong_aggregate_data(): + poisson_problem = Poisson() + poisson_problem.discretise_domain(0, "random", domains=["D"]) + assert not poisson_problem._collected_data + with pytest.raises(RuntimeError): + poisson_problem.collect_data()