Skip to content

Commit 523ab01

Browse files
authored
Merge pull request #1807 from CEED/jeremy/extra-ops-trim
op - minor at points diagonal improvements
2 parents 85bbdf9 + 5cde1db commit 523ab01

3 files changed

Lines changed: 108 additions & 180 deletions

File tree

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 36 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -750,15 +750,19 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {
750750
//------------------------------------------------------------------------------
751751
static inline int CeedOperatorInputBasisAtPoints_Cuda(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field,
752752
CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const CeedInt *num_points,
753-
const bool skip_active, CeedOperator_Cuda *impl) {
753+
const bool skip_active, const bool skip_passive, CeedOperator_Cuda *impl) {
754754
bool is_active = false;
755755
CeedEvalMode eval_mode;
756756
CeedVector l_vec, e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field];
757757

758758
// Skip active input
759759
CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec));
760760
is_active = l_vec == CEED_VECTOR_ACTIVE;
761-
if (is_active && skip_active) return CEED_ERROR_SUCCESS;
761+
if (skip_active && is_active) return CEED_ERROR_SUCCESS;
762+
if (skip_passive && !is_active) {
763+
CeedCallBackend(CeedVectorDestroy(&l_vec));
764+
return CEED_ERROR_SUCCESS;
765+
}
762766
if (is_active) {
763767
l_vec = in_vec;
764768
if (!e_vec) e_vec = active_e_vec;
@@ -842,7 +846,7 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,
842846
CeedCallBackend(
843847
CeedOperatorInputRestrict_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request));
844848
CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem,
845-
num_points, false, impl));
849+
num_points, false, false, impl));
846850
}
847851

848852
// Output pointers, as necessary
@@ -1845,19 +1849,8 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18451849
// Process inputs
18461850
for (CeedInt i = 0; i < num_input_fields; i++) {
18471851
CeedCallBackend(CeedOperatorInputRestrict_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request));
1848-
CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, impl));
1849-
}
1850-
1851-
// Clear active input Qvecs
1852-
for (CeedInt i = 0; i < num_input_fields; i++) {
1853-
bool is_active = false;
1854-
CeedVector l_vec;
1855-
1856-
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
1857-
is_active = l_vec == CEED_VECTOR_ACTIVE;
1858-
CeedCallBackend(CeedVectorDestroy(&l_vec));
1859-
if (!is_active) continue;
1860-
CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
1852+
CeedCallBackend(
1853+
CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, false, impl));
18611854
}
18621855

18631856
// Output pointers, as necessary
@@ -1876,19 +1869,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18761869
// Loop over active fields
18771870
for (CeedInt i = 0; i < num_input_fields; i++) {
18781871
bool is_active = false, is_active_at_points = true;
1879-
CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0;
1872+
CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0, field_in = impl->input_field_order[i];
18801873
CeedRestrictionType rstr_type;
18811874
CeedVector l_vec;
18821875
CeedElemRestriction elem_rstr;
18831876

18841877
// -- Skip non-active input
1885-
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
1878+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[field_in], &l_vec));
18861879
is_active = l_vec == CEED_VECTOR_ACTIVE;
18871880
CeedCallBackend(CeedVectorDestroy(&l_vec));
1888-
if (!is_active) continue;
1881+
if (!is_active || impl->skip_rstr_in[field_in]) continue;
18891882

18901883
// -- Get active restriction type
1891-
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
1884+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[field_in], &elem_rstr));
18921885
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
18931886
is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS;
18941887
if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
@@ -1897,16 +1890,9 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18971890
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
18981891

18991892
e_vec_size = elem_size * num_comp_active;
1893+
CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0));
19001894
for (CeedInt s = 0; s < e_vec_size; s++) {
1901-
bool is_active = false;
1902-
CeedEvalMode eval_mode;
1903-
CeedVector l_vec, q_vec = impl->q_vecs_in[i];
1904-
1905-
// Skip non-active input
1906-
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
1907-
is_active = l_vec == CEED_VECTOR_ACTIVE;
1908-
CeedCallBackend(CeedVectorDestroy(&l_vec));
1909-
if (!is_active) continue;
1895+
CeedVector q_vec = impl->q_vecs_in[field_in];
19101896

19111897
// Update unit vector
19121898
{
@@ -1915,8 +1901,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19151901
CeedSize start = node * 1 + comp * (elem_size * num_elem);
19161902
CeedSize stop = (comp + 1) * (elem_size * num_elem);
19171903

1918-
if (s == 0) CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0));
1919-
else CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 0.0));
1904+
if (s != 0) CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 0.0));
19201905

