Skip to content

Commit 4101ee3

Browse files
committed
op - gpu minor at points diagonal improvements
1 parent 48fdef1 commit 4101ee3

2 files changed

Lines changed: 68 additions & 118 deletions

File tree

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 33 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -750,15 +750,16 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {
750750
//------------------------------------------------------------------------------
751751
static inline int CeedOperatorInputBasisAtPoints_Cuda(CeedOperatorField op_input_field, CeedQFunctionField qf_input_field, CeedInt input_field,
752752
CeedVector in_vec, CeedVector active_e_vec, CeedInt num_elem, const CeedInt *num_points,
753-
const bool skip_active, CeedOperator_Cuda *impl) {
753+
const bool skip_active, const bool skip_passive, CeedOperator_Cuda *impl) {
754754
bool is_active = false;
755755
CeedEvalMode eval_mode;
756756
CeedVector l_vec, e_vec = impl->e_vecs_in[input_field], q_vec = impl->q_vecs_in[input_field];
757757

758758
// Skip active or passive input, as requested by skip_active / skip_passive
759759
CeedCallBackend(CeedOperatorFieldGetVector(op_input_field, &l_vec));
760760
is_active = l_vec == CEED_VECTOR_ACTIVE;
761-
if (is_active && skip_active) return CEED_ERROR_SUCCESS;
761+
if (skip_active && is_active) return CEED_ERROR_SUCCESS;
762+
if (skip_passive && !is_active) return CEED_ERROR_SUCCESS;
762763
if (is_active) {
763764
l_vec = in_vec;
764765
if (!e_vec) e_vec = active_e_vec;
@@ -842,7 +843,7 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,
842843
CeedCallBackend(
843844
CeedOperatorInputRestrict_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, false, impl, request));
844845
CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[field], qf_input_fields[field], field, in_vec, active_e_vec, num_elem,
845-
num_points, false, impl));
846+
num_points, false, false, impl));
846847
}
847848

848849
// Output pointers, as necessary
@@ -1845,19 +1846,8 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18451846
// Process inputs
18461847
for (CeedInt i = 0; i < num_input_fields; i++) {
18471848
CeedCallBackend(CeedOperatorInputRestrict_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, true, impl, request));
1848-
CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, impl));
1849-
}
1850-
1851-
// Clear active input Qvecs
1852-
for (CeedInt i = 0; i < num_input_fields; i++) {
1853-
bool is_active = false;
1854-
CeedVector l_vec;
1855-
1856-
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
1857-
is_active = l_vec == CEED_VECTOR_ACTIVE;
1858-
CeedCallBackend(CeedVectorDestroy(&l_vec));
1859-
if (!is_active) continue;
1860-
CeedCallBackend(CeedVectorSetValue(impl->q_vecs_in[i], 0.0));
1849+
CeedCallBackend(
1850+
CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[i], qf_input_fields[i], i, NULL, NULL, num_elem, num_points, true, false, impl));
18611851
}
18621852

18631853
// Output pointers, as necessary
@@ -1876,19 +1866,19 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18761866
// Loop over active fields
18771867
for (CeedInt i = 0; i < num_input_fields; i++) {
18781868
bool is_active = false, is_active_at_points = true;
1879-
CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0;
1869+
CeedInt elem_size = 1, num_comp_active = 1, e_vec_size = 0, field_in = impl->input_field_order[i];
18801870
CeedRestrictionType rstr_type;
18811871
CeedVector l_vec;
18821872
CeedElemRestriction elem_rstr;
18831873

18841874
// -- Skip non-active input and inputs flagged in skip_rstr_in
1885-
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
1875+
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[field_in], &l_vec));
18861876
is_active = l_vec == CEED_VECTOR_ACTIVE;
18871877
CeedCallBackend(CeedVectorDestroy(&l_vec));
1888-
if (!is_active) continue;
1878+
if (!is_active || impl->skip_rstr_in[field_in]) continue;
18891879

18901880
// -- Get active restriction type
1891-
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
1881+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[field_in], &elem_rstr));
18921882
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
18931883
is_active_at_points = rstr_type == CEED_RESTRICTION_POINTS;
18941884
if (!is_active_at_points) CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
@@ -1897,16 +1887,9 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
18971887
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
18981888

18991889
e_vec_size = elem_size * num_comp_active;
1890+
CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0));
19001891
for (CeedInt s = 0; s < e_vec_size; s++) {
1901-
bool is_active = false;
1902-
CeedEvalMode eval_mode;
1903-
CeedVector l_vec, q_vec = impl->q_vecs_in[i];
1904-
1905-
// Skip non-active input
1906-
CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &l_vec));
1907-
is_active = l_vec == CEED_VECTOR_ACTIVE;
1908-
CeedCallBackend(CeedVectorDestroy(&l_vec));
1909-
if (!is_active) continue;
1892+
CeedVector q_vec = impl->q_vecs_in[field_in];
19101893

