Skip to content

Commit 6fd96ad

Browse files
committed
fix problem data collection for v0.1
1 parent 9719bc8 commit 6fd96ad

2 files changed

Lines changed: 54 additions & 30 deletions

File tree

pina/problem/abstract_problem.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Module for the AbstractProblem class."""
22

33
from abc import ABCMeta, abstractmethod
4+
import warnings
45
from copy import deepcopy
56
from ..utils import check_consistency
67
from ..domain import DomainInterface, CartesianDomain
78
from ..condition.domain_equation_condition import DomainEquationCondition
89
from ..label_tensor import LabelTensor
9-
from ..utils import merge_tensors
10+
from ..utils import merge_tensors, custom_warning_format
1011

1112

1213
class AbstractProblem(metaclass=ABCMeta):
@@ -43,16 +44,35 @@ def __init__(self):
4344
@property
4445
def collected_data(self):
4546
"""
46-
Return the collected data from the problem's conditions.
47+
Return the collected data from the problem's conditions. If some domains
48+
are not sampled, they will not be returned by collected data.
4749
4850
:return: The collected data. Keys are condition names, and values are
4951
dictionaries containing the input points and the corresponding
5052
equations or target points.
5153
:rtype: dict
5254
"""
53-
if not self._collected_data:
54-
raise RuntimeError(
55-
"You have to call collect_data() before accessing the data."
55+
# collect data so far
56+
self.collect_data()
57+
# raise warning if some sample data are missing
58+
if not self.are_all_domains_discretised:
59+
warnings.formatwarning = custom_warning_format
60+
warnings.filterwarnings("always", category=RuntimeWarning)
61+
warning_message = "\n".join(
62+
[
63+
f"""{" " * 13} ---> Domain {key} {
64+
"sampled" if key in self.discretised_domains
65+
else
66+
"not sampled"}"""
67+
for key in self.domains
68+
]
69+
)
70+
warnings.warn(
71+
"Some of the domains are still not sampled. Consider calling "
72+
"problem.discretise_domain function for all domains before "
73+
"accessing the collected data:\n"
74+
f"{warning_message}",
75+
RuntimeWarning,
5676
)
5777
return self._collected_data
5878

@@ -61,17 +81,14 @@ def collected_data(self):
6181
def input_pts(self):
6282
"""
6383
Return a dictionary mapping condition names to their corresponding
64-
input points.
84+
input points. If some domains are not sampled, they will not be returned
85+
and the corresponding condition will be empty.
6586
6687
:return: The input points of the problem.
6788
:rtype: dict
6889
"""
6990
to_return = {}
70-
if self._collected_data is None:
71-
raise RuntimeError(
72-
"You have to call collect_data() before accessing the data."
73-
)
74-
for cond_name, data in self._collected_data.items():
91+
for cond_name, data in self.collected_data.items():
7592
to_return[cond_name] = data["input"]
7693
return to_return
7794

@@ -303,22 +320,19 @@ def collect_data(self):
303320
Aggregate data from the problem's conditions into a single dictionary.
304321
"""
305322
data = {}
306-
# check if all domains are discretised
307-
if not self.are_all_domains_discretised:
308-
raise RuntimeError(
309-
"All domains must be discretised before aggregating data."
310-
)
311323
# Iterate over the conditions and collect data
312324
for condition_name in self.conditions:
313325
condition = self.conditions[condition_name]
314326
# Check if the condition has an domain attribute
315327
if hasattr(condition, "domain"):
316-
# Store the discretisation points
317-
samples = self.discretised_domains[condition.domain]
318-
data[condition_name] = {
319-
"input": samples,
320-
"equation": condition.equation,
321-
}
328+
# Only store the discretisation points if the domain is
329+
# in the dictionary
330+
if condition.domain in self.discretised_domains:
331+
samples = self.discretised_domains[condition.domain]
332+
data[condition_name] = {
333+
"input": samples,
334+
"equation": condition.equation,
335+
}
322336
else:
323337
# If the condition does not have a domain attribute, store
324338
# the input and target points

tests/test_problem.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,24 @@ def test_variables_correct_order_sampling():
5050
)
5151

5252

53+
def test_input_pts():
54+
n = 10
55+
poisson_problem = Poisson()
56+
poisson_problem.discretise_domain(n, "grid")
57+
assert sorted(list(poisson_problem.input_pts.keys())) == sorted(
58+
list(poisson_problem.conditions.keys())
59+
)
60+
61+
62+
def test_collected_data():
63+
n = 10
64+
poisson_problem = Poisson()
65+
poisson_problem.discretise_domain(n, "grid")
66+
assert sorted(list(poisson_problem.collected_data.keys())) == sorted(
67+
list(poisson_problem.conditions.keys())
68+
)
69+
70+
5371
def test_add_points():
5472
poisson_problem = Poisson()
5573
poisson_problem.discretise_domain(0, "random", domains=["D"])
@@ -109,11 +127,3 @@ def test_aggregate_data():
109127
assert "input" in poisson_problem.collected_data[name].keys()
110128
assert "target" not in poisson_problem.collected_data[name].keys()
111129
assert "equation" in poisson_problem.collected_data[name].keys()
112-
113-
114-
def test_wrong_aggregate_data():
115-
poisson_problem = Poisson()
116-
poisson_problem.discretise_domain(0, "random", domains=["D"])
117-
assert not poisson_problem._collected_data
118-
with pytest.raises(RuntimeError):
119-
poisson_problem.collect_data()

0 commit comments

Comments
 (0)