@@ -79,6 +79,25 @@ def __init_finalize__(self, *args, **kwargs):
7979 inds , _ = Function .__indices_setup__ (grid = grid , dimensions = dimensions )
8080 self ._space_dimensions = inds
8181
82+ @classmethod
83+ def _component_kwargs (cls , * inds , ** kwargs ):
84+ """
85+ Get the kwargs for a single component
86+ from the kwargs of the TensorFunction.
87+ """
88+ kw = {}
89+ for k , v in kwargs .items ():
90+ if isinstance (v , MatrixBase ):
91+ kw [k ] = v [* inds ]
92+ elif isinstance (v , (list , tuple )):
93+ if len (inds ) > 1 :
94+ kw [k ] = v [inds [0 ]][inds [1 ]]
95+ else :
96+ kw [k ] = v [inds [0 ]]
97+ else :
98+ kw [k ] = v
99+ return kw
100+
82101 @classmethod
83102 def __subfunc_setup__ (cls , * args , ** kwargs ):
84103 """
@@ -108,20 +127,10 @@ def __subfunc_setup__(cls, *args, **kwargs):
108127 start = i if (symm or diag ) else 0
109128 stop = i + 1 if diag else len (dims )
110129 for j in range (start , stop ):
111- staggj = (stagg [i ][j ] if stagg is not None
112- else (NODE if i == j else (d , dims [j ])))
113- # Setup kwargs for subfunction
114- # Through rebuilding or user input, the kwargs could be
115- # Tensors as well from a per-component property
116130 sub_kwargs = {'name' : f"{ name } _{ d .name } { dims [j ].name } " ,
117- 'staggered' : staggj }
118- for k , v in kwargs .items ():
119- if isinstance (v , MatrixBase ):
120- sub_kwargs [k ] = v [i , j ]
121- elif isinstance (v , (list , tuple )):
122- sub_kwargs [k ] = v [i ][j ]
123- else :
124- sub_kwargs [k ] = v
131+ 'staggered' : (stagg [i ][j ] if stagg is not None
132+ else (NODE if i == j else (d , dims [j ])))}
133+ sub_kwargs .update (cls ._component_kwargs (i , j , ** kwargs ))
125134 funcs2 [j ] = cls ._sub_type (** sub_kwargs )
126135 funcs .append (funcs2 )
127136
@@ -335,17 +344,9 @@ def __subfunc_setup__(cls, *args, **kwargs):
335344 stagg = kwargs .get ("staggered" , None )
336345 name = kwargs .get ("name" )
337346 for i , d in enumerate (dims ):
338- kwargs ["name" ] = "%s_%s" % (name , d .name )
339- kwargs ["staggered" ] = stagg [i ] if stagg is not None else d
340- # Setup kwargs for subfunction
341- # Through rebuilding or user input, the kwargs could be
342- # Tensors as well from a per-component property
343- sub_kwargs = {}
344- for k , v in kwargs .items ():
345- if isinstance (v , (list , tuple , MatrixBase )):
346- sub_kwargs [k ] = v [i ]
347- else :
348- sub_kwargs [k ] = v
347+ sub_kwargs = {'name' : f"{ name } _{ d .name } " ,
348+ 'staggered' : stagg [i ] if stagg is not None else d }
349+ sub_kwargs .update (cls ._component_kwargs (i , ** kwargs ))
349350 funcs .append (cls ._sub_type (** sub_kwargs ))
350351
351352 return funcs
0 commit comments