Skip to content

Commit 45e1399

Browse files
committed
pc - CPU support for AtPoints assembly
1 parent 48fdef1 commit 45e1399

6 files changed

Lines changed: 705 additions & 1 deletion

File tree

backends/ref/ceed-ref-operator.c

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,212 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref(CeedOperator op, Ce
15241524
return CEED_ERROR_SUCCESS;
15251525
}
15261526

1527+
//------------------------------------------------------------------------------
1528+
// Assemble Operator AtPoints
1529+
//------------------------------------------------------------------------------
1530+
// Assemble the element matrices of a single AtPoints operator into `values`, working on host arrays.
//
// For every element and every active, non-skipped input field, each entry of the input E-vector is set to 1.0 in turn
// (all other entries 0.0). The input bases, the QFunction, and the output bases are then applied, and the resulting
// active output E-vector is copied into `values` as one column of that element's matrix.
//
// Parameters:
//   op     - operator to assemble
//   offset - index added to every position written in `values`
//   values - vector receiving the assembled entries; accessed through a writable CEED_MEM_HOST array
//
// Returns CEED_ERROR_SUCCESS after all calls complete.
// NOTE(review): CeedCallBackend presumably returns early on failure, which would skip the restore/cleanup calls at the
//   end of this function - confirm against the macro definition.
static int CeedSingleOperatorAssembleAtPoints_Ref(CeedOperator op, CeedInt offset, CeedVector values) {
1531+
// num_points_offset is a running total of the points in the elements already processed
CeedInt num_points_offset = 0, num_input_fields, num_output_fields, num_elem, num_comp_active = 1;
1532+
// e_data is handed to the input setup/basis/restore helpers; `assembled` is the writable view of `values`
CeedScalar *e_data[2 * CEED_FIELD_MAX] = {0}, *assembled;
1533+
Ceed ceed;
1534+
CeedVector point_coords = NULL, in_vec, out_vec;
1535+
CeedElemRestriction rstr_points = NULL;
1536+
CeedQFunctionField *qf_input_fields, *qf_output_fields;
1537+
CeedQFunction qf;
1538+
CeedOperatorField *op_input_fields, *op_output_fields;
1539+
CeedOperator_Ref *impl;
1540+
1541+
// Backend data, element count, QFunction, and the operator/QFunction field lists
CeedCallBackend(CeedOperatorGetData(op, &impl));
1542+
CeedCallBackend(CeedOperatorGetNumElements(op, &num_elem));
1543+
CeedCallBackend(CeedOperatorGetQFunction(op, &qf));
1544+
CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &op_input_fields, &num_output_fields, &op_output_fields));
1545+
CeedCallBackend(CeedQFunctionGetFields(qf, NULL, &qf_input_fields, NULL, &qf_output_fields));
1546+
1547+
// Setup
1548+
CeedCallBackend(CeedOperatorSetupAtPoints_Ref(op));
1549+
1550+
// Ceed
// -- `ceed` ends up referring to the parent of the operator's Ceed; it is used below to create the work vectors
// NOTE(review): presumably CeedReferenceCopy releases the reference obtained from CeedOperatorGetCeed - confirm
1551+
{
1552+
Ceed ceed_parent;
1553+
1554+
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
1555+
CeedCallBackend(CeedGetParent(ceed, &ceed_parent));
1556+
CeedCallBackend(CeedReferenceCopy(ceed_parent, &ceed));
1557+
CeedCallBackend(CeedDestroy(&ceed_parent));
1558+
}
1559+
1560+
// Point coordinates
1561+
CeedCallBackend(CeedOperatorAtPointsGetPoints(op, &rstr_points, &point_coords));
1562+
1563+
// Input and output vectors
// -- Work vectors sized to the operator's active vector lengths; only out_vec is given values here
// NOTE(review): in_vec is never filled in this function, it is only passed to CeedOperatorInputBasisAtPoints_Ref
1564+
{
1565+
CeedSize input_size, output_size;
1566+
1567+
CeedCallBackend(CeedOperatorGetActiveVectorLengths(op, &input_size, &output_size));
1568+
CeedCallBackend(CeedVectorCreate(ceed, input_size, &in_vec));
1569+
CeedCallBackend(CeedVectorCreate(ceed, output_size, &out_vec));
1570+
CeedCallBackend(CeedVectorSetValue(out_vec, 0.0));
1571+
}
1572+
1573+
// Assembled array
1574+
CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_HOST, &assembled));
1575+
1576+
// Clear input Evecs
// -- Only E-vectors of active inputs whose restriction is not skipped are zeroed
1577+
for (CeedInt i = 0; i < num_input_fields; i++) {
1578+
bool is_active;
1579+
CeedVector vec;
1580+
1581+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
1582+
// Record whether the field is active before the handle is released
is_active = vec == CEED_VECTOR_ACTIVE;
1583+
CeedCallBackend(CeedVectorDestroy(&vec));
1584+
if (!is_active || impl->skip_rstr_in[i]) continue;
1585+
CeedCallBackend(CeedVectorSetValue(impl->e_vecs_in[i], 0.0));
1586+
}
1587+
1588+
// Input Evecs and Restriction
// -- No input vector is supplied (NULL); the active inputs are driven by the unit vectors written below
1589+
CeedCallBackend(CeedOperatorSetupInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, CEED_REQUEST_IMMEDIATE));
1590+
1591+
// Loop through elements
1592+
for (CeedInt e = 0; e < num_elem; e++) {
1593+
// e_vec_size is the number of entries in the active input E-vector for this element (set per input field below)
CeedInt num_points, e_vec_size = 0;
1594+
1595+
// Setup points for element
1596+
CeedCallBackend(
1597+
CeedElemRestrictionApplyAtPointsInElement(rstr_points, e, CEED_NOTRANSPOSE, point_coords, impl->point_coords_elem, CEED_REQUEST_IMMEDIATE));
1598+
CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(rstr_points, e, &num_points));
1599+
1600+
// Input basis apply for non-active bases
1601+
CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields, in_vec,
1602+
impl->point_coords_elem, true, false, e_data, impl, CEED_REQUEST_IMMEDIATE));
1603+
1604+
// Loop over points on element
// NOTE(review): this loop iterates over the input fields; the per-entry loop is the inner loop over `s`
1605+
for (CeedInt i = 0; i < num_input_fields; i++) {
1606+
bool is_active_at_points = true, is_active;
1607+
CeedInt elem_size_active = 1;
1608+
CeedRestrictionType rstr_type;
1609+
CeedVector vec;
1610+
CeedElemRestriction elem_rstr;
1611+
1612+
// -- Skip non-active input
1613+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec));
1614+
is_active = vec == CEED_VECTOR_ACTIVE;
1615+
CeedCallBackend(CeedVectorDestroy(&vec));
1616+
if (!is_active || impl->skip_rstr_in[i]) continue;
1617+
1618+
// -- Get active restriction type
// ---- A points restriction contributes num_points entries per component, any other type its element size
1619+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
1620+
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
1621+
is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS;
1622+
if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size_active));
1623+
else elem_size_active = num_points;
1624+
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp_active));
1625+
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1626+
1627+
e_vec_size = elem_size_active * num_comp_active;
1628+
// One pass per entry of the input E-vector; entry s corresponds to component comp_in and node/point node_in
for (CeedInt s = 0; s < e_vec_size; s++) {
1629+
const CeedInt comp_in = s / elem_size_active;
1630+
const CeedInt node_in = s % elem_size_active;
1631+
1632+
// -- Update unit vector
// ---- Zero the whole E-vector on the first pass, afterwards only clear the entry set on the previous pass
1633+
{
1634+
CeedScalar *array;
1635+
1636+
if (s == 0) CeedCallBackend(CeedVectorSetValue(impl->e_vecs_in[i], 0.0));
1637+
CeedCallBackend(CeedVectorGetArray(impl->e_vecs_in[i], CEED_MEM_HOST, &array));
1638+
array[s] = 1.0;
1639+
if (s > 0) array[s - 1] = 0.0;
1640+
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &array));
1641+
}
1642+
// Input basis apply for active bases
1643+
CeedCallBackend(CeedOperatorInputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_input_fields, op_input_fields, num_input_fields,
1644+
in_vec, impl->point_coords_elem, false, true, e_data, impl, CEED_REQUEST_IMMEDIATE));
1645+
1646+
// -- Q function
// ---- Not applied when the backend data marks the QFunction as identity
1647+
if (!impl->is_identity_qf) {
1648+
CeedCallBackend(CeedQFunctionApply(qf, num_points, impl->q_vecs_in, impl->q_vecs_out));
1649+
}
1650+
1651+
// -- Output basis apply and restriction
1652+
CeedCallBackend(CeedOperatorOutputBasisAtPoints_Ref(e, num_points_offset, num_points, qf_output_fields, op_output_fields, num_input_fields,
1653+
num_output_fields, impl->apply_add_basis_out, impl->skip_rstr_out, op, out_vec,
1654+
impl->point_coords_elem, true, impl, CEED_REQUEST_IMMEDIATE));
1655+
1656+
// -- Build element matrix
// ---- Copy the active output E-vector(s) produced by this unit vector into `assembled`
1657+
for (CeedInt j = 0; j < num_output_fields; j++) {
1658+
bool is_active;
1659+
CeedInt elem_size = 0;
1660+
CeedRestrictionType rstr_type;
1661+
CeedVector vec;
1662+
CeedElemRestriction elem_rstr;
1663+
1664+
// ---- Skip non-active output
1665+
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec));
1666+
is_active = vec == CEED_VECTOR_ACTIVE;
1667+
CeedCallBackend(CeedVectorDestroy(&vec));
1668+
if (!is_active || impl->skip_rstr_out[j]) continue;
1669+
1670+
// ---- Check if elem size matches
// ------ An input on a points restriction is only paired with an output on a points restriction
1671+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr));
1672+
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
1673+
if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) {
1674+
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1675+
continue;
1676+
}
1677+
if (rstr_type == CEED_RESTRICTION_POINTS) {
1678+
CeedCallBackend(CeedElemRestrictionGetNumPointsInElement(elem_rstr, e, &elem_size));
1679+
} else {
1680+
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
1681+
}
// ------ Outputs whose E-vector length differs from the input E-vector length are skipped
1682+
{
1683+
CeedInt num_comp = 0;
1684+
1685+
CeedCallBackend(CeedElemRestrictionGetNumComponents(elem_rstr, &num_comp));
1686+
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1687+
if (e_vec_size != num_comp * elem_size) continue;
1688+
}
1689+
// ---- Copy output
// ------ Position in `assembled`: element block, then (comp_in, comp_out) block, then row node_out, column node_in
// NOTE(review): comp_out/node_out are derived from the input's elem_size_active and num_comp_active; the check above
//   only guarantees equal products - confirm the output has the same component count and element size
// NOTE(review): the element block start uses this element's e_vec_size; if the point count varies between elements
//   the blocks of earlier elements may have a different size - confirm against the symbolic assembly layout
1690+
{
1691+
const CeedScalar *output;
1692+
1693+
CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs_out[j], CEED_MEM_HOST, &output));
1694+
for (CeedInt k = 0; k < e_vec_size; k++) {
1695+
const CeedInt comp_out = k / elem_size_active;
1696+
const CeedInt node_out = k % elem_size_active;
1697+
1698+
assembled[offset + e * e_vec_size * e_vec_size + (comp_in * num_comp_active + comp_out) * elem_size_active * elem_size_active +
1699+
node_out * elem_size_active + node_in] = output[k];
1700+
}
1701+
CeedCallBackend(CeedVectorRestoreArrayRead(impl->e_vecs_out[j], &output));
1702+
}
1703+
}
1704+
// -- Reset unit vector
// ---- After the last pass, clear the final entry so the E-vector is all zeros again
1705+
if (s == e_vec_size - 1) {
1706+
CeedScalar *array;
1707+
1708+
CeedCallBackend(CeedVectorGetArray(impl->e_vecs_in[i], CEED_MEM_HOST, &array));
1709+
array[s] = 0.0;
1710+
CeedCallBackend(CeedVectorRestoreArray(impl->e_vecs_in[i], &array));
1711+
}
1712+
}
1713+
}
1714+
// Advance the running point count past this element
num_points_offset += num_points;
1715+
}
1716+
1717+
// Restore input arrays
1718+
CeedCallBackend(CeedOperatorRestoreInputs_Ref(num_input_fields, qf_input_fields, op_input_fields, true, e_data, impl));
1719+
1720+
// Restore assembled values
1721+
CeedCallBackend(CeedVectorRestoreArray(values, &assembled));
1722+
1723+
// Cleanup
// -- Release the work vectors and the handles obtained above
// NOTE(review): presumably the getters above returned owned references that must be destroyed here - confirm
1724+
CeedCallBackend(CeedDestroy(&ceed));
1725+
CeedCallBackend(CeedVectorDestroy(&in_vec));
1726+
CeedCallBackend(CeedVectorDestroy(&out_vec));
1727+
CeedCallBackend(CeedVectorDestroy(&point_coords));
1728+
CeedCallBackend(CeedElemRestrictionDestroy(&rstr_points));
1729+
CeedCallBackend(CeedQFunctionDestroy(&qf));
1730+
return CEED_ERROR_SUCCESS;
1731+
}
1732+
15271733
//------------------------------------------------------------------------------
15281734
// Operator Destroy
15291735
//------------------------------------------------------------------------------
@@ -1592,6 +1798,7 @@ int CeedOperatorCreateAtPoints_Ref(CeedOperator op) {
15921798
CeedCallBackend(
15931799
CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleQFunctionUpdate", CeedOperatorLinearAssembleQFunctionAtPointsUpdate_Ref));
15941800
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddDiagonal", CeedOperatorLinearAssembleAddDiagonalAtPoints_Ref));
1801+
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedSingleOperatorAssembleAtPoints_Ref));
15951802
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAddAtPoints_Ref));
15961803
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Ref));
15971804
CeedCallBackend(CeedDestroy(&ceed));

interface/ceed-preconditioning.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ static int CeedSingleOperatorAssembleSymbolic(CeedOperator op, CeedInt offset, C
567567
@ref Developer
568568
**/
569569
static int CeedSingleOperatorAssemble(CeedOperator op, CeedInt offset, CeedVector values) {
570-
bool is_composite;
570+
bool is_composite, is_at_points;
571571

572572
CeedCall(CeedOperatorIsComposite(op, &is_composite));
573573
CeedCheck(!is_composite, CeedOperatorReturnCeed(op), CEED_ERROR_UNSUPPORTED, "Composite operator not supported");
@@ -595,6 +595,10 @@ static int CeedSingleOperatorAssemble(CeedOperator op, CeedInt offset, CeedVecto
595595
}
596596
}
597597

598+
CeedCall(CeedOperatorIsAtPoints(op, &is_at_points));
599+
CeedCheck(!is_at_points, CeedOperatorReturnCeed(op), CEED_ERROR_UNSUPPORTED,
600+
"Backend does not implement CeedOperatorLinearAssemble for AtPoints operator");
601+
598602
// Assemble QFunction
599603
CeedInt layout_qf[3];
600604
const CeedScalar *assembled_qf_array;

0 commit comments

Comments
 (0)