Skip to content

Commit dde0c85

Browse files
committed
gpu/cuda/ref: mixed field assembly support
1 parent a238442 commit dde0c85

4 files changed

Lines changed: 382 additions & 16 deletions

File tree

backends/cuda-ref/ceed-cuda-ref-operator.c

Lines changed: 261 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,29 @@ static int CeedOperatorDestroy_Cuda(CeedOperator op) {
9595
CeedCallCuda(ceed, cudaFree(impl->asmb->d_B_out));
9696
CeedCallBackend(CeedDestroy(&ceed));
9797
}
98+
9899
CeedCallBackend(CeedFree(&impl->asmb));
99100

101+
if (impl->asmb_blocks) {
102+
Ceed ceed;
103+
104+
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
105+
for (CeedInt i = 0; i < impl->num_blocks_in; i++) {
106+
for (CeedInt j = 0; j < impl->num_blocks_out; j++) {
107+
CeedOperatorAssemble_Cuda *asmb = impl->asmb_blocks[i * impl->num_blocks_out + j];
108+
109+
if (asmb) {
110+
CeedCallCuda(ceed, cuModuleUnload(asmb->module));
111+
CeedCallCuda(ceed, cudaFree(asmb->d_B_in));
112+
CeedCallCuda(ceed, cudaFree(asmb->d_B_out));
113+
}
114+
CeedCallBackend(CeedFree(&asmb));
115+
}
116+
}
117+
CeedCallBackend(CeedFree(&impl->asmb_blocks));
118+
CeedCallBackend(CeedDestroy(&ceed));
119+
}
120+
100121
CeedCallBackend(CeedFree(&impl));
101122
return CEED_ERROR_SUCCESS;
102123
}
@@ -1052,9 +1073,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op,
10521073
CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */
10531074

10541075
// Create output restriction
1055-
CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out,
1056-
(CeedSize)num_active_in * (CeedSize)num_active_out * (CeedSize)num_elem * (CeedSize)Q, strides,
1057-
rstr));
1076+
CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, l_size, strides, rstr));
10581077
// Create assembled vector
10591078
CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled));
10601079
}
@@ -1504,6 +1523,133 @@ static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda(CeedOperator op,
15041523
return CEED_ERROR_SUCCESS;
15051524
}
15061525

