Skip to content

Commit d5db852

Browse files
committed
api: lif tensor component kwargs init to its own method
1 parent 658064e commit d5db852

1 file changed

Lines changed: 25 additions & 24 deletions

File tree

devito/types/tensor.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,25 @@ def __init_finalize__(self, *args, **kwargs):
7979
inds, _ = Function.__indices_setup__(grid=grid, dimensions=dimensions)
8080
self._space_dimensions = inds
8181

82+
@classmethod
83+
def _component_kwargs(cls, *inds, **kwargs):
84+
"""
85+
Get the kwargs for a single component
86+
from the kwargs of the TensorFunction.
87+
"""
88+
kw = {}
89+
for k, v in kwargs.items():
90+
if isinstance(v, MatrixBase):
91+
kw[k] = v[*inds]
92+
elif isinstance(v, (list, tuple)):
93+
if len(inds) > 1:
94+
kw[k] = v[inds[0]][inds[1]]
95+
else:
96+
kw[k] = v[inds[0]]
97+
else:
98+
kw[k] = v
99+
return kw
100+
82101
@classmethod
83102
def __subfunc_setup__(cls, *args, **kwargs):
84103
"""
@@ -108,20 +127,10 @@ def __subfunc_setup__(cls, *args, **kwargs):
108127
start = i if (symm or diag) else 0
109128
stop = i + 1 if diag else len(dims)
110129
for j in range(start, stop):
111-
staggj = (stagg[i][j] if stagg is not None
112-
else (NODE if i == j else (d, dims[j])))
113-
# Setup kwargs for subfunction
114-
# Through rebuilding or user input, the kwargs could be
115-
# Tensors as well from a per-component property
116130
sub_kwargs = {'name': f"{name}_{d.name}{dims[j].name}",
117-
'staggered': staggj}
118-
for k, v in kwargs.items():
119-
if isinstance(v, MatrixBase):
120-
sub_kwargs[k] = v[i, j]
121-
elif isinstance(v, (list, tuple)):
122-
sub_kwargs[k] = v[i][j]
123-
else:
124-
sub_kwargs[k] = v
131+
'staggered': (stagg[i][j] if stagg is not None
132+
else (NODE if i == j else (d, dims[j])))}
133+
sub_kwargs.update(cls._component_kwargs(i, j, **kwargs))
125134
funcs2[j] = cls._sub_type(**sub_kwargs)
126135
funcs.append(funcs2)
127136

@@ -335,17 +344,9 @@ def __subfunc_setup__(cls, *args, **kwargs):
335344
stagg = kwargs.get("staggered", None)
336345
name = kwargs.get("name")
337346
for i, d in enumerate(dims):
338-
kwargs["name"] = "%s_%s" % (name, d.name)
339-
kwargs["staggered"] = stagg[i] if stagg is not None else d
340-
# Setup kwargs for subfunction
341-
# Through rebuilding or user input, the kwargs could be
342-
# Tensors as well from a per-component property
343-
sub_kwargs = {}
344-
for k, v in kwargs.items():
345-
if isinstance(v, (list, tuple, MatrixBase)):
346-
sub_kwargs[k] = v[i]
347-
else:
348-
sub_kwargs[k] = v
347+
sub_kwargs = {'name': f"{name}_{d.name}",
348+
'staggered': stagg[i] if stagg is not None else d}
349+
sub_kwargs.update(cls._component_kwargs(i, **kwargs))
349350
funcs.append(cls._sub_type(**sub_kwargs))
350351

351352
return funcs

0 commit comments

Comments
 (0)