Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions backends/ref/ceed-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -1524,6 +1524,212 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Assemble Operator AtPoints
//------------------------------------------------------------------------------
static int CeedSingleOperatorAssembleAtPoints_Ref(CeedOperator op, CeedInt offset, CeedVector values) {
CeedInt num_points_offset = 0, num_input_fields, num_output_fields, num_elem, num_comp_active = 1;
CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0}, *assembled;
Ceed ceed;
CeedVector point_coords = NULL, in_vec, out_vec;
CeedElemRestriction rstr_points = NULL;
CeedQFunctionField *qf_input_fields, *qf_output_fields;
CeedQFunction qf;
CeedOperatorField *op_input_fields, *op_output_fields;
CeedOperator_Ref *impl;

CeedCallBackend(CeedOperatorGetData(op, &impl));
CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));

// Setup
CeedCallBackend(CeedOperatorSetupAtPoints_Ref(op));

// Ceed
{
Ceed ceed_parent;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
CeedCallBackend(CeedReferenceCopy(ceed_parent, &ceed));
CeedCallBackend(CeedDestroy(&ceed_parent));
}

// Point coordinates
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords));

// Input and output vectors
{
CeedSize input_size, output_size;

CeedCallBackend(CeedOperatorGetActiveVectorLengths(op, &input_size, &output_size));
CeedCallBackend(CeedVectorCreate(ceed, input_size, &in_vec));
CeedCallBackend(CeedVectorCreate(ceed, output_size, &out_vec));
CeedCallBackend(CeedVectorSetValue(out_vec, 0.0));
}

// Assembled array
CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_HOST, &assembled));

// Clear input Evecs
for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_active;
CeedVector vec;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&vec));
if (!is_active || impl->skip_rstr_in[i]) continue;
CeedCallBackend(CeedVectorSetValue(impl->e_vecs_in[i], 0.0));
}

// Input Evecs and Restriction
CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, CEED_REQUEST_IMMEDIATE));

// Loop through elements
for (CeedInt e = 0; e < num_elem; e++) {
CeedInt num_points, e_vec_size = 0;

// Setup points for element
CeedCallBackend(
CeedElemRestrictionApplyAtPointsInElement(rstr_points, e, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, CEED_REQUEST_IMMEDIATE));
CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points));

// Input basis apply for non-active bases
CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, in_vec,
impl->point_coords_elem, true, false, e_data, impl, CEED_REQUEST_IMMEDIATE));

// Loop over points on element
for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_active_at_points = true, is_active;
CeedInt elem_size_active = 1;
CeedRestrictionType rstr_type;
CeedVector vec;
CeedElemRestriction elem_rstr;

// -- Skip non-active input
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&vec));
if (!is_active || impl->skip_rstr_in[i]) continue;

// -- Get active restriction type
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS;
if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size_active));
else elem_size_active = num_points;
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));

e_vec_size = elem_size_active * num_comp_active;
for (CeedInt s = 0; s < e_vec_size; s++) {
const CeedInt comp_in = s / elem_size_active;
const CeedInt node_in = s % elem_size_active;

// -- Update unit vector
{
CeedScalar *array;

if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs_in[i], 0.0));
CeedCallBackend(CeedVectorGetArray(impl->e_vecs_in[i], CEED_MEM_HOST, &array));
array[s] = 1.0;
if (s > 0) array[s - 1] = 0.0;
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &array));
}
// Input basis apply for active bases
CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields,
in_vec, impl->point_coords_elem, false, true, e_data, impl, CEED_REQUEST_IMMEDIATE));

// -- Q function
if (!impl->is_identity_qf) {
CeedCallBackend(CeedQFunctionApply(qf, num_points, impl->q_vecs_in, impl->q_vecs_out));
}

// -- Output basis apply and restriction
CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields,
num_output_fields, impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec,
impl->point_coords_elem, true, impl, CEED_REQUEST_IMMEDIATE));

// -- Build element matrix
for (CeedInt j = 0; j < num_output_fields; j++) {
bool is_active;
CeedInt elem_size = 0;
CeedRestrictionType rstr_type;
CeedVector vec;
CeedElemRestriction elem_rstr;

// ---- Skip non-active output
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec));
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedVectorDestroy(&vec));
if (!is_active || impl->skip_rstr_out[j]) continue;

// ---- Check if elem size matches
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) {
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
continue;
}
if (rstr_type == CEED_RESTRICTION_POINTS) {
CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(elem_rstr, e, &elem_size));
} else {
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
}
{
CeedInt num_comp = 0;

CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
if (e_vec_size != num_comp * elem_size) continue;
}
// ---- Copy output
{
const CeedScalar *output;

CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_out[j], CEED_MEM_HOST, &output));
for (CeedInt k = 0; k < e_vec_size; k++) {
const CeedInt comp_out = k / elem_size_active;
const CeedInt node_out = k % elem_size_active;

assembled[offset + e * e_vec_size * e_vec_size + (comp_in * num_comp_active + comp_out) * elem_size_active * elem_size_active +
node_out * elem_size_active + node_in] = output[k];
}
CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_out[j], &output));
}
}
// -- Reset unit vector
if (s == e_vec_size - 1) {
CeedScalar *array;

CeedCallBackend(CeedVectorGetArray(impl->e_vecs_in[i], CEED_MEM_HOST, &array));
array[s] = 0.0;
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &array));
}
}
}
num_points_offset += num_points;
}

// Restore input arrays
CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl));

// Restore assembled values
CeedCallBackend(CeedVectorRestoreArray(values, &assembled));

// Cleanup
CeedCallBackend(CeedDestroy(&ceed));
CeedCallBackend(CeedVectorDestroy(&in_vec));
CeedCallBackend(CeedVectorDestroy(&out_vec));
CeedCallBackend(CeedVectorDestroy(&point_coords));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
CeedCallBackend(CeedQFunctionDestroy(&qf));
return CEED_ERROR_SUCCESS;
}

//------------------------------------------------------------------------------
// Operator Destroy
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -1592,6 +1798,7 @@ int CeedOperatorCreateAtPoints_Ref(CeedOperator op) {
CeedCallBackend(
CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionAtPointsUpdate_Ref));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssembleAtPoints_Ref));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Ref));
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Ref));
CeedCallBackend(CeedDestroy(&ceed));
Expand Down
6 changes: 5 additions & 1 deletion interface/ceed-preconditioning.c
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ static int CeedSingleOperatorAssembleSymbolic(CeedOperator op, CeedInt offset, C
@ref Developer
**/
static int CeedSingleOperatorAssemble(CeedOperator op, CeedInt offset, CeedVector values) {
bool is_composite;
bool is_composite, is_at_points;

CeedCall(CeedOperatorIsComposite(op, &is_composite));
CeedCheck(!is_composite, CeedOperatorReturnCeed(op), CEED_ERROR_UNSUPPORTED, "Composite operator not supported");
Expand Down Expand Up @@ -595,6 +595,10 @@ static int CeedSingleOperatorAssemble(CeedOperator op, CeedInt offset, CeedVecto
}
}

CeedCall(CeedOperatorIsAtPoints(op, &is_at_points));
CeedCheck(!is_at_points, CeedOperatorReturnCeed(op), CEED_ERROR_UNSUPPORTED,
"Backend does not implement CeedOperatorLinearAssemble for AtPoints operator");

// Assemble QFunction
CeedInt layout_qf[3];
const CeedScalar *assembled_qf_array;
Expand Down
Loading