Skip to content

Commit 5388e7c

Browse files
committed
gpu/hip/ref: add support for mixed basis assembly
1 parent 49c8696 commit 5388e7c

5 files changed

Lines changed: 412 additions & 16 deletions

File tree

backends/hip-ref/ceed-hip-ref-operator.c

Lines changed: 266 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,27 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) {
9494
CeedCallHip(ceed, hipFree(impl->asmb->d_B_out));
9595
CeedCallBackend(CeedDestroy(&ceed));
9696
}
97+
9798
CeedCallBackend(CeedFree(&impl->asmb));
99+
if (impl->asmb_blocks) {
100+
Ceed ceed;
101+
102+
CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
103+
for (CeedInt i = 0; i < impl->num_blocks_in; i++) {
104+
for (CeedInt j = 0; j < impl->num_blocks_out; j++) {
105+
CeedOperatorAssemble_Hip *asmb = impl->asmb_blocks[i * impl->num_blocks_out + j];
106+
107+
if (asmb) {
108+
CeedCallHip(ceed, hipModuleUnload(asmb->module));
109+
CeedCallHip(ceed, hipFree(asmb->d_B_in));
110+
CeedCallHip(ceed, hipFree(asmb->d_B_out));
111+
}
112+
CeedCallBackend(CeedFree(&asmb));
113+
}
114+
}
115+
CeedCallBackend(CeedFree(&impl->asmb_blocks));
116+
CeedCallBackend(CeedDestroy(&ceed));
117+
}
98118

99119
CeedCallBackend(CeedFree(&impl));
100120
return CEED_ERROR_SUCCESS;
@@ -1049,9 +1069,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, b
10491069
CeedInt strides[3] = {1, num_elem * Q, Q}; /* *NOPAD* */
10501070

10511071
// Create output restriction
1052-
CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out,
1053-
(CeedSize)num_active_in * (CeedSize)num_active_out * (CeedSize)num_elem * (CeedSize)Q, strides,
1054-
rstr));
1072+
CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out, l_size, strides, rstr));
10551073
// Create assembled vector
10561074
CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled));
10571075
}
@@ -1501,6 +1519,139 @@ static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op,
15011519
return CEED_ERROR_SUCCESS;
15021520
}
15031521

