Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 36 additions & 59 deletions backends/cuda-ref/ceed-cuda-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -750,15 +750,19 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {
//------------------------------------------------------------------------------
static inline int CeedOperatorInputBasisAtPoints_Cuda(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field,
CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const CeedInt *num_points,
const bool skip_active, CeedOperator_Cuda *impl) {
const bool skip_active, const bool skip_passive, CeedOperator_Cuda *impl) {
bool is_active = false;
CeedEvalMode eval_mode;
CeedVector l_vec, e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field];

// Skip active input
CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec));
is_active = l_vec == CEED_VECTOR_ACTIVE;
if (is_active && skip_active) return CEED_ERROR_SUCCESS;
if (skip_active && is_active) return CEED_ERROR_SUCCESS;
if (skip_passive && !is_active) {
CeedCallBackend(CeedVectorDestroy(&l_vec));
return CEED_ERROR_SUCCESS;
}
if (is_active) {
l_vec = in_vec;
if (!e_vec) e_vec = active_e_vec;
Expand Down Expand Up @@ -842,7 +846,7 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,
CeedCallBackend(
CeedOperatorInputRestrict_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request));
CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem,
num_points, false, impl));
num_points, false, false, impl));
}

// Output pointers, as necessary
Expand Down Expand Up @@ -1845,19 +1849,8 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
// Process inputs
for (CeedInt i = 0; i < num_input_fields; i++) {
CeedCallBackend(CeedOperatorInputRestrict_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request));
CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, impl));
}

// Clear active input Qvecs
for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_active = false;
CeedVector l_vec;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
is_active = l_vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&l_vec));
if (!is_active) continue;
CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
CeedCallBackend(
CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, false, impl));
}

// Output pointers, as necessary
Expand All @@ -1876,19 +1869,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
// Loop over active fields
for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_active = false, is_active_at_points = true;
CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0;
CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0, field_in = impl->input_field_order[i];
CeedRestrictionType rstr_type;
CeedVector l_vec;
CeedElemRestriction elem_rstr;

// -- Skip non-active input
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[field_in], &l_vec));
is_active = l_vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&l_vec));
if (!is_active) continue;
if (!is_active || impl->skip_rstr_in[field_in]) continue;

// -- Get active restriction type
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[field_in], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS;
if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
Expand All @@ -1897,16 +1890,9 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));

e_vec_size = elem_size * num_comp_active;
CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0));
for (CeedInt s = 0; s < e_vec_size; s++) {
bool is_active = false;
CeedEvalMode eval_mode;
CeedVector l_vec, q_vec = impl->q_vecs_in[i];

// Skip non-active input
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
is_active = l_vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&l_vec));
if (!is_active) continue;
CeedVector q_vec = impl->q_vecs_in[field_in];

// Update unit vector
{
Expand All @@ -1915,8 +1901,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
CeedSize start = node * 1 + comp * (elem_size * num_elem);
CeedSize stop = (comp + 1) * (elem_size * num_elem);

if (s == 0) CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0));
else CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 0.0));
if (s != 0) CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 0.0));

node = s % elem_size, comp = s / elem_size;
start = node * 1 + comp * (elem_size * num_elem);
Expand All @@ -1925,29 +1910,11 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
}

// Basis action
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
switch (eval_mode) {
case CEED_EVAL_NONE: {
const CeedScalar *e_vec_array;

CeedCallBackend(CeedVectorGetArrayRead(active_e_vec_in, CEED_MEM_DEVICE, &e_vec_array));
CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
break;
}
case CEED_EVAL_INTERP:
case CEED_EVAL_GRAD:
case CEED_EVAL_DIV:
case CEED_EVAL_CURL: {
CeedBasis basis;
for (CeedInt j = 0; j < num_input_fields; j++) {
CeedInt field = impl->input_field_order[j];

CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
CeedCallBackend(
CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, active_e_vec_in, q_vec));
CeedCallBackend(CeedBasisDestroy(&basis));
break;
}
case CEED_EVAL_WEIGHT:
break; // No action
CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[field], qf_input_fields[field], field, NULL, active_e_vec_in, num_elem,
num_points, false, true, impl));
}

// Q function
Expand All @@ -1957,20 +1924,21 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
for (CeedInt j = 0; j < num_output_fields; j++) {
bool is_active = false;
CeedInt elem_size = 0;
CeedInt field_out = impl->output_field_order[j];
CeedRestrictionType rstr_type;
CeedEvalMode eval_mode;
CeedVector l_vec, e_vec = impl->e_vecs_out[j], q_vec = impl->q_vecs_out[j];
CeedVector l_vec, e_vec = impl->e_vecs_out[field_out], q_vec = impl->q_vecs_out[field_out];
CeedElemRestriction elem_rstr;

// ---- Skip non-active output
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &l_vec));
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field_out], &l_vec));
is_active = l_vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&l_vec));
if (!is_active) continue;
if (!e_vec) e_vec = active_e_vec_out;

// ---- Check if elem size matches
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field_out], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) continue;
if (rstr_type == CEED_RESTRICTION_POINTS) {
Expand All @@ -1986,7 +1954,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
}

// Basis action
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field_out], &eval_mode));
switch (eval_mode) {
case CEED_EVAL_NONE: {
CeedScalar *e_vec_array;
Expand All @@ -2001,8 +1969,13 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
case CEED_EVAL_CURL: {
CeedBasis basis;

CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis));
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[field_out], &basis));
if (impl->apply_add_basis_out[field_out]) {
CeedCallBackend(
CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
} else {
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
}
CeedCallBackend(CeedBasisDestroy(&basis));
break;
}
Expand All @@ -2014,6 +1987,10 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
}

// Mask output e-vec
if (impl->skip_rstr_out[field_out]) {
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
continue;
}
CeedCallBackend(CeedVectorPointwiseMult(e_vec, active_e_vec_in, e_vec));

// Restrict
Expand Down
Loading
Loading