19111894
// Update unit vector
19121895
{
@@ -1915,8 +1898,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19151898
CeedSize start = node * 1 + comp * (elem_size * num_elem);
19161899
CeedSize stop = (comp + 1) * (elem_size * num_elem);
19171900

1918-
if (s == 0) CeedCallBackend(CeedVectorSetValue(active_e_vec_in, 0.0));
1919-
else CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 0.0));
1901+
if (s != 0) CeedCallBackend(CeedVectorSetValueStrided(active_e_vec_in, start, stop, elem_size, 0.0));
19201902

19211903
node = s % elem_size, comp = s / elem_size;
19221904
start = node * 1 + comp * (elem_size * num_elem);
@@ -1925,29 +1907,11 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19251907
}
19261908

19271909
// Basis action
1928-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_input_fields[i], &eval_mode));
1929-
switch (eval_mode) {
1930-
case CEED_EVAL_NONE: {
1931-
const CeedScalar *e_vec_array;
1932-
1933-
CeedCallBackend(CeedVectorGetArrayRead(active_e_vec_in, CEED_MEM_DEVICE, &e_vec_array));
1934-
CeedCallBackend(CeedVectorSetArray(q_vec, CEED_MEM_DEVICE, CEED_USE_POINTER, (CeedScalar *)e_vec_array));
1935-
break;
1936-
}
1937-
case CEED_EVAL_INTERP:
1938-
case CEED_EVAL_GRAD:
1939-
case CEED_EVAL_DIV:
1940-
case CEED_EVAL_CURL: {
1941-
CeedBasis basis;
1910+
for (CeedInt j = 0; j < num_input_fields; j++) {
1911+
CeedInt field = impl->input_field_order[j];
19421912

1943-
CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
1944-
CeedCallBackend(
1945-
CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_NOTRANSPOSE, eval_mode, impl->point_coords_elem, active_e_vec_in, q_vec));
1946-
CeedCallBackend(CeedBasisDestroy(&basis));
1947-
break;
1948-
}
1949-
case CEED_EVAL_WEIGHT:
1950-
break; // No action
1913+
CeedCallBackend(CeedOperatorInputBasisAtPoints_Cuda(op_input_fields[field], qf_input_fields[field], field, NULL, active_e_vec_in, num_elem,
1914+
num_points, false, true, impl));
19511915
}
19521916

19531917
// Q function
@@ -1957,20 +1921,21 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19571921
for (CeedInt j = 0; j < num_output_fields; j++) {
19581922
bool is_active = false;
19591923
CeedInt elem_size = 0;
1924+
CeedInt field_out = impl->output_field_order[j];
19601925
CeedRestrictionType rstr_type;
19611926
CeedEvalMode eval_mode;
1962-
CeedVector l_vec, e_vec = impl->e_vecs_out[j], q_vec = impl->q_vecs_out[j];
1927+
CeedVector l_vec, e_vec = impl->e_vecs_out[field_out], q_vec = impl->q_vecs_out[field_out];
19631928
CeedElemRestriction elem_rstr;
19641929

19651930
// ---- Skip non-active output
1966-
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &l_vec));
1931+
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[field_out], &l_vec));
19671932
is_active = l_vec == CEED_VECTOR_ACTIVE;
19681933
CeedCallBackend(CeedVectorDestroy(&l_vec));
19691934
if (!is_active) continue;
19701935
if (!e_vec) e_vec = active_e_vec_out;
19711936

19721937
// ---- Check if elem size matches
1973-
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &elem_rstr));
1938+
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[field_out], &elem_rstr));
19741939
CeedCallBackend(CeedElemRestrictionGetType(elem_rstr, &rstr_type));
19751940
if (is_active_at_points && rstr_type != CEED_RESTRICTION_POINTS) continue;
19761941
if (rstr_type == CEED_RESTRICTION_POINTS) {
@@ -1986,7 +1951,7 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
19861951
}
19871952

19881953
// Basis action
1989-
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode));
1954+
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[field_out], &eval_mode));
19901955
switch (eval_mode) {
19911956
case CEED_EVAL_NONE: {
19921957
CeedScalar *e_vec_array;
@@ -2001,8 +1966,13 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
20011966
case CEED_EVAL_CURL: {
20021967
CeedBasis basis;
20031968

2004-
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[j], &basis));
2005-
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
1969+
CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[field_out], &basis));
1970+
if (impl->apply_add_basis_out[field_out]) {
1971+
CeedCallBackend(
1972+
CeedBasisApplyAddAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
1973+
} else {
1974+
CeedCallBackend(CeedBasisApplyAtPoints(basis, num_elem, num_points, CEED_TRANSPOSE, eval_mode, impl->point_coords_elem, q_vec, e_vec));
1975+
}
20061976
CeedCallBackend(CeedBasisDestroy(&basis));
20071977
break;
20081978
}
@@ -2014,6 +1984,10 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
20141984
}
20151985

20161986
// Mask output e-vec
1987+
if (impl->skip_rstr_out[field_out]) {
1988+
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
1989+
continue;
1990+
}
20171991
CeedCallBackend(CeedVectorPointwiseMult(e_vec, active_e_vec_in, e_vec));
20181992

20191993
// Restrict

0 commit comments

Comments
 (0)