Skip to content

Commit f05046e

Browse files
committed
op - minor improvements to AtPoints diagonal assembly
1 parent 85bbdf9 commit f05046e

1 file changed

Lines changed: 32 additions & 61 deletions

File tree

backends/ref/ceed-ref-operator.c

Lines changed: 32 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -886,8 +886,8 @@ static int CeedOperatorSetupAtPoints_Ref(CeedOperator op) {
886886
//------------------------------------------------------------------------------
887887
static inline int CeedOperatorInputBasisAtPoints_Ref(CeedInt e, CeedInt num_points_offset, CeedInt num_points, CeedQFunctionField *qf_input_fields,
888888
CeedOperatorField *op_input_fields, CeedInt num_input_fields, CeedVector in_vec,
889-
CeedVector point_coords_elem, bool skip_active, CeedScalar *e_data[2 * CEED_FIELD_MAX],
890-
CeedOperator_Ref *impl, CeedRequest *request) {
889+
CeedVector point_coords_elem, bool skip_active, bool skip_passive,
890+
CeedScalar *e_data[2 * CEED_FIELD_MAX], CeedOperator_Ref *impl, CeedRequest *request) {
891891
for (CeedInt i = 0; i < num_input_fields; i++) {
892892
bool is_active;
893893
CeedInt elem_size, size, num_comp;
@@ -902,14 +902,15 @@ static inline int CeedOperatorInputBasisAtPoints_Ref(CeedInt e, CeedInt num_poin
902902
is_active = vec == CEED_VECTOR_ACTIVE;
903903
CeedCallBackend(CeedVectorDestroy(&vec));
904904
if (skip_active && is_active) continue;
905+
if (skip_passive && !is_active) continue;
905906

906907
// Get elem_size, eval_mode, size
907908
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
908909
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
909910
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
910911
CeedCallBackend(CeedQFunctionFieldGetSize(qf_input_fields[i], &size));
911912
// Restrict block active input
912-
if (is_active && !impl->skip_rstr_in[i]) {
913+
if (is_active && !impl->skip_rstr_in[i] && !skip_passive) {
913914
if (rstr_type == CEED_RESTRICTION_POINTS) {
914915
CeedCallBackend(CeedElemRestrictionApplyAtPointsInElement(elem_rstr, e, CEED_NOTRANSPOSE, in_vec, impl->e_vecs_in[i], request));
915916
} else {
@@ -952,7 +953,7 @@ static inline int CeedOperatorInputBasisAtPoints_Ref(CeedInt e, CeedInt num_poin
952953
static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_points_offset, CeedInt num_points, CeedQFunctionField *qf_output_fields,
953954
CeedOperatorField *op_output_fields, CeedInt num_input_fields, CeedInt num_output_fields,
954955
bool *apply_add_basis, bool *skip_rstr, CeedOperator op, CeedVector out_vec,
955-
CeedVector point_coords_elem, CeedOperator_Ref *impl, CeedRequest *request) {
956+
CeedVector point_coords_elem, bool skip_passive, CeedOperator_Ref *impl, CeedRequest *request) {
956957
for (CeedInt i = 0; i < num_output_fields; i++) {
957958
bool is_active;
958959
CeedRestrictionType rstr_type;
@@ -961,6 +962,12 @@ static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_poi
961962
CeedElemRestriction elem_rstr;
962963
CeedBasis basis;
963964

965+
// Skip passive output, if requested
966+
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
967+
is_active = vec == CEED_VECTOR_ACTIVE;
968+
CeedCallBackend(CeedVectorDestroy(&vec));
969+
if (skip_passive && !is_active) continue;
970+
964971
// Get elem_size, eval_mode, size
965972
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
966973
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
@@ -989,15 +996,14 @@ static inline int CeedOperatorOutputBasisAtPoints_Ref(CeedInt e, CeedInt num_poi
989996
}
990997
}
991998
// Restrict output block
992-
if (skip_rstr[i]) {
999+
if (skip_rstr[i] || skip_passive) {
9931000
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
9941001
continue;
9951002
}
9961003

9971004
// Get output vector
9981005
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
9991006
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
1000-
is_active = vec == CEED_VECTOR_ACTIVE;
10011007
if (is_active) vec = out_vec;
10021008
// Restrict
10031009
if (rstr_type == CEED_RESTRICTION_POINTS) {
@@ -1049,7 +1055,7 @@ static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec,
10491055

10501056
// Input basis apply
10511057
CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, in_vec,
1052-
impl->point_coords_elem, false, e_data, impl, request));
1058+
impl->point_coords_elem, false, false, e_data, impl, request));
10531059

10541060
// Q function
10551061
if (!impl->is_identity_qf) {
@@ -1059,7 +1065,7 @@ static int CeedOperatorApplyAddAtPoints_Ref(CeedOperator op, CeedVector in_vec,
10591065
// Output basis apply and restriction
10601066
CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields,
10611067
num_output_fields, impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec,
1062-
impl->point_coords_elem, impl, request));
1068+
impl->point_coords_elem, false, impl, request));
10631069

10641070
num_points_offset += num_points;
10651071
}
@@ -1202,7 +1208,7 @@ static inline int CeedOperatorLinearAssembleQFunctionAtPointsCore_Ref(CeedOperat
12021208

12031209
// Input basis apply
12041210
CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, NULL,
1205-
impl->point_coords_elem, true, e_data_full, impl, request));
1211+
impl->point_coords_elem, true, false, e_data_full, impl, request));
12061212