19211906
node = s % elem_size, comp = s / elem_size;
19221907
start = node * 1 + comp * (elem_size * num_elem);
@@ -1925,29 +1910,11 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19251910
}
19261911

19271912
// Basis action
1928-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
1929-
switch (eval_mode) {
1930-
case CEED_EVAL_NONE: {
1931-
const CeedScalar *e_vec_array;
1932-
1933-
CeedCallBackend(CeedVectorGetArrayRead(active_e_vec_in, CEED_MEM_DEVICE, &e_vec_array));
1934-
CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
1935-
break;
1936-
}
1937-
case CEED_EVAL_INTERP:
1938-
case CEED_EVAL_GRAD:
1939-
case CEED_EVAL_DIV:
1940-
case CEED_EVAL_CURL: {
1941-
CeedBasis basis;
1913+
for (CeedInt j = 0; j < num_input_fields; j++) {
1914+
CeedInt field = impl->input_field_order[j];
19421915

1943-
CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
1944-
CeedCallBackend(
1945-
CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, active_e_vec_in, q_vec));
1946-
CeedCallBackend(CeedBasisDestroy(&basis));
1947-
break;
1948-
}
1949-
case CEED_EVAL_WEIGHT:
1950-
break; // No action
1916+
CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[field], qf_input_fields[field], field, NULL, active_e_vec_in, num_elem,
1917+
num_points, false, true, impl));
19511918
}
19521919

19531920
// Q function
@@ -1957,20 +1924,21 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19571924
for (CeedInt j = 0; j < num_output_fields; j++) {
19581925
bool is_active = false;
19591926
CeedInt elem_size = 0;
1927+
CeedInt field_out = impl->output_field_order[j];
19601928
CeedRestrictionType rstr_type;
19611929
CeedEvalMode eval_mode;
1962-
CeedVector l_vec, e_vec = impl->e_vecs_out[j], q_vec = impl->q_vecs_out[j];
1930+
CeedVector l_vec, e_vec = impl->e_vecs_out[field_out], q_vec = impl->q_vecs_out[field_out];
19631931
CeedElemRestriction elem_rstr;
19641932

19651933
// ---- Skip non-active output
1966-
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &l_vec));
1934+
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field_out], &l_vec));
19671935
is_active = l_vec == CEED_VECTOR_ACTIVE;
19681936
CeedCallBackend(CeedVectorDestroy(&l_vec));
19691937
if (!is_active) continue;
19701938
if (!e_vec) e_vec = active_e_vec_out;
19711939

19721940
// ---- Check if elem size matches
1973-
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr));
1941+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field_out], &elem_rstr));
19741942
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
19751943
if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) continue;
19761944
if (rstr_type == CEED_RESTRICTION_POINTS) {
@@ -1986,7 +1954,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19861954
}
19871955

19881956
// Basis action
1989-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode));
1957+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field_out], &eval_mode));
19901958
switch (eval_mode) {
19911959
case CEED_EVAL_NONE: {
19921960
CeedScalar *e_vec_array;
@@ -2001,8 +1969,13 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
20011969
case CEED_EVAL_CURL: {
20021970
CeedBasis basis;
20031971

2004-
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis));
2005-
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
1972+
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[field_out], &basis));
1973+
if (impl->apply_add_basis_out[field_out]) {
1974+
CeedCallBackend(
1975+
CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
1976+
} else {
1977+
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
1978+
}
20061979
CeedCallBackend(CeedBasisDestroy(&basis));
20071980
break;
20081981
}
@@ -2014,6 +1987,10 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
20141987
}
20151988

20161989
// Mask output e-vec
1990+
if (impl->skip_rstr_out[field_out]) {
1991+
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1992+
continue;
1993+
}
20171994
CeedCallBackend(CeedVectorPointwiseMult(e_vec, active_e_vec_in, e_vec));
20181995

20191996
// Restrict

0 commit comments

Comments
 (0)