@@ -750,15 +750,16 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {
750750//------------------------------------------------------------------------------
751751static inline int CeedOperatorInputBasisAtPoints_Cuda (CeedOperatorField op_input_field , CeedQFunctionField qf_input_field , CeedInt input_field ,
752752 CeedVector in_vec , CeedVector active_e_vec , CeedInt num_elem , const CeedInt * num_points ,
753- const bool skip_active , CeedOperator_Cuda * impl ) {
753+ const bool skip_active , const bool skip_passive , CeedOperator_Cuda * impl ) {
754754 bool is_active = false;
755755 CeedEvalMode eval_mode ;
756756 CeedVector l_vec , e_vec = impl -> e_vecs_in [input_field ], q_vec = impl -> q_vecs_in [input_field ];
757757
758758 // Skip active input
759759 CeedCallBackend (CeedOperatorFieldGetVector (op_input_field , & l_vec ));
760760 is_active = l_vec == CEED_VECTOR_ACTIVE ;
761- if (is_active && skip_active ) return CEED_ERROR_SUCCESS ;
761+ if (skip_active && is_active ) return CEED_ERROR_SUCCESS ;
762+ if (skip_passive && !is_active ) return CEED_ERROR_SUCCESS ;
762763 if (is_active ) {
763764 l_vec = in_vec ;
764765 if (!e_vec ) e_vec = active_e_vec ;
@@ -842,7 +843,7 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,
842843 CeedCallBackend (
843844 CeedOperatorInputRestrict_Cuda (op_input_fields [field ], qf_input_fields [field ], field , in_vec , active_e_vec , false, impl , request ));
844845 CeedCallBackend (CeedOperatorInputBasisAtPoints_Cuda (op_input_fields [field ], qf_input_fields [field ], field , in_vec , active_e_vec , num_elem ,
845- num_points , false, impl ));
846+ num_points , false, false, impl ));
846847 }
847848
848849 // Output pointers, as necessary
@@ -1845,19 +1846,8 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18451846 // Process inputs
18461847 for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
18471848 CeedCallBackend (CeedOperatorInputRestrict_Cuda (op_input_fields [i ], qf_input_fields [i ], i , NULL , NULL , true, impl , request ));
1848- CeedCallBackend (CeedOperatorInputBasisAtPoints_Cuda (op_input_fields [i ], qf_input_fields [i ], i , NULL , NULL , num_elem , num_points , true, impl ));
1849- }
1850-
1851- // Clear active input Qvecs
1852- for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
1853- bool is_active = false;
1854- CeedVector l_vec ;
1855-
1856- CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [i ], & l_vec ));
1857- is_active = l_vec == CEED_VECTOR_ACTIVE ;
1858- CeedCallBackend (CeedVectorDestroy (& l_vec ));
1859- if (!is_active ) continue ;
1860- CeedCallBackend (CeedVectorSetValue (impl -> q_vecs_in [i ], 0.0 ));
1849+ CeedCallBackend (
1850+ CeedOperatorInputBasisAtPoints_Cuda (op_input_fields [i ], qf_input_fields [i ], i , NULL , NULL , num_elem , num_points , true, false, impl ));
18611851 }
18621852
18631853 // Output pointers, as necessary
@@ -1876,19 +1866,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18761866 // Loop over active fields
18771867 for (CeedInt i = 0 ; i < num_input_fields ; i ++ ) {
18781868 bool is_active = false, is_active_at_points = true;
1879- CeedInt elem_size = 1 , num_comp_active = 1 , e_vec_size = 0 ;
1869+ CeedInt elem_size = 1 , num_comp_active = 1 , e_vec_size = 0 , field_in = impl -> input_field_order [ i ] ;
18801870 CeedRestrictionType rstr_type ;
18811871 CeedVector l_vec ;
18821872 CeedElemRestriction elem_rstr ;
18831873
18841874 // -- Skip non-active input
1885- CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [i ], & l_vec ));
1875+ CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [field_in ], & l_vec ));
18861876 is_active = l_vec == CEED_VECTOR_ACTIVE ;
18871877 CeedCallBackend (CeedVectorDestroy (& l_vec ));
1888- if (!is_active ) continue ;
1878+ if (!is_active || impl -> skip_rstr_in [ field_in ] ) continue ;
18891879
18901880 // -- Get active restriction type
1891- CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_input_fields [i ], & elem_rstr ));
1881+ CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_input_fields [field_in ], & elem_rstr ));
18921882 CeedCallBackend (CeedElemRestrictionGetType (elem_rstr , & rstr_type ));
18931883 is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS ;
18941884 if (!is_active_at_points ) CeedCallBackend (CeedElemRestrictionGetElementSize (elem_rstr , & elem_size ));
@@ -1897,16 +1887,9 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18971887 CeedCallBackend (CeedElemRestrictionDestroy (& elem_rstr ));
18981888
18991889 e_vec_size = elem_size * num_comp_active ;
1890+ CeedCallBackend (CeedVectorSetValue (active_e_vec_in , 0.0 ));
19001891 for (CeedInt s = 0 ; s < e_vec_size ; s ++ ) {
1901- bool is_active = false;
1902- CeedEvalMode eval_mode ;
1903- CeedVector l_vec , q_vec = impl -> q_vecs_in [i ];
1904-
1905- // Skip non-active input
1906- CeedCallBackend (CeedOperatorFieldGetVector (op_input_fields [i ], & l_vec ));
1907- is_active = l_vec == CEED_VECTOR_ACTIVE ;
1908- CeedCallBackend (CeedVectorDestroy (& l_vec ));
1909- if (!is_active ) continue ;
1892+ CeedVector q_vec = impl -> q_vecs_in [field_in ];
19101893
19111894 // Update unit vector
19121895 {
@@ -1915,8 +1898,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19151898 CeedSize start = node * 1 + comp * (elem_size * num_elem );
19161899 CeedSize stop = (comp + 1 ) * (elem_size * num_elem );
19171900
1918- if (s == 0 ) CeedCallBackend (CeedVectorSetValue (active_e_vec_in , 0.0 ));
1919- else CeedCallBackend (CeedVectorSetValueStrided (active_e_vec_in , start , stop , elem_size , 0.0 ));
1901+ if (s != 0 ) CeedCallBackend (CeedVectorSetValueStrided (active_e_vec_in , start , stop , elem_size , 0.0 ));
19201902
19211903 node = s % elem_size , comp = s / elem_size ;
19221904 start = node * 1 + comp * (elem_size * num_elem );
@@ -1925,29 +1907,11 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19251907 }
19261908
19271909 // Basis action
1928- CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_input_fields [i ], & eval_mode ));
1929- switch (eval_mode ) {
1930- case CEED_EVAL_NONE : {
1931- const CeedScalar * e_vec_array ;
1932-
1933- CeedCallBackend (CeedVectorGetArrayRead (active_e_vec_in , CEED_MEM_DEVICE , & e_vec_array ));
1934- CeedCallBackend (CeedVectorSetArray (q_vec , CEED_MEM_DEVICE , CEED_USE_POINTER , (CeedScalar * )e_vec_array ));
1935- break ;
1936- }
1937- case CEED_EVAL_INTERP :
1938- case CEED_EVAL_GRAD :
1939- case CEED_EVAL_DIV :
1940- case CEED_EVAL_CURL : {
1941- CeedBasis basis ;
1910+ for (CeedInt j = 0 ; j < num_input_fields ; j ++ ) {
1911+ CeedInt field = impl -> input_field_order [j ];
19421912
1943- CeedCallBackend (CeedOperatorFieldGetBasis (op_input_fields [i ], & basis ));
1944- CeedCallBackend (
1945- CeedBasisApplyAtPoints (basis , num_elem , num_points , CEED_NOTRANSPOSE , eval_mode , impl -> point_coords_elem , active_e_vec_in , q_vec ));
1946- CeedCallBackend (CeedBasisDestroy (& basis ));
1947- break ;
1948- }
1949- case CEED_EVAL_WEIGHT :
1950- break ; // No action
1913+ CeedCallBackend (CeedOperatorInputBasisAtPoints_Cuda (op_input_fields [field ], qf_input_fields [field ], field , NULL , active_e_vec_in , num_elem ,
1914+ num_points , false, true, impl ));
19511915 }
19521916
19531917 // Q function
@@ -1957,20 +1921,21 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19571921 for (CeedInt j = 0 ; j < num_output_fields ; j ++ ) {
19581922 bool is_active = false;
19591923 CeedInt elem_size = 0 ;
1924+ CeedInt field_out = impl -> output_field_order [j ];
19601925 CeedRestrictionType rstr_type ;
19611926 CeedEvalMode eval_mode ;
1962- CeedVector l_vec , e_vec = impl -> e_vecs_out [j ], q_vec = impl -> q_vecs_out [j ];
1927+ CeedVector l_vec , e_vec = impl -> e_vecs_out [field_out ], q_vec = impl -> q_vecs_out [field_out ];
19631928 CeedElemRestriction elem_rstr ;
19641929
19651930 // ---- Skip non-active output
1966- CeedCallBackend (CeedOperatorFieldGetVector (op_output_fields [j ], & l_vec ));
1931+ CeedCallBackend (CeedOperatorFieldGetVector (op_output_fields [field_out ], & l_vec ));
19671932 is_active = l_vec == CEED_VECTOR_ACTIVE ;
19681933 CeedCallBackend (CeedVectorDestroy (& l_vec ));
19691934 if (!is_active ) continue ;
19701935 if (!e_vec ) e_vec = active_e_vec_out ;
19711936
19721937 // ---- Check if elem size matches
1973- CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields [j ], & elem_rstr ));
1938+ CeedCallBackend (CeedOperatorFieldGetElemRestriction (op_output_fields [field_out ], & elem_rstr ));
19741939 CeedCallBackend (CeedElemRestrictionGetType (elem_rstr , & rstr_type ));
19751940 if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS ) continue ;
19761941 if (rstr_type == CEED_RESTRICTION_POINTS ) {
@@ -1986,7 +1951,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19861951 }
19871952
19881953 // Basis action
1989- CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_output_fields [j ], & eval_mode ));
1954+ CeedCallBackend (CeedQFunctionFieldGetEvalMode (qf_output_fields [field_out ], & eval_mode ));
19901955 switch (eval_mode ) {
19911956 case CEED_EVAL_NONE : {
19921957 CeedScalar * e_vec_array ;
@@ -2001,8 +1966,13 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
20011966 case CEED_EVAL_CURL : {
20021967 CeedBasis basis ;
20031968
2004- CeedCallBackend (CeedOperatorFieldGetBasis (op_output_fields [j ], & basis ));
2005- CeedCallBackend (CeedBasisApplyAtPoints (basis , num_elem , num_points , CEED_TRANSPOSE , eval_mode , impl -> point_coords_elem , q_vec , e_vec ));
1969+ CeedCallBackend (CeedOperatorFieldGetBasis (op_output_fields [field_out ], & basis ));
1970+ if (impl -> apply_add_basis_out [field_out ]) {
1971+ CeedCallBackend (
1972+ CeedBasisApplyAddAtPoints (basis , num_elem , num_points , CEED_TRANSPOSE , eval_mode , impl -> point_coords_elem , q_vec , e_vec ));
1973+ } else {
1974+ CeedCallBackend (CeedBasisApplyAtPoints (basis , num_elem , num_points , CEED_TRANSPOSE , eval_mode , impl -> point_coords_elem , q_vec , e_vec ));
1975+ }
20061976 CeedCallBackend (CeedBasisDestroy (& basis ));
20071977 break ;
20081978 }
@@ -2014,6 +1984,10 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
20141984 }
20151985
20161986 // Mask output e-vec
1987+ if (impl -> skip_rstr_out [field_out ]) {
1988+ CeedCallBackend (CeedElemRestrictionDestroy (& elem_rstr ));
1989+ continue ;
1990+ }
20171991 CeedCallBackend (CeedVectorPointwiseMult (e_vec , active_e_vec_in , e_vec ));
20181992
20191993 // Restrict
0 commit comments