44import requests
55import torch
66from io import BytesIO
7- from requests .exceptions import RequestException
87from ... import Condition
98from ... import LabelTensor
109from ...operator import laplacian
1110from ...domain import CartesianDomain
1211from ...equation import Equation , FixedValue
1312from ...problem import SpatialProblem , InverseProblem
14- from ...utils import custom_warning_format
13+ from ...utils import custom_warning_format , check_positive_integer
1514
1615warnings .formatwarning = custom_warning_format
1716warnings .filterwarnings ("always" , category = ResourceWarning )
1817
1918
20- def _load_tensor_from_url (url , labels ):
19+ def _load_tensor_from_url (url , labels , timeout = 10 ):
2120 """
2221 Downloads a tensor file from a URL and wraps it in a LabelTensor.
2322
@@ -28,21 +27,24 @@ def _load_tensor_from_url(url, labels):
2827
2928 :param str url: URL to the remote `.pth` tensor file.
3029 :param list[str] | tuple[str] labels: Labels for the resulting LabelTensor.
30+ :param int timeout: Timeout for the request in seconds.
3131 :return: A LabelTensor object if successful, otherwise None.
3232 :rtype: LabelTensor | None
3333 """
34+ # Try to download the tensor file from the given URL
3435 try :
35- response = requests .get (url )
36+ response = requests .get (url , timeout = timeout )
3637 response .raise_for_status ()
3738 tensor = torch .load (
3839 BytesIO (response .content ), weights_only = False
3940 ).tensor .detach ()
4041 return LabelTensor (tensor , labels )
41- except RequestException as e :
42- print (
43- "Could not download data for 'InversePoisson2DSquareProblem' "
44- f"from '{ url } '. "
45- f"Reason: { e } . Skipping data loading." ,
42+
43+ # If the request fails, issue a warning and return None
44+ except requests .exceptions .RequestException as e :
45+ warnings .warn (
46+ f"Could not download data for 'InversePoisson2DSquareProblem' "
47+ f"from '{ url } '. Reason: { e } . Skipping data loading." ,
4648 ResourceWarning ,
4749 )
4850 return None
@@ -66,19 +68,6 @@ def laplace_equation(input_, output_, params_):
6668 return delta_u - force_term
6769
6870
69- # loading data
70- input_url = (
71- "https://github.com/mathLab/PINA/raw/refs/heads/master"
72- "/tutorials/tutorial7/data/pts_0.5_0.5"
73- )
74- output_url = (
75- "https://github.com/mathLab/PINA/raw/refs/heads/master"
76- "/tutorials/tutorial7/data/pinn_solution_0.5_0.5"
77- )
78- input_data = _load_tensor_from_url (input_url , ["x" , "y" , "mu1" , "mu2" ])
79- output_data = _load_tensor_from_url (output_url , ["u" ])
80-
81-
8271class InversePoisson2DSquareProblem (SpatialProblem , InverseProblem ):
8372 r"""
8473 Implementation of the inverse 2-dimensional Poisson problem in the square
@@ -113,5 +102,43 @@ class InversePoisson2DSquareProblem(SpatialProblem, InverseProblem):
113102 "D" : Condition (domain = "D" , equation = Equation (laplace_equation )),
114103 }
115104
116- if input_data is not None and input_data is not None :
117- conditions ["data" ] = Condition (input = input_data , target = output_data )
105+ def __init__ (self , load = True , n_data = None ):
106+ """
107+ Initialization of the :class:`InversePoisson2DSquareProblem`.
108+
109+ :param bool load: If True, it attempts to load data from remote URLs.
110+ Set to False to skip data loading (e.g., if no internet connection).
111+ :param int n_data: Number of data points to use for "data" condition.
112+ If `None`, all available data points are used. Default is `None`.
113+ :raises ValueError: If `n_data` is not a positive integer.
114+ """
115+ super ().__init__ ()
116+
117+ # Load data if requested
118+ if load :
119+
120+ # Set the number of data points to use for "data" condition
121+ n_data = n_data or 2500
122+ check_positive_integer (n_data )
123+
124+ # Define URLs for input and output data
125+ input_url = (
126+ "https://github.com/mathLab/PINA/raw/refs/heads/master"
127+ "/tutorials/tutorial7/data/pts_0.5_0.5"
128+ )
129+ output_url = (
130+ "https://github.com/mathLab/PINA/raw/refs/heads/master"
131+ "/tutorials/tutorial7/data/pinn_solution_0.5_0.5"
132+ )
133+
134+ # Define input and output data
135+ input_data = _load_tensor_from_url (
136+ input_url , ["x" , "y" , "mu1" , "mu2" ]
137+ )
138+ output_data = _load_tensor_from_url (output_url , ["u" ])
139+
140+ # Add the "data" condition
141+ if input_data is not None and output_data is not None :
142+ self .conditions ["data" ] = Condition (
143+ input = input_data [:n_data ], target = output_data [:n_data ]
144+ )
0 commit comments