1526+
//------------------------------------------------------------------------------
1527+
// Single Operator Block Assembly Setup (one active input/output basis pair)
1528+
//------------------------------------------------------------------------------
1529+
// Build and cache the assembly data for one (active_input, active_output) block of the operator.
//
// Parameters:
//   op               - operator being assembled
//   active_input     - index into the operator's active input bases/restrictions
//   active_output    - index into the operator's active output bases/restrictions
//   use_ceedsize_idx - forwarded to the JiT compiler as the USE_CEEDSIZE define
//
// On success, impl->asmb_blocks[active_input * impl->num_blocks_out + active_output] holds the compiled module,
//   the "LinearAssembleBlock" kernel handle, the launch dimensions, and device copies of the B_in/B_out matrices.
// The impl->asmb_blocks table itself is allocated on the first call.
//
// Returns CEED_ERROR_SUCCESS, or the error code of the first failing call/check.
static int CeedOperatorAssembleSingleBlockSetup_Cuda(CeedOperator op, CeedInt active_input, CeedInt active_output, CeedInt use_ceedsize_idx) {
  Ceed                ceed;
  Ceed_Cuda          *cuda_data;
  CeedInt             num_input_fields, num_output_fields, num_eval_modes_in = 0, num_eval_modes_out = 0;
  CeedInt             elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out, num_comp_out;
  CeedSize            num_output_components;
  const CeedScalar   *h_B_in, *h_B_out;
  CeedElemRestriction rstr_in = NULL, rstr_out = NULL;
  CeedBasis           basis_in = NULL, basis_out = NULL;
  CeedOperatorField  *input_fields, *output_fields;
  CeedOperator_Cuda  *impl;
  char               *eval_mode_offsets_in_str, *eval_mode_offsets_out_str;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));

  // Get input and output fields
  // Note: the returned field arrays are not used below; the call is kept so any error it reports is still raised
  CeedCallBackend(CeedOperatorGetFields(op, &num_input_fields, &input_fields, &num_output_fields, &output_fields));

  {
    CeedInt              num_active_bases_in, *t_num_eval_modes_in, num_active_bases_out, *t_num_eval_modes_out;
    CeedSize           **eval_modes_offsets_in, **eval_modes_offsets_out;
    CeedBasis           *active_bases_in, *active_bases_out;
    CeedElemRestriction *active_rstrs_in, *active_rstrs_out;
    const CeedScalar   **B_mats_in, **B_mats_out;
    CeedOperatorAssemblyData data;

    CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &data));
    CeedCallBackend(CeedOperatorAssemblyDataGetEvalModes(data, &num_active_bases_in, &t_num_eval_modes_in, NULL, &eval_modes_offsets_in,
                                                         &num_active_bases_out, &t_num_eval_modes_out, NULL, &eval_modes_offsets_out,
                                                         &num_output_components));
    // Number of elem restrictions is the same as the number of bases
    CeedCallBackend(CeedOperatorAssemblyDataGetElemRestrictions(data, NULL, &active_rstrs_in, NULL, &active_rstrs_out));
    CeedCallBackend(CeedOperatorAssemblyDataGetBases(data, NULL, &active_bases_in, &B_mats_in, NULL, &active_bases_out, &B_mats_out));

    // Every per-basis array below is indexed by active_input/active_output, so reject bad indices before the first read
    CeedCheck(active_input >= 0 && active_input < num_active_bases_in && active_output >= 0 && active_output < num_active_bases_out, ceed,
              CEED_ERROR_BACKEND, "Active input/output block index out of range");

    num_eval_modes_in  = t_num_eval_modes_in[active_input];
    num_eval_modes_out = t_num_eval_modes_out[active_output];
    CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs");

    // Lazily allocate the (input x output) table of per-block assembly data
    if (!impl->asmb_blocks) {
      CeedCallBackend(CeedCalloc(num_active_bases_in * num_active_bases_out, &impl->asmb_blocks));
      impl->num_blocks_in  = num_active_bases_in;
      impl->num_blocks_out = num_active_bases_out;
    }

    rstr_in  = active_rstrs_in[active_input];
    basis_in = active_bases_in[active_input];
    CeedCallBackend(CeedBuildArrayConstantSize_Cuda(ceed, "EVAL_MODE_OFFSETS_IN", num_eval_modes_in, eval_modes_offsets_in[active_input],
                                                    &eval_mode_offsets_in_str));
    h_B_in = B_mats_in[active_input];

    rstr_out  = active_rstrs_out[active_output];
    basis_out = active_bases_out[active_output];
    CeedCallBackend(CeedBuildArrayConstantSize_Cuda(ceed, "EVAL_MODE_OFFSETS_OUT", num_eval_modes_out, eval_modes_offsets_out[active_output],
                                                    &eval_mode_offsets_out_str));
    h_B_out = B_mats_out[active_output];
  }

  // Without a basis, the quadrature point count falls back to the element size
  CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in));
  if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in;
  else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in));

  CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out));
  if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out;
  else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out));
  CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED,
            "Active input and output bases must have the same number of quadrature points");

  // Allocate the data for this block; the threadblock is (input nodes) x (output nodes) x (1 element)
  CeedCallBackend(CeedCalloc(1, &impl->asmb_blocks[active_input * impl->num_blocks_out + active_output]));
  CeedOperatorAssemble_Cuda *asmb = impl->asmb_blocks[active_input * impl->num_blocks_out + active_output];

  asmb->elems_per_block = 1;
  asmb->block_size_x    = elem_size_in;
  asmb->block_size_y    = elem_size_out;

  CeedCallBackend(CeedGetData(ceed, &cuda_data));
  bool fallback = asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > cuda_data->device_prop.maxThreadsPerBlock;

  if (fallback) {
    // Use fallback kernel with 1D threadblock
    asmb->block_size_y = 1;
  }

  // Compile kernels
  // The JiT source is the two generated offset arrays followed by the include of the block assembly kernel
  const char     assembly_kernel_source[] = "// Full assembly source\n#include <ceed/jit-source/cuda/cuda-ref-operator-assemble-block.h>\n";
  const CeedSize len                      = strlen(assembly_kernel_source) + strlen(eval_mode_offsets_in_str) + strlen(eval_mode_offsets_out_str) + 3;
  char          *source;

  // CeedCalloc zero-fills the buffer, so strcat starts from an empty string; the +3 covers two newlines and the terminator
  CeedCallBackend(CeedCalloc(len, &source));
  strcat(source, eval_mode_offsets_in_str);
  strcat(source, "\n");
  strcat(source, eval_mode_offsets_out_str);
  strcat(source, "\n");
  strcat(source, assembly_kernel_source);

  CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in));
  CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out));
  // num_output_components is a CeedSize while every other define value is a CeedInt; cast so all variadic values share one type
  // NOTE(review): assumes CeedCompile_Cuda reads each define value as an int-sized argument - confirm against its definition
  CeedCallBackend(CeedCompile_Cuda(ceed, source, &asmb->module, 11, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", num_eval_modes_out,
                                   "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "TOTAL_NUM_COMP_OUT", (CeedInt)num_output_components,
                                   "NUM_NODES_IN", elem_size_in, "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE",
                                   asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y,
                                   "USE_CEEDSIZE", use_ceedsize_idx));
  CeedCallBackend(CeedGetKernel_Cuda(ceed, asmb->module, "LinearAssembleBlock", &asmb->LinearAssemble));

  // Load into B_in, in order that they will be used in eval_modes_in
  {
    // Compute the byte count in size_t so the product cannot overflow a 32-bit CeedInt
    const size_t in_bytes = (size_t)elem_size_in * (size_t)num_qpts_in * (size_t)num_eval_modes_in * sizeof(CeedScalar);

    CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_in, in_bytes));
    CeedCallCuda(ceed, cudaMemcpy(asmb->d_B_in, h_B_in, in_bytes, cudaMemcpyHostToDevice));
  }

  // Load into B_out, in order that they will be used in eval_modes_out
  {
    const size_t out_bytes = (size_t)elem_size_out * (size_t)num_qpts_out * (size_t)num_eval_modes_out * sizeof(CeedScalar);

    CeedCallCuda(ceed, cudaMalloc((void **)&asmb->d_B_out, out_bytes));
    CeedCallCuda(ceed, cudaMemcpy(asmb->d_B_out, h_B_out, out_bytes, cudaMemcpyHostToDevice));
  }

  // Cleanup temporary host strings
  CeedCallBackend(CeedFree(&eval_mode_offsets_in_str));
  CeedCallBackend(CeedFree(&eval_mode_offsets_out_str));
  CeedCallBackend(CeedFree(&source));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
