@@ -95,8 +95,29 @@ static int CeedOperatorDestroy_Cuda(CeedOperator op) {
9595 CeedCallCuda (ceed , cudaFree (impl -> asmb -> d_B_out ));
9696 CeedCallBackend (CeedDestroy (& ceed ));
9797 }
98+
9899 CeedCallBackend (CeedFree (& impl -> asmb ));
99100
101+ if (impl -> asmb_blocks ) {
102+ Ceed ceed ;
103+
104+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
105+ for (CeedInt i = 0 ; i < impl -> num_blocks_in ; i ++ ) {
106+ for (CeedInt j = 0 ; j < impl -> num_blocks_out ; j ++ ) {
107+ CeedOperatorAssemble_Cuda * asmb = impl -> asmb_blocks [i * impl -> num_blocks_out + j ];
108+
109+ if (asmb ) {
110+ CeedCallCuda (ceed , cuModuleUnload (asmb -> module ));
111+ CeedCallCuda (ceed , cudaFree (asmb -> d_B_in ));
112+ CeedCallCuda (ceed , cudaFree (asmb -> d_B_out ));
113+ }
114+ CeedCallBackend (CeedFree (& asmb ));
115+ }
116+ }
117+ CeedCallBackend (CeedFree (& impl -> asmb_blocks ));
118+ CeedCallBackend (CeedDestroy (& ceed ));
119+ }
120+
100121 CeedCallBackend (CeedFree (& impl ));
101122 return CEED_ERROR_SUCCESS ;
102123}
@@ -1052,9 +1073,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op,
10521073 CeedInt strides [3 ] = {1 , num_elem * Q , Q }; /* *NOPAD* */
10531074
10541075 // Create output restriction
1055- CeedCallBackend (CeedElemRestrictionCreateStrided (ceed_parent , num_elem , Q , num_active_in * num_active_out ,
1056- (CeedSize )num_active_in * (CeedSize )num_active_out * (CeedSize )num_elem * (CeedSize )Q , strides ,
1057- rstr ));
1076+ CeedCallBackend (CeedElemRestrictionCreateStrided (ceed_parent , num_elem , Q , num_active_in * num_active_out , l_size , strides , rstr ));
10581077 // Create assembled vector
10591078 CeedCallBackend (CeedVectorCreate (ceed_parent , l_size , assembled ));
10601079 }
@@ -1504,6 +1523,133 @@ static int CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda(CeedOperator op,
15041523 return CEED_ERROR_SUCCESS ;
15051524}
15061525
1526+ //------------------------------------------------------------------------------
1527+ // Single Operator Assembly Setup
1528+ //------------------------------------------------------------------------------
1529+ static int CeedOperatorAssembleSingleBlockSetup_Cuda (CeedOperator op , CeedInt active_input , CeedInt active_output , CeedInt use_ceedsize_idx ) {
1530+ Ceed ceed ;
1531+ Ceed_Cuda * cuda_data ;
1532+ CeedInt num_input_fields , num_output_fields , num_eval_modes_in = 0 , num_eval_modes_out = 0 ;
1533+ CeedInt elem_size_in , num_qpts_in = 0 , num_comp_in , elem_size_out , num_qpts_out , num_comp_out ;
1534+ CeedSize num_output_components ;
1535+ const CeedScalar * h_B_in , * h_B_out ;
1536+ CeedElemRestriction rstr_in = NULL , rstr_out = NULL ;
1537+ CeedBasis basis_in = NULL , basis_out = NULL ;
1538+ CeedOperatorField * input_fields , * output_fields ;
1539+ CeedOperator_Cuda * impl ;
1540+ char * eval_mode_offsets_in_str , * eval_mode_offsets_out_str ;
1541+
1542+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
1543+ CeedCallBackend (CeedOperatorGetData (op , & impl ));
1544+
1545+ // Get intput and output fields
1546+ CeedCallBackend (CeedOperatorGetFields (op , & num_input_fields , & input_fields , & num_output_fields , & output_fields ));
1547+
1548+ {
1549+ CeedInt num_active_bases_in , * t_num_eval_modes_in , num_active_bases_out , * t_num_eval_modes_out ;
1550+ CeedSize * * eval_modes_offsets_in , * * eval_modes_offsets_out ;
1551+ CeedBasis * active_bases_in , * active_bases_out ;
1552+ CeedElemRestriction * active_rstrs_in , * active_rstrs_out ;
1553+ const CeedScalar * * B_mats_in , * * B_mats_out ;
1554+ CeedOperatorAssemblyData data ;
1555+
1556+ CeedCall (CeedOperatorGetOperatorAssemblyData (op , & data ));
1557+ CeedCall (CeedOperatorAssemblyDataGetEvalModes (data , & num_active_bases_in , & t_num_eval_modes_in , NULL , & eval_modes_offsets_in ,
1558+ & num_active_bases_out , & t_num_eval_modes_out , NULL , & eval_modes_offsets_out ,
1559+ & num_output_components ));
1560+ // Number of elem restrictions is the same as the number of bases
1561+ CeedCall (CeedOperatorAssemblyDataGetElemRestrictions (data , NULL , & active_rstrs_in , NULL , & active_rstrs_out ));
1562+ CeedCall (CeedOperatorAssemblyDataGetBases (data , NULL , & active_bases_in , & B_mats_in , NULL , & active_bases_out , & B_mats_out ));
1563+
1564+ num_eval_modes_in = t_num_eval_modes_in [active_input ];
1565+ num_eval_modes_out = t_num_eval_modes_out [active_output ];
1566+ CeedCheck (num_eval_modes_in > 0 && num_eval_modes_out > 0 , ceed , CEED_ERROR_UNSUPPORTED , "Cannot assemble operator without inputs/outputs" );
1567+
1568+ if (!impl -> asmb_blocks ) {
1569+ CeedCallBackend (CeedCalloc (num_active_bases_in * num_active_bases_out , & impl -> asmb_blocks ));
1570+ impl -> num_blocks_in = num_active_bases_in ;
1571+ impl -> num_blocks_out = num_active_bases_out ;
1572+ }
1573+
1574+ rstr_in = active_rstrs_in [active_input ];
1575+ basis_in = active_bases_in [active_input ];
1576+ CeedCallBackend (CeedBuildArrayConstantSize_Cuda (ceed , "EVAL_MODE_OFFSETS_IN" , num_eval_modes_in , eval_modes_offsets_in [active_input ],
1577+ & eval_mode_offsets_in_str ));
1578+ h_B_in = B_mats_in [active_input ];
1579+ rstr_out = active_rstrs_out [active_output ];
1580+ basis_out = active_bases_out [active_output ];
1581+ CeedCallBackend (CeedBuildArrayConstantSize_Cuda (ceed , "EVAL_MODE_OFFSETS_OUT" , num_eval_modes_out , eval_modes_offsets_out [active_output ],
1582+ & eval_mode_offsets_out_str ));
1583+ h_B_out = B_mats_out [active_output ];
1584+ }
1585+
1586+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_in , & elem_size_in ));
1587+ if (basis_in == CEED_BASIS_NONE ) num_qpts_in = elem_size_in ;
1588+ else CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis_in , & num_qpts_in ));
1589+
1590+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_out , & elem_size_out ));
1591+ if (basis_out == CEED_BASIS_NONE ) num_qpts_out = elem_size_out ;
1592+ else CeedCallBackend (CeedBasisGetNumQuadraturePoints (basis_out , & num_qpts_out ));
1593+ CeedCheck (num_qpts_in == num_qpts_out , ceed , CEED_ERROR_UNSUPPORTED ,
1594+ "Active input and output bases must have the same number of quadrature points" );
1595+
1596+ CeedCallBackend (CeedCalloc (1 , & impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ]));
1597+ CeedOperatorAssemble_Cuda * asmb = impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ];
1598+ asmb -> elems_per_block = 1 ;
1599+ asmb -> block_size_x = elem_size_in ;
1600+ asmb -> block_size_y = elem_size_out ;
1601+
1602+ CeedCallBackend (CeedGetData (ceed , & cuda_data ));
1603+ bool fallback = asmb -> block_size_x * asmb -> block_size_y * asmb -> elems_per_block > cuda_data -> device_prop .maxThreadsPerBlock ;
1604+
1605+ if (fallback ) {
1606+ // Use fallback kernel with 1D threadblock
1607+ asmb -> block_size_y = 1 ;
1608+ }
1609+
1610+ // Compile kernels
1611+ const char assembly_kernel_source [] = "// Full assembly source\n#include <ceed/jit-source/cuda/cuda-ref-operator-assemble-block.h>\n" ;
1612+ const CeedSize len = strlen (assembly_kernel_source ) + strlen (eval_mode_offsets_in_str ) + strlen (eval_mode_offsets_out_str ) + 3 ;
1613+ char * source ;
1614+
1615+ CeedCallBackend (CeedCalloc (len , & source ));
1616+ strcat (source , eval_mode_offsets_in_str );
1617+ strcat (source , "\n" );
1618+ strcat (source , eval_mode_offsets_out_str );
1619+ strcat (source , "\n" );
1620+ strcat (source , assembly_kernel_source );
1621+
1622+ CeedCallBackend (CeedElemRestrictionGetNumComponents (rstr_in , & num_comp_in ));
1623+ CeedCallBackend (CeedElemRestrictionGetNumComponents (rstr_out , & num_comp_out ));
1624+ CeedCallBackend (CeedCompile_Cuda (ceed , source , & asmb -> module , 11 , "NUM_EVAL_MODES_IN" , num_eval_modes_in , "NUM_EVAL_MODES_OUT" , num_eval_modes_out ,
1625+ "NUM_COMP_IN" , num_comp_in , "NUM_COMP_OUT" , num_comp_out , "TOTAL_NUM_COMP_OUT" , num_output_components ,
1626+ "NUM_NODES_IN" , elem_size_in , "NUM_NODES_OUT" , elem_size_out , "NUM_QPTS" , num_qpts_in , "BLOCK_SIZE" ,
1627+ asmb -> block_size_x * asmb -> block_size_y * asmb -> elems_per_block , "BLOCK_SIZE_Y" , asmb -> block_size_y ,
1628+ "USE_CEEDSIZE" , use_ceedsize_idx ));
1629+ CeedCallBackend (CeedGetKernel_Cuda (ceed , asmb -> module , "LinearAssembleBlock" , & asmb -> LinearAssemble ));
1630+
1631+ // Load into B_in, in order that they will be used in eval_modes_in
1632+ {
1633+ const CeedInt in_bytes = elem_size_in * num_qpts_in * num_eval_modes_in * sizeof (CeedScalar );
1634+
1635+ CeedCallCuda (ceed , cudaMalloc ((void * * )& asmb -> d_B_in , in_bytes ));
1636+ CeedCallCuda (ceed , cudaMemcpy (asmb -> d_B_in , h_B_in , in_bytes , cudaMemcpyHostToDevice ));
1637+ }
1638+
1639+ // Load into B_out, in order that they will be used in eval_modes_out
1640+ {
1641+ const CeedInt out_bytes = elem_size_out * num_qpts_out * num_eval_modes_out * sizeof (CeedScalar );
1642+
1643+ CeedCallCuda (ceed , cudaMalloc ((void * * )& asmb -> d_B_out , out_bytes ));
1644+ CeedCallCuda (ceed , cudaMemcpy (asmb -> d_B_out , h_B_out , out_bytes , cudaMemcpyHostToDevice ));
1645+ }
1646+ CeedCallBackend (CeedFree (& eval_mode_offsets_in_str ));
1647+ CeedCallBackend (CeedFree (& eval_mode_offsets_out_str ));
1648+ CeedCallBackend (CeedFree (& source ));
1649+ CeedCallBackend (CeedDestroy (& ceed ));
1650+ return CEED_ERROR_SUCCESS ;
1651+ }
1652+
15071653//------------------------------------------------------------------------------
15081654// Single Operator Assembly Setup
15091655//------------------------------------------------------------------------------
@@ -1705,6 +1851,117 @@ static int CeedOperatorAssembleSingleSetup_Cuda(CeedOperator op, CeedInt use_cee
17051851 return CEED_ERROR_SUCCESS ;
17061852}
17071853
1854+ //------------------------------------------------------------------------------
1855+ // Assemble matrix data for one block of a COO matrix of assembled operator.
1856+ // The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
1857+ //------------------------------------------------------------------------------
1858+ static int CeedOperatorAssembleSingleBlock_Cuda (CeedOperator op , CeedInt offset , CeedInt active_input , CeedInt active_output , CeedVector values ) {
1859+ Ceed ceed ;
1860+ CeedSize values_length = 0 , assembled_qf_length = 0 ;
1861+ CeedInt use_ceedsize_idx = 0 , num_elem_in , num_elem_out , elem_size_in , elem_size_out ;
1862+ CeedScalar * values_array ;
1863+ const CeedScalar * assembled_qf_array ;
1864+ CeedVector assembled_qf = NULL ;
1865+ CeedElemRestriction assembled_rstr = NULL , rstr_in , rstr_out ;
1866+ CeedRestrictionType rstr_type_in , rstr_type_out ;
1867+ const bool * orients_in = NULL , * orients_out = NULL ;
1868+ const CeedInt8 * curl_orients_in = NULL , * curl_orients_out = NULL ;
1869+ CeedOperator_Cuda * impl ;
1870+
1871+ CeedCallBackend (CeedOperatorGetCeed (op , & ceed ));
1872+ CeedCallBackend (CeedOperatorGetData (op , & impl ));
1873+
1874+ // Assemble QFunction
1875+ CeedCallBackend (CeedOperatorLinearAssembleQFunctionBuildOrUpdate (op , & assembled_qf , & assembled_rstr , CEED_REQUEST_IMMEDIATE ));
1876+ CeedCallBackend (CeedElemRestrictionDestroy (& assembled_rstr ));
1877+ CeedCallBackend (CeedVectorGetArrayRead (assembled_qf , CEED_MEM_DEVICE , & assembled_qf_array ));
1878+
1879+ CeedCallBackend (CeedVectorGetLength (values , & values_length ));
1880+ CeedCallBackend (CeedVectorGetLength (assembled_qf , & assembled_qf_length ));
1881+ if ((values_length > INT_MAX ) || (assembled_qf_length > INT_MAX )) use_ceedsize_idx = 1 ;
1882+
1883+ // Setup
1884+ if (!impl -> asmb_blocks || (impl -> asmb_blocks && !impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ])) {
1885+ CeedCallBackend (CeedOperatorAssembleSingleBlockSetup_Cuda (op , active_input , active_output , use_ceedsize_idx ));
1886+ }
1887+ CeedOperatorAssemble_Cuda * asmb = impl -> asmb_blocks [active_input * impl -> num_blocks_out + active_output ];
1888+
1889+ assert (asmb != NULL );
1890+
1891+ // Assemble element operator
1892+ CeedCallBackend (CeedVectorGetArray (values , CEED_MEM_DEVICE , & values_array ));
1893+ values_array += offset ;
1894+
1895+ CeedElemRestriction * active_rstrs_in , * active_rstrs_out ;
1896+ CeedOperatorAssemblyData data ;
1897+
1898+ CeedCall (CeedOperatorGetOperatorAssemblyData (op , & data ));
1899+ CeedCall (CeedOperatorAssemblyDataGetElemRestrictions (data , NULL , & active_rstrs_in , NULL , & active_rstrs_out ));
1900+
1901+ rstr_in = active_rstrs_in [active_input ];
1902+ rstr_out = active_rstrs_out [active_output ];
1903+ CeedCallBackend (CeedElemRestrictionGetNumElements (rstr_in , & num_elem_in ));
1904+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_in , & elem_size_in ));
1905+
1906+ CeedCallBackend (CeedElemRestrictionGetType (rstr_in , & rstr_type_in ));
1907+ if (rstr_type_in == CEED_RESTRICTION_ORIENTED ) {
1908+ CeedCallBackend (CeedElemRestrictionGetOrientations (rstr_in , CEED_MEM_DEVICE , & orients_in ));
1909+ } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED ) {
1910+ CeedCallBackend (CeedElemRestrictionGetCurlOrientations (rstr_in , CEED_MEM_DEVICE , & curl_orients_in ));
1911+ }
1912+
1913+ if (rstr_in != rstr_out ) {
1914+ CeedCallBackend (CeedElemRestrictionGetNumElements (rstr_out , & num_elem_out ));
1915+ CeedCheck (num_elem_in == num_elem_out , ceed , CEED_ERROR_UNSUPPORTED ,
1916+ "Active input and output operator restrictions must have the same number of elements" );
1917+ CeedCallBackend (CeedElemRestrictionGetElementSize (rstr_out , & elem_size_out ));
1918+
1919+ CeedCallBackend (CeedElemRestrictionGetType (rstr_out , & rstr_type_out ));
1920+ if (rstr_type_out == CEED_RESTRICTION_ORIENTED ) {
1921+ CeedCallBackend (CeedElemRestrictionGetOrientations (rstr_out , CEED_MEM_DEVICE , & orients_out ));
1922+ } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED ) {
1923+ CeedCallBackend (CeedElemRestrictionGetCurlOrientations (rstr_out , CEED_MEM_DEVICE , & curl_orients_out ));
1924+ }
1925+ } else {
1926+ elem_size_out = elem_size_in ;
1927+ orients_out = orients_in ;
1928+ curl_orients_out = curl_orients_in ;
1929+ }
1930+
1931+ // Compute B^T D B
1932+ CeedInt shared_mem =
1933+ ((curl_orients_in || curl_orients_out ? elem_size_in * elem_size_out : 0 ) + (curl_orients_in ? elem_size_in * asmb -> block_size_y : 0 )) *
1934+ sizeof (CeedScalar );
1935+ CeedInt grid = CeedDivUpInt (num_elem_in , asmb -> elems_per_block );
1936+ void * args [] = {(void * )& num_elem_in , & asmb -> d_B_in , & asmb -> d_B_out , & orients_in , & curl_orients_in ,
1937+ & orients_out , & curl_orients_out , & assembled_qf_array , & values_array };
1938+
1939+ CeedCallBackend (CeedRunKernelDimShared_Cuda (ceed , asmb -> LinearAssemble , NULL , grid , asmb -> block_size_x , asmb -> block_size_y , asmb -> elems_per_block ,
1940+ shared_mem , args ));
1941+ CeedCallCuda (ceed , cudaDeviceSynchronize ());
1942+
1943+ // Restore arrays
1944+ CeedCallBackend (CeedVectorRestoreArray (values , & values_array ));
1945+ CeedCallBackend (CeedVectorRestoreArrayRead (assembled_qf , & assembled_qf_array ));
1946+
1947+ // Cleanup
1948+ CeedCallBackend (CeedVectorDestroy (& assembled_qf ));
1949+ if (rstr_type_in == CEED_RESTRICTION_ORIENTED ) {
1950+ CeedCallBackend (CeedElemRestrictionRestoreOrientations (rstr_in , & orients_in ));
1951+ } else if (rstr_type_in == CEED_RESTRICTION_CURL_ORIENTED ) {
1952+ CeedCallBackend (CeedElemRestrictionRestoreCurlOrientations (rstr_in , & curl_orients_in ));
1953+ }
1954+ if (rstr_in != rstr_out ) {
1955+ if (rstr_type_out == CEED_RESTRICTION_ORIENTED ) {
1956+ CeedCallBackend (CeedElemRestrictionRestoreOrientations (rstr_out , & orients_out ));
1957+ } else if (rstr_type_out == CEED_RESTRICTION_CURL_ORIENTED ) {
1958+ CeedCallBackend (CeedElemRestrictionRestoreCurlOrientations (rstr_out , & curl_orients_out ));
1959+ }
1960+ }
1961+ CeedCallBackend (CeedDestroy (& ceed ));
1962+ return CEED_ERROR_SUCCESS ;
1963+ }
1964+
17081965//------------------------------------------------------------------------------
17091966// Assemble matrix data for COO matrix of assembled operator.
17101967// The sparsity pattern is set by CeedOperatorLinearAssembleSymbolic.
@@ -2090,6 +2347,7 @@ int CeedOperatorCreate_Cuda(CeedOperator op) {
20902347 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleAddPointBlockDiagonal" ,
20912348 CeedOperatorLinearAssembleAddPointBlockDiagonal_Cuda ));
20922349 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleSingle" , CeedOperatorAssembleSingle_Cuda ));
2350+ CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "LinearAssembleSingleBlock" , CeedOperatorAssembleSingleBlock_Cuda ));
20932351 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "ApplyAdd" , CeedOperatorApplyAdd_Cuda ));
20942352 CeedCallBackend (CeedSetBackendFunction (ceed , "Operator" , op , "Destroy" , CeedOperatorDestroy_Cuda ));
20952353 CeedCallBackend (CeedDestroy (& ceed ));
0 commit comments