Skip to content

Commit 5e54f60

Browse files
authored
Merge branch 'master' into multirom
2 parents 2a92a10 + 3121318 commit 5e54f60

22 files changed

Lines changed: 2180 additions & 573 deletions

.github/workflows/testing_pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
fail-fast: false
1414
matrix:
1515
os: [windows-latest, macos-latest, ubuntu-latest]
16-
python-version: [3.7, 3.8]
16+
python-version: [3.8, 3.9, 3.10, 3.11]
1717

1818
steps:
1919
- uses: actions/checkout@v2

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,6 @@ target/
6666

6767
#Ipython Notebook
6868
.ipynb_checkpoints
69+
70+
#revieweing and package modernization
71+
venv_ezyrb/

docs/source/conf.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,22 @@
1818
import sphinx
1919
from sphinx.errors import VersionRequirementError
2020
import sphinx_rtd_theme
21+
import time
22+
import importlib.metadata
23+
2124

2225
# If extensions (or modules to document with autodoc) are in another directory,
2326
# add these directories to sys.path here. If the directory is relative to the
2427
# documentation root, use os.path.abspath to make it absolute, like shown here.
2528
sys.path.insert(0, os.path.abspath('../..'))
26-
import ezyrb.meta as meta
29+
30+
31+
# -- Project infirmation --------
32+
_DISTRIBUTION_METADATA = importlib.metadata.metadata("ezyrb")
33+
project = _DISTRIBUTION_METADATA["Name"]
34+
copyright = f'2016-{time.strftime("%Y")}, EZyRB contributors'
35+
author = _DISTRIBUTION_METADATA["Author"]
36+
2737

2838
# -- General configuration ------------------------------------------------
2939

@@ -69,10 +79,6 @@
6979
# The master toctree document.
7080
master_doc = 'index'
7181

72-
# General information about the project.
73-
project = meta.__project__
74-
copyright = meta.__copyright__
75-
author = meta.__author__
7682

7783
# autoclass
7884
autoclass_content = 'both'
@@ -82,9 +88,9 @@
8288
# built documents.
8389
#
8490
# The short X.Y version.
85-
version = meta.__version__
86-
# The full version, including alpha/beta/rc tags.
91+
version = _DISTRIBUTION_METADATA["Version"]
8792
release = version
93+
# The full version, including alpha/beta/rc tags.
8894

8995
# The language for content autogenerated by Sphinx. Refer to documentation
9096
# for a list of supported languages.
@@ -287,7 +293,7 @@
287293
# One entry per manual page. List of tuples
288294
# (source start file, name, description, authors, manual section).
289295
man_pages = [
290-
(master_doc, meta.__title__, u'EZyRB Documentation',
296+
(master_doc, 'ezyrb' , u'EZyRB Documentation',
291297
[author], 1)
292298
]
293299

ezyrb/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
'MultiReducedOrderModel'
88
]
99

10-
from .meta import *
1110
from .database import Database
1211
from .snapshot import Snapshot
1312
from .parameter import Parameter

ezyrb/approximation/ann.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ def __init__(self, layers, function, stop_training, loss=None,
6060
if not isinstance(stop_training, list):
6161
stop_training = [stop_training]
6262

63+
if torch.cuda.is_available(): # Check if GPU is available
64+
print("Using cuda device")
65+
torch.cuda.empty_cache()
66+
self.use_cuda = True
67+
else:
68+
self.use_cuda = False
69+
6370
self.layers = layers
6471
self.function = function
6572
self.loss = loss
@@ -156,13 +163,19 @@ def fit(self, points, values):
156163
"""
157164

158165
self._build_model(points, values)
166+
167+
if self.use_cuda:
168+
self.model = self.model.cuda()
169+
points = self._convert_numpy_to_torch(points).cuda()
170+
values = self._convert_numpy_to_torch(values).cuda()
171+
else:
172+
points = self._convert_numpy_to_torch(points)
173+
values = self._convert_numpy_to_torch(values)
174+
159175
optimizer = self.optimizer(
160176
self.model.parameters(),
161177
lr=self.lr, weight_decay=self.l2_regularization)
162178

163-
points = self._convert_numpy_to_torch(points)
164-
values = self._convert_numpy_to_torch(values)
165-
166179
n_epoch = 1
167180
flag = True
168181
while flag:
@@ -201,6 +214,12 @@ def predict(self, new_point):
201214
:return: the predicted values via the ANN.
202215
:rtype: numpy.ndarray
203216
"""
204-
new_point = self._convert_numpy_to_torch(np.array(new_point))
205-
y_new = self.model(new_point)
206-
return self._convert_torch_to_numpy(y_new)
217+
if self.use_cuda :
218+
new_point = self._convert_numpy_to_torch(new_point).cuda()
219+
new_point = self._convert_numpy_to_torch(
220+
np.array(new_point.cpu())).cuda()
221+
y_new = self._convert_torch_to_numpy(self.model(new_point).cpu())
222+
else:
223+
new_point = self._convert_numpy_to_torch(np.array(new_point))
224+
y_new = self._convert_torch_to_numpy(self.model(new_point))
225+
return y_new

ezyrb/meta.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

ezyrb/plugin/automatic_shift.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import torch
55

6+
from ezyrb import Database, Snapshot, Parameter
67
from .plugin import Plugin
78

89

@@ -152,16 +153,25 @@ def fit_preprocessing(self, rom):
152153
snap.values.reshape(-1, 1))
153154

154155
snap.values = self.interpolator.predict(
155-
reference_snapshot.space.reshape(-1, 1)).flatten()
156+
reference_snapshot.space.reshape(-1, 1)).flatten() # reconstructing shifted snapshots in physical space
156157

157158
def predict_postprocessing(self, rom):
158159

159160
ref_space = self.reference_snapshot.space
161+
db = Database()
160162

161163
for param, snap in rom.predict_full_database._pairs:
162164
input_shift = np.hstack([
163165
ref_space.reshape(-1, 1),
164166
np.ones(shape=(ref_space.shape[0], 1))*param.values])
165167
shift = self.shift_network.predict(input_shift)
166-
snap.space = ref_space + shift.flatten()
168+
snap.space = ref_space + shift.flatten() # shifted space transports to correct physical frame
167169
snap.space = snap.space.flatten()
170+
171+
self.interpolator.fit(snap.space, snap.values.reshape(-1, 1))
172+
snap.values = self.interpolator.predict(ref_space) # reconstruct snapshot in physical space
173+
174+
snaps = Snapshot(values = snap.values, space = ref_space)
175+
db.add(Parameter(param.values), snaps)
176+
177+
rom._full_database = db

ezyrb/plugin/scaler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, scaler, mode, target) -> None:
2626
self.scaler = scaler
2727
self.mode = mode
2828
self.target = target
29-
29+
3030
@property
3131
def target(self):
3232
"""
@@ -141,5 +141,5 @@ def rom_postprocessing(self, rom):
141141
db.parameters_matrix,
142142
self.scaler.inverse_transform(self._select_matrix(db)),
143143
)
144-
144+
145145
rom._reduced_database = new_db

0 commit comments

Comments
 (0)