11"""Module for the AbstractProblem class."""
22
33from abc import ABCMeta , abstractmethod
4+ import warnings
45from copy import deepcopy
56from ..utils import check_consistency
67from ..domain import DomainInterface , CartesianDomain
78from ..condition .domain_equation_condition import DomainEquationCondition
89from ..label_tensor import LabelTensor
9- from ..utils import merge_tensors
10+ from ..utils import merge_tensors , custom_warning_format
1011
1112
1213class AbstractProblem (metaclass = ABCMeta ):
@@ -43,16 +44,35 @@ def __init__(self):
4344 @property
4445 def collected_data (self ):
4546 """
46- Return the collected data from the problem's conditions.
47+ Return the collected data from the problem's conditions. If some domains
48+ are not sampled, they will not be returned by collected data.
4749
4850 :return: The collected data. Keys are condition names, and values are
4951 dictionaries containing the input points and the corresponding
5052 equations or target points.
5153 :rtype: dict
5254 """
53- if not self ._collected_data :
54- raise RuntimeError (
55- "You have to call collect_data() before accessing the data."
55+ # collect data so far
56+ self .collect_data ()
57+ # raise warning if some sample data are missing
58+ if not self .are_all_domains_discretised :
59+ warnings .formatwarning = custom_warning_format
60+ warnings .filterwarnings ("always" , category = RuntimeWarning )
61+ warning_message = "\n " .join (
62+ [
63+ f"""{ " " * 13 } ---> Domain { key } {
64+ "sampled" if key in self .discretised_domains
65+ else
66+ "not sampled" } """
67+ for key in self .domains .keys ()
68+ ]
69+ )
70+ warnings .warn (
71+ "Some of the domains are still not sampled. Consider calling "
72+ "problem.discretise_domain function for all domains before "
73+ "accessing the collected data:\n "
74+ f"{ warning_message } " ,
75+ RuntimeWarning ,
5676 )
5777 return self ._collected_data
5878
@@ -61,17 +81,14 @@ def collected_data(self):
6181 def input_pts (self ):
6282 """
6383 Return a dictionary mapping condition names to their corresponding
64- input points.
84+ input points. If some domains are not sampled, they will not be returned
85+ and the corresponding condition will be empty.
6586
6687 :return: The input points of the problem.
6788 :rtype: dict
6889 """
6990 to_return = {}
70- if self ._collected_data is None :
71- raise RuntimeError (
72- "You have to call collect_data() before accessing the data."
73- )
74- for cond_name , data in self ._collected_data .items ():
91+ for cond_name , data in self .collected_data .items ():
7592 to_return [cond_name ] = data ["input" ]
7693 return to_return
7794
@@ -303,22 +320,18 @@ def collect_data(self):
303320 Aggregate data from the problem's conditions into a single dictionary.
304321 """
305322 data = {}
306- # check if all domains are discretised
307- if not self .are_all_domains_discretised :
308- raise RuntimeError (
309- "All domains must be discretised before aggregating data."
310- )
311323 # Iterate over the conditions and collect data
312324 for condition_name in self .conditions :
313325 condition = self .conditions [condition_name ]
314326 # Check if the condition has an domain attribute
315327 if hasattr (condition , "domain" ):
316- # Store the discretisation points
317- samples = self .discretised_domains [condition .domain ]
318- data [condition_name ] = {
319- "input" : samples ,
320- "equation" : condition .equation ,
321- }
328+ # Only store the discretisation points if the domain is in the dictionary
329+ if condition .domain in self .discretised_domains :
330+ samples = self .discretised_domains [condition .domain ]
331+ data [condition_name ] = {
332+ "input" : samples ,
333+ "equation" : condition .equation ,
334+ }
322335 else :
323336 # If the condition does not have a domain attribute, store
324337 # the input and target points
0 commit comments