@@ -94,7 +94,27 @@ static int CeedOperatorDestroy_Hip(CeedOperator op) {
9494 CeedCallHip (ceed , hipFree (impl -> asmb -> d_B_out ));
9595 CeedCallBackend (CeedDestroy (& ceed ));
9696 }
97+
9798 CeedCallBackend (CeedFree (& impl -> asmb ));
99+ if (impl -> asmb_blocks ) {
100+ Ceed ceed ;
101+
102+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
103+ for (CeedInt i = 0 ; i < impl -> num_blocks_in ; i ++ ) {
104+ for (CeedInt j = 0 ; j < impl -> num_blocks_out ; j ++ ) {
105+ CeedOperatorAssemble_Hip * asmb = impl -> asmb_blocks [i * impl -> num_blocks_out + j ];
106+
107+ if (asmb ) {
108+ CeedCallHip (ceed , hipModuleUnload (asmb -> module ));
109+ CeedCallHip (ceed , hipFree (asmb -> d_B_in ));
110+ CeedCallHip (ceed , hipFree (asmb -> d_B_out ));
111+ }
112+ CeedCallBackend (CeedFree (& asmb ));
113+ }
114+ }
115+ CeedCallBackend (CeedFree (& impl -> asmb_blocks ));
116+ CeedCallBackend (CeedDestroy (& ceed ));
117+ }
98118
99119 CeedCallBackend (CeedFree (& impl ));
100120 return CEED_ERROR_SUCCESS ;
@@ -1049,9 +1069,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, b
10491069 CeedInt strides [3 ] = {1 , num_elem * Q , Q }; /* *NOPAD* */
10501070
10511071 // Create output restriction
1052- CeedCallBackend (CeedElemRestrictionCreateStrided (ceed_parent , num_elem , Q , num_active_in * num_active_out ,
1053- (CeedSize )num_active_in * (CeedSize )num_active_out * (CeedSize )num_elem * (CeedSize )Q , strides ,
1054- rstr ));
1072+ CeedCallBackend (CeedElemRestrictionCreateStrided (ceed_parent , num_elem , Q , num_active_in * num_active_out , l_size , strides , rstr ));
10551073 // Create assembled vector
10561074 CeedCallBackend (CeedVectorCreate (ceed_parent , l_size , assembled ));
10571075 }
@@ -1501,6 +1519,139 @@ static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip(CeedOperator op,
15011519 return CEED_ERROR_SUCCESS ;
15021520}
15031521
1522+ //------------------------------------------------------------------------------
1523+ // Single Operator Block Assembly Setup
1524+ //------------------------------------------------------------------------------
1525+ static int CeedOperatorAssembleSingleBlockSetup_Hip (CeedOperator op , CeedInt active_input , CeedInt active_output , CeedInt use_ceedsize_idx ) {
1526+ Ceed ceed ;
1527+ Ceed_Hip * hip_data ;
1528+ CeedInt num_input_fields , num_output_fields , num_eval_modes_in = 0 , num_eval_modes_out = 0 ;
1529+ CeedInt elem_size_in , num_qpts_in = 0 , num_comp_in , elem_size_out , num_qpts_out , num_comp_out ;
1530+ CeedSize num_output_components ;
1531+ const CeedScalar * h_B_in , * h_B_out ;
1532+ CeedElemRestriction rstr_in = NULL , rstr_out = NULL ;
1533+ CeedBasis basis_in = NULL , basis_out = NULL ;
1534+ CeedOperatorField * input_fields , * output_fields ;
1535+ CeedOperator_Hip * impl ;
1536+ CeedOperatorAssemblyData data ;
1537+
1538+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
1539+ CeedCallBackend (CeedOperatorGetData (op , & impl ));
1540+ CeedCall (CeedOperatorGetOperatorAssemblyData (op , & data ));
1541+
1542+ // Get intput and output fields
1543+ CeedCallBackend (CeedOperatorGetFields (op , & num_input_fields , & input_fields , & num_output_fields , & output_fields ));
1544+
1545+ {
1546+ CeedInt num_active_bases_in , * t_num_eval_modes_in , num_active_bases_out , * t_num_eval_modes_out ;
1547+ CeedBasis * active_bases_in , * active_bases_out ;
1548+ CeedElemRestriction * active_rstrs_in , * active_rstrs_out ;
1549+ const CeedScalar * * B_mats_in , * * B_mats_out ;
1550+
1551+ CeedCall (CeedOperatorAssemblyDataGetEvalModes (data , & num_active_bases_in , & t_num_eval_modes_in , NULL , NULL , & num_active_bases_out ,
1552+ & t_num_eval_modes_out , NULL , NULL , & num_output_components ));
1553+ // Number of elem restrictions is the same as the number of bases
1554+ CeedCall (CeedOperatorAssemblyDataGetElemRestrictions (data , NULL , & active_rstrs_in , NULL , & active_rstrs_out ));
1555+ CeedCall (CeedOperatorAssemblyDataGetBases (data , NULL , & active_bases_in , & B_mats_in , NULL , & active_bases_out , & B_mats_out ));
1556+
1557+ num_eval_modes_in = t_num_eval_modes_in [active_input ];
1558+ num_eval_modes_out = t_num_eval_modes_out [active_output ];
1559+ CeedCheck (num_eval_modes_in > 0 && num_eval_modes_out > 0 , ceed , CEED_ERROR_UNSUPPORTED , "Cannot assemble operator without inputs/outputs" );
1560+
1561+ if (!impl -> asmb_blocks ) {
1562+ CeedCallBackend (CeedCalloc (num_active_bases_in * num_active_bases_out , & impl -> asmb_blocks ));
1563+ impl -> num_blocks_in = num_active_bases_in ;
1564+ impl -> num_blocks_out = num_active_bases_out ;
1565+ }
1566+
1567+ rstr_in = active_rstrs_in [active_input ];
1568+ basis_in = active_bases_in [active_input ];
1569+ h_B_in = B_mats_in [active_input ];
1570+ rstr_out = active_rstrs_out [active_output ];
1571+ basis_out = active_bases_out [active_output ];
1572+ h_B_out = B_mats_out [active_output ];
1573+ }
1574+
1575+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_in , & elem_size_in ));
1576+ if (basis_in == CEED_BASIS_NONE ) num_qpts_in = elem_size_in ;
1577+ else CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis_in , & num_qpts_in ));
1578+
1579+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_out , & elem_size_out ));
1580+ if (basis_out == CEED_BASIS_NONE ) num_qpts_out = elem_size_out ;
1581+ else CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis_out , & num_qpts_out ));
1582+ CeedCheck (num_qpts_in == num_qpts_out , ceed , CEED_ERROR_UNSUPPORTED ,
1583+ "Active input and output bases must have the same number of quadrature points" );
1584+
1585+ CeedCallBackend (CeedCalloc (1 , & impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ]));
1586+ CeedOperatorAssemble_Hip * asmb = impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ];
1587+ asmb -> elems_per_block = 1 ;
1588+ asmb -> block_size_x = elem_size_in ;
1589+ asmb -> block_size_y = elem_size_out ;
1590+
1591+ CeedCallBackend (CeedGetData (ceed , & hip_data ));
1592+ bool fallback = asmb -> block_size_x * asmb -> block_size_y * asmb -> elems_per_block > hip_data -> device_prop .maxThreadsPerBlock ;
1593+
1594+ if (fallback ) {
1595+ // Use fallback kernel with 1D threadblock
1596+ asmb -> block_size_y = 1 ;
1597+ }
1598+
1599+ // Compile kernels
1600+ char * source ;
1601+ {
1602+ CeedInt num_eval_modes_in , * t_num_eval_modes_in , num_eval_modes_out , * t_num_eval_modes_out ;
1603+ CeedSize * * eval_modes_offsets_in , * * eval_modes_offsets_out ;
1604+ char * eval_mode_offsets_in_str , * eval_mode_offsets_out_str ;
1605+
1606+ CeedCall (CeedOperatorAssemblyDataGetEvalModes (data , NULL , & t_num_eval_modes_in , NULL , & eval_modes_offsets_in , NULL , & t_num_eval_modes_out , NULL ,
1607+ & eval_modes_offsets_out , NULL ));
1608+ num_eval_modes_in = t_num_eval_modes_in [active_input ];
1609+ num_eval_modes_out = t_num_eval_modes_out [active_output ];
1610+
1611+ CeedCallBackend (CeedBuildArrayConstantSize_Hip (ceed , "EVAL_MODE_OFFSETS_IN" , num_eval_modes_in , eval_modes_offsets_in [active_input ],
1612+ & eval_mode_offsets_in_str ));
1613+ CeedCallBackend (CeedBuildArrayConstantSize_Hip (ceed , "EVAL_MODE_OFFSETS_OUT" , num_eval_modes_out , eval_modes_offsets_out [active_output ],
1614+ & eval_mode_offsets_out_str ));
1615+
1616+ const char assembly_kernel_source [] = "// Full assembly source\n#include <ceed/jit-source/hip/hip-ref-operator-assemble-block.h>\n" ;
1617+ const CeedSize len = strlen (assembly_kernel_source ) + strlen (eval_mode_offsets_in_str ) + strlen (eval_mode_offsets_out_str ) + 3 ;
1618+
1619+ CeedCallBackend (CeedCalloc (len , & source ));
1620+ snprintf (source , len , "%s\n%s\n%s" , eval_mode_offsets_in_str , eval_mode_offsets_out_str , assembly_kernel_source );
1621+
1622+ CeedCallBackend (CeedFree (& eval_mode_offsets_in_str ));
1623+ CeedCallBackend (CeedFree (& eval_mode_offsets_out_str ));
1624+ }
1625+
1626+ CeedCallBackend (CeedElemRestrictionGetNumComponents (rstr_in , & num_comp_in ));
1627+ CeedCallBackend (CeedElemRestrictionGetNumComponents (rstr_out , & num_comp_out ));
1628+ CeedCallBackend (CeedCompile_Hip (ceed , source , & asmb -> module , 11 , "NUM_EVAL_MODES_IN" , num_eval_modes_in , "NUM_EVAL_MODES_OUT" , num_eval_modes_out ,
1629+ "NUM_COMP_IN" , num_comp_in , "NUM_COMP_OUT" , num_comp_out , "TOTAL_NUM_COMP_OUT" , num_output_components ,
1630+ "NUM_NODES_IN" , elem_size_in , "NUM_NODES_OUT" , elem_size_out , "NUM_QPTS" , num_qpts_in , "BLOCK_SIZE" ,
1631+ asmb -> block_size_x * asmb -> block_size_y * asmb -> elems_per_block , "BLOCK_SIZE_Y" , asmb -> block_size_y , "USE_CEEDSIZE" ,
1632+ use_ceedsize_idx ));
1633+ CeedCallBackend (CeedGetKernel_Hip (ceed , asmb -> module , "LinearAssembleBlock" , & asmb -> LinearAssemble ));
1634+
1635+ // Load into B_in, in order that they will be used in eval_modes_in
1636+ {
1637+ const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof (CeedScalar );
1638+
1639+ CeedCallHip (ceed , hipMalloc ((void * * )& asmb -> d_B_in , in_bytes ));
1640+ CeedCallHip (ceed , hipMemcpy (asmb -> d_B_in , h_B_in , in_bytes , hipMemcpyHostToDevice ));
1641+ }
1642+
1643+ // Load into B_out, in order that they will be used in eval_modes_out
1644+ {
1645+ const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof (CeedScalar );
1646+
1647+ CeedCallHip (ceed , hipMalloc ((void * * )& asmb -> d_B_out , out_bytes ));
1648+ CeedCallHip (ceed , hipMemcpy (asmb -> d_B_out , h_B_out , out_bytes , hipMemcpyHostToDevice ));
1649+ }
1650+ CeedCallBackend (CeedFree (& source ));
1651+ CeedCallBackend (CeedDestroy (& ceed ));
1652+ return CEED_ERROR_SUCCESS ;
1653+ }
1654+
15041655//------------------------------------------------------------------------------
15051656// Single Operator Assembly Setup
15061657//------------------------------------------------------------------------------
@@ -1702,6 +1853,117 @@ static int CeedOperatorAssembleSingleSetup_Hip(CeedOperator op, CeedInt use_ceed
17021853 return CEED_ERROR_SUCCESS ;
17031854}
17041855
1856+ //------------------------------------------------------------------------------
1857+ // Assemble matrix data for one block of a COO matrix of assembled operator.
1858+ // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
1859+ //------------------------------------------------------------------------------
1860+ static int CeedOperatorAssembleSingleBlock_Hip (CeedOperator op , CeedInt offset , CeedInt active_input , CeedInt active_output , CeedVector values ) {
1861+ Ceed ceed ;
1862+ CeedSize values_length = 0 , assembled_qf_length = 0 ;
1863+ CeedInt use_ceedsize_idx = 0 , num_elem_in , num_elem_out , elem_size_in , elem_size_out ;
1864+ CeedScalar * values_array ;
1865+ const CeedScalar * assembled_qf_array ;
1866+ CeedVector assembled_qf = NULL ;
1867+ CeedElemRestriction assembled_rstr = NULL , rstr_in , rstr_out ;
1868+ CeedRestrictionType rstr_type_in , rstr_type_out ;
1869+ const bool * orients_in = NULL , * orients_out = NULL ;
1870+ const CeedInt8 * curl_orients_in = NULL , * curl_orients_out = NULL ;
1871+ CeedOperator_Hip * impl ;
1872+
1873+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
1874+ CeedCallBackend (CeedOperatorGetData (op , & impl ));
1875+
1876+ // Assemble QFunction
1877+ CeedCallBackend (CeedOperatorLinearAssembleQFunctionBuildOrUpdate (op , & assembled_qf , & assembled_rstr , CEED_REQUEST_IMMEDIATE ));
1878+ CeedCallBackend (CeedElemRestrictionDestroy (& assembled_rstr ));
1879+ CeedCallBackend (CeedVectorGetArrayRead (assembled_qf , CEED_MEM_DEVICE , & assembled_qf_array ));
1880+
1881+ CeedCallBackend (CeedVectorGetLength (values , & values_length ));
1882+ CeedCallBackend (CeedVectorGetLength (assembled_qf , & assembled_qf_length ));
1883+ if ((values_length > INT_MAX ) || (assembled_qf_length > INT_MAX )) use_ceedsize_idx = 1 ;
1884+
1885+ // Setup
1886+ if (!impl -> asmb_blocks || (impl -> asmb_blocks && !impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ])) {
1887+ CeedCallBackend (CeedOperatorAssembleSingleBlockSetup_Hip (op , active_input , active_output , use_ceedsize_idx ));
1888+ }
1889+ assert (impl -> asmb_blocks && impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ]);
1890+ CeedOperatorAssemble_Hip * asmb = impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ];
1891+
1892+ assert (asmb != NULL );
1893+
1894+ // Assemble element operator
1895+ CeedCallBackend (CeedVectorGetArray (values , CEED_MEM_DEVICE , & values_array ));
1896+ values_array += offset ;
1897+
1898+ CeedElemRestriction * active_rstrs_in , * active_rstrs_out ;
1899+ CeedOperatorAssemblyData data ;
1900+
1901+ CeedCall (CeedOperatorGetOperatorAssemblyData (op , & data ));
1902+ CeedCall (CeedOperatorAssemblyDataGetElemRestrictions (data , NULL , & active_rstrs_in , NULL , & active_rstrs_out ));
1903+
1904+ rstr_in = active_rstrs_in [active_input ];
1905+ rstr_out = active_rstrs_out [active_output ];
1906+ CeedCallBackend (CeedElemRestrictionGetNumElements (rstr_in , & num_elem_in ));
1907+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_in , & elem_size_in ));
1908+
1909+ CeedCallBackend (CeedElemRestrictionGetType (rstr_in , & rstr_type_in ));
1910+ if (rstr_type_in == CEED_RESTRICTION_ORIENTED ) {
1911+ CeedCallBackend (CeedElemRestrictionGetOrientations (rstr_in , CEED_MEM_DEVICE , & orients_in ));
1912+ } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED ) {
1913+ CeedCallBackend (CeedElemRestrictionGetCurlOrientations (rstr_in , CEED_MEM_DEVICE , & curl_orients_in ));
1914+ }
1915+
1916+ if (rstr_in != rstr_out ) {
1917+ CeedCallBackend (CeedElemRestrictionGetNumElements (rstr_out , & num_elem_out ));
1918+ CeedCheck (num_elem_in == num_elem_out , ceed , CEED_ERROR_UNSUPPORTED ,
1919+ "Active input and output operator restrictions must have the same number of elements" );
1920+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_out , & elem_size_out ));
1921+
1922+ CeedCallBackend (CeedElemRestrictionGetType (rstr_out , & rstr_type_out ));
1923+ if (rstr_type_out == CEED_RESTRICTION_ORIENTED ) {
1924+ CeedCallBackend (CeedElemRestrictionGetOrientations (rstr_out , CEED_MEM_DEVICE , & orients_out ));
1925+ } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED ) {
1926+ CeedCallBackend (CeedElemRestrictionGetCurlOrientations (rstr_out , CEED_MEM_DEVICE , & curl_orients_out ));
1927+ }
1928+ } else {
1929+ elem_size_out = elem_size_in ;
1930+ orients_out = orients_in ;
1931+ curl_orients_out = curl_orients_in ;
1932+ }
1933+
1934+ // Compute B^T D B
1935+ CeedInt shared_mem =
1936+ ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0 ) + (curl_orients_in ? elem_size_in * asmb -> block_size_y : 0 )) *
1937+ sizeof (CeedScalar );
1938+ CeedInt grid = CeedDivUpInt (num_elem_in , asmb -> elems_per_block );
1939+ void * args [] = {(void * )& num_elem_in , & asmb -> d_B_in , & asmb -> d_B_out , & orients_in , & curl_orients_in ,
1940+ & orients_out , & curl_orients_out , & assembled_qf_array , & values_array };
1941+
1942+ CeedCallBackend (CeedRunKernelDimShared_Hip (ceed , asmb -> LinearAssemble , NULL , grid , asmb -> block_size_x , asmb -> block_size_y , asmb -> elems_per_block ,
1943+ shared_mem , args ));
1944+
1945+ // Restore arrays
1946+ CeedCallBackend (CeedVectorRestoreArray (values , & values_array ));
1947+ CeedCallBackend (CeedVectorRestoreArrayRead (assembled_qf , & assembled_qf_array ));
1948+
1949+ // Cleanup
1950+ CeedCallBackend (CeedVectorDestroy (& assembled_qf ));
1951+ if (rstr_type_in == CEED_RESTRICTION_ORIENTED ) {
1952+ CeedCallBackend (CeedElemRestrictionRestoreOrientations (rstr_in , & orients_in ));
1953+ } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED ) {
1954+ CeedCallBackend (CeedElemRestrictionRestoreCurlOrientations (rstr_in , & curl_orients_in ));
1955+ }
1956+ if (rstr_in != rstr_out ) {
1957+ if (rstr_type_out == CEED_RESTRICTION_ORIENTED ) {
1958+ CeedCallBackend (CeedElemRestrictionRestoreOrientations (rstr_out , & orients_out ));
1959+ } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED ) {
1960+ CeedCallBackend (CeedElemRestrictionRestoreCurlOrientations (rstr_out , & curl_orients_out ));
1961+ }
1962+ }
1963+ CeedCallBackend (CeedDestroy (& ceed ));
1964+ return CEED_ERROR_SUCCESS ;
1965+ }
1966+
17051967//------------------------------------------------------------------------------
17061968// Assemble matrix data for COO matrix of assembled operator.
17071969// The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
@@ -2089,6 +2351,7 @@ int CeedOperatorCreate_Hip(CeedOperator op) {
20892351 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleAddPointBlockDiagonal" ,
20902352 CeedOperatorLinearAssembleAddPointBlockDiagonal_Hip ));
20912353 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleSingle" , CeedOperatorAssembleSingle_Hip ));
2354+ CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleSingleBlock" , CeedOperatorAssembleSingleBlock_Hip ));
20922355 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "ApplyAdd" , CeedOperatorApplyAdd_Hip ));
20932356 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "Destroy" , CeedOperatorDestroy_Hip ));
20942357 CeedCallBackend (CeedDestroy (& ceed ));
0 commit comments