Skip to content

Commit e21715a

Browse files
committed
Remove collector
1 parent 6b355b4 commit e21715a

5 files changed

Lines changed: 69 additions & 296 deletions

File tree

pina/collector.py

Lines changed: 0 additions & 129 deletions
This file was deleted.

pina/data/data_module.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from torch.utils.data.distributed import DistributedSampler
1313
from ..label_tensor import LabelTensor
1414
from .dataset import PinaDatasetFactory, PinaTensorDataset
15-
from ..collector import Collector
15+
16+
# from ..collector import Collector
1617

1718

1819
class DummyDataloader:
@@ -330,9 +331,10 @@ def __init__(
330331
self.pin_memory = pin_memory
331332

332333
# Collect data
333-
collector = Collector(problem)
334-
collector.store_fixed_data()
335-
collector.store_sample_domains()
334+
# collector = Collector(problem)
335+
# collector.store_fixed_data()
336+
# collector.store_sample_domains()
337+
problem.aggregate_data()
336338

337339
# Check if the splits are correct
338340
self._check_slit_sizes(train_size, test_size, val_size)
@@ -361,7 +363,7 @@ def __init__(
361363
# raises NotImplementedError
362364
self.val_dataloader = super().val_dataloader
363365

364-
self.collector_splits = self._create_splits(collector, splits_dict)
366+
self.data_splits = self._create_splits(problem.data, splits_dict)
365367
self.transfer_batch_to_device = self._transfer_batch_to_device
366368

367369
def setup(self, stage=None):
@@ -376,23 +378,23 @@ def setup(self, stage=None):
376378
"""
377379
if stage == "fit" or stage is None:
378380
self.train_dataset = PinaDatasetFactory(
379-
self.collector_splits["train"],
381+
self.data_splits["train"],
380382
max_conditions_lengths=self.find_max_conditions_lengths(
381383
"train"
382384
),
383385
automatic_batching=self.automatic_batching,
384386
)
385-
if "val" in self.collector_splits.keys():
387+
if "val" in self.data_splits.keys():
386388
self.val_dataset = PinaDatasetFactory(
387-
self.collector_splits["val"],
389+
self.data_splits["val"],
388390
max_conditions_lengths=self.find_max_conditions_lengths(
389391
"val"
390392
),
391393
automatic_batching=self.automatic_batching,
392394
)
393395
elif stage == "test":
394396
self.test_dataset = PinaDatasetFactory(
395-
self.collector_splits["test"],
397+
self.data_splits["test"],
396398
max_conditions_lengths=self.find_max_conditions_lengths("test"),
397399
automatic_batching=self.automatic_batching,
398400
)
@@ -473,7 +475,7 @@ def _apply_shuffle(condition_dict, len_data):
473475
for (
474476
condition_name,
475477
condition_dict,
476-
) in collector.data_collections.items():
478+
) in collector.items():
477479
len_data = len(condition_dict["input"])
478480
if self.shuffle:
479481
_apply_shuffle(condition_dict, len_data)
@@ -540,7 +542,7 @@ def find_max_conditions_lengths(self, split):
540542
"""
541543

542544
max_conditions_lengths = {}
543-
for k, v in self.collector_splits[split].items():
545+
for k, v in self.data_splits[split].items():
544546
if self.batch_size is None:
545547
max_conditions_lengths[k] = len(v["input"])
546548
elif self.repeat:

pina/problem/abstract_problem.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,11 @@ def __init__(self):
2323
Initialization of the :class:`AbstractProblem` class.
2424
"""
2525
self._discretised_domains = {}
26-
# create collector to manage problem data
2726

2827
# create hook conditions <-> problems
2928
for condition_name in self.conditions:
3029
self.conditions[condition_name].problem = self
3130

32-
self._batching_dimension = 0
33-
3431
# Store in domains dict all the domains object directly passed to
3532
# ConditionInterface. Done for back compatibility with PINA <0.2
3633
if not hasattr(self, "domains"):
@@ -41,24 +38,7 @@ def __init__(self):
4138
self.domains[cond_name] = cond.domain
4239
cond.domain = cond_name
4340

44-
@property
45-
def batching_dimension(self):
46-
"""
47-
Get batching dimension.
48-
49-
:return: The batching dimension.
50-
:rtype: int
51-
"""
52-
return self._batching_dimension
53-
54-
@batching_dimension.setter
55-
def batching_dimension(self, value):
56-
"""
57-
Set the batching dimension.
58-
59-
:param int value: The batching dimension.
60-
"""
61-
self._batching_dimension = value
41+
self.data = None
6242

6343
# back compatibility 0.1
6444
@property
@@ -300,3 +280,26 @@ def add_points(self, new_points_dict):
300280
self.discretised_domains[k] = LabelTensor.vstack(
301281
[self.discretised_domains[k], v]
302282
)
283+
284+
def aggregate_data(self):
    """
    Aggregate data from the problem's conditions into a single dictionary.

    The result is stored in ``self.data`` and is keyed by condition name.
    For a condition exposing a ``domain`` attribute the entry is a
    dictionary with keys ``"input"`` (the samples stored in
    ``self.discretised_domains`` for that domain) and ``"equation"``.
    For any other condition the entry maps each name listed in the
    condition's ``__slots__`` to the corresponding attribute value.

    ``self.data`` is replaced only when the aggregation succeeds; if the
    method raises, the previous value is left untouched.

    :raises RuntimeError: If not all domains have been discretised.
    """
    # Validate before touching any state, so a failed call does not wipe
    # previously aggregated data.
    if not self.are_all_domains_discretised:
        raise RuntimeError(
            "All domains must be discretised before aggregating data."
        )

    aggregated = {}
    for condition_name in self.conditions:
        condition = self.conditions[condition_name]
        if hasattr(condition, "domain"):
            # Domain-based condition: pair the sampled points with the
            # equation to be enforced on them.
            aggregated[condition_name] = {
                "input": self.discretised_domains[condition.domain],
                "equation": condition.equation,
            }
        else:
            slot_names = condition.__slots__
            # ``__slots__`` may legally be a single string; wrap it so the
            # loop below does not iterate over its characters.
            if isinstance(slot_names, str):
                slot_names = (slot_names,)
            aggregated[condition_name] = {
                name: getattr(condition, name) for name in slot_names
            }

    self.data = aggregated

0 commit comments

Comments
 (0)