1522+
//------------------------------------------------------------------------------
1523+
// Single Operator Block Assembly Setup
1524+
//------------------------------------------------------------------------------
1525+
/**
  @brief Set up block assembly data for one (active input, active output) basis pair

  Allocates the `CeedOperatorAssemble_Hip` entry for the pair, JiT compiles the `LinearAssembleBlock` kernel, and copies the host basis matrices for the pair to the device.
  The table of blocks stored on the operator is allocated on first use.

  @param[in] op               CeedOperator to set up
  @param[in] active_input     Index of the active input basis
  @param[in] active_output    Index of the active output basis
  @param[in] use_ceedsize_idx Value passed to the kernel as the `USE_CEEDSIZE` define

  @return An error code: 0 - success, otherwise - failure
**/
static int CeedOperatorAssembleSingleBlockSetup_Hip(CeedOperator op, CeedInt active_input, CeedInt active_output, CeedInt use_ceedsize_idx) {
  Ceed                      ceed;
  Ceed_Hip                 *hip_data;
  char                     *source;
  CeedInt                   num_active_bases_in = 0, num_active_bases_out = 0, *all_num_eval_modes_in, *all_num_eval_modes_out;
  CeedInt                   num_eval_modes_in = 0, num_eval_modes_out = 0;
  CeedInt                   elem_size_in, num_qpts_in = 0, num_comp_in, elem_size_out, num_qpts_out = 0, num_comp_out;
  CeedSize                  num_output_components;
  CeedSize                **eval_modes_offsets_in, **eval_modes_offsets_out;
  const CeedScalar        **B_mats_in, **B_mats_out, *h_B_in, *h_B_out;
  CeedElemRestriction      *active_rstrs_in, *active_rstrs_out, rstr_in = NULL, rstr_out = NULL;
  CeedBasis                *active_bases_in, *active_bases_out, basis_in = NULL, basis_out = NULL;
  CeedOperator_Hip         *impl;
  CeedOperatorAssemble_Hip *asmb;
  CeedOperatorAssemblyData  data;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));
  CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &data));

  // Fetch eval mode counts and offsets, restrictions, and bases for all active fields in one pass
  CeedCallBackend(CeedOperatorAssemblyDataGetEvalModes(data, &num_active_bases_in, &all_num_eval_modes_in, NULL, &eval_modes_offsets_in,
                                                       &num_active_bases_out, &all_num_eval_modes_out, NULL, &eval_modes_offsets_out,
                                                       &num_output_components));
  // Number of elem restrictions is the same as the number of bases
  CeedCallBackend(CeedOperatorAssemblyDataGetElemRestrictions(data, NULL, &active_rstrs_in, NULL, &active_rstrs_out));
  CeedCallBackend(CeedOperatorAssemblyDataGetBases(data, NULL, &active_bases_in, &B_mats_in, NULL, &active_bases_out, &B_mats_out));

  // The block indices are used to index every array fetched above, so validate them first
  CeedCheck(active_input >= 0 && active_input < num_active_bases_in && active_output >= 0 && active_output < num_active_bases_out, ceed,
            CEED_ERROR_BACKEND, "Active input/output basis index out of range for block assembly");

  num_eval_modes_in  = all_num_eval_modes_in[active_input];
  num_eval_modes_out = all_num_eval_modes_out[active_output];
  CeedCheck(num_eval_modes_in > 0 && num_eval_modes_out > 0, ceed, CEED_ERROR_UNSUPPORTED, "Cannot assemble operator without inputs/outputs");

  // Allocate the table of blocks on first use
  if (!impl->asmb_blocks) {
    CeedCallBackend(CeedCalloc(num_active_bases_in * num_active_bases_out, &impl->asmb_blocks));
    impl->num_blocks_in  = num_active_bases_in;
    impl->num_blocks_out = num_active_bases_out;
  }

  rstr_in   = active_rstrs_in[active_input];
  basis_in  = active_bases_in[active_input];
  h_B_in    = B_mats_in[active_input];
  rstr_out  = active_rstrs_out[active_output];
  basis_out = active_bases_out[active_output];
  h_B_out   = B_mats_out[active_output];

  // Without a basis, the quadrature points are taken to be the element nodes
  CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in));
  if (basis_in == CEED_BASIS_NONE) num_qpts_in = elem_size_in;
  else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_in, &num_qpts_in));

  CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out));
  if (basis_out == CEED_BASIS_NONE) num_qpts_out = elem_size_out;
  else CeedCallBackend(CeedBasisGetNumQuadraturePoints(basis_out, &num_qpts_out));
  CeedCheck(num_qpts_in == num_qpts_out, ceed, CEED_ERROR_UNSUPPORTED,
            "Active input and output bases must have the same number of quadrature points");

  // Allocate the entry for this block
  {
    const CeedInt block_index = active_input * impl->num_blocks_out + active_output;

    CeedCallBackend(CeedCalloc(1, &impl->asmb_blocks[block_index]));
    asmb = impl->asmb_blocks[block_index];
  }
  asmb->elems_per_block = 1;
  asmb->block_size_x    = elem_size_in;
  asmb->block_size_y    = elem_size_out;

  CeedCallBackend(CeedGetData(ceed, &hip_data));
  if (asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block > hip_data->device_prop.maxThreadsPerBlock) {
    // Use fallback kernel with 1D threadblock
    asmb->block_size_y = 1;
  }

  // Build kernel source: the two eval mode offset arrays followed by the assembly kernel include
  {
    char      *eval_mode_offsets_in_str, *eval_mode_offsets_out_str;
    const char assembly_kernel_source[] = "// Full assembly source\n#include <ceed/jit-source/hip/hip-ref-operator-assemble-block.h>\n";

    CeedCallBackend(CeedBuildArrayConstantSize_Hip(ceed, "EVAL_MODE_OFFSETS_IN", num_eval_modes_in, eval_modes_offsets_in[active_input],
                                                   &eval_mode_offsets_in_str));
    CeedCallBackend(CeedBuildArrayConstantSize_Hip(ceed, "EVAL_MODE_OFFSETS_OUT", num_eval_modes_out, eval_modes_offsets_out[active_output],
                                                   &eval_mode_offsets_out_str));

    // Two newline separators plus the terminating NUL
    const size_t len = strlen(assembly_kernel_source) + strlen(eval_mode_offsets_in_str) + strlen(eval_mode_offsets_out_str) + 3;

    CeedCallBackend(CeedCalloc(len, &source));
    snprintf(source, len, "%s\n%s\n%s", eval_mode_offsets_in_str, eval_mode_offsets_out_str, assembly_kernel_source);

    CeedCallBackend(CeedFree(&eval_mode_offsets_in_str));
    CeedCallBackend(CeedFree(&eval_mode_offsets_out_str));
  }

  // Compile kernels
  CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_in, &num_comp_in));
  CeedCallBackend(CeedElemRestrictionGetNumComponents(rstr_out, &num_comp_out));
  // Every define value goes through varargs, so all of them must share one type; cast the CeedSize to CeedInt like the rest
  CeedCallBackend(CeedCompile_Hip(ceed, source, &asmb->module, 11, "NUM_EVAL_MODES_IN", num_eval_modes_in, "NUM_EVAL_MODES_OUT", num_eval_modes_out,
                                  "NUM_COMP_IN", num_comp_in, "NUM_COMP_OUT", num_comp_out, "TOTAL_NUM_COMP_OUT", (CeedInt)num_output_components,
                                  "NUM_NODES_IN", elem_size_in, "NUM_NODES_OUT", elem_size_out, "NUM_QPTS", num_qpts_in, "BLOCK_SIZE",
                                  asmb->block_size_x * asmb->block_size_y * asmb->elems_per_block, "BLOCK_SIZE_Y", asmb->block_size_y, "USE_CEEDSIZE",
                                  use_ceedsize_idx));
  CeedCallBackend(CeedGetKernel_Hip(ceed, asmb->module, "LinearAssembleBlock", &asmb->LinearAssemble));

  // Load into B_in, in order that they will be used in eval_modes_in
  {
    // Byte count computed in size_t so the product cannot overflow a 32-bit CeedInt
    const size_t in_bytes = (size_t)elem_size_in * (size_t)num_qpts_in * (size_t)num_eval_modes_in * sizeof(CeedScalar);

    CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_in, in_bytes));
    CeedCallHip(ceed, hipMemcpy(asmb->d_B_in, h_B_in, in_bytes, hipMemcpyHostToDevice));
  }

  // Load into B_out, in order that they will be used in eval_modes_out
  {
    const size_t out_bytes = (size_t)elem_size_out * (size_t)num_qpts_out * (size_t)num_eval_modes_out * sizeof(CeedScalar);

    CeedCallHip(ceed, hipMalloc((void **)&asmb->d_B_out, out_bytes));
    CeedCallHip(ceed, hipMemcpy(asmb->d_B_out, h_B_out, out_bytes, hipMemcpyHostToDevice));
  }

  // Cleanup
  CeedCallBackend(CeedFree(&source));
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
1654+
15041655
//------------------------------------------------------------------------------
15051656
// Single Operator Assembly Setup
15061657
//------------------------------------------------------------------------------
@@ -1702,6 +1853,117 @@ static int CeedOperatorAssembleSingleSetup_Hip(CeedOperator op, CeedInt use_ceed
17021853
return CEED_ERROR_SUCCESS;
17031854
}
17041855

