@@ -329,6 +329,8 @@ def generate_proxy_coefficient_packing(self):
329329 pw_array = L .ArrayDecl (pw , sizes = int (proxy_coeff_offset [- 1 ]))
330330
331331 for i , (proxy_coeff , expr_name ) in enumerate (self .ir .sub_expressions ):
332+ declarations = []
333+
332334 # Get active coefficients
333335 active_coefficient_offsets = self .ir .proxy_coefficient_offsets [i : i + 2 ]
334336 active_coefficients = self .ir .coefficients_in_proxy [
@@ -337,12 +339,17 @@ def generate_proxy_coefficient_packing(self):
337339 sub_coefficient_sizes = [
338340 active_coefficient .ufl_element ().dim for active_coefficient in active_coefficients
339341 ]
342+
340343 sub_coefficient_offsets = [
341344 self .ir .expression .coefficient_offsets [coeff ] for coeff in active_coefficients
342345 ]
346+ sub_coeff_pos = np .zeros (len (active_coefficients ) + 1 , dtype = np .int32 )
347+ sub_coeff_pos [1 :] = np .cumsum (sub_coefficient_sizes )
343348 # Declare array that holds subset of coefficients
344349 sub_coeff = L .Symbol (f"sub_coeff_{ i } " , dtype = L .DataType .SCALAR )
345350 sub_coeff_array = L .ArrayDecl (sub_coeff , sizes = int (np .sum (sub_coefficient_sizes )))
351+ declarations .append (sub_coeff_array )
352+
346353 pi = L .Symbol ("pi" , dtype = L .DataType .INT )
347354 # Pack coefficients into contiguous array for expression evaluation
348355 coeff_loops = []
@@ -354,7 +361,7 @@ def generate_proxy_coefficient_packing(self):
354361 sub_coefficient_sizes [j ],
355362 [
356363 L .Assign (
357- sub_coeff [pi ],
364+ sub_coeff [sub_coeff_pos [ j ] + pi ],
358365 self .backend .symbols .coefficients [sub_coefficient_offsets [j ] + pi ],
359366 )
360367 ],
@@ -367,6 +374,7 @@ def generate_proxy_coefficient_packing(self):
367374 proxy_coefficient = L .ArrayDecl (
368375 pz_at_itg_points , sizes = int (np .prod (self .ir .proxy_pack_shape [i ]))
369376 )
377+ declarations .append (proxy_coefficient )
370378 # NOTE: Need to do something similar for constants, currently we just pass them in
371379 custom_data = L .Symbol ("custom_data" , dtype = L .DataType .SCALAR )
372380 func_call = L .CallOp (
@@ -383,24 +391,44 @@ def generate_proxy_coefficient_packing(self):
383391 )
384392 decl = L .Statement (func_call )
385393 # Compute matvec between tabulated expression and interpolation matrix
386- im = proxy_coeff .ufl_element ().basix_element .interpolation_matrix
387- im_table = self .declare_table (f"proxy_im_{ i } " , im )[0 ]
388- im_rows = im .shape [0 ]
389- im_cols = im .shape [1 ]
390- pj = L .Symbol ("pj" , dtype = L .DataType .INT )
394+ identity_assign = False
395+ if isinstance (proxy_coeff .operator , ufl .Interpolate ):
396+ be = proxy_coeff .ufl_element ().basix_element
397+ identity_assign = be .interpolation_is_identity
398+ if not identity_assign :
399+ im = be .interpolation_matrix
400+ else :
401+ raise NotImplementedError (
402+ "Only proxy coefficients for Interpolate supported at the moment"
403+ )
404+
405+ num_dofs = self .ir .proxy_coefficient_sizes [i ]
391406 assign_start = proxy_coeff_offset [i ]
392- inner_assign_loop = L .ForRange (
393- pj ,
394- 0 ,
395- im_cols ,
396- [L .Assign (pw [assign_start + pj ], pz_at_itg_points [pj ] * im_table .symbol [pi ][pj ])],
397- )
398- assign_loop = L .ForRange (pi , 0 , im_rows , [inner_assign_loop ])
407+ if identity_assign :
408+ inner_assign_loop = L .Assign (pw [assign_start + pi ], pz_at_itg_points [pi ])
409+ else :
410+ assert im .shape [0 ] == num_dofs
411+ im_table = self .declare_table (f"proxy_im_{ i } " , im )[0 ]
412+ declarations .append (im_table )
413+ num_quadrature_points = im .shape [1 ]
414+ pj = L .Symbol ("pj" , dtype = L .DataType .INT )
415+ inner_assign_loop = L .ForRange (
416+ pj ,
417+ 0 ,
418+ num_quadrature_points ,
419+ [
420+ L .AssignAdd (
421+ pw [assign_start + pi ], pz_at_itg_points [pj ] * im_table .symbol [pi ][pj ]
422+ )
423+ ],
424+ )
425+ init_pw = L .Assign (pw [assign_start + pi ], 0 )
426+ assign_loop = L .ForRange (pi , 0 , num_dofs , [init_pw , inner_assign_loop ])
399427 intermediates += [
400428 L .Section (
401429 f"Packing { i } th proxy coefficient" ,
402430 statements = [coeff_loops , decl ],
403- declarations = [ proxy_coefficient , sub_coeff_array , im_table ] ,
431+ declarations = declarations ,
404432 input = [],
405433 output = [],
406434 )
0 commit comments