1652+
15071653
//------------------------------------------------------------------------------
15081654
// Single Operator Assembly Setup
15091655
//------------------------------------------------------------------------------
@@ -1705,6 +1851,117 @@ static int CeedOperatorAssembleSingleSetup_Cuda(CeedOperator op, CeedInt use_cee
17051851
return CEED_ERROR_SUCCESS;
17061852
}
17071853

1854+
//------------------------------------------------------------------------------
1855+
// Assemble matrix data for one block of a COO matrix of assembled operator.
1856+
// The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
1857+
//------------------------------------------------------------------------------
1858+
// Assemble the values for the (active_input, active_output) block of the operator's COO matrix.
//
// Parameters:
//   op            - operator being assembled
//   offset        - number of entries to skip at the start of `values` before writing this block
//   active_input  - index of the active input basis/restriction for this block
//   active_output - index of the active output basis/restriction for this block
//   values        - vector receiving the assembled entries; accessed on the device
//
// The per-block kernel and basis data are created on first use and cached in impl->asmb_blocks.
//
// Returns CEED_ERROR_SUCCESS, or the error code of the first failing call/check.
static int CeedOperatorAssembleSingleBlock_Cuda(CeedOperator op, CeedInt offset, CeedInt active_input, CeedInt active_output, CeedVector values) {
  Ceed                       ceed;
  CeedSize                   values_length = 0, assembled_qf_length = 0;
  CeedInt                    use_ceedsize_idx = 0, num_elem_in, num_elem_out, elem_size_in, elem_size_out;
  CeedScalar                *values_array;
  const CeedScalar          *assembled_qf_array;
  CeedVector                 assembled_qf   = NULL;
  CeedElemRestriction        assembled_rstr = NULL, rstr_in, rstr_out;
  CeedElemRestriction       *active_rstrs_in, *active_rstrs_out;
  CeedRestrictionType        rstr_type_in, rstr_type_out;
  const bool                *orients_in = NULL, *orients_out = NULL;
  const CeedInt8            *curl_orients_in = NULL, *curl_orients_out = NULL;
  CeedOperatorAssemblyData   data;
  CeedOperatorAssemble_Cuda *asmb;
  CeedOperator_Cuda         *impl;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));

  // Assemble QFunction
  CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE));
  CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr));
  CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));

  // Select the 64-bit indexing kernel variant when either vector is too long for int indexing
  CeedCallBackend(CeedVectorGetLength(values, &values_length));
  CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
  if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;

  // Setup
  if (impl->asmb_blocks) {
    // The block table already exists, so validate the indices before reading from it
    CeedCheck(active_input >= 0 && active_input < impl->num_blocks_in && active_output >= 0 && active_output < impl->num_blocks_out, ceed,
              CEED_ERROR_BACKEND, "Active input/output block index out of range");
  }
  if (!impl->asmb_blocks || !impl->asmb_blocks[active_input * impl->num_blocks_out + active_output]) {
    CeedCallBackend(CeedOperatorAssembleSingleBlockSetup_Cuda(op, active_input, active_output, use_ceedsize_idx));
  }
  // Read the entry only after setup, since setup may allocate the table and set num_blocks_out
  asmb = impl->asmb_blocks[active_input * impl->num_blocks_out + active_output];
  // Report a missing block through the Ceed error handler; an assert() would vanish under NDEBUG
  CeedCheck(asmb != NULL, ceed, CEED_ERROR_BACKEND, "Block assembly setup did not produce assembly data");

  // Assemble element operator
  CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array));
  values_array += offset;

  CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &data));
  CeedCallBackend(CeedOperatorAssemblyDataGetElemRestrictions(data, NULL, &active_rstrs_in, NULL, &active_rstrs_out));

  rstr_in  = active_rstrs_in[active_input];
  rstr_out = active_rstrs_out[active_output];
  CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem_in));
  CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in));

  // Input orientation data; stays NULL for restriction types without orientations
  CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in));
  if (rstr_type_in == CEED_RESTRICTION_ORIENTED) {
    CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in));
  } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) {
    CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in));
  }

  if (rstr_in != rstr_out) {
    CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out));
    CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED,
              "Active input and output operator restrictions must have the same number of elements");
    CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out));

    CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out));
    if (rstr_type_out == CEED_RESTRICTION_ORIENTED) {
      CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out));
    } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) {
      CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out));
    }
  } else {
    // Same restriction on both sides: reuse the input data instead of acquiring it twice
    elem_size_out    = elem_size_in;
    orients_out      = orients_in;
    curl_orients_out = curl_orients_in;
  }

  // Compute B^T D B
  // Dynamic shared memory is only requested when curl-oriented restrictions are involved
  CeedInt shared_mem =
      ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? elem_size_in * asmb->block_size_y : 0)) *
      sizeof(CeedScalar);
  CeedInt grid   = CeedDivUpInt(num_elem_in, asmb->elems_per_block);
  void   *args[] = {(void *)&num_elem_in, &asmb->d_B_in,     &asmb->d_B_out,      &orients_in,  &curl_orients_in,
                    &orients_out,         &curl_orients_out, &assembled_qf_array, &values_array};

  CeedCallBackend(CeedRunKernelDimShared_Cuda(ceed, asmb->LinearAssemble, NULL, grid, asmb->block_size_x, asmb->block_size_y, asmb->elems_per_block,
                                              shared_mem, args));
  CeedCallCuda(ceed, cudaDeviceSynchronize());

  // Restore arrays
  CeedCallBackend(CeedVectorRestoreArray(values, &values_array));
  CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));

  // Cleanup
  CeedCallBackend(CeedVectorDestroy(&assembled_qf));
  if (rstr_type_in == CEED_RESTRICTION_ORIENTED) {
    CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in));
  } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) {
    CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in));
  }
  // Output orientation data was only acquired separately when the restrictions differ
  if (rstr_in != rstr_out) {
    if (rstr_type_out == CEED_RESTRICTION_ORIENTED) {
      CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out));
    } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) {
      CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out));
    }
  }
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
1964+
17081965
//------------------------------------------------------------------------------
17091966
// Assemble matrix data for COO matrix of assembled operator.
17101967
// The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
@@ -2090,6 +2347,7 @@ int CeedOperatorCreate_Cuda(CeedOperator op) {
20902347
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal",
20912348
CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda));
20922349
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedOperatorAssembleSingle_Cuda));
2350+
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingleBlock", CeedOperatorAssembleSingleBlock_Cuda));
20932351
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Cuda));
20942352
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Cuda));
20952353
CeedCallBackend(CeedDestroy(&ceed));

