Skip to content

Commit 47bf03b

Browse files
committed
Conditions refactoring (#758)
1 parent 8292542 commit 47bf03b

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

pina/_src/condition/condition_base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ def __init__(self, **kwargs):
3434
"""
3535
super().__init__()
3636
self.data = self.store_data(**kwargs)
37+
<<<<<<< HEAD
3738
self.has_custom_dataloader_fn = False
39+
=======
40+
>>>>>>> 4ca4993a (Conditions refactoring (#758))
3841

3942
@property
4043
def problem(self):
@@ -87,8 +90,12 @@ def automatic_batching_collate_fn(cls, batch):
8790
if not batch:
8891
return {}
8992
instance_class = batch[0].__class__
93+
<<<<<<< HEAD
9094
batch = instance_class.create_batch(batch)
9195
return batch
96+
=======
97+
return instance_class.create_batch(batch)
98+
>>>>>>> 4ca4993a (Conditions refactoring (#758))
9299

93100
@staticmethod
94101
def collate_fn(batch, condition):
@@ -106,11 +113,15 @@ def collate_fn(batch, condition):
106113
return data
107114

108115
def create_dataloader(
116+
<<<<<<< HEAD
109117
self,
110118
dataset,
111119
batch_size,
112120
automatic_batching,
113121
**kwargs,
122+
=======
123+
self, dataset, batch_size, shuffle, automatic_batching
124+
>>>>>>> 4ca4993a (Conditions refactoring (#758))
114125
):
115126
"""
116127
Create a DataLoader for the condition.
@@ -121,14 +132,23 @@ def create_dataloader(
121132
:rtype: torch.utils.data.DataLoader
122133
"""
123134
if batch_size == len(dataset):
135+
<<<<<<< HEAD
124136
return DummyDataloader(dataset)
125137
return DataLoader(
126138
dataset=dataset,
139+
=======
140+
pass # will be updated in the near future
141+
return DataLoader(
142+
dataset=dataset,
143+
batch_size=batch_size,
144+
shuffle=shuffle,
145+
>>>>>>> 4ca4993a (Conditions refactoring (#758))
127146
collate_fn=(
128147
partial(self.collate_fn, condition=self)
129148
if not automatic_batching
130149
else self.automatic_batching_collate_fn
131150
),
151+
<<<<<<< HEAD
132152
batch_size=batch_size,
133153
**kwargs,
134154
)
@@ -146,3 +166,6 @@ def switch_dataloader_fn(self, create_dataloader_fn):
146166
# the new function
147167
self.has_custom_dataloader_fn = True
148168
self.create_dataloader = create_dataloader_fn
169+
=======
170+
)
171+
>>>>>>> 4ca4993a (Conditions refactoring (#758))

pina/_src/condition/data_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def create_batch(items):
120120
else torch.stack
121121
)
122122
batch_data[k] = batch_fn(vals)
123+
batch_data[k] = batch_fn(vals, dim=0)
123124
else:
124125
batch_data[k] = sample
125126
return batch_data

pina/_src/condition/domain_equation_condition.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ class DomainEquationCondition(ConditionBase):
3131
# Available slots
3232
__fields__ = ["domain", "equation"]
3333

34+
<<<<<<< HEAD
3435
_avail_domain_cls = (DomainInterface, str)
36+
=======
37+
_avail_domain_cls = DomainInterface
38+
>>>>>>> 4ca4993a (Conditions refactoring (#758))
3539
_avail_equation_cls = EquationInterface
3640

3741
def __new__(cls, domain, equation):

0 commit comments

Comments
 (0)