1856+
//------------------------------------------------------------------------------
1857+
// Assemble matrix data for one block of a COO matrix of assembled operator.
1858+
// The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
1859+
//------------------------------------------------------------------------------
1860+
/**
  @brief Assemble the COO values for one (active input, active output) block of an operator

  Assembles the QFunction, sets up the block kernel on first use, and launches `LinearAssemble` for the block, writing into `values` starting at `offset`.

  @param[in]     op            CeedOperator to assemble
  @param[in]     offset        Offset into the device array of `values` at which this block's entries start
  @param[in]     active_input  Index of the active input basis
  @param[in]     active_output Index of the active output basis
  @param[in,out] values        CeedVector receiving the assembled values

  @return An error code: 0 - success, otherwise - failure
**/
static int CeedOperatorAssembleSingleBlock_Hip(CeedOperator op, CeedInt offset, CeedInt active_input, CeedInt active_output, CeedVector values) {
  Ceed                      ceed;
  CeedSize                  values_length = 0, assembled_qf_length = 0;
  CeedInt                   use_ceedsize_idx = 0, num_active_rstrs_in = 0, num_active_rstrs_out = 0;
  CeedInt                   num_elem_in, num_elem_out, elem_size_in, elem_size_out;
  CeedScalar               *values_array;
  const CeedScalar         *assembled_qf_array;
  CeedVector                assembled_qf   = NULL;
  CeedElemRestriction       assembled_rstr = NULL, rstr_in, rstr_out, *active_rstrs_in, *active_rstrs_out;
  CeedRestrictionType       rstr_type_in, rstr_type_out;
  const bool               *orients_in = NULL, *orients_out = NULL;
  const CeedInt8           *curl_orients_in = NULL, *curl_orients_out = NULL;
  CeedOperator_Hip         *impl;
  CeedOperatorAssemble_Hip *asmb;
  CeedOperatorAssemblyData  data;

  CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
  CeedCallBackend(CeedOperatorGetData(op, &impl));

  // Validate the block before acquiring any arrays, so a failed check does not leave vectors checked out
  CeedCallBackend(CeedOperatorGetOperatorAssemblyData(op, &data));
  CeedCallBackend(CeedOperatorAssemblyDataGetElemRestrictions(data, &num_active_rstrs_in, &active_rstrs_in, &num_active_rstrs_out, &active_rstrs_out));
  CeedCheck(active_input >= 0 && active_input < num_active_rstrs_in && active_output >= 0 && active_output < num_active_rstrs_out, ceed,
            CEED_ERROR_BACKEND, "Active input/output basis index out of range for block assembly");

  rstr_in  = active_rstrs_in[active_input];
  rstr_out = active_rstrs_out[active_output];
  CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_in, &num_elem_in));
  CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_in, &elem_size_in));
  CeedCallBackend(CeedElemRestrictionGetType(rstr_in, &rstr_type_in));
  if (rstr_in != rstr_out) {
    CeedCallBackend(CeedElemRestrictionGetNumElements(rstr_out, &num_elem_out));
    CeedCheck(num_elem_in == num_elem_out, ceed, CEED_ERROR_UNSUPPORTED,
              "Active input and output operator restrictions must have the same number of elements");
    CeedCallBackend(CeedElemRestrictionGetElementSize(rstr_out, &elem_size_out));
    CeedCallBackend(CeedElemRestrictionGetType(rstr_out, &rstr_type_out));
  } else {
    elem_size_out = elem_size_in;
    rstr_type_out = rstr_type_in;
  }

  // Assemble QFunction
  CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &assembled_rstr, CEED_REQUEST_IMMEDIATE));
  CeedCallBackend(CeedElemRestrictionDestroy(&assembled_rstr));

  // Kernel indexing needs CeedSize when either vector is too long for a 32-bit index
  CeedCallBackend(CeedVectorGetLength(values, &values_length));
  CeedCallBackend(CeedVectorGetLength(assembled_qf, &assembled_qf_length));
  if ((values_length > INT_MAX) || (assembled_qf_length > INT_MAX)) use_ceedsize_idx = 1;

  // Setup, on first use of this block
  if (!impl->asmb_blocks || !impl->asmb_blocks[active_input * impl->num_blocks_out + active_output]) {
    CeedCallBackend(CeedOperatorAssembleSingleBlockSetup_Hip(op, active_input, active_output, use_ceedsize_idx));
  }
  // Checked at runtime rather than with assert so the guard survives NDEBUG builds
  CeedCheck(impl->asmb_blocks && impl->asmb_blocks[active_input * impl->num_blocks_out + active_output], ceed, CEED_ERROR_BACKEND,
            "Block assembly setup did not produce assembly data");
  asmb = impl->asmb_blocks[active_input * impl->num_blocks_out + active_output];

  // Acquire device arrays
  CeedCallBackend(CeedVectorGetArrayRead(assembled_qf, CEED_MEM_DEVICE, &assembled_qf_array));
  CeedCallBackend(CeedVectorGetArray(values, CEED_MEM_DEVICE, &values_array));
  values_array += offset;

  // Acquire orientation data for oriented and curl-oriented restrictions
  if (rstr_type_in == CEED_RESTRICTION_ORIENTED) {
    CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_in, CEED_MEM_DEVICE, &orients_in));
  } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) {
    CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_in, CEED_MEM_DEVICE, &curl_orients_in));
  }
  if (rstr_in != rstr_out) {
    if (rstr_type_out == CEED_RESTRICTION_ORIENTED) {
      CeedCallBackend(CeedElemRestrictionGetOrientations(rstr_out, CEED_MEM_DEVICE, &orients_out));
    } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) {
      CeedCallBackend(CeedElemRestrictionGetCurlOrientations(rstr_out, CEED_MEM_DEVICE, &curl_orients_out));
    }
  } else {
    // Same restriction on both sides: reuse the input orientation arrays
    orients_out      = orients_in;
    curl_orients_out = curl_orients_in;
  }

  // Compute B^T D B
  {
    // Shared memory is only requested when curl-orientation data is present
    CeedInt shared_mem =
        ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0) + (curl_orients_in ? elem_size_in * asmb->block_size_y : 0)) *
        sizeof(CeedScalar);
    CeedInt grid   = CeedDivUpInt(num_elem_in, asmb->elems_per_block);
    void   *args[] = {(void *)&num_elem_in, &asmb->d_B_in,     &asmb->d_B_out,      &orients_in,  &curl_orients_in,
                      &orients_out,         &curl_orients_out, &assembled_qf_array, &values_array};

    CeedCallBackend(CeedRunKernelDimShared_Hip(ceed, asmb->LinearAssemble, NULL, grid, asmb->block_size_x, asmb->block_size_y,
                                               asmb->elems_per_block, shared_mem, args));
  }

  // Restore arrays
  CeedCallBackend(CeedVectorRestoreArray(values, &values_array));
  CeedCallBackend(CeedVectorRestoreArrayRead(assembled_qf, &assembled_qf_array));

  // Cleanup
  CeedCallBackend(CeedVectorDestroy(&assembled_qf));
  if (rstr_type_in == CEED_RESTRICTION_ORIENTED) {
    CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_in, &orients_in));
  } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED) {
    CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_in, &curl_orients_in));
  }
  if (rstr_in != rstr_out) {
    if (rstr_type_out == CEED_RESTRICTION_ORIENTED) {
      CeedCallBackend(CeedElemRestrictionRestoreOrientations(rstr_out, &orients_out));
    } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED) {
      CeedCallBackend(CeedElemRestrictionRestoreCurlOrientations(rstr_out, &curl_orients_out));
    }
  }
  CeedCallBackend(CeedDestroy(&ceed));
  return CEED_ERROR_SUCCESS;
}
1966+
17051967
//------------------------------------------------------------------------------
17061968
// Assemble matrix data for COO matrix of assembled operator.
17071969
// The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
@@ -2089,6 +2351,7 @@ int CeedOperatorCreate_Hip(CeedOperator op) {
20892351
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleAddPointBlockDiagonal",
20902352
CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip));
20912353
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingle", CeedOperatorAssembleSingle_Hip));
2354+
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "LinearAssembleSingleBlock", CeedOperatorAssembleSingleBlock_Hip));
20922355
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "ApplyAdd", CeedOperatorApplyAdd_Hip));
20932356
CeedCallBackend(CeedSetBackendFunction(ceed, "Operator", op, "Destroy", CeedOperatorDestroy_Hip));
20942357
CeedCallBackend(CeedDestroy(&ceed));

