Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 36 additions & 22 deletions pina/problem/abstract_problem.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Module for the AbstractProblem class."""

from abc import ABCMeta, abstractmethod
import warnings
from copy import deepcopy
from ..utils import check_consistency
from ..domain import DomainInterface, CartesianDomain
from ..condition.domain_equation_condition import DomainEquationCondition
from ..label_tensor import LabelTensor
from ..utils import merge_tensors
from ..utils import merge_tensors, custom_warning_format


class AbstractProblem(metaclass=ABCMeta):
Expand Down Expand Up @@ -43,16 +44,35 @@ def __init__(self):
@property
def collected_data(self):
"""
Return the collected data from the problem's conditions.
Return the collected data from the problem's conditions. If some domains
are not sampled, they will not be returned by collected data.

:return: The collected data. Keys are condition names, and values are
dictionaries containing the input points and the corresponding
equations or target points.
:rtype: dict
"""
if not self._collected_data:
raise RuntimeError(
"You have to call collect_data() before accessing the data."
# collect data so far
self.collect_data()
# raise warning if some sample data are missing
if not self.are_all_domains_discretised:
warnings.formatwarning = custom_warning_format
warnings.filterwarnings("always", category=RuntimeWarning)
warning_message = "\n".join(
[
f"""{" " * 13} ---> Domain {key} {
"sampled" if key in self.discretised_domains
else
"not sampled"}"""
for key in self.domains
]
)
warnings.warn(
"Some of the domains are still not sampled. Consider calling "
"problem.discretise_domain function for all domains before "
"accessing the collected data:\n"
f"{warning_message}",
RuntimeWarning,
)
return self._collected_data

Expand All @@ -61,17 +81,14 @@ def collected_data(self):
def input_pts(self):
"""
Return a dictionary mapping condition names to their corresponding
input points.
input points. If some domains are not sampled, they will not be returned
and the corresponding condition will be empty.

:return: The input points of the problem.
:rtype: dict
"""
to_return = {}
if self._collected_data is None:
raise RuntimeError(
"You have to call collect_data() before accessing the data."
)
for cond_name, data in self._collected_data.items():
for cond_name, data in self.collected_data.items():
to_return[cond_name] = data["input"]
return to_return

Expand Down Expand Up @@ -303,22 +320,19 @@ def collect_data(self):
Aggregate data from the problem's conditions into a single dictionary.
"""
data = {}
# check if all domains are discretised
if not self.are_all_domains_discretised:
raise RuntimeError(
"All domains must be discretised before aggregating data."
)
# Iterate over the conditions and collect data
for condition_name in self.conditions:
condition = self.conditions[condition_name]
# Check if the condition has an domain attribute
if hasattr(condition, "domain"):
# Store the discretisation points
samples = self.discretised_domains[condition.domain]
data[condition_name] = {
"input": samples,
"equation": condition.equation,
}
# Only store the discretisation points if the domain is
# in the dictionary
if condition.domain in self.discretised_domains:
samples = self.discretised_domains[condition.domain]
data[condition_name] = {
"input": samples,
"equation": condition.equation,
}
else:
# If the condition does not have a domain attribute, store
# the input and target points
Expand Down
26 changes: 18 additions & 8 deletions tests/test_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ def test_variables_correct_order_sampling():
)


def test_input_pts():
n = 10
poisson_problem = Poisson()
poisson_problem.discretise_domain(n, "grid")
assert sorted(list(poisson_problem.input_pts.keys())) == sorted(
list(poisson_problem.conditions.keys())
)


def test_collected_data():
n = 10
poisson_problem = Poisson()
poisson_problem.discretise_domain(n, "grid")
assert sorted(list(poisson_problem.collected_data.keys())) == sorted(
list(poisson_problem.conditions.keys())
)


def test_add_points():
poisson_problem = Poisson()
poisson_problem.discretise_domain(0, "random", domains=["D"])
Expand Down Expand Up @@ -109,11 +127,3 @@ def test_aggregate_data():
assert "input" in poisson_problem.collected_data[name].keys()
assert "target" not in poisson_problem.collected_data[name].keys()
assert "equation" in poisson_problem.collected_data[name].keys()


def test_wrong_aggregate_data():
poisson_problem = Poisson()
poisson_problem.discretise_domain(0, "random", domains=["D"])
assert not poisson_problem._collected_data
with pytest.raises(RuntimeError):
poisson_problem.collect_data()