backends/cuda-ref/ceed-cuda-ref.h

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -132,19 +132,21 @@ typedef struct {
132132
} CeedOperatorAssemble_Cuda;
133133

134134
typedef struct {
135-
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
136-
uint64_t *input_states, points_state; // State tracking for passive inputs
137-
CeedVector *e_vecs_in, *e_vecs_out;
138-
CeedVector *q_vecs_in, *q_vecs_out;
139-
CeedInt num_inputs, num_outputs;
140-
CeedInt num_active_in, num_active_out;
141-
CeedInt *input_field_order, *output_field_order;
142-
CeedSize max_active_e_vec_len;
143-
CeedInt max_num_points;
144-
CeedInt *num_points;
145-
CeedVector *qf_active_in, point_coords_elem;
146-
CeedOperatorDiag_Cuda *diag;
147-
CeedOperatorAssemble_Cuda *asmb;
135+
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
136+
uint64_t *input_states, points_state; // State tracking for passive inputs
137+
CeedVector *e_vecs_in, *e_vecs_out;
138+
CeedVector *q_vecs_in, *q_vecs_out;
139+
CeedInt num_inputs, num_outputs;
140+
CeedInt num_active_in, num_active_out;
141+
CeedInt *input_field_order, *output_field_order;
142+
CeedSize max_active_e_vec_len;
143+
CeedInt max_num_points;
144+
CeedInt *num_points;
145+
CeedVector *qf_active_in, point_coords_elem;
146+
CeedOperatorDiag_Cuda *diag;
147+
CeedOperatorAssemble_Cuda *asmb;
148+
CeedOperatorAssemble_Cuda **asmb_blocks;
149+
CeedInt num_blocks_in, num_blocks_out;
148150
} CeedOperator_Cuda;
149151

150152
CEED_INTERN int CeedGetCublasHandle_Cuda(Ceed ceed, cublasHandle_t *handle);

doc/sphinx/source/releasenotes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Specifically, directories set with `CeedAddJitSourceRoot(ceed, "foo/bar")` will
3030
- Added support to code generation backends `/gpu/cuda/gen` and `/gpu/hip/gen` for operators with both tensor and non-tensor bases.
3131
- Add `CeedGetGitVersion()` to access the Git commit and dirty state of the repository at build time.
3232
- Add `CeedGetBuildConfiguration()` to access compilers, flags, and related information about the build environment.
33+
- Add support for full `CeedOperator` assembly for operators with multiple active fields with different bases for CPU backends and `/gpu/cuda/ref` and `/gpu/hip/gen` backends.
3334

3435
### Examples
3536

0 commit comments

Comments
 (0)