backends/hip-ref/ceed-hip-ref.h

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,21 @@ typedef struct {
137137
} CeedOperatorAssemble_Hip;
138138

139139
// Backend data attached to a HIP operator (retrieved with CeedOperatorGetData in the operator routines)
typedef struct {
140-
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
141-
uint64_t *input_states, points_state; // State tracking for passive inputs
142-
CeedVector *e_vecs_in, *e_vecs_out;
143-
CeedVector *q_vecs_in, *q_vecs_out;
144-
CeedInt num_inputs, num_outputs;
145-
CeedInt num_active_in, num_active_out;
146-
CeedInt *input_field_order, *output_field_order;
147-
CeedSize max_active_e_vec_len;
148-
CeedInt max_num_points;
149-
CeedInt *num_points;
150-
CeedVector *qf_active_in, point_coords_elem;
151-
CeedOperatorDiag_Hip *diag;
152-
CeedOperatorAssemble_Hip *asmb;
140+
bool *skip_rstr_in, *skip_rstr_out, *apply_add_basis_out;
141+
uint64_t *input_states, points_state; // State tracking for passive inputs
142+
CeedVector *e_vecs_in, *e_vecs_out;
143+
CeedVector *q_vecs_in, *q_vecs_out;
144+
CeedInt num_inputs, num_outputs;
145+
CeedInt num_active_in, num_active_out;
146+
CeedInt *input_field_order, *output_field_order;
147+
CeedSize max_active_e_vec_len;
148+
CeedInt max_num_points;
149+
CeedInt *num_points;
150+
CeedVector *qf_active_in, point_coords_elem;
151+
CeedOperatorDiag_Hip *diag;
152+
CeedOperatorAssemble_Hip *asmb;
153+
// Per-block assembly data, one entry per (active input, active output) basis pair, indexed as
// [input * num_blocks_out + output]; the table is zero-allocated and entries stay NULL until that block is set up
CeedOperatorAssemble_Hip **asmb_blocks;
154+
// Dimensions of the asmb_blocks table, set when the table is allocated
CeedInt num_blocks_in, num_blocks_out;
153155
} CeedOperator_Hip;
154156

155157
CEED_INTERN int CeedGetHipblasHandle_Hip(Ceed ceed, hipblasHandle_t *handle);

backends/hip/ceed-hip-compile.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <ceed/backend.h>
1212
#include <ceed/jit-source/hip/hip-chipstar.h>
1313
#include <ceed/jit-tools.h>
14+
#include <iomanip>
1415
#include <stdarg.h>
1516
#include <string.h>
1617
#include <hip/hiprtc.h>
@@ -200,6 +201,29 @@ static int CeedCompileCore_Hip(Ceed ceed, const char *source, const bool throw_e
200201
return CEED_ERROR_SUCCESS;
201202
}
202203

204+
template <typename ArrayT>
205+
struct CeedArrayView {
206+
const ArrayT *array;
207+
CeedInt size;
208+
209+
CeedArrayView(const ArrayT *array_, CeedInt size_) : array(array_), size(size_) {}
210+
};
211+
212+
template <typename OStream, typename ArrayT>
213+
OStream &operator<<(OStream &ostream, const CeedArrayView<ArrayT> &view) {
214+
ostream << "{";
215+
for (CeedInt i = 0; i < view.size; i++) ostream << std::setprecision(17) << view.array[i] << (i == view.size - 1 ? "}" : ", ");
216+
return ostream;
217+
}
218+
219+
// Build the source line `constexpr CeedSize <name>[<length>] = {...};` from a host array of CeedSize values.
// The result is written to a newly allocated string returned in `line`.
int CeedBuildArrayConstantSize_Hip(Ceed ceed, const char *name, CeedInt length, const CeedSize *array, char **line) {
  const CeedArrayView<CeedSize> values(array, length);
  std::ostringstream            declaration;

  declaration << "constexpr CeedSize " << name << "[" << length << "] = ";
  declaration << values << ";";
  CeedCallBackend(CeedStringAllocCopy(declaration.str().c_str(), line));
  return CEED_ERROR_SUCCESS;
}
226+
203227
int CeedCompile_Hip(Ceed ceed, const char *source, hipModule_t *module, const CeedInt num_defines, ...) {
204228
bool is_compile_good = true;
205229
va_list args;

backends/hip/ceed-hip-compile.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
// Integer division of numerator by denominator after adding (denominator - 1) to the numerator;
// for positive operands this rounds the quotient up.
static inline CeedInt CeedDivUpInt(CeedInt numerator, CeedInt denominator) {
  const CeedInt biased_numerator = numerator + denominator - 1;

  return biased_numerator / denominator;
}
1414

15+
CEED_INTERN int CeedBuildArrayConstantSize_Hip(Ceed ceed, const char *name, CeedInt length, const CeedSize *array, char **line);
16+
1517
CEED_INTERN int CeedCompile_Hip(Ceed ceed, const char *source, hipModule_t *module, const CeedInt num_defines, ...);
1618
CEED_INTERN int CeedTryCompile_Hip(Ceed ceed, const char *source, bool *is_compile_good, hipModule_t *module, const CeedInt num_defines, ...);
1719

0 commit comments

Comments
 (0)