12071213
// Assemble QFunction
12081214
for (CeedInt i = 0; i < num_input_fields; i++) {
@@ -1360,16 +1366,16 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce
13601366
CeedCallBackend(CeedVectorSetValue(out_vec, 0.0));
13611367
}
13621368

1363-
// Clear input Qvecs
1369+
// Clear input Evecs
13641370
for (CeedInt i = 0; i < num_input_fields; i++) {
13651371
bool is_active;
13661372
CeedVector vec;
13671373

13681374
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
13691375
is_active = vec == CEED_VECTOR_ACTIVE;
13701376
CeedCallBackend(CeedVectorDestroy(&vec));
1371-
if (!is_active) continue;
1372-
CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
1377+
if (!is_active || impl->skip_rstr_in[i]) continue;
1378+
CeedCallBackend(CeedVectorSetValue(impl->e_vecs_in[i], 0.0));
13731379
}
13741380

13751381
// Input Evecs and Restriction
@@ -1385,7 +1391,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce
13851391

13861392
// Input basis apply for non-active bases
13871393
CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, in_vec,
1388-
impl->point_coords_elem, true, e_data, impl, request));
1394+
impl->point_coords_elem, true, false, e_data, impl, request));
13891395

13901396
// Loop over active input fields on element
13911397
for (CeedInt i = 0; i < num_input_fields; i++) {
@@ -1399,7 +1405,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce
13991405
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
14001406
is_active = vec == CEED_VECTOR_ACTIVE;
14011407
CeedCallBackend(CeedVectorDestroy(&vec));
1402-
if (!is_active) continue;
1408+
if (!is_active || impl->skip_rstr_in[i]) continue;
14031409

14041410
// -- Get active restriction type
14051411
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
@@ -1412,37 +1418,18 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce
14121418

14131419
e_vec_size = elem_size_active * num_comp_active;
14141420
for (CeedInt s = 0; s < e_vec_size; s++) {
1415-
CeedEvalMode eval_mode;
1416-
CeedBasis basis;
1417-
14181421
// -- Update unit vector
14191422
{
14201423
CeedScalar *array;
14211424

1422-
if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs_in[i], 0.0));
14231425
CeedCallBackend(CeedVectorGetArray(impl->e_vecs_in[i], CEED_MEM_HOST, &array));
14241426
array[s] = 1.0;
14251427
if (s > 0) array[s - 1] = 0.0;
14261428
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &array));
14271429
}
1428-
// -- Basis action
1429-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
1430-
switch (eval_mode) {
1431-
case CEED_EVAL_NONE:
1432-
break;
1433-
// Note - these basis eval modes require FEM fields
1434-
case CEED_EVAL_INTERP:
1435-
case CEED_EVAL_GRAD:
1436-
case CEED_EVAL_DIV:
1437-
case CEED_EVAL_CURL:
1438-
CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
1439-
CeedCallBackend(CeedBasisApplyAtPoints(basis, 1, &num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, impl->e_vecs_in[i],
1440-
impl->q_vecs_in[i]));
1441-
CeedCallBackend(CeedBasisDestroy(&basis));
1442-
break;
1443-
case CEED_EVAL_WEIGHT:
1444-
break; // No action
1445-
}
1430+
// Input basis apply for active bases
1431+
CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields,
1432+
in_vec, impl->point_coords_elem, false, true, e_data, impl, request));
14461433

14471434
// -- Q function
14481435
if (!impl->is_identity_qf) {
@@ -1452,23 +1439,22 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce
14521439
// -- Output basis apply and restriction
14531440
CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields,
14541441
num_output_fields, impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec,
1455-
impl->point_coords_elem, impl, request));
1442+
impl->point_coords_elem, true, impl, request));
14561443

14571444
// -- Grab diagonal value
14581445
for (CeedInt j = 0; j < num_output_fields; j++) {
14591446
bool is_active;
14601447
CeedInt elem_size = 0;
14611448
CeedRestrictionType rstr_type;
1462-
CeedEvalMode eval_mode;
14631449
CeedVector vec;
14641450
CeedElemRestriction elem_rstr;
1465-
CeedBasis basis;
14661451

14671452
// ---- Skip non-active output
14681453
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec));
14691454
is_active = vec == CEED_VECTOR_ACTIVE;
14701455
CeedCallBackend(CeedVectorDestroy(&vec));
14711456
if (!is_active) continue;
1457+
if (impl->skip_rstr_out[j]) continue;
14721458

14731459
// ---- Check if elem size matches
14741460
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr));
@@ -1491,27 +1477,6 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce
14911477
continue;
14921478
}
14931479
}
1494-
1495-
// ---- Basis action
1496-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode));
1497-
switch (eval_mode) {
1498-
case CEED_EVAL_NONE:
1499-
break; // No action
1500-
case CEED_EVAL_INTERP:
1501-
case CEED_EVAL_GRAD:
1502-
case CEED_EVAL_DIV:
1503-
case CEED_EVAL_CURL:
1504-
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis));
1505-
CeedCallBackend(CeedBasisApplyAtPoints(basis, 1, &num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, impl->q_vecs_out[j],
1506-
impl->e_vecs_out[j]));
1507-
CeedCallBackend(CeedBasisDestroy(&basis));
1508-
break;
1509-
// LCOV_EXCL_START
1510-
case CEED_EVAL_WEIGHT: {
1511-
return CeedError(CeedOperatorReturnCeed(op), CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
1512-
// LCOV_EXCL_STOP
1513-
}
1514-
}
15151480
// ---- Update output vector
15161481
{
15171482
CeedScalar *array, current_value = 0.0;
@@ -1533,7 +1498,13 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce
15331498
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
15341499
}
15351500
// -- Reset unit vector
1536-
if (s == e_vec_size - 1) CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
1501+
if (s == e_vec_size - 1) {
1502+
CeedScalar *array;
1503+
1504+
CeedCallBackend(CeedVectorGetArray(impl->e_vecs_in[i], CEED_MEM_HOST, &array));
1505+
array[s] = 0.0;
1506+
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &array));
1507+
}
15371508
}
15381509
}
15391510
num_points_offset += num_points;

0 commit comments

Comments
 (0)