Skip to content

Commit 18d1f9b

Browse files
committed
api: lif tensor component kwargs init to its own method
1 parent 658064e commit 18d1f9b

1 file changed

Lines changed: 28 additions & 24 deletions

File tree

devito/types/tensor.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,28 @@ def __init_finalize__(self, *args, **kwargs):
7979
inds, _ = Function.__indices_setup__(grid=grid, dimensions=dimensions)
8080
self._space_dimensions = inds
8181

82+
@classmethod
83+
def _component_kwargs(cls, *inds, **kwargs):
84+
"""
85+
Get the kwargs for a single component
86+
from the kwargs of the TensorFunction.
87+
"""
88+
kw = {}
89+
for k, v in kwargs.items():
90+
if isinstance(v, MatrixBase):
91+
if len(inds) > 1:
92+
kw[k] = v[inds[0], inds[1]]
93+
else:
94+
kw[k] = v[inds[0]]
95+
elif isinstance(v, (list, tuple)):
96+
if len(inds) > 1:
97+
kw[k] = v[inds[0]][inds[1]]
98+
else:
99+
kw[k] = v[inds[0]]
100+
else:
101+
kw[k] = v
102+
return kw
103+
82104
@classmethod
83105
def __subfunc_setup__(cls, *args, **kwargs):
84106
"""
@@ -108,20 +130,10 @@ def __subfunc_setup__(cls, *args, **kwargs):
108130
start = i if (symm or diag) else 0
109131
stop = i + 1 if diag else len(dims)
110132
for j in range(start, stop):
111-
staggj = (stagg[i][j] if stagg is not None
112-
else (NODE if i == j else (d, dims[j])))
113-
# Setup kwargs for subfunction
114-
# Through rebuilding or user input, the kwargs could be
115-
# Tensors as well from a per-component property
116133
sub_kwargs = {'name': f"{name}_{d.name}{dims[j].name}",
117-
'staggered': staggj}
118-
for k, v in kwargs.items():
119-
if isinstance(v, MatrixBase):
120-
sub_kwargs[k] = v[i, j]
121-
elif isinstance(v, (list, tuple)):
122-
sub_kwargs[k] = v[i][j]
123-
else:
124-
sub_kwargs[k] = v
134+
'staggered': (stagg[i][j] if stagg is not None
135+
else (NODE if i == j else (d, dims[j])))}
136+
sub_kwargs.update(cls._component_kwargs(i, j, **kwargs))
125137
funcs2[j] = cls._sub_type(**sub_kwargs)
126138
funcs.append(funcs2)
127139

@@ -335,17 +347,9 @@ def __subfunc_setup__(cls, *args, **kwargs):
335347
stagg = kwargs.get("staggered", None)
336348
name = kwargs.get("name")
337349
for i, d in enumerate(dims):
338-
kwargs["name"] = "%s_%s" % (name, d.name)
339-
kwargs["staggered"] = stagg[i] if stagg is not None else d
340-
# Setup kwargs for subfunction
341-
# Through rebuilding or user input, the kwargs could be
342-
# Tensors as well from a per-component property
343-
sub_kwargs = {}
344-
for k, v in kwargs.items():
345-
if isinstance(v, (list, tuple, MatrixBase)):
346-
sub_kwargs[k] = v[i]
347-
else:
348-
sub_kwargs[k] = v
350+
sub_kwargs = {'name': f"{name}_{d.name}",
351+
'staggered': stagg[i] if stagg is not None else d}
352+
sub_kwargs.update(cls._component_kwargs(i, **kwargs))
349353
funcs.append(cls._sub_type(**sub_kwargs))
350354

351355
return funcs

0 commit comments

